# exseek/snakefiles/feature_selection.snakemake
1
include: 'common.snakemake'
2
3
import yaml
4
import re
5
# Compare groups from the workflow config, keyed by group name. The
# cross_validation rule reads element 0 of each entry as the negative class
# and element 1 as the positive class.
compare_groups = config['compare_groups']

# NOTE(review): the two comment lines below describe a lookup that is not
# visible in this file -- confirm whether they are stale.
# Read best preprocess from output file of select_preprocess_method
# key: count_method, value: preprocess_method

# Names of the feature selectors and classifiers to evaluate; each name is
# later used as a wildcard value and as a key into the matching *_params
# section of config['machine_learning'].
feature_selectors = list(config['machine_learning']['feature_selectors'])
classifiers = list(config['machine_learning']['classifiers'])
12
13
# All workflow targets, keyed by name. The mapping is returned by
# get_all_inputs() for rule `all`, and summarize_cross_validation looks up
# its input through inputs[wildcards.cross_validation].
inputs = {
    # cross-validation output directories, expanded over the configured
    # preprocessing settings, compare groups, classifiers, feature counts
    # and feature selectors
    'cross_validation': expand(
        '{output_dir}/cross_validation/filter.{imputation_method}.Norm_{normalization_method}.Batch_{batch_removal_method}_{batch_index}.{count_method}/{compare_group}/{classifier}.{n_features_to_select}.{selector}',
        output_dir=output_dir,
        imputation_method=config['imputation_method'],
        normalization_method=config['normalization_method'],
        batch_removal_method=config['batch_removal_method'],
        batch_index=config['batch_index'],
        count_method=config['count_method'],
        classifier=classifiers,
        selector=feature_selectors,
        compare_group=list(compare_groups.keys()),
        n_features_to_select=config['n_features_to_select'],
    ),
}

# summary tables written by summarize_cross_validation
for _target_name, _target_pattern in (
    ('metrics_test', '{output_dir}/summary/{cross_validation}/metrics.test.txt'),
    ('metrics_train', '{output_dir}/summary/{cross_validation}/metrics.train.txt'),
    ('feature_stability', '{output_dir}/summary/{cross_validation}/feature_stability.txt'),
):
    inputs[_target_name] = expand(
        _target_pattern,
        output_dir=output_dir,
        cross_validation=['cross_validation'],
    )
32
33
def get_all_inputs(wildcards):
    """Return the module-level ``inputs`` mapping of all workflow targets.

    Used as the input function of rule ``all`` (through ``unpack``).
    ``wildcards`` is accepted to satisfy the input-function signature
    and is not used.
    """
    return inputs
35
        
36
# Default target rule: requests every entry of `inputs`; unpack() turns the
# mapping returned by get_all_inputs into named inputs.
rule all:
    input:
        unpack(get_all_inputs)
39
40
# Run one cross-validation job for a single combination of preprocessed
# matrix, compare group, classifier, number of features and feature selector.
#
# The run block assembles a per-job configuration from
# config['machine_learning'], writes it to <output dir>/config.yaml and then
# calls `machine_learning.py run_pipeline` on the matrix with that file.
rule cross_validation:
    input:
        matrix='{output_dir}/matrix_processing/{preprocess_method}.{count_method}.txt',
        sample_classes=data_dir + '/sample_classes.txt'
    output:
        dir=directory('{output_dir}/cross_validation/{preprocess_method}.{count_method}/{compare_group}/{classifier}.{n_features_to_select}.{selector}')
    run:
        import shlex
        from copy import deepcopy

        ml_config = config['machine_learning']

        output_config = {}
        # number of features
        output_config['n_features_to_select'] = int(wildcards.n_features_to_select)
        # copy global config parameters; deep copies keep the edits made below
        # (notably the append to preprocess_steps) out of the shared `config`,
        # which would otherwise accumulate steps across jobs in one process
        for key in ('transpose', 'features', 'cv_params', 'sample_weight', 'preprocess_steps'):
            if key in ml_config:
                output_config[key] = deepcopy(ml_config[key])

        # copy selector config
        selector_config = deepcopy(ml_config['feature_selector_params'][wildcards.selector])
        selector_config['enabled'] = True
        # tolerate a missing or empty `params` section
        selector_config['params'] = selector_config.get('params') or {}
        # script path for differential expression
        if selector_config['name'] == 'DiffExpFilter':
            selector_config['params']['script'] = os.path.join(bin_dir, 'differential_expression.R')
        # copy selector grid search params: global defaults overridden by the
        # selector-specific values (if any)
        if selector_config['params'].get('grid_search', False):
            grid_search_params = deepcopy(ml_config['selector_grid_search_params'])
            grid_search_params.update(selector_config['params'].get('grid_search_params') or {})
            selector_config['params']['grid_search_params'] = grid_search_params
        # append to preprocess_steps (start from an empty list when the config
        # defines no preprocess steps)
        preprocess_steps = output_config.get('preprocess_steps') or []
        preprocess_steps.append({'feature_selection': selector_config})
        output_config['preprocess_steps'] = preprocess_steps

        # copy classifier config
        classifier_config = deepcopy(ml_config['classifier_params'][wildcards.classifier])
        output_config['classifier'] = classifier_config['classifier']
        output_config['classifier_params'] = classifier_config.get('classifier_params', {})
        # copy classifier grid search params: global defaults overridden by the
        # classifier-specific values (if any)
        if classifier_config.get('grid_search', False):
            grid_search_params = deepcopy(ml_config['classifier_grid_search_params'])
            grid_search_params.update(classifier_config.get('grid_search_params') or {})
            # add classifier grid search config
            output_config['grid_search'] = True
            output_config['grid_search_params'] = grid_search_params

        # write output config
        os.makedirs(output.dir, exist_ok=True)
        output_config_file = os.path.join(output.dir, 'config.yaml')
        with open(output_config_file, 'w') as f:
            yaml.dump(output_config, f, default_flow_style=False)

        # element 1 of a compare group is the positive class, element 0 the negative class
        command = [
            'python',
            os.path.join(config['bin_dir'], 'machine_learning.py'), 'run_pipeline',
            '--matrix', input.matrix,
            '--sample-classes', input.sample_classes,
            '--output-dir', output.dir,
            '--positive-class', compare_groups[wildcards.compare_group][1],
            '--negative-class', compare_groups[wildcards.compare_group][0],
            '--config', output_config_file
        ]
        # quote every argument so that paths or class names containing spaces,
        # quotes or other shell metacharacters reach the script unchanged
        shell(' '.join(shlex.quote(str(arg)) for arg in command))
99
100
101
# Summarize a set of cross-validation runs into three tables (test metrics,
# train metrics, feature stability) using scripts/summarize_cross_validation.py.
rule summarize_cross_validation:
    input:
        # the {cross_validation} wildcard selects an entry of the module-level
        # `inputs` mapping
        input_dir=lambda wildcards: inputs[wildcards.cross_validation]
    output:
        metrics_test='{output_dir}/summary/{cross_validation}/metrics.test.txt',
        metrics_train='{output_dir}/summary/{cross_validation}/metrics.train.txt',
        feature_stability='{output_dir}/summary/{cross_validation}/feature_stability.txt'
    script:
        'scripts/summarize_cross_validation.py'