a b/tests/edgepy/test_glm.py
1
import unittest
2
3
import numpy as np
4
import pandas as pd
5
6
from inmoose.edgepy import DGEList, glmFit, glmLRT, glmQLFTest, topTags
7
from inmoose.utils import rnbinom
8
9
10
class test_DGEGLM(unittest.TestCase):
11
    def test_constructor(self):
12
        from inmoose.edgepy import DGEGLM
13
14
        d = DGEGLM((1, 2, 3, 4, 5))
15
16
        self.assertIsNotNone(d.coefficients)
17
        self.assertIsNotNone(d.fitted_values)
18
        self.assertIsNotNone(d.deviance)
19
        self.assertIsNotNone(d.iter)
20
        self.assertIsNotNone(d.failed)
21
22
        self.assertIsNone(d.counts)
23
        self.assertIsNone(d.design)
24
        self.assertIsNone(d.offset)
25
        self.assertIsNone(d.dispersion)
26
        self.assertIsNone(d.weights)
27
        self.assertIsNone(d.prior_count)
28
        self.assertIsNone(d.unshrunk_coefficients)
29
        self.assertIsNone(d.method)
30
        self.assertIsNone(d.AveLogCPM)
31
32
33
class test_glm(unittest.TestCase):
34
    def setUp(self):
35
        y = np.array(rnbinom(80, size=5, mu=20, seed=42)).reshape((20, 4))
36
        y = np.vstack(([0, 0, 0, 0], [0, 0, 2, 2], y))
37
        self.group = np.array([1, 1, 2, 2])
38
        self.d = DGEList(counts=y, group=self.group, lib_size=np.arange(1001, 1005))
39
40
    def test_glmFit(self):
41
        with self.assertRaisesRegex(
42
            ValueError, expected_regex="No dispersion values found in DGEList object"
43
        ):
44
            self.d.glmFit()
45
        # first estimate common dispersion
46
        self.d.estimateGLMCommonDisp()
47
48
        # test oneway method
49
        e = self.d.glmFit(prior_count=0)
50
        self.assertEqual(e.method, "oneway")
51
        coef_ref = np.array(
52
            [
53
                [-100000000, 0],
54
                [-100000000, 99999993.7818981],
55
                [-3.818158376, -0.2600061895],
56
                [-4.306653048, -0.001994836928],
57
                [-3.710751661, -0.5260476658],
58
                [-4.511242547, 0.9498904882],
59
                [-4.047000403, 0.5031039058],
60
                [-3.864988612, 0.2859441576],
61
                [-4.235093266, 0.1860767109],
62
                [-4.235334586, 0.5821202392],
63
                [-3.576947226, -0.4440123944],
64
                [-4.04714793, -0.03093068582],
65
                [-3.576961828, -0.4441419915],
66
                [-3.475280168, -0.7278715951],
67
                [-3.888726197, 0.1761259908],
68
                [-3.795782156, 0.5221193611],
69
                [-3.991513569, -0.2807027773],
70
                [-3.559630535, -0.2833225949],
71
                [-3.475132503, -0.05155169483],
72
                [-3.325699504, -0.1042127818],
73
                [-3.888819724, 0.2916754177],
74
                [-4.018902851, 0.1281254313],
75
            ]
76
        )
77
        self.assertTrue(np.allclose(e.coefficients, coef_ref, atol=1e-6, rtol=0))
78
        self.d.AveLogCPM = None
79
        e = self.d.glmFit(prior_count=0)
80
        self.assertTrue(np.allclose(e.coefficients, coef_ref, atol=1e-6, rtol=0))
81
82
        # test levenberg method
83
        design = np.array([[1, 0], [1, 0], [0, 1], [0, 2]])
84
        e = self.d.glmFit(design=design, prior_count=0)
85
        self.assertEqual(e.method, "levenberg")
86
        coef_ref = np.array(
87
            [
88
                [np.nan, np.nan],
89
                [-22.911238587272, -4.33912438169024],
90
                [-3.81815837585705, -2.1292021791932],
91
                [-4.30665304828008, -2.29993049236002],
92
                [-3.71075166096661, -3.34440560955535],
93
                [-4.51124254688311, -1.97663680068325],
94
                [-4.04700040309538, -2.04478756288351],
95
                [-3.86498861235895, -1.99475491866954],
96
                [-4.23509326552448, -2.31061055207476],
97
                [-4.23533458597804, -1.96831279885414],
98
                [-3.57694722597284, -2.1476132892472],
99
                [-4.04714792989704, -2.22576713012883],
100
                [-3.57696182768343, -2.00991843107384],
101
                [-3.47528016753611, -2.43228621567992],
102
                [-3.88872619677031, -2.04166124945388],
103
                [-3.79578215609008, -1.82484690967382],
104
                [-3.99151356863038, -2.36307176428644],
105
                [-3.55963053465756, -2.62529213670109],
106
                [-3.47513250262264, -2.22250430528328],
107
                [-3.32569950353846, -2.03272287830771],
108
                [-3.88881972448195, -1.91858334888063],
109
                [-4.0189028512534, -2.23681029760319],
110
            ]
111
        )
