|
a |
|
b/exseek/snakefiles/feature_selection.snakemake |
|
|
1 |
include: 'common.snakemake' |
|
|
2 |
|
|
|
3 |
import yaml |
|
|
4 |
import re |
|
|
5 |
# Sample-group comparisons keyed by name. For each value, rule cross_validation
# uses element [0] as the negative class and element [1] as the positive class.
compare_groups = config['compare_groups']

# Names of the feature selectors and classifiers to evaluate. Each name must
# have an entry in config['machine_learning']['feature_selector_params'] or
# config['machine_learning']['classifier_params'] respectively, because
# rule cross_validation looks them up there by wildcard.
feature_selectors = list(config['machine_learning']['feature_selectors'])
classifiers = list(config['machine_learning']['classifiers'])

# Named targets requested by `rule all`.
# - 'cross_validation': one run directory per combination of preprocessing
#   settings, count method, comparison, classifier, feature count and selector.
#   The 'filter.<imputation>.Norm_<...>.Batch_<...>_<index>' path component is
#   what rule cross_validation captures as its {preprocess_method} wildcard.
# - 'metrics_test' / 'metrics_train' / 'feature_stability': summary tables
#   built by rule summarize_cross_validation over the 'cross_validation' group.
inputs = {
    'cross_validation': expand(
        '{output_dir}/cross_validation/filter.{imputation_method}.Norm_{normalization_method}.Batch_{batch_removal_method}_{batch_index}.{count_method}/{compare_group}/{classifier}.{n_features_to_select}.{selector}',
        output_dir=output_dir,
        imputation_method=config['imputation_method'],
        normalization_method=config['normalization_method'],
        batch_removal_method=config['batch_removal_method'],
        batch_index=config['batch_index'],
        count_method=config['count_method'],
        classifier=classifiers,
        selector=feature_selectors,
        compare_group=list(compare_groups),
        n_features_to_select=config['n_features_to_select']),
    'metrics_test': expand('{output_dir}/summary/{cross_validation}/metrics.test.txt',
        output_dir=output_dir, cross_validation=['cross_validation']),
    'metrics_train': expand('{output_dir}/summary/{cross_validation}/metrics.train.txt',
        output_dir=output_dir, cross_validation=['cross_validation']),
    'feature_stability': expand('{output_dir}/summary/{cross_validation}/feature_stability.txt',
        output_dir=output_dir, cross_validation=['cross_validation'])
}
|
|
32 |
|
|
|
33 |
def get_all_inputs(wildcards):
    """Input function for `rule all`: hand back the module-level `inputs` dict.

    Snakemake calls input functions with the rule's wildcards. The targets
    returned here never depend on them, so `wildcards` is ignored.
    """
    all_targets = inputs
    return all_targets
|
|
35 |
|
|
|
36 |
# Default target: request every named entry of `inputs`, i.e. the per-model
# cross-validation run directories plus the summary tables.
rule all:
    input:
        # unpack() turns the dict returned by get_all_inputs into named inputs
        unpack(get_all_inputs)
