
src/biodiscml/AdaptDatasetToTraining.java
/*
 * Get clinical and gene expression data
 * Perform some feature extraction
 */
package biodiscml;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.TreeMap;
import utils.Weka_module;
import utils.utils.TableObject;
import static utils.utils.*;

/**
 *
 * @author Mickael
 */
public class AdaptDatasetToTraining {

    public static boolean debug = Main.debug;
    public static HashMap<String, String> removedFeatures = new HashMap<>();

    public AdaptDatasetToTraining() {
    }

    /**
     * Create the training dataset with the configured class column.
     *
     * @param trainFile path of the training file (CSV) to generate
     */
    public AdaptDatasetToTraining(String trainFile) {
        //remove files of a previous run if they exist
        if (!Main.restoreRun || !Main.resumeTraining) {
            System.out.println("Check if files exist and should be deleted...");
            String allDataFile = trainFile.replace("data_to_train", "all_data");
            if (new File(allDataFile).exists()) {
                System.out.println("\t" + allDataFile + " exists... deleting...");
                new File(allDataFile).delete();
            }
            if (new File(trainFile).exists()) {
                System.out.println("\t" + trainFile + " exists... deleting...");
                new File(trainFile).delete();
                new File(trainFile.replace(".csv", ".arff")).delete();
            }
        }

        //create the adapted training file
        System.out.println("# Training file(s)");
        if (Main.doClassification) {
            createFileCompatibleForWeka(Main.classificationClassName, Main.hmTrainFiles, trainFile, Main.separator, true);
        } else {
            createFileCompatibleForWeka(Main.regressionClassName, Main.hmTrainFiles, trainFile, Main.separator, true);
        }
        //create the adapted test file
        if (Main.doSampling) {
            System.out.println("## Apply sampling configuration");
            String trainAndTestFile = trainFile.replace("data_to_train", "all_data");
            String testFile = trainFile.replace("data_to_train.csv", "data_to_test.csv");
            Weka_module weka = new Weka_module();
            String trainSetRange = "";
            //if a test file is provided
            if (!Main.hmNewDataFiles.isEmpty()) {
                System.out.println("# Testing file(s)");
                if (Main.doClassification) {
                    createFileCompatibleForWeka(Main.classificationClassName, Main.hmNewDataFiles, testFile, Main.separator, false);
                } else {
                    createFileCompatibleForWeka(Main.regressionClassName, Main.hmNewDataFiles, testFile, Main.separator, false);
                }
                //the provided test file is merged with the train file and split
                //again to preserve a compatible arff format between train and test sets
                trainSetRange = mergeTrainAndTestFiles(trainFile, testFile, trainAndTestFile);

            } else {
                //no test file provided: rename the train file to the file that
                //contains all train and test data; sampling will split it below
                new File(trainFile).renameTo(new File(trainAndTestFile));
            }

            //perform sampling
            weka.sampling(trainAndTestFile, trainFile, testFile, Main.isClassification, trainSetRange);
        }
        System.out.println("");

    }
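
    /* Sketch of the adapted training file written above (tab-separated, with
     * hypothetical column names and values): the first column is Main.mergingID,
     * features get a "prefix__" when one is configured, and the class is last:
     *
     *   PatientID   clin__age   gene__TP53   class
     *   case_01     64          7.23         relapse
     *   case_02     51          5.98         no_relapse
     */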

    /**
     * Get the identifiers common to all input files.
     *
     * @param al_tables the loaded input tables
     * @return the identifiers (lower-cased) found in every table
     */
    public HashMap<String, String> getCommonIds(ArrayList<TableObject> al_tables) {
        HashMap<String, String> hm_ids = new HashMap<>();
        //if many infiles
        if (al_tables.size() > 1) {
            HashMap<String, Integer> hm_counts = new HashMap<>();
            //get all ids, count how many times each one is seen
            for (TableObject table : al_tables) {
                for (String s : table.hmIDsList.keySet()) {
                    s = s.toLowerCase();
                    if (hm_counts.containsKey(s)) {
                        int tmp = hm_counts.get(s);
                        tmp++;
                        hm_counts.put(s, tmp);
                    } else {
                        hm_counts.put(s, 1);
                    }
                }
            }
            //keep only ids seen in every table
            for (String s : hm_counts.keySet()) {
                if (hm_counts.get(s) == al_tables.size()) {
                    hm_ids.put(s, "");
                }
            }
        } else {//for one infile
            for (String s : al_tables.get(0).hmIDsList.keySet()) {
                hm_ids.put(s.toLowerCase(), "");
            }
        }

        return hm_ids;
    }
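
    /* Worked example (hypothetical ids): with table A holding {Case_01, case_02}
     * and table B holding {CASE_01, case_03}, ids are lower-cased and counted;
     * only "case_01" reaches a count equal to al_tables.size(), so it is the
     * only id returned.
     */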

    private void createFileCompatibleForWeka(String theClass, HashMap<String, String> infiles, String outfile, String separator, Boolean trainingFile) {
        //convert hashmap to arrays of files and prefixes
        String[] files = new String[infiles.size()];
        String[] prefixes = new String[infiles.size()];
        int cpt = 0;
        for (String f : infiles.keySet()) {
            files[cpt] = f;
            prefixes[cpt] = infiles.get(f);
            cpt++;
        }

        //load datasets of features
        if (debug) {
            System.out.println("loading files");
        }

        ArrayList<TableObject> al_tables = new ArrayList<>();
        int classIndex = -1;

        for (int i = 0; i < files.length; i++) {
            String file = files[i];
            if (debug) {
                System.out.println(file);
            }

            TableObject tbo = new TableObject(readTable(file, separator));
            //locate the class column
            if (tbo.containsClass(theClass)) {
                classIndex = i;
            }
            al_tables.add(tbo);
        }

        //extract class
        ArrayList<String> myClass = new ArrayList<>();
        try {
            myClass = al_tables.get(classIndex).getTheClass(theClass);
        } catch (Exception e) {
            System.err.println("[error] Class " + theClass + " not found. Error in the input file.");
            if (Main.debug) {
                e.printStackTrace();
            }
            System.exit(1); //non-zero exit code: the class column is required
        }
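
        //classIndex remembers which input table provided the class column, so
        //the label of each instance id can be fetched from that table below
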
        //remove useless features where 100% of values are identical
        if (trainingFile) {
            try {
                for (TableObject tbo : al_tables) {
                    for (String s : tbo.getSortedHmDataKeyset()) {
                        HashMap<String, String> hm = new HashMap<>();
                        for (String value : tbo.hmData.get(s)) {
                            hm.put(value, value);
                        }
                        if (hm.size() == 1) {
                            tbo.hmData.remove(s);
                            if (hm.keySet().toArray()[0].equals("?")) {
                                System.out.println("Removing feature " + s + " "
                                        + "because 100% of values are missing");
                            } else {
                                System.out.println("Removing feature " + s + " "
                                        + "because 100% of values are identical "
                                        + "{" + hm.keySet().toArray()[0] + "}");
                            }
                            removedFeatures.put(s, s);
                        }
                    }
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        } else {
            //test file: drop the same features that were removed from the training file
            try {
                for (TableObject tbo : al_tables) {
                    for (String s : tbo.getSortedHmDataKeyset()) {
                        if (removedFeatures.containsKey(s)) {
                            tbo.hmData.remove(s);
                        }
                    }
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
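
        //removedFeatures was filled while adapting the training file and is
        //replayed on the test file above, so both files keep identical columns
        //and the generated arff files remain compatible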

        //create outfile
        if (debug) {
            System.out.println("create outfile " + outfile);
        }
        try {
            PrintWriter pw = new PrintWriter(new FileWriter(outfile));
            ///////// PRINT HEADER
            int featuresCpt = 0;
            pw.print(Main.mergingID);
            cpt = 0;
            for (TableObject tbo : al_tables) {
                for (String s : tbo.getSortedHmDataKeyset()) {
                    if (!Main.hmExcludedFeatures.containsKey(s)) {
                        if (!prefixes[cpt].isEmpty()) {
                            pw.print("\t" + prefixes[cpt] + "__" + s);
                        } else {
                            pw.print("\t" + s);
                        }
                        featuresCpt++;
                    }
                }
                cpt++;
            }
            pw.println("\tclass");
            pw.flush();

            //search for ids present in all datasets
            HashMap<String, String> hm_ids = getCommonIds(al_tables);
            if (al_tables.size() > 1) {
                System.out.println("Total number of common instances between files: " + hm_ids.size());
            } else {
                System.out.println("Total number of instances: " + hm_ids.size());
            }

            System.out.println("Total number of features: " + featuresCpt);

            ///////PRINT CONTENT
            //TreeMap keeps instance ids sorted, so rows are written in a deterministic order
            TreeMap<String, Integer> tm = new TreeMap<>();
            tm.putAll(al_tables.get(0).hmIDsList);
            int existing_spaces = 0;
            for (String id : tm.keySet()) {
                if (hm_ids.containsKey(id.toLowerCase()) && !id.equals(Main.mergingID.toLowerCase())) {
                    pw.print(id);
                    for (TableObject tbo : al_tables) {
                        int idIndex = tbo.hmIDsList.get(id);
                        for (String s : tbo.getSortedHmDataKeyset()) {
                            if (!Main.hmExcludedFeatures.containsKey(s)) { //skip excluded features
                                //print values and replace decimal commas by dots
                                String out = tbo.hmData.get(s).get(idIndex).replace(",", ".").trim();
                                if (out.isEmpty() || out.equals("NA") || out.equals("na")
                                        || out.equals("N/A") || out.equals("n/a")) {
                                    out = "?"; //Weka's missing value marker
                                }
                                pw.print("\t" + out);
                            }
                        }
                    }
                    String classe = myClass.get(al_tables.get(classIndex).hmIDsList.get(id));
                    if (classe.contains(" ")) {
                        existing_spaces++;
                    }
                    classe = classe.replace(" ", "_");
                    pw.print("\t" + classe);
                    pw.println();
                }
            }
            if (existing_spaces > 0) {
                System.out.println("Spaces detected in class labels. They were replaced by _");
            }
            pw.flush();

            if (debug) {
                System.out.println("closing outfile " + outfile);
            }
            pw.close();

        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Concatenate the train and test files into a single file so that both
     * share the same arff-compatible structure.
     *
     * @param trainFile the adapted training file
     * @param testFile the adapted test file
     * @param trainAndTestFile the merged output file
     * @return range of the train set instances (ex: 1-100)
     */
    private String mergeTrainAndTestFiles(String trainFile, String testFile, String trainAndTestFile) {
        int cpt = -1; //starts at -1 so the header line is not counted
        try {
            //create train+test file
            PrintWriter pw = new PrintWriter(new FileWriter(trainAndTestFile));

            //read train
            BufferedReader br = new BufferedReader(new FileReader(trainFile));

            while (br.ready()) {
                pw.println(br.readLine());
                cpt++;
            }
            br.close();
            pw.flush();

            //read test
            br = new BufferedReader(new FileReader(testFile));
            br.readLine(); // skip header
            while (br.ready()) {
                pw.println(br.readLine());
            }
            br.close();
            pw.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
        return "1-" + cpt;
    }
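
    /* Example: for a training file holding a header plus 100 instances, cpt ends
     * at 100 and "1-100" is returned, which weka.sampling presumably uses as the
     * instance range of the train set within the merged file.
     */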

}
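
/* Typical use (a sketch; the call site and file name are assumed, not taken
 * from this file):
 *
 *   //produces myProject_data_to_train.csv, plus myProject_all_data.csv and
 *   //myProject_data_to_test.csv when sampling is enabled
 *   new AdaptDatasetToTraining("myProject_data_to_train.csv");
 */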