--- /dev/null
+++ b/src/biodiscml/AdaptDatasetToTesting.java
@@ -0,0 +1,276 @@
+/*
+ * Get clinical and gene expression data.
+ * Perform some feature extraction.
+ */
+package biodiscml;
+
+import java.io.FileWriter;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.TreeMap;
+import utils.Weka_module;
+import utils.utils.TableObject;
+import static utils.utils.readTable;
+import weka.classifiers.Classifier;
+import weka.core.SerializationHelper;
+
+/**
+ *
+ * @author Mickael
+ */
+public class AdaptDatasetToTesting {
+
+    public static boolean debug = Main.debug;
+    public static boolean missingClass = false;
+
+    public AdaptDatasetToTesting() {
+    }
+
+    public boolean isMissingClass() {
+        return missingClass;
+    }
+
+    /**
+     * Create a testing dataset with a determined class.
+     *
+     * @param theClass name of the class column
+     * @param infiles input files mapped to their feature-name prefixes
+     * @param outfile output file
+     * @param separator column separator of the input files
+     * @param model serialized Weka model file
+     */
+    public AdaptDatasetToTesting(String theClass, HashMap<String, String> infiles, String outfile, String separator, String model) {
+        //get model features
+        Weka_module weka = new Weka_module();
+        ArrayList<String> alModelFeatures = weka.getFeaturesFromClassifier(model); //features in the right order
+        HashMap<String, String> hmModelFeatures = new HashMap<>(); //indexed features for fast lookup
+        System.out.println("# Model features:");
+
+        boolean voteModel = false;
+        try {
+            voteModel = ((Classifier) SerializationHelper.read(model)).getClass().toString().contains("weka.classifiers.meta.Vote");
+        } catch (Exception e) {
+            if (Main.debug) {
+                e.printStackTrace();
+            }
+        }
+        if (voteModel) {
+            System.out.println(" Combined Vote model");
+        }
+        for (String f : alModelFeatures) {
+            if (!f.startsWith("Model") && !f.startsWith("class")) {
+                hmModelFeatures.put(f, f);
+            }
+        }
+        hmModelFeatures.put("class", "class");
+        for (String s : hmModelFeatures.keySet()) {
+            System.out.println("\t" + s);
+        }
+
+        //convert the infiles map to parallel arrays (file -> feature-name prefix)
+        String[] files = new String[infiles.size()];
+        String[] prefixes = new String[infiles.size()];
+        int cpt = 0;
+        for (String f : infiles.keySet()) {
+            files[cpt] = f;
+            prefixes[cpt] = infiles.get(f);
+            cpt++;
+        }
+
+        //load datasets of features
+        if (debug) {
+            System.out.println("loading files");
+        }
+
+        ArrayList<TableObject> al_tables = new ArrayList<>();
+        int classIndex = -1;
+
+        for (int i = 0; i < files.length; i++) {
+            String file = files[i];
+            TableObject tbo = new TableObject(readTable(file, separator));
+            //locate the class column
+            if (tbo.containsClass(theClass)) {
+                classIndex = i;
+            }
+            al_tables.add(tbo);
+        }
+
+        //extract the class; if no file contains it, fill the class values with ?
+        ArrayList<String> myClass = new ArrayList<>();
+        try {
+            myClass = al_tables.get(classIndex).getTheClass(theClass);
+        } catch (Exception e) {
+            String f = "";
+            for (String s : infiles.keySet()) {
+                f += " " + s;
+            }
+            System.out.println("Class " + theClass + " not found in" + f + ". Class values are filled by ?");
+            missingClass = true;
+        }
+
+        //replace spaces by _ in class values
+        for (int i = 0; i < myClass.size(); i++) {
+            myClass.set(i, myClass.get(i).replace(" ", "_"));
+        }
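+
+        // Output layout (as built below): a tab-separated file with the
+        // merging ID column first, one column per model feature (named
+        // "<prefix>__<feature>" when a prefix was given for its source
+        // file), and the class column last. Empty cells are replaced by
+        // Main.missingValueToReplace; features absent from the test set
+        // are written as "?".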
+
+        //keep only the features needed by the model and create the outfile
+        System.out.println("create outfile " + outfile);
+
+        HashMap<String, ArrayList<String>> hmOutput = new HashMap<>();
+        try {
+            PrintWriter pw = new PrintWriter(new FileWriter(outfile));
+            ///////// PREPARE HEADER
+            hmOutput.put(Main.mergingID, new ArrayList<>());
+            for (int i = 0; i < al_tables.size(); i++) {
+                TableObject tbo = al_tables.get(i);
+                for (String s : tbo.hmData.keySet()) {
+                    String head;
+                    if (!prefixes[i].isEmpty()) {
+                        head = prefixes[i] + "__" + s;
+                    } else {
+                        head = s;
+                    }
+                    if (hmModelFeatures.containsKey(head)) {
+                        hmOutput.put(head, new ArrayList<>());
+                    }
+                }
+            }
+            hmOutput.put("class", new ArrayList<>());
+
+            //search for ids present in all datasets
+            HashMap<String, String> hm_ids = getCommonIds(al_tables);
+            if (al_tables.size() > 1) {
+                System.out.println("Total number of common instances between files: " + hm_ids.size());
+            } else {
+                System.out.println("Total number of instances: " + hm_ids.size());
+            }
+
+            /////// PREPARE CONTENT FOR PRINTING
+            TreeMap<String, Integer> tm = new TreeMap<>();
+            tm.putAll(al_tables.get(0).hmIDsList);
+            for (String id : tm.keySet()) {
+                if (hm_ids.containsKey(id)) { //if the sample exists in all files
+                    hmOutput.get(Main.mergingID).add(id);
+                    for (int i = 0; i < al_tables.size(); i++) {
+                        TableObject tbo = al_tables.get(i);
+                        int idIndex = tbo.hmIDsList.get(id); //index of the sample
+                        for (String s : tbo.hmData.keySet()) { //for all features
+                            String feature;
+                            if (!prefixes[i].isEmpty()) {
+                                feature = prefixes[i] + "__" + s;
+                            } else {
+                                feature = s;
+                            }
+                            if (hmModelFeatures.containsKey(feature)) {
+                                //store the value, replacing decimal commas by dots
+                                String out = tbo.hmData.get(s).get(idIndex).replace(",", ".").trim();
+                                if (out.isEmpty()) {
+                                    out = Main.missingValueToReplace;
+                                }
+                                hmOutput.get(feature).add(out);
+                            }
+                        }
+                    }
+                    if (missingClass) {
+                        hmOutput.get("class").add("?");
+                    } else {
+                        hmOutput.get("class").add(myClass.get(al_tables.get(classIndex).hmIDsList.get(id)));
+                    }
+                }
+            }
+
+            //PRINT CONTENT IN THE RIGHT ORDER
+            if (voteModel || (alModelFeatures.size() != hmModelFeatures.size())) {
+                //rebuild the feature list, ensuring class is at the end
+                alModelFeatures = new ArrayList<>();
+                for (String feature : hmModelFeatures.keySet()) {
+                    if (!feature.equals("class")) {
+                        alModelFeatures.add(feature);
+                    }
+                }
+                alModelFeatures.add("class");
+            }
+            //header
+            pw.print(Main.mergingID);
+            for (String feature : alModelFeatures) {
+                pw.print("\t" + feature);
+            }
+            pw.println();
+            pw.flush();
+            //content
+            for (int i = 0; i < hm_ids.size(); i++) { //for every instance
+                pw.print(hmOutput.get(Main.mergingID).get(i));
+                for (String feature : alModelFeatures) {
+                    try {
+                        pw.print("\t" + hmOutput.get(feature).get(i));
+                    } catch (Exception e) {
+                        System.out.println("Feature " + feature + " is not present in the test set. Replacing with missing data");
+                        pw.print("\t?");
+                    }
+                }
+                pw.println();
+            }
+            pw.flush();
+            if (debug) {
+                System.out.println("closing outfile " + outfile);
+            }
+            pw.close();
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
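+
+    /* Illustrative usage (hypothetical file names, class label and model;
+     * Main.mergingID and the other Main settings are assumed to be
+     * configured as in a normal BioDiscML run):
+     *
+     *   HashMap<String, String> infiles = new HashMap<>();
+     *   infiles.put("clinical.txt", "clin"); //file -> feature prefix
+     *   infiles.put("expression.txt", "");   //empty prefix: names kept as-is
+     *   AdaptDatasetToTesting adapt = new AdaptDatasetToTesting(
+     *           "diagnosis", infiles, "testing.tsv", "\t", "best.model");
+     *   if (adapt.isMissingClass()) {
+     *       //the class column was absent; class values were set to "?"
+     *   }
+     */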
+
+    /**
+     * Get the IDs common to all infiles.
+     *
+     * @param al_tables
+     * @return the IDs present in every table
+     */
+    public static HashMap<String, String> getCommonIds(ArrayList<TableObject> al_tables) {
+        HashMap<String, String> hm_ids = new HashMap<>();
+        if (al_tables.size() > 1) { //many infiles
+            //count in how many tables each id is seen
+            HashMap<String, Integer> hm_counts = new HashMap<>();
+            for (TableObject table : al_tables) {
+                for (String s : table.hmIDsList.keySet()) {
+                    if (hm_counts.containsKey(s)) {
+                        hm_counts.put(s, hm_counts.get(s) + 1);
+                    } else {
+                        hm_counts.put(s, 1);
+                    }
+                }
+            }
+            //keep the ids seen in every table
+            for (String s : hm_counts.keySet()) {
+                if (hm_counts.get(s) == al_tables.size()) {
+                    hm_ids.put(s, "");
+                }
+            }
+        } else { //a single infile: all of its ids are kept
+            for (String s : al_tables.get(0).hmIDsList.keySet()) {
+                hm_ids.put(s, "");
+            }
+        }
+        return hm_ids;
+    }
+}