|
|
39 |
|
|
|
40 |
# Run one cross-validation experiment: a single preprocessed matrix, one
# comparison group, one classifier, one feature selector and one feature count.
# A job-specific machine-learning config is derived from
# config['machine_learning'] and written to <output.dir>/config.yaml; then
# machine_learning.py run_pipeline is invoked on it.
rule cross_validation:
    input:
        matrix='{output_dir}/matrix_processing/{preprocess_method}.{count_method}.txt',
        sample_classes=data_dir + '/sample_classes.txt'
    output:
        dir=directory('{output_dir}/cross_validation/{preprocess_method}.{count_method}/{compare_group}/{classifier}.{n_features_to_select}.{selector}')
    run:
        import subprocess
        from copy import deepcopy

        ml_config = config['machine_learning']
        output_config = {}
        # number of features, taken from the target path
        output_config['n_features_to_select'] = int(wildcards.n_features_to_select)
        # Copy global parameters. deepcopy is required: `run` blocks of several
        # jobs may execute in the same Python process, and the append to
        # 'preprocess_steps' below must not modify the shared global config
        # (otherwise selector steps from earlier jobs leak into later ones).
        for key in ('transpose', 'features', 'cv_params', 'sample_weight', 'preprocess_steps'):
            if key in ml_config:
                output_config[key] = deepcopy(ml_config[key])

        # feature selector config for the {selector} wildcard
        selector_config = deepcopy(ml_config['feature_selector_params'][wildcards.selector])
        selector_config['enabled'] = True
        selector_config['params'] = selector_config.get('params') or {}
        # script path for differential expression
        if selector_config['name'] == 'DiffExpFilter':
            selector_config['params']['script'] = os.path.join(bin_dir, 'differential_expression.R')
        # selector grid search: global defaults overridden by per-selector values
        if selector_config['params'].get('grid_search', False):
            grid_search_params = deepcopy(ml_config.get('selector_grid_search_params') or {})
            grid_search_params.update(selector_config['params'].get('grid_search_params') or {})
            selector_config['params']['grid_search_params'] = grid_search_params
        # feature selection is appended after any configured preprocessing steps
        output_config['preprocess_steps'] = output_config.get('preprocess_steps') or []
        output_config['preprocess_steps'].append({'feature_selection': selector_config})

        # classifier config for the {classifier} wildcard
        classifier_config = deepcopy(ml_config['classifier_params'][wildcards.classifier])
        output_config['classifier'] = classifier_config['classifier']
        output_config['classifier_params'] = classifier_config.get('classifier_params', {})
        # classifier grid search: global defaults overridden by per-classifier values
        if classifier_config.get('grid_search', False):
            grid_search_params = deepcopy(ml_config.get('classifier_grid_search_params') or {})
            grid_search_params.update(classifier_config.get('grid_search_params') or {})
            output_config['grid_search'] = True
            output_config['grid_search_params'] = grid_search_params

        # write the job-specific config next to the results
        os.makedirs(output.dir, exist_ok=True)
        output_config_file = os.path.join(output.dir, 'config.yaml')
        with open(output_config_file, 'w') as f:
            yaml.dump(output_config, f, default_flow_style=False)

        # compare_groups values are indexed as [negative, positive]
        group = compare_groups[wildcards.compare_group]
        negative_class = str(group[0])
        positive_class = str(group[1])
        # Pass an argument list instead of a shell string: class names with
        # spaces, quotes or '$' reach the script verbatim, and braces are not
        # interpreted by Snakemake's shell() string formatting.
        command = [
            'python',
            os.path.join(config['bin_dir'], 'machine_learning.py'), 'run_pipeline',
            '--matrix', input.matrix,
            '--sample-classes', input.sample_classes,
            '--output-dir', output.dir,
            '--positive-class', positive_class,
            '--negative-class', negative_class,
            '--config', output_config_file,
        ]
        subprocess.check_call(command)
|
|
99 |
|
|
|
100 |
|
|
|
101 |
# Aggregate a group of cross-validation run directories into summary tables.
# Judging by the output names, these are test metrics, train metrics and
# feature stability. The summarizing itself is done in
# scripts/summarize_cross_validation.py, which is not visible here.
rule summarize_cross_validation:
    input:
        # Run directories looked up in `inputs` by the {cross_validation}
        # wildcard. Only 'cross_validation' is requested by the targets in
        # `inputs`; any other value raises KeyError here.
        input_dir=lambda wildcards: inputs[wildcards.cross_validation]
    output:
        metrics_test='{output_dir}/summary/{cross_validation}/metrics.test.txt',
        metrics_train='{output_dir}/summary/{cross_validation}/metrics.train.txt',
        feature_stability='{output_dir}/summary/{cross_validation}/feature_stability.txt'
    script:
        'scripts/summarize_cross_validation.py'