Diff of /src/main.py [000000] .. [b798eb]

"""Quantum machine learning on neural-network embeddings.

Trains a neural network on multi-omics data, extracts its learned
embeddings, and reports performance metrics for the neural network, a
support vector classifier (SVC), and a quantum support vector classifier
(QSVC) trained on those embeddings.
"""
### Author: Aritra Bose <a.bose@ibm.com>
### MIT license


### --- base class imports --- ###
import pandas as pd
import numpy as np
import argparse
import os
import copy
from time import strftime, gmtime
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('dark')

# ====== Torch imports ======
import torch
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl
from torchmetrics import ConfusionMatrix, F1Score

# ====== Scikit-learn imports ======
from sklearn.svm import SVC
from sklearn.metrics import (
    auc,
    roc_curve,
    ConfusionMatrixDisplay,
    f1_score,
    balanced_accuracy_score,
)
from sklearn.preprocessing import StandardScaler, LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold

# ====== Qiskit imports ======
from qiskit.circuit.library import ZZFeatureMap, ZFeatureMap, PauliFeatureMap
from qiskit import QuantumCircuit
from qiskit_ibm_runtime import QiskitRuntimeService
from qiskit_algorithms.utils import algorithm_globals
from qiskit.primitives import Sampler
from qiskit_aer import AerSimulator
from qiskit_algorithms.state_fidelities import ComputeUncompute
from qiskit_machine_learning.kernels import FidelityQuantumKernel
from qiskit_machine_learning.algorithms import QSVC, PegasosQSVC

# ====== Local imports ======
from model import LModel
from dataset import OmicsData


def parse_args():
    """Parse the command line arguments using argparse.

    Returns:
        argparse.Namespace of parsed arguments.
    """
    parser = argparse.ArgumentParser(
        prog="quantum machine learning on multi-omics",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-f",
        "--file",
        type=str,
        default=None,
        help="Multi-omics data file"
    )
    parser.add_argument(
        "-cv",
        "--num_cv",
        type=int,
        default=5,
        help="Number of cross-validation folds (KFold requires at least 2)"
    )
    parser.add_argument(
        "-e", "--epoch",
        type=int,
        default=100,
        help="Number of training epochs"
    )
    parser.add_argument(
        "-b",
        "--batch_size",
        type=int,
        default=20,
        help="Train/test batch size"
    )
    parser.add_argument(
        "-lr",
        "--lr",
        type=float,
        default=1e-3,
        help="Learning rate"
    )
    parser.add_argument(
        "-l2",
        "--weight_decay",
        type=float,
        default=1e-5,
        help="L2 regularization"
    )
    parser.add_argument(
        "-p",
        "--patience",
        type=int,
        default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "-i",
        "--iter",
        type=int,
        default=1,
        help="Number of iterations"
    )
    parser.add_argument(
        "-d",
        "--dim",
        type=int,
        default=8,
        help="Number of dimensions for the neural network embedding"
    )
    parser.add_argument(
        "-c",
        "--C",
        type=float,
        default=1.0,
        help="Regularization parameter for SVC"
    )
    # type=bool would treat any non-empty string as True, so use a store_true flag
    parser.add_argument(
        "-pq",
        "--pegasos",
        action="store_true",
        help="Flag to use PegasosQSVC"
    )
    parser.add_argument(
        "-en",
        "--encoding",
        type=str,
        default="ZZ",
        choices=['ZZ', 'Z', 'P'],
        help="Encoding for QML"
    )
    args = parser.parse_args()
    return args


def validate_args(args):
    """Validate the arguments.

    Args:
        args (argparse.Namespace): The parsed arguments returned by parse_args().

    Raises:
        ValueError: If no input file is given or the path does not exist.
    """
    # os.path.exists() returns a bool, never None, so test it directly
    if args.file is None or not os.path.exists(args.file):
        raise ValueError("Input file path error!")


def process_data(file):
    """Process the data file.

    Args:
        file (path): Path to a .csv file with the column structure
                     [Sample ID, Genes..., label], where the label column
                     is named 'y'.

    Returns:
        numpy ndarrays for the working (training) split and the held-out
        test split: X_working, y_working, X_held_out, y_held_out.
    """
    df = pd.read_csv(file)
    y = df['y'].values.astype(float)
    X = df[df.columns[1:-1]].values

    # held-out master split
    X_working, X_held_out, y_working, y_held_out = train_test_split(X,
                                                                    y,
                                                                    train_size=0.8,
                                                                    shuffle=True)

    return X_working, y_working, X_held_out, y_held_out
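
# Note: the split above does not stratify by label, so class proportions can
# drift between the working and held-out sets on small data. A minimal
# alternative (an optional change, not part of the original code):
#
#     X_working, X_held_out, y_working, y_held_out = train_test_split(
#         X, y, train_size=0.8, shuffle=True, stratify=y)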


# Helper to compute metrics from raw logits. Currently unused by the pipeline
# below; rewritten from the original commented-out draft to match the
# torchmetrics API, which needs task/num_classes up front, so a `num_classes`
# argument is assumed here.
def compute_metrics(y_hat, y, num_classes):
    _, preds = torch.max(y_hat, 1)
    f1 = F1Score(task="multiclass", num_classes=num_classes, average="micro")(preds, y)
    cm = ConfusionMatrix(task="multiclass", num_classes=num_classes)(preds, y)
    return f1, cm


def kfold_cross_validation(args, model, fname, X, y, k, early_stopping_patience, iteration, **trainer_kwargs):
    """K-fold cross-validation method to train the neural network model.

    Args:
        args (argparse.Namespace): parsed input arguments
        model (LModel): The model object of the LModel class
        fname (str): file name stem for checkpoints and logs
        X (numpy ndarray): Training data
        y (numpy array): Training labels
        k (int): Number of cross-validation folds
        early_stopping_patience (int): Patience for early stopping checks
        iteration (int): index of the current pipeline iteration

    Returns:
        best_model_weights (dict): state dict of the best model found across folds
        best_train_index (list): train indices that produced the best model
    """
    kfold = KFold(n_splits=k, shuffle=True)
    best_model_weights = None
    best_train_index = None
    best_val_metric = float("-inf")

    for fold, (train_index, val_index) in enumerate(kfold.split(X)):
        print(f"Fold {fold+1}: {len(train_index)} train / {len(val_index)} val samples")
        X_train, X_val = X[train_index], X[val_index]
        y_train, y_val = y[train_index], y[val_index]

        # create dataloaders
        train_data = OmicsData(X_train, y_train)
        val_data = OmicsData(X_val, y_val)
        train_dataloader = DataLoader(train_data)
        val_dataloader = DataLoader(val_data)

        checkpoint_callback = ModelCheckpoint(
            dirpath=f"checkpoints/{fname}/fold_{fold}",
            save_top_k=1,
            monitor="val_loss",
            mode="min",
        )
        early_stopping = EarlyStopping(
            monitor="val_loss",
            patience=early_stopping_patience,
            mode="min",
        )

        logger = TensorBoardLogger(save_dir="logs", name=f"{fname}_fold_{fold}")

        trainer = pl.Trainer(
            accelerator="auto",  # falls back to CPU when no GPU is available
            devices=1,
            max_epochs=args.epoch,
            callbacks=[early_stopping, checkpoint_callback],
            accumulate_grad_batches=len(train_dataloader),
            check_val_every_n_epoch=10,
            logger=logger,
            **trainer_kwargs,
        )

        trainer.fit(model=model,
                    train_dataloaders=train_dataloader,
                    val_dataloaders=val_dataloader)

        # val_acc may be absent from callback_metrics, so guard against None
        val_metric = trainer.callback_metrics.get("val_acc")
        print(f"Fold {fold+1} val_acc: {val_metric}")
        if val_metric is not None and val_metric > best_val_metric:
            best_val_metric = val_metric
            # deepcopy so training in later folds does not overwrite the saved weights
            best_model_weights = copy.deepcopy(model.state_dict())
            best_train_index = train_index.tolist()

    return best_model_weights, best_train_index
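
# Caveat: the same `model` instance carries its trained weights from one fold
# into the next, so folds are not fully independent. A minimal sketch of
# re-initializing per fold (assuming LModel's constructor arguments, as used
# in training() below, are available here):
#
#     for fold, (train_index, val_index) in enumerate(kfold.split(X)):
#         model = LModel(dim=X.shape[1], output_dim=args.dim,
#                        batch_size=args.batch_size,
#                        weight_decay=args.weight_decay, lr=args.lr)
#         ...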


def training(args, fname, X, y, iteration):
    """Training method which calls the k-fold cross-validation code.

    Args:
        args (argparse.Namespace): parsed input arguments
        fname (str): file name stem for storing checkpoints and embeddings
        X (numpy ndarray): Training data
        y (numpy array): Training labels
        iteration (int): index of the current pipeline iteration

    Returns:
        embedded_train (numpy ndarray): Embedded training data of size samples x output dimension
        train_index (list): training indices that produced the best model
        model (LModel): LModel object
        model_weights (dict): learned state dict of the best model
    """
    num_feats = X.shape[1]
    model = LModel(
        dim=num_feats,
        output_dim=args.dim,
        batch_size=args.batch_size,
        weight_decay=args.weight_decay,
        lr=args.lr
    )
    model_weights, train_index = kfold_cross_validation(args,
                                                        model,
                                                        fname,
                                                        X,
                                                        y,
                                                        args.num_cv,
                                                        args.patience,
                                                        iteration)
    model.load_state_dict(model_weights)
    embedded_train = model.embedder(torch.tensor(X[train_index], dtype=torch.float32)).detach().numpy()

    return embedded_train, train_index, model, model_weights


def testing(X, y, model, model_weights):
    """Evaluate the trained model on held-out data.

    Args:
        X (numpy ndarray): Held-out test data
        y (numpy array): Held-out test labels
        model (LModel): LModel object
        model_weights (dict): learned state dict to load before testing

    Returns:
        results (list): per-dataloader metric dicts from trainer.test()
        embedded_test (numpy ndarray): embedded held-out data
    """
    test_data = OmicsData(X, y)
    test_dataloader = DataLoader(test_data)
    model.load_state_dict(model_weights)
    # convert to a tensor once; re-wrapping a tensor in torch.tensor() is redundant
    X = torch.tensor(X, dtype=torch.float32)
    embedded_test = model.embedder(X).detach().numpy()
    trainer = pl.Trainer()
    results = trainer.test(model=model, dataloaders=test_dataloader)

    return results, embedded_test


def compute_svc(X_train, y_train, X_test, y_test, c=1):
    """Fit a classical SVC on the embeddings and return its micro-averaged F1 score."""
    svc = SVC(C=c)
    svc_vanilla = svc.fit(X_train, y_train)
    labels_vanilla = svc_vanilla.predict(X_test)
    f1_svc = f1_score(y_test, labels_vanilla, average='micro')

    return f1_svc
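
# Note: StandardScaler is imported above but never applied. If the embedding
# dimensions end up on very different scales, standardizing before the
# (Q)SVC may help; a minimal sketch of this optional preprocessing step
# (an assumption about intent, not part of the original pipeline):
#
#     scaler = StandardScaler().fit(X_train)
#     X_train = scaler.transform(X_train)
#     X_test = scaler.transform(X_test)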


def compute_QSVC(X_train, y_train, X_test, y_test, encoding='ZZ', c=1, pegasos=False):
    """Fit a quantum-kernel SVC (QSVC or PegasosQSVC) on the embeddings and
    return its micro-averaged F1 score."""
    service = QiskitRuntimeService(instance="accelerated-disc/internal/default")
    backend = service.least_busy(simulator=False, operational=True)
    # service = QiskitRuntimeService()
    # backend = AerSimulator(method='statevector')
    algorithm_globals.random_seed = 12345

    if encoding == 'ZZ':
        feature_map = ZZFeatureMap(feature_dimension=X_train.shape[1],
                                   reps=2,
                                   entanglement='linear')
    elif encoding == 'Z':
        feature_map = ZFeatureMap(feature_dimension=X_train.shape[1],
                                  reps=2)
    elif encoding == 'P':
        feature_map = PauliFeatureMap(feature_dimension=X_train.shape[1],
                                      reps=2,
                                      entanglement='linear')
    else:
        raise ValueError(f"Unknown encoding: {encoding}")

    # The reference Sampler from qiskit.primitives runs locally and takes no
    # backend argument; to execute on the runtime backend selected above, swap
    # in the Sampler from qiskit_ibm_runtime instead.
    sampler = Sampler(options={"shots": 1024})
    fidelity = ComputeUncompute(sampler=sampler)
    Qkernel = FidelityQuantumKernel(fidelity=fidelity, feature_map=feature_map)
    if not pegasos:
        qsvc = QSVC(quantum_kernel=Qkernel, C=c)
    else:
        qsvc = PegasosQSVC(quantum_kernel=Qkernel, C=c)
    qsvc_model = qsvc.fit(X_train, y_train)
    labels_qsvc = qsvc_model.predict(X_test)
    f1_qsvc = f1_score(y_test, labels_qsvc, average='micro')

    return f1_qsvc
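
# For reference, QSVC is equivalent to a classical SVC on a precomputed
# fidelity kernel matrix; a minimal sketch of that route (not part of the
# original flow):
#
#     K_train = Qkernel.evaluate(x_vec=X_train)                # (n_train, n_train)
#     K_test = Qkernel.evaluate(x_vec=X_test, y_vec=X_train)   # (n_test, n_train)
#     svc = SVC(kernel='precomputed', C=c).fit(K_train, y_train)
#     labels = svc.predict(K_test)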


if __name__ == "__main__":
    args = parse_args()
    validate_args(args)
    file_name = os.path.splitext(os.path.basename(args.file))[0]
    results_iter = {}
    for i in range(args.iter):
        print(f"===== Iteration {i+1} =====")
        # process data to obtain the master split
        X_working, y_working, X_held_out, y_held_out = process_data(args.file)
        print("Training size: ", X_working.shape[0])
        print("Held out size: ", X_held_out.shape[0])

        fname = f"{file_name}_iter{i}"
        # get embedded training data and the best-performing model weights using cross-validation
        embedded_train, train_idx, model, model_weights = training(args,
                                                                   fname,
                                                                   X_working,
                                                                   y_working,
                                                                   i)
        fname_train = fname + "_train_embedding"
        np.save(f"checkpoints/{fname}/{fname_train}", embedded_train)
        fname_train_y = fname + "_train_target"
        np.save(f"checkpoints/{fname}/{fname_train_y}", y_working[train_idx])

        results_dict, embedded_test = testing(X_held_out, y_held_out, model, model_weights)
        # trainer.test() returns one metric dict per dataloader; a single dataloader is used here
        results_nn = results_dict[0]
        print("NN results on held-out data:", results_nn['test_acc'])

        fname_test = fname + "_test_embedding"
        np.save(f"checkpoints/{fname}/{fname_test}", embedded_test)
        fname_test_y = fname + "_test_target"
        np.save(f"checkpoints/{fname}/{fname_test_y}", y_held_out)

        results_svc = compute_svc(embedded_train,
                                  y_working[train_idx],
                                  embedded_test,
                                  y_held_out,
                                  args.C)
        print("SVC results on held-out data: " + str(results_svc))

        # pass the pegasos flag through; it was parsed but never used before
        results_qsvc = compute_QSVC(embedded_train,
                                    y_working[train_idx],
                                    embedded_test,
                                    y_held_out,
                                    args.encoding,
                                    args.C,
                                    pegasos=args.pegasos)
        print("QSVC results on held-out data: " + str(results_qsvc))

        results_iter[i] = [results_nn['test_acc'], results_svc, results_qsvc]

    results_df = pd.DataFrame.from_dict(results_iter, orient='index')
    print(results_df)

    str_time = strftime("%Y-%m-%d-%H-%M", gmtime())
    of_name = f"{file_name}_{str_time}_Results.csv"
    results_df.to_csv(of_name, index=False, header=['NN', 'SVC', 'QSVC'])
    # only query CUDA memory stats when a GPU is actually present
    if torch.cuda.is_available():
        max_memory_allocated = torch.cuda.max_memory_allocated()
        print(f"{max_memory_allocated / 1024**3:.2f} GB of GPU memory allocated")
442