112
        self.assertTrue(
113
            np.allclose(e.coefficients, coef_ref, atol=1e-5, rtol=0, equal_nan=True)
114
        )
115
116
        with self.assertRaisesRegex(
117
            ValueError,
118
            expected_regex="design should have as many rows as y has columns",
119
        ):
120
            glmFit(self.d.counts, design=np.ones((5, 1)))
121
122
        with self.assertRaisesRegex(
123
            ValueError, expected_regex="No dispersion values provided"
124
        ):
125
            glmFit(self.d.counts, design=design)
126
127
        with self.assertRaisesRegex(
128
            ValueError,
129
            expected_regex="Dimensions of dispersion do not agree with dimensions of y",
130
        ):
131
            glmFit(self.d.counts, design=design, dispersion=np.ones((5, 1)))
132
133
        with self.assertRaisesRegex(
134
            ValueError,
135
            expected_regex="Dimensions of offset do not agree with dimensions of y",
136
        ):
137
            glmFit(
138
                self.d.counts, design=design, dispersion=0.05, offset=np.ones((5, 1))
139
            )
140
141
        with self.assertRaisesRegex(
142
            ValueError,
143
            expected_regex="lib_size has wrong length, should agree with ncol\(y\)",
144
        ):
145
            glmFit(
146
                self.d.counts, design=design, dispersion=0.05, lib_size=np.ones((2,))
147
            )
148
149
        e = glmFit(self.d.counts, dispersion=0.05, prior_count=0)
150
        coef_ref = np.array(
151
            [
152
                -1.000000e08,
153
                -6.099236e00,
154
                -3.120865e00,
155
                -3.495951e00,
156
                -3.138641e00,
157
                -3.114458e00,
158
                -2.953059e00,
159
                -2.908932e00,
160
                -3.325914e00,
161
                -3.096400e00,
162
                -2.953421e00,
163
                -3.252006e00,
164
                -2.951040e00,
165
                -2.960371e00,
166
                -2.980034e00,
167
                -2.691442e00,
168
                -3.309082e00,
169
                -2.895371e00,
170
                -2.686176e00,
171
                -2.563754e00,
172
                -2.917718e00,
173
                -3.142277e00,
174
            ]
175
        ).reshape(e.coefficients.shape)
176
        self.assertTrue(np.allclose(e.coefficients, coef_ref, atol=1e-6, rtol=0))
177
178
    def test_glmQLFit(self):
179
        with self.assertRaisesRegex(
180
            ValueError, expected_regex="No dispersion values found in DGEList object"
181
        ):
182
            self.d.glmQLFit()
183
        # first estimate common dispersion
184
        self.d.estimateGLMCommonDisp()
185
186
        e = self.d.glmQLFit()
187
        self.assertEqual(e.method, "oneway")
188
        coef_ref = np.array(
189
            [
190
                [-8.989943, 0.000000000],
191
                [-8.989943, 2.832275076],
192
                [-3.812746, -0.258338047],
193
                [-4.297698, -0.001976527],
194
                [-3.705921, -0.522525658],
195
                [-4.500199, 0.942978065],
196
                [-4.040138, 0.500298444],
197
                [-3.859316, 0.284480541],
198
                [-4.226767, 0.184626544],
199
                [-4.227016, 0.578352261],
200
                [-3.572745, -0.441539334],
201
                [-4.040290, -0.030706290],
202
                [-3.572760, -0.441672710],
203
                [-3.471510, -0.723582927],
204
                [-3.882900, 0.175143962],
205
                [-3.790498, 0.519873066],
206
                [-3.985036, -0.278532340],
207
                [-3.555513, -0.281880293],
208
                [-3.471359, -0.051338428],
209
                [-3.322486, -0.103831822],
210
                [-3.882996, 0.290140213],
211
                [-4.012239, 0.127298731],
212
            ]
213
        )
214
        self.assertTrue(np.allclose(e.coefficients, coef_ref, atol=1e-4, rtol=0))
