Switch to unified view

a b/src/biodiscml/AdaptDatasetToTesting.java
1
/*
2
 * Get clinical and genes expression
3
 * Make some feature extraction
4
 */
5
package biodiscml;
6
7
import java.io.FileWriter;
8
import java.io.PrintWriter;
9
import java.util.ArrayList;
10
import java.util.HashMap;
11
import java.util.TreeMap;
12
import utils.Weka_module;
13
import utils.utils.TableObject;
14
import static utils.utils.readTable;
15
import weka.classifiers.Classifier;
16
import weka.core.SerializationHelper;
17
18
/**
19
 *
20
 * @author Mickael
21
 */
22
public class AdaptDatasetToTesting {
23
24
    public static boolean debug = Main.debug;
25
    public static boolean missingClass = false;
26
27
    public AdaptDatasetToTesting() {
28
    }
29
30
    public boolean isMissingClass() {
31
        return missingClass;
32
    }
33
34
    /**
35
     * create dataset with a determined class
36
     *
37
     * @param theClass
38
     * @param infiles
39
     * @param outfile
40
     */
41
    public AdaptDatasetToTesting(String theClass, HashMap<String, String> infiles, String outfile, String separator, String model) {
42
        //get model features
43
        Weka_module weka = new Weka_module();
44
        ArrayList<String> alModelFeatures = weka.getFeaturesFromClassifier(model); //features in the right order
45
        HashMap<String, String> hmModelFeatures = new HashMap<>();//indexed hashed features
46
        System.out.println("# Model features: ");
47
48
        Boolean voteModel = false;
49
        try {
50
            voteModel = (((Classifier) SerializationHelper.read(model)).getClass().toString().contains("weka.classifiers.meta.Vote"));
51
        } catch (Exception e) {
52
            if (Main.debug) {
53
                e.printStackTrace();
54
            }
55
        }
56
        if (voteModel) {
57
            System.out.println("    Combined Vote model");
58
        }
59
        for (String f : alModelFeatures) {
60
            if (!f.startsWith("Model") && !f.startsWith("class")) {
61
                hmModelFeatures.put(f, f);
62
            }
63
        }
64
        hmModelFeatures.put("class", "class");
65
        for (String s : hmModelFeatures.keySet()) {
66
            System.out.println("\t" + s);
67
        }
68
69
        //convert hashmap to list
70
        String[] files = new String[infiles.size()];
71
        String[] prefixes = new String[infiles.size()];
72
        int cpt = 0;
73
        for (String f : infiles.keySet()) {
74
            files[cpt] = f;
75
            prefixes[cpt] = infiles.get(f);
76
            cpt++;
77
        }
78
79
        //load datasets of features
80
        if (debug) {
81
            System.out.println("loading files");
82
        }
83
84
        ArrayList<TableObject> al_tables = new ArrayList<>();
85
        int classIndex = -1;
86
87
        for (int i = 0; i < files.length; i++) {
88
            String file = files[i];
89
            TableObject tbo = new TableObject(readTable(file, separator));
90
            //locate class
91
            if (tbo.containsClass(theClass)) {
92
                classIndex = i;
93
            }
94
            al_tables.add(tbo);
95
        }
96
97
        //extract class 
98
        ArrayList<String> myClass = new ArrayList<>();
99
100
        try {
101
            myClass = al_tables.get(classIndex).getTheClass(theClass);
102
        } catch (Exception e) {
103
            String f = "";
104
            for (String s : infiles.keySet()) {
105
                f += " " + s;
106
            }
107
            System.out.println("Class " + theClass + " not found in " + f + ". Class values are filled by ?");
108
            missingClass = true;
109
        }
110
111
        // replace spaces by _ in class
112
        for (int i = 0; i < myClass.size(); i++) {
113
            String c = myClass.get(i).replace(" ", "_");
114
            myClass.set(i, c);
115
        }
116
117
        //remove feature that are not needed by the model
118
        //create outfile
119
        System.out.println("create outfile " + outfile);
120
121
        HashMap<String, ArrayList<String>> hmOutput = new HashMap<>();
122
        try {
123
            PrintWriter pw = new PrintWriter(new FileWriter(outfile));
124
            ///////// PRINT HEADER
125
            //pw.print(Main.mergingID);
126
            hmOutput.put(Main.mergingID, new ArrayList<>());
127
            for (int i = 0; i < al_tables.size(); i++) {
128
                TableObject tbo = al_tables.get(i);
129
                for (String s : tbo.hmData.keySet()) {
130
                    String head = null;
131
                    if (!prefixes[i].isEmpty()) {
132
                        head = prefixes[i] + "__" + s;
133
                    } else {
134
                        head = s;
135
                    }
136
                    if (hmModelFeatures.containsKey(head)) {
137
                        //pw.print("\t" + head);
138
                        hmOutput.put(head, new ArrayList<>());
139
                    }
140
                }
141
            }
142
143
            //pw.println("\tclass");
144
            hmOutput.put("class", new ArrayList<>());
145
            //pw.flush();
146
147
            //search for ids present in all datasets
148
            HashMap<String, String> hm_ids = getCommonIds(al_tables);
149
            if (al_tables.size() > 1) {
150
                System.out.println("Total number of common instances between files:" + hm_ids.size());
151
            } else {
152
                System.out.println("Total number of instances between files:" + hm_ids.size());
153
            }
154
155
            ///////PREPARE CONTENT FOR PRINTING
156
            TreeMap<String, Integer> tm = new TreeMap<>();
157
            tm.putAll(al_tables.get(0).hmIDsList);
158
            for (String id : tm.keySet()) {
159
                if (hm_ids.containsKey(id)) { //if sample exist in all files
160
                    //pw.print(id);
161
                    hmOutput.get(Main.mergingID).add(id);
162
                    for (int i = 0; i < al_tables.size(); i++) {
163
                        TableObject tbo = al_tables.get(i);
164
                        int idIndex = tbo.hmIDsList.get(id); //get index for sample
165
                        for (String s : tbo.hmData.keySet()) { //for all features
166
                            String feature = null;
167
                            if (!prefixes[i].isEmpty()) {
168
                                feature = prefixes[i] + "__" + s;
169
                            } else {
170
                                feature = s;
171
                            }
172
                            if (hmModelFeatures.containsKey(feature)) {
173
                                //pw.print("\t" + tbo.hmData.get(s).get(idIndex));
174
                                // print values and replace , per .
175
                                String out = tbo.hmData.get(s).get(idIndex).replace(",", ".").trim();
176
                                if (out.isEmpty()) {
177
                                    out = Main.missingValueToReplace;
178
                                }
179
                                hmOutput.get(feature).add(out);
180
                            }
181
                        }
182
                    }
183
184
                    if (missingClass) {
185
                        hmOutput.get("class").add("?");
186
                    } else {
187
                        //pw.print("\t" + myClass.get(al_tables.get(classIndex).hmIDsList.get(id)));
188
                        hmOutput.get("class").add(myClass.get(al_tables.get(classIndex).hmIDsList.get(id)));
189
                        //pw.print("\t" + myClass.get(idIndex).replace("1", "true").replace("0", "false"));
190
                    }
191
192
                    //pw.println();
193
                }
194
            }
195
            //PRINTING CONTENT IN THE RIGHT ORDER
196
            if (voteModel || (alModelFeatures.size() != hmModelFeatures.size())) {
197
                //ensure class is at the end of the list
198
                alModelFeatures = new ArrayList<>();
199
                for (String feature : hmModelFeatures.keySet()) {
200
                    if (!feature.equals("class")) {
201
                        alModelFeatures.add(feature);
202
                    }
203
                }
204
                alModelFeatures.add("class");
205
            }
206
            //header
207
            pw.print(Main.mergingID);
208
            for (String feature : alModelFeatures) {
209
                pw.print("\t" + feature);
210
            }
211
            pw.println();
212
            pw.flush();
213
            //content
214
            for (int i = 0; i < hm_ids.size(); i++) {//for every instance
215
                pw.print(hmOutput.get(Main.mergingID).get(i));
216
                for (String feature : alModelFeatures) {
217
                    try {
218
                        pw.print("\t" + hmOutput.get(feature).get(i));
219
                    } catch (Exception e) {
220
                        System.out.println("Feature "+feature+" is not present in the test set. Replacing with missing data");
221
                        pw.print("\t?");
222
                    }
223
                }
224
                pw.println();
225
            }
226
            pw.flush();
227
            pw.close();
228
            if (debug) {
229
                System.out.println("closing outfile " + outfile);
230
            }
231
            pw.close();
232
        } catch (Exception e) {
233
            e.printStackTrace();
234
        }
235
    }
236
237
    /**
238
     * get common ids between infiles
239
     *
240
     * @param al_tables
241
     * @return
242
     */
243
    public static HashMap<String, String> getCommonIds(ArrayList<TableObject> al_tables) {
244
        HashMap<String, String> hm_ids = new HashMap<>();
245
        //if many infiles
246
        if (al_tables.size() > 0) {
247
            HashMap<String, Integer> hm_counts = new HashMap<>();
248
            //get all ids, count how many times each one is seen
249
            for (TableObject table : al_tables) {
250
                for (String s : table.hmIDsList.keySet()) {
251
                    //s = s.toLowerCase();
252
                    if (hm_counts.containsKey(s)) {
253
                        int tmp = hm_counts.get(s);
254
                        tmp++;
255
                        hm_counts.put(s, tmp);
256
                    } else {
257
                        hm_counts.put(s, 1);
258
                    }
259
                }
260
            }
261
            //check number of times ids have been seen
262
            for (String s : hm_counts.keySet()) {
263
                if (hm_counts.get(s) == al_tables.size()) {
264
                    hm_ids.put(s, "");
265
                }
266
            }
267
        } else {//for one infile
268
            for (String s : al_tables.get(0).hmIDsList.keySet()) {
269
                hm_ids.put(s, "");
270
            }
271
        }
272
273
        return hm_ids;
274
    }
275
276
}