exseek/snakefiles/common.snakemake

shell.prefix('set -x; set -e;')
import os
import yaml
import re
from collections import OrderedDict
from glob import glob

def require_variable(variable, condition=None):
    '''Get a required config variable, optionally checking that it points to an existing directory or file'''
    value = config.get(variable)
    if value is None:
        raise ValueError('configuration variable "{}" is required'.format(variable))
    if (condition == 'input_dir') and (not os.path.isdir(value)):
        raise ValueError('cannot find input directory {}: {}'.format(variable, value))
    elif (condition == 'input_file') and (not os.path.isfile(value)):
        raise ValueError('cannot find input file {}: {}'.format(variable, value))
    return value

def get_config_file(filename):
    '''Return the first match for filename among the configured config_dirs'''
    for config_dir in config_dirs:
        if os.path.isfile(os.path.join(config_dir, filename)):
            return os.path.join(config_dir, filename)

# set up global variables
package_dir = require_variable('package_dir')
root_dir = require_variable('root_dir', 'input_dir')
config_dirs = require_variable('config_dirs')
config_dirs = config_dirs.split(':')
# load default config
with open(get_config_file('default_config.yaml'), 'r') as f:
    default_config = yaml.safe_load(f)
# read selector-classifier configuration
with open(get_config_file('machine_learning.yaml'), 'r') as f:
    default_config['machine_learning'] = yaml.safe_load(f)

# user-provided config values override the defaults
default_config.update(config)
config = default_config

dataset = config['dataset']
data_dir = require_variable('data_dir', 'input_dir')
genome_dir = require_variable('genome_dir')
bin_dir = require_variable('bin_dir', 'input_dir')
output_dir = require_variable('output_dir')
rna_types = require_variable('rna_types')
tools_dir = require_variable('tools_dir')
temp_dir = require_variable('temp_dir')
# r_dir = require_variable('r_dir')
# create temp_dir if it does not exist
if not os.path.isdir(temp_dir):
    os.makedirs(temp_dir)

# read sample IDs from file
with open(os.path.join(data_dir, 'sample_ids.txt'), 'r') as f:
    sample_ids = f.read().split()
for sample_id in sample_ids:
    if '.' in sample_id:
        raise ValueError('"." is not allowed in sample ID: {}'.format(sample_id))
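
# sample_ids.txt is expected to contain whitespace-separated sample IDs,
# conventionally one per line, e.g. (hypothetical IDs):
#
#   sample1
#   sample2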

# RNA types that have an annotation GTF
rna_types_with_gtf = []
for rna_type in config['rna_types']:
    if rna_type not in ('univec', 'rRNA', 'spikein'):
        rna_types_with_gtf.append(rna_type)

# long RNA types
long_rna_types = list(filter(lambda x: x not in ('miRNA', 'piRNA', 'rRNA', 'spikein', 'univec'), config['rna_types']))

# read adapter sequences
# a value of the form 'file:PATH' refers to a 2-column TSV mapping sample IDs to adapter sequences
for key in ('adaptor', 'adaptor_5p'):
    if isinstance(config[key], str):
        if config[key].startswith('file:'):
            filename = config[key][5:].strip()
            adapters = {}
            with open(filename, 'r') as f:
                for lineno, line in enumerate(f):
                    c = line.strip().split('\t')
                    if len(c) != 2:
                        raise ValueError('expected 2 columns in adapter file {}, found {}'.format(filename, len(c)))
                    adapters[c[0]] = c[1]
            config[key] = adapters
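
# An adapter file referenced as 'file:PATH' is expected to look like this
# (hypothetical sample IDs and sequences; columns are tab-separated):
#
#   sample1	AGATCGGAAGAGC
#   sample2	TGGAATTCTCGG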

def get_preprocess_methods():
    '''Get combinations of preprocess methods for feature selection and feature evaluation
    '''
    preprocess_methods = []
    for batch_removal_method in config['batch_removal_methods']:
        if batch_removal_method in ('ComBat', 'limma'):
            template = 'filter.{imputation_method}.Norm_{normalization_method}.Batch_{batch_removal_method}_{batch_index}'
            preprocess_methods += expand(template,
                imputation_method=config['imputation_methods'],
                normalization_method=config['normalization_methods'],
                batch_removal_method=batch_removal_method,
                batch_index=config['batch_index'])
        elif batch_removal_method in ('RUV', 'RUVn', 'null'):
            template = 'filter.{imputation_method}.Norm_{normalization_method}.Batch_{batch_removal_method}_1'
            preprocess_methods += expand(template,
                imputation_method=config['imputation_methods'],
                normalization_method=config['normalization_methods'],
                batch_removal_method=batch_removal_method)
    return preprocess_methods
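
# For example (hypothetical config values), with imputation_methods=['null'],
# normalization_methods=['TMM'], batch_removal_methods=['ComBat'] and batch_index=1,
# this returns ['filter.null.Norm_TMM.Batch_ComBat_1']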

def auto_gzip_input(template):
    '''Return a Snakemake input function that prefers gzip-compressed files when they exist
    '''
    def get_filename(wildcards):
        gzip_names = expand(template + '.gz', **wildcards)
        if all(os.path.exists(f) for f in gzip_names):
            return gzip_names
        original_names = expand(template, **wildcards)
        #if all(os.path.exists(f) for f in original_names):
        return original_names

    return get_filename
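
# A usage sketch (hypothetical rule and path): the returned closure can be used
# directly as a rule input function, resolving to 'clean.fastq.gz' when it exists
# and falling back to 'clean.fastq' otherwise:
#
# rule quality_control:
#     input:
#         auto_gzip_input('{output_dir}/unmapped/{sample_id}/clean.fastq')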

def parse_fastqc_data(fp):
    '''Parse a fastqc_data.txt file opened in binary mode, returning the
    Basic Statistics entries merged with the per-module QC status
    '''
    section = None
    qc_status = OrderedDict()
    basic_statistics = OrderedDict()
    for line in fp:
        line = str(line, encoding='utf-8')
        line = line.strip()
        if line.startswith('>>'):
            if line == '>>END_MODULE':
                continue
            section, status = line[2:].split('\t')
            qc_status[section] = status
        else:
            if section == 'Basic Statistics':
                key, val = line.split('\t')
                basic_statistics[key] = val
    for key, val in qc_status.items():
        basic_statistics[key] = val
    return basic_statistics
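
# A usage sketch (the archive member layout follows the FastQC convention;
# the zip path is hypothetical):
#
# import zipfile
# with zipfile.ZipFile('sample1_fastqc.zip') as z:
#     with z.open('sample1_fastqc/fastqc_data.txt') as f:  # binary handle, as expected
#         stats = parse_fastqc_data(f)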

def get_input_matrix(wildcards):
    # use RPM for small RNA
    if config['small_rna']:
        return '{output_dir}/matrix_processing/{preprocess_method}.{count_method}.txt'.format(**wildcards)
    # use RPKM for long RNA
    else:
        return '{output_dir}/rpkm/{preprocess_method}.{count_method}.txt'.format(**wildcards)
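
# For example (hypothetical wildcard values), with small_rna=True and wildcards
# output_dir='output/my_dataset', preprocess_method='filter.null.Norm_TMM.Batch_ComBat_1',
# count_method='featurecounts', this resolves to:
#   'output/my_dataset/matrix_processing/filter.null.Norm_TMM.Batch_ComBat_1.featurecounts.txt'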

def get_known_biomarkers():
    '''Collect known biomarkers as (compare_group, feature_set) pairs'''
    biomarkers = []
    for filename in glob(os.path.join(data_dir, 'known_biomarkers', '*/*.txt')):
        c = filename.split('/')
        compare_group = c[-2]
        feature_set = os.path.splitext(c[-1])[0]
        biomarkers.append((compare_group, feature_set))
    return biomarkers
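
# Expected layout (hypothetical names): a file such as
#   {data_dir}/known_biomarkers/Normal-Cancer/panel1.txt
# yields the pair ('Normal-Cancer', 'panel1')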

def to_list(obj):
    '''Wrap an object as a list
    '''
    if isinstance(obj, list):
        return obj
    if isinstance(obj, (str, bytes)):
        return [obj]
    return list(obj)
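
# Examples:
#   to_list('a')        -> ['a']
#   to_list(['a', 'b']) -> ['a', 'b']
#   to_list(('a', 'b')) -> ['a', 'b']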

def get_cutadapt_adapter_args(wildcards, adapter, option):
    '''Build cutadapt adapter arguments for a sample

    Parameters:
        adapter: adapter config value (string, list of sequences,
            or dict mapping sample IDs to sequences)
        option: cutadapt option to emit for each adapter (e.g. -a or -g)

    Returns:
        arguments: command-line arguments for cutadapt
    '''
    if isinstance(adapter, str):
        if len(adapter) == 0:
            return ''
        else:
            return option + ' ' + adapter
    elif isinstance(adapter, list):
        return ' '.join(expand(option + ' {a}', a=adapter))
    elif isinstance(adapter, dict):
        adapter_seq = adapter.get(wildcards.sample_id)
        if adapter_seq is None:
            raise ValueError('adapter sequence is not found for sample {}'.format(wildcards.sample_id))
        return option + ' ' + adapter_seq
    else:
        return ''
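
# For example (hypothetical values), adapter='AGATCGGAAGAGC' with option='-a'
# yields '-a AGATCGGAAGAGC', and adapter=['AAA', 'CCC'] yields '-a AAA -a CCC'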

def get_library_size_small(summary_file, sample_id):
    '''Get the library size for a sample from a read-count summary file
    '''
    columns = {}
    data = {}
    with open(summary_file, 'r') as f:
        for lineno, line in enumerate(f):
            c = line.strip().split('\t')
            if lineno == 0:
                # header line: map sample IDs to column indices
                columns = {c[i]: i for i in range(1, len(c))}
                column = columns[sample_id]
            else:
                data[c[0]] = int(c[column])
    library_size = data['clean.unmapped'] - data['other.unmapped']
    # remove these types from library size calculation
    for key in ['spikein.mapped', 'univec.mapped', 'rRNA.mapped']:
        if key in data:
            library_size -= data[key]
    return library_size
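
# The summary file is assumed to be a TSV with one row per read category and one
# column per sample, e.g. (header label and counts hypothetical):
#
#   item	sample1	sample2
#   clean.unmapped	100000	120000
#   other.unmapped	5000	6000
#   rRNA.mapped	2000	1500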

# shell template for executing a notebook with nbconvert
nbconvert_command = '''cp {input.jupyter} {output.jupyter}
jupyter nbconvert --execute --to html \
    --HTMLExporter.exclude_code_cell=False \
    --HTMLExporter.exclude_input_prompt=True \
    --HTMLExporter.exclude_output_prompt=True \
    {output.jupyter}
'''
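
# A usage sketch (hypothetical rule): the template expects input.jupyter and
# output.jupyter to be defined by the rule that runs it:
#
# rule report:
#     input:
#         jupyter='templates/summary.ipynb'
#     output:
#         jupyter='{output_dir}/summary/summary.ipynb'
#     shell: nbconvert_command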

# export singularity wrappers by prepending the wrapper directory to PATH
use_singularity = config.get('use_singularity')
if use_singularity:
    os.environ['PATH'] = config['container']['wrapper_dir'] + ':' + os.environ['PATH']

# wildcard constraint patterns
sub_matrix_regex = '(mirna_only)'
# note: the broad pattern on the next line supersedes this enumerated one
count_method_regex = '(featurecounts)|(htseq)|(transcript)|(mirna_and_long_fragments)|(featurecounts_lncrna)'
count_method_regex = r'[^\.]+'
imputation_method_regex = '(scimpute_count)|(viper_count)|(null)'
normalization_method_regex = '(SCnorm)|(TMM)|(RLE)|(CPM)|(CPM_top)|(CPM_rm)|(CPM_refer)|(UQ)|(null)'
batch_removal_method_with_batchinfo_regex = '(ComBat)|(limma)'
batch_removal_method_without_batchinfo_regex = '(RUV)|(RUVn)|(null)'
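
# A usage sketch (hypothetical constraints block): these patterns are intended for
# Snakemake wildcard constraints, e.g.
#
# wildcard_constraints:
#     count_method=count_method_regex,
#     imputation_method=imputation_method_regex,
#     normalization_method=normalization_method_regex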

# whether batch information is provided
has_batch_info = os.path.isfile(os.path.join(data_dir, 'batch_info.txt'))
# clustering scores for normalization, feature_selection and evaluate_features
clustering_scores = ['uca_score']
if has_batch_info:
    #clustering_scores += ['kbet_score', 'combined_score', 'knn_score']
    clustering_scores += ['combined_score', 'knn_score']