215
        self.assertTrue(
216
            np.array_equal(
217
                [0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
218
                e.df_residual_zeros,
219
            )
220
        )
221
        self.assertAlmostEqual(e.df_prior, 4.672744, places=6)
222
        var_post_ref = np.array(
223
            [
224
                4.458200e-07,
225
                8.302426e-06,
226
                6.524738e-01,
227
                3.709944e-01,
228
                2.159878e00,
229
                6.208389e-01,
230
                5.878879e-01,
231
                1.325040e00,
232
                3.203313e-01,
233
                1.176470e00,
234
                6.941698e-01,
235
                4.743232e-01,
236
                1.205789e00,
237
                5.523578e-01,
238
                6.314750e-01,
239
                5.425240e-01,
240
                3.352875e-01,
241
                2.129557e00,
242
                9.924612e-01,
243
                5.895944e-01,
244
                5.991600e-01,
245
                4.650080e-01,
246
            ]
247
        )
248
        self.assertTrue(np.allclose(e.var_post, var_post_ref, atol=1e-6, rtol=0))
249
        var_prior_ref = np.array(
250
            [
251
                4.458200e-07,
252
                9.918589e-06,
253
                6.392936e-01,
254
                2.804475e-01,
255
                6.394279e-01,
256
                6.503286e-01,
257
                7.568130e-01,
258
                7.713810e-01,
259
                4.405790e-01,
260
                6.716242e-01,
261
                7.527023e-01,
262
                5.182746e-01,
263
                7.526737e-01,
264
                7.527380e-01,
265
                7.426247e-01,
266
                7.491848e-01,
267
                4.566134e-01,
268
                7.744945e-01,
269
                7.490555e-01,
270
                7.042350e-01,
271
                7.668183e-01,
272
                6.273684e-01,
273
            ]
274
        )
275
        self.assertTrue(np.allclose(e.var_prior, var_prior_ref, atol=1e-6, rtol=0))
276
277
    def test_glmQLFTest(self):
278
        # first estimate common dispersion
279
        self.d.estimateGLMCommonDisp()
280
        s = glmQLFTest(self.d.glmQLFit())
281
        table_ref = pd.DataFrame(
282
            {
283
                "log2FoldChange": [
284
                    0.00000,
285
                    4.086109e00,
286
                    -0.3727030,
287
                    -2.851525e-03,
288
                    -0.7538452,
289
                    1.36042978,
290
                    0.7217781,
291
                    0.4104187,
292
                    0.2663598,
293
                    0.8343859,
294
                    -0.6370066,
295
                    -0.04429981,
296
                    -0.6371990,
297
                    -1.04390950,
298
                    0.2526793,
299
                    0.7500183,
300
                    -0.4018372,
301
                    -0.4066673,
302
                    -0.07406570,
303
                    -0.14979766,
304
                    0.4185838,
305
                    0.1836532,
306
                ],
307
                "lfcSE": [
308
                    0.00000,
309
                    0.119306,
310
                    0.308323,
311
                    0.339974,
312
                    0.312279,
313
                    0.323978,
314
                    0.299188,
315
                    0.293219,
316
                    0.324064,
317
                    0.310603,
318
                    0.298936,
319
                    0.316756,
320
                    0.298942,
321
                    0.304455,
322
                    0.297726,
323
                    0.284142,
324
                    0.323617,
325
                    0.291959,
326
                    0.280812,
327
                    0.275328,
328
                    0.294518,
329
                    0.308360,
330
                ],
331
                "logCPM": [
332
                    10.95644,
333
                    1.154123e01,
334
                    14.3827970,
335
                    1.391051e01,
336
                    14.3829969,
337
                    14.39896457,
338
                    14.6144562,
339
                    14.6840429,
340
                    14.1263018,
341
                    14.4316968,
342
                    14.6005670,
343
                    14.22314458,
344
                    14.6004725,
345
                    14.60068510,
346
                    14.5711003,
347
                    14.9673456,
348
                    14.1463618,
349
                    14.7114142,
350
                    14.96790942,
351
                    15.13649877,
352
                    14.6566606,
353
                    14.3657820,
354
                ],
355
                "stat": [
356
                    0.000000e00,
357
                    5.998382e05,
358
                    4.845263e-01,
359
                    4.551851e-05,
360
                    5.906064e-01,
361
                    6.441791e00,
362
                    2.069610e00,
363
                    3.032625e-01,
364
                    4.811457e-01,
365
                    1.334095e00,
366
                    1.367280e00,
367
                    9.186695e-03,
368
                    7.875974e-01,
369
                    4.518866e00,
370
                    2.379445e-01,
371
                    2.538764e00,
372
                    1.047358e00,
373
                    1.860394e-01,
374
                    1.375613e-02,
375
                    9.649525e-02,
376
                    6.947198e-01,
377
                    1.651446e-01,
378
                ],
379
                "pvalue": [
380
                    1.00000000,
381
                    0.01861634,
382
                    0.50988854,
383
                    0.99481417,
384
                    0.46850481,
385
                    0.04035278,
386
                    0.19546361,
387
                    0.59978572,
388
                    0.51131874,
389
                    0.28775816,
390
                    0.28234748,
391
                    0.92645653,
392
                    0.40567730,
393
                    0.07303738,
394
                    0.64131120,
395
                    0.15720197,
396
                    0.34177834,
397
                    0.67982404,
398
                    0.91008476,
399
                    0.76555392,
400
                    0.43338141,
401
                    0.69717923,
402
                ],
403
            },
404
            index=[f"gene{i}" for i in range(22)],
405
        )
406
        pd.testing.assert_frame_equal(table_ref, s, check_frame_type=False, rtol=1e-4)
407
408
    def test_glmLRT(self):
409
        # first estimate common dispersion
410
        self.d.estimateGLMCommonDisp()
411
        s = glmLRT(self.d.glmFit())
412
        table_ref = pd.DataFrame(
413
            {
414
                "log2FoldChange": [
415
                    0.00000,
416
                    4.08610921,
417
                    -0.3727030,
418
                    -2.851525e-03,
419
                    -0.7538452,
420
                    1.36042978,
421
                    0.7217781,
422
                    0.4104187,
423
                    0.2663598,
424
                    0.8343859,
425
                    -0.6370066,
426
                    -0.044299812,
427
                    -0.6371990,
428
                    -1.043910,
429
                    0.2526793,
430
                    0.7500183,
431
                    -0.4018372,
432
                    -0.4066673,
433
                    -0.07406570,
434
                    -0.1497977,
435
                    0.4185838,
436
                    0.18365325,
437
                ],
438
                "lfcSE": [
439
                    0.00000,
440
                    0.119306,
441
                    0.308323,
442
                    0.339974,
443
                    0.312279,
444
                    0.323978,
445
                    0.299188,
446
                    0.293219,
447
                    0.324064,
448
                    0.310603,
449
                    0.298936,
450
                    0.316756,
451
                    0.298942,
452
                    0.304455,
453
                    0.297726,
454
                    0.284142,
455
                    0.323617,
456
                    0.291959,
457
                    0.280812,
458
                    0.275328,
459
                    0.294518,
460
                    0.308360,
461
                ],
462
                "logCPM": [
463
                    10.95644,
464
                    11.54122852,
465
                    14.3827970,
466
                    1.391051e01,
467
                    14.3829969,
468
                    14.39896457,
469
                    14.6144562,
470
                    14.6840429,
471
                    14.1263018,
472
                    14.4316968,
473
                    14.6005670,
474
                    14.223144579,
475
                    14.6004725,
476
                    14.600685,
477
                    14.5711003,
478
                    14.9673456,
479
                    14.1463618,
480
                    14.7114142,
481
                    14.96790942,
482
                    15.1364988,
483
                    14.6566606,
484
                    14.36578202,
485
                ],
486
                "stat": [
487
                    0.000000e00,
488
                    4.980113e00,
489
                    3.161407e-01,
490
                    1.688711e-05,
491
                    1.275638e00,
492
                    3.999315e00,
493
                    1.216698e00,
494
                    4.018350e-01,
495
                    1.541260e-01,
496
                    1.569523e00,
497
                    9.491243e-01,
498
                    4.357462e-03,
499
                    9.496767e-01,
500
                    2.496031e00,
501
                    1.502560e-01,
502
                    1.377340e00,
503
                    3.511659e-01,
504
                    3.961817e-01,
505
                    1.365242e-02,
506
                    5.689306e-02,
507
                    4.162483e-01,
508
                    7.679356e-02,
509
                ],
510
                "pvalue": [
511
                    1.00000000,
512
                    0.02564032,
513
                    0.57393623,
514
                    0.99672119,
515
                    0.25871164,
516
                    0.04551876,
517
                    0.27000955,
518
                    0.52614310,
519
                    0.69462318,
520
                    0.21027629,
521
                    0.32994229,
522
                    0.94736901,
523
                    0.32980161,
524
                    0.11413364,
525
                    0.69829083,
526
                    0.24055469,
527
                    0.55345388,
528
                    0.52906782,
529
                    0.90698400,
530
                    0.81147574,
531
                    0.51881503,
532
                    0.78169064,
533
                ],
534
            },
535
            index=[f"gene{i}" for i in range(22)],
536
        )
537
        pd.testing.assert_frame_equal(table_ref, s, check_frame_type=False, rtol=1e-4)
538
539
    def test_topTags(self):
540
        self.d.estimateGLMCommonDisp()
541
        s = glmLRT(self.d.glmFit())
542
        t = topTags(s)
543
        self.assertTrue(
544
            np.array_equal(
545
                t.table.index,
546
                [
547
                    "gene1",
548
                    "gene5",
549
                    "gene13",
550
                    "gene9",
551
                    "gene15",
552
                    "gene4",
553
                    "gene6",
554
                    "gene12",
555
                    "gene10",
556
                    "gene20",
557
                ],
558
            )
559
        )