109 lines (99 with data), 5.5 kB
include: 'common.snakemake'
import yaml
import re
# Comparison groups from the workflow config. The cross_validation rule reads
# index 0 of each entry as the negative class and index 1 as the positive class.
compare_groups = config['compare_groups']
# NOTE(review): an earlier comment here described reading the "best preprocess"
# method (key: count_method, value: preprocess_method) from the output of
# select_preprocess_method; no such code is visible in this file -- confirm
# whether it was removed intentionally.

# Feature selectors and classifiers to evaluate, materialized as lists.
feature_selectors = list(config['machine_learning']['feature_selectors'])
classifiers = list(config['machine_learning']['classifiers'])

# Output directory pattern of a single cross-validation run.
_cross_validation_pattern = '{output_dir}/cross_validation/filter.{imputation_method}.Norm_{normalization_method}.Batch_{batch_removal_method}_{batch_index}.{count_method}/{compare_group}/{classifier}.{n_features_to_select}.{selector}'

# Summary tables written by summarize_cross_validation, keyed by target name.
_summary_patterns = {
    'metrics_test': '{output_dir}/summary/{cross_validation}/metrics.test.txt',
    'metrics_train': '{output_dir}/summary/{cross_validation}/metrics.train.txt',
    'feature_stability': '{output_dir}/summary/{cross_validation}/feature_stability.txt',
}

# Target table: every cross-validation directory first, then the summary files.
inputs = {
    'cross_validation': expand(
        _cross_validation_pattern,
        output_dir=output_dir,
        imputation_method=config['imputation_method'],
        normalization_method=config['normalization_method'],
        batch_removal_method=config['batch_removal_method'],
        batch_index=config['batch_index'],
        count_method=config['count_method'],
        classifier=classifiers,
        selector=feature_selectors,
        compare_group=list(compare_groups.keys()),
        n_features_to_select=config['n_features_to_select'],
    ),
}
inputs.update({
    target: expand(pattern, output_dir=output_dir, cross_validation=['cross_validation'])
    for target, pattern in _summary_patterns.items()
})
def get_all_inputs(wildcards):
    """Return the module-level ``inputs`` table for use as an input function.

    ``wildcards`` is accepted because the function is passed to ``unpack()``
    in ``rule all``; it is not used.
    """
    all_targets = inputs
    return all_targets
# Default target: requests every entry of the `inputs` table (the
# cross-validation directories and the three summary files).
rule all:
    input:
        unpack(get_all_inputs)
# Run one cross-validation experiment for a given preprocessed matrix,
# comparison group, classifier, feature selector and number of features.
#
# The run block assembles a per-job configuration from
# config['machine_learning'], writes it to <output dir>/config.yaml and then
# invokes `machine_learning.py run_pipeline` on the matrix.
#
# Everything taken from the global config is deep-copied before it is modified:
# the feature-selection step is appended to 'preprocess_steps', and appending
# to the shared list would make steps accumulate across jobs that execute in
# the same Python process.
rule cross_validation:
    input:
        matrix='{output_dir}/matrix_processing/{preprocess_method}.{count_method}.txt',
        sample_classes=data_dir + '/sample_classes.txt'
    output:
        dir=directory('{output_dir}/cross_validation/{preprocess_method}.{count_method}/{compare_group}/{classifier}.{n_features_to_select}.{selector}')
    run:
        from copy import deepcopy

        ml_config = config['machine_learning']
        output_config = {}
        # number of features
        output_config['n_features_to_select'] = int(wildcards.n_features_to_select)
        # copy global config parameters (deep copies, so the per-job edits
        # below never leak back into the shared workflow config)
        for key in ('transpose', 'features', 'cv_params', 'sample_weight', 'preprocess_steps'):
            if key in ml_config:
                output_config[key] = deepcopy(ml_config[key])

        # copy selector config
        selector_config = deepcopy(ml_config['feature_selector_params'][wildcards.selector])
        selector_config['enabled'] = True
        selector_config['params'] = selector_config.get('params', {})
        # script path for differential expression
        if selector_config['name'] == 'DiffExpFilter':
            selector_config['params']['script'] = os.path.join(bin_dir, 'differential_expression.R')
        # selector grid search: global defaults overridden by the selector's
        # own settings (which may be absent)
        if selector_config['params'].get('grid_search', False):
            grid_search_params = deepcopy(ml_config['selector_grid_search_params'])
            grid_search_params.update(selector_config['params'].get('grid_search_params') or {})
            selector_config['params']['grid_search_params'] = grid_search_params
        # append to preprocess_steps (created here if the config defines none)
        output_config.setdefault('preprocess_steps', []).append({'feature_selection': selector_config})

        # copy classifier config
        classifier_config = deepcopy(ml_config['classifier_params'][wildcards.classifier])
        output_config['classifier'] = classifier_config['classifier']
        output_config['classifier_params'] = classifier_config.get('classifier_params', {})
        # classifier grid search: global defaults overridden by the
        # classifier's own settings (which may be absent)
        if classifier_config.get('grid_search', False):
            grid_search_params = deepcopy(ml_config['classifier_grid_search_params'])
            grid_search_params.update(classifier_config.get('grid_search_params') or {})
            output_config['grid_search'] = True
            output_config['grid_search_params'] = grid_search_params

        # write output config
        os.makedirs(output.dir, exist_ok=True)
        output_config_file = os.path.join(output.dir, 'config.yaml')
        with open(output_config_file, 'w') as f:
            yaml.dump(output_config, f, default_flow_style=False)

        # compare_groups[name] is indexed as [negative class, positive class];
        # class names are double-quoted so names containing spaces survive the shell.
        # NOTE(review): this path uses config['bin_dir'] while the R script
        # above uses the bare `bin_dir` variable -- confirm both point to the
        # same directory.
        command = [
            'python',
            os.path.join(config['bin_dir'], 'machine_learning.py'), 'run_pipeline',
            '--matrix', input.matrix,
            '--sample-classes', input.sample_classes,
            '--output-dir', output.dir,
            '--positive-class', '"' + compare_groups[wildcards.compare_group][1] + '"',
            '--negative-class', '"' + compare_groups[wildcards.compare_group][0] + '"',
            '--config', output_config_file
        ]
        shell(' '.join(command))
def _summary_input_dirs(wildcards):
    """Look up the entry of ``inputs`` named by the ``cross_validation`` wildcard."""
    return inputs[wildcards.cross_validation]


# Produce the three summary tables (test metrics, train metrics, feature
# stability) for one group of cross-validation results by running
# scripts/summarize_cross_validation.py.
rule summarize_cross_validation:
    input:
        input_dir=_summary_input_dirs
    output:
        metrics_test='{output_dir}/summary/{cross_validation}/metrics.test.txt',
        metrics_train='{output_dir}/summary/{cross_validation}/metrics.train.txt',
        feature_stability='{output_dir}/summary/{cross_validation}/feature_stability.txt'
    script:
        'scripts/summarize_cross_validation.py'