Diff of /src/likelihoods.cpp [000000] .. [dfe06d]

Switch to unified view

a b/src/likelihoods.cpp
1
#include <Rmath.h>
2
#include <Rcpp.h>
3
#include "internals.h"
4
#include "likelihoods.h"
5
6
7
// IMPORTANT: ON INDEXING VECTORS AND ANCESTRIES
8
9
// Most of the functions implemented here are susceptible to be called from R
10
// via Rcpp, and are therefore treated as interfaces. This causes a number of
11
// headaches when using indices of cases defined in R (1:N) to refer to elements
12
// in Rcpp / Cpp vectors (0:N-1). By convention, we store all data on the
13
// original scale (1:N), and modify indices whenever accessing elements of
14
// vectors. In other words, in an expression like 'alpha[j]', 'j' should always
15
// be on the internal scale (0:N-1).
16
17
// In all these functions, 'SEXP i' is an optional vector of case indices, on
18
// the 1:N scale.
19
20
21
22
23
24
25
// ---------------------------
26
27
// This likelihood corresponds to the probability of observing a number of
28
// mutations between cases and their ancestors. See src/likelihoods.cpp for
29
// details of the Rcpp implmentation.
30
31
// The likelihood is based on the number of mutations between a case and its
32
// ancestor; these are extracted from a pairwise genetic distance matrix
33
// (data$D) the log-likelihood is computed as: sum(mu^nmut + (1-mu)^(L-nmut))
34
// with:
35
36
// 'mu' is the mutation probability
37
// 'L' the number of sites in the alignment
38
// 'n_mut' the number of mutations between an ancestor and its descendent
39
// 'n_non_mut' the number of sites that have not mutated
40
41
// For any given case at 'nmut' mutations from its ancestor, with kappa
42
// generations in between, the log-likelihood is defined as:
43
44
// log(mu) * n_mut + log(1 - mu) * {(L - n_mut) + (L * (kappa-1))}
45
46
47
// when summing over several individuals, it becomes:
48
49
// log(mu) * sum_i(n_mut_i) + log(1-mu) * sum_i((L - n_mut_i) + (L * (kappa_i - 1)))
50
51
double cpp_ll_genetic(Rcpp::List data, Rcpp::List param, SEXP i,
52
              Rcpp::RObject custom_function) {
53
  Rcpp::IntegerMatrix D = data["D"];
54
  if (D.ncol() < 1) return 0.0;
55
56
  size_t N = static_cast<size_t>(data["N"]);
57
  if (N < 2) return 0.0;
58
59
  Rcpp::List l;
60
  if (custom_function != R_NilValue) {
61
    l = Rcpp::as<Rcpp::List>(custom_function);
62
  }
63
  if (custom_function == R_NilValue || (l.size() > 0 && l[0] == R_NilValue)) {
64
65
    // Variables from the data & param
66
    Rcpp::NumericMatrix w_dens = data["log_w_dens"];
67
    size_t K = w_dens.nrow();
68
    double mu = Rcpp::as<double>(param["mu"]);
69
    long int L = Rcpp::as<int>(data["L"]);
70
    Rcpp::IntegerVector alpha = param["alpha"]; // values are on 1:N
71
    Rcpp::IntegerVector kappa = param["kappa"];
72
    Rcpp::LogicalVector has_dna = data["has_dna"];
73
74
75
    // Local variables used for computatoins
76
    size_t n_mut = 0;
77
    size_t n_non_mut = 0;
78
    double out = 0;
79
    bool found[1];
80
    size_t ances[1];
81
    size_t n_generations[1];
82
    found[0] = false;
83
    ances[0] = NA_INTEGER;
84
    n_generations[0] = NA_INTEGER;
85
86
  
87
    // Invalid values of mu
88
    if (mu < 0.0 || mu > 1.0) {
89
      return R_NegInf;
90
    }
91
92
    
93
    // NOTE ON MISSING SEQUENCES
94
95
    // Terms participating to the genetic likelihood correspond to pairs
96
    // of ancestor-descendent which have a genetic sequence. The
97
    // log-likelihood of other pairs is 0.0, and can therefore be
98
    // ommitted. Note the possible source of confusion in indices here:
99
100
    // 'has_dna' is a vector, thus indexed from 0:(N-1)
101
      
102
    // 'cpp_get_n_mutations' is a function, and thus takes indices on 1:N
103
104
105
    
106
    // all cases are retained
107
    
108
    if (i == R_NilValue) {
109
      for (size_t j = 0; j < N; j++) { // 'j' on 0:(N-1)
110
    if (alpha[j] != NA_INTEGER) {
111
112
      // kappa restriction
113
114
      if (kappa[j] < 1 || kappa[j] > K) {
115
        return R_NegInf;
116
      }
117
118
      // missing sequences handled here
119
      
120
      if (has_dna[j]) {
121
     
122
        lookup_sequenced_ancestor(alpha, kappa, has_dna, j + 1,
123
                      ances, n_generations, found);
124
125
        if (found[0]) {
126
127
          n_mut = cpp_get_n_mutations(data, j + 1, ances[0]); // remember the offset
128
          n_non_mut = L - n_mut;
129
130
          out += n_mut*log(n_generations[0]*mu) +
131
        n_non_mut*log(1 - n_generations[0]*mu);
132
        
133
        }
134
      }
135
    }
136
      }
137
138
    } else {
139
      // only the cases listed in 'i' are retained
140
      size_t length_i = static_cast<size_t>(LENGTH(i));
141
      Rcpp::IntegerVector vec_i(i);
142
      for (size_t k = 0; k < length_i; k++) {
143
    size_t j = vec_i[k] - 1; // offset
144
    if (alpha[j] != NA_INTEGER) {
145
      // kappa restriction
146
      if (kappa[j] < 1 || kappa[j] > K) {
147
        return R_NegInf;
148
      }
149
150
      // missing sequences handled here
151
          
152
      if (has_dna[j]) {
153
     
154
        lookup_sequenced_ancestor(alpha, kappa, has_dna, j + 1, 
155
                      ances, n_generations, found);
156
157
        if (found[0]) {
158
159
          n_mut = cpp_get_n_mutations(data, j + 1, ances[0]); // remember the offset
160
          n_non_mut = L - n_mut;
161
          
162
          out += n_mut*log(n_generations[0]*mu) +
163
        n_non_mut*log(1 - n_generations[0]*mu);
164
          
165
        }
166
      }
167
      
168
    }
169
170
      }
171
    }
172
173
    return(out);
174
    
175
  } else { // use of a customized likelihood function
176
    Rcpp::Function f = Rcpp::as<Rcpp::Function>(l[0]);
177
    int arity = l[1];
178
    if (arity == 3) return Rcpp::as<double>(f(data, param, i));
179
    return Rcpp::as<double>(f(data, param));
180
  }
181
}
182
183
184
double cpp_ll_genetic(Rcpp::List data, Rcpp::List param, size_t i,
185
              Rcpp::RObject custom_function) {
186
  SEXP si = PROTECT(Rcpp::wrap(i));
187
  double ret = cpp_ll_genetic(data, param, si, custom_function);
188
  UNPROTECT(1);
189
  return ret;
190
}
191
192
193
194
195
196
// ---------------------------
197
198
// This likelihood corresponds to the probability of observing infection dates
199
// of cases given the infection dates of their ancestors.
200
201
double cpp_ll_timing_infections(Rcpp::List data, Rcpp::List param, SEXP i,
202
                Rcpp::RObject custom_function) {
203
  size_t N = static_cast<size_t>(data["N"]);
204
  if(N < 2) return 0.0;
205
206
  Rcpp::List l;
207
  if (custom_function != R_NilValue) {
208
    l = Rcpp::as<Rcpp::List>(custom_function);
209
  }
210
  if (custom_function == R_NilValue || (l.size() > 0 && l[0] == R_NilValue)) {
211
212
    Rcpp::IntegerVector alpha = param["alpha"];
213
    Rcpp::IntegerVector t_inf = param["t_inf"];
214
    Rcpp::IntegerVector kappa = param["kappa"];
215
    Rcpp::NumericMatrix w_dens = data["log_w_dens"];
216
    size_t K = w_dens.nrow();
217
218
    double out = 0.0;
219
220
    // all cases are retained
221
    if (i == R_NilValue) {
222
      for (size_t j = 0; j < N; j++) {
223
    if (alpha[j] != NA_INTEGER) {
224
      size_t delay = t_inf[j] - t_inf[alpha[j] - 1]; // offset
225
      if (delay < 1 || delay > w_dens.ncol()) {
226
        return  R_NegInf;
227
      }
228
      if (kappa[j] < 1 || kappa[j] > K) {
229
        return  R_NegInf;
230
      }
231
232
      out += w_dens(kappa[j] - 1, delay - 1);
233
    }
234
      }
235
    } else {
236
      // only the cases listed in 'i' are retained
237
      size_t length_i = static_cast<size_t>(LENGTH(i));
238
      Rcpp::IntegerVector vec_i(i);
239
      for (size_t k = 0; k < length_i; k++) {
240
    size_t j = vec_i[k] - 1; // offset
241
    if (alpha[j] != NA_INTEGER) {
242
      size_t delay = t_inf[j] - t_inf[alpha[j] - 1]; // offset
243
      if (delay < 1 || delay > w_dens.ncol()) {
244
        return  R_NegInf;
245
      }
246
      if (kappa[j] < 1 || kappa[j] > K) {
247
        return  R_NegInf;
248
      }
249
250
      out += w_dens(kappa[j] - 1, delay - 1);
251
    }
252
253
      }
254
    }
255
256
    return out;
257
  } else { // use of a customized likelihood function
258
    Rcpp::Function f = Rcpp::as<Rcpp::Function>(l[0]);
259
    int arity = l[1];
260
    if (arity == 3) return Rcpp::as<double>(f(data, param, i));
261
    return Rcpp::as<double>(f(data, param));
262
  }
263
}
264
265
266
double cpp_ll_timing_infections(Rcpp::List data, Rcpp::List param, size_t i,
267
              Rcpp::RObject custom_function) {
268
  SEXP si = PROTECT(Rcpp::wrap(i));
269
  double ret = cpp_ll_timing_infections(data, param, si, custom_function);
270
  UNPROTECT(1);
271
  return ret;
272
}
273
274
275
276
277
278
// ---------------------------
279
280
// This likelihood corresponds to the probability of reporting dates of cases
281
// given their infection dates.
282
283
double cpp_ll_timing_sampling(Rcpp::List data, Rcpp::List param, SEXP i,
284
                  Rcpp::RObject custom_function) {
285
  size_t N = static_cast<size_t>(data["N"]);
286
  if(N < 2) return 0.0;
287
288
  Rcpp::List l;
289
  if (custom_function != R_NilValue) {
290
    l = Rcpp::as<Rcpp::List>(custom_function);
291
  }
292
  if (custom_function == R_NilValue || (l.size() > 0 && l[0] == R_NilValue)) {
293
294
    Rcpp::IntegerVector dates = data["dates"];
295
    Rcpp::IntegerVector t_inf = param["t_inf"];
296
    Rcpp::NumericVector f_dens = data["log_f_dens"];
297
298
    double out = 0.0;
299
300
    // all cases are retained
301
    if (i == R_NilValue) {
302
      for (size_t j = 0; j < N; j++) {
303
    size_t delay = dates[j] - t_inf[j];
304
    if (delay < 1 || delay > f_dens.size()) {
305
      return  R_NegInf;
306
    }
307
    out += f_dens[delay - 1];
308
      }
309
    } else {
310
      // only the cases listed in 'i' are retained
311
      size_t length_i = static_cast<size_t>(LENGTH(i));
312
      Rcpp::IntegerVector vec_i(i);
313
      for (size_t k = 0; k < length_i; k++) {
314
    size_t j = vec_i[k] - 1; // offset
315
    size_t delay = dates[j] - t_inf[j];
316
    if (delay < 1 || delay > f_dens.size()) {
317
      return  R_NegInf;
318
    }
319
    out += f_dens[delay - 1];
320
      }
321
    }
322
323
    return out;
324
  }  else { // use of a customized likelihood function
325
    Rcpp::Function f = Rcpp::as<Rcpp::Function>(l[0]);
326
    int arity = l[1];
327
    if (arity == 3) return Rcpp::as<double>(f(data, param, i));
328
    return Rcpp::as<double>(f(data, param));
329
  }
330
}
331
332
333
double cpp_ll_timing_sampling(Rcpp::List data, Rcpp::List param, size_t i,
334
              Rcpp::RObject custom_function) {
335
  SEXP si = PROTECT(Rcpp::wrap(i));
336
  double ret = cpp_ll_timing_sampling(data, param, si, custom_function);
337
  UNPROTECT(1);
338
  return ret;
339
}
340
341
342
343
344
345
// ---------------------------
346
347
// This likelihood corresponds to the probability of a given number of
348
// unreported cases on an ancestry.
349
350
// The likelihood is given by a geometric distribution with probability 'pi'
351
// to report a case
352
353
// - 'kappa' is the number of generation between two successive cases
354
// - 'kappa-1' is the number of unreported cases
355
356
double cpp_ll_reporting(Rcpp::List data, Rcpp::List param, SEXP i,
357
            Rcpp::RObject custom_function) {
358
  Rcpp::NumericMatrix w_dens = data["log_w_dens"];
359
  size_t K = w_dens.nrow();
360
361
  size_t N = static_cast<size_t>(data["N"]);
362
  if(N < 2) return 0.0;
363
364
  double pi = static_cast<double>(param["pi"]);
365
  Rcpp::IntegerVector kappa = param["kappa"];
366
367
  // p(pi < 0) = p(pi > 1) = 0
368
  if (pi < 0.0 || pi > 1.0) {
369
    return R_NegInf;
370
  }
371
372
  Rcpp::List l;
373
  if (custom_function != R_NilValue) {
374
    l = Rcpp::as<Rcpp::List>(custom_function);
375
  }
376
  if (custom_function == R_NilValue || (l.size() > 0 && l[0] == R_NilValue)) {
377
378
    double out = 0.0;
379
380
    // all cases are retained
381
    if (i == R_NilValue) {
382
      for (size_t j = 0; j < N; j++) {
383
    if (kappa[j] != NA_INTEGER) {
384
      if (kappa[j] < 1 || kappa[j] > K) {
385
        return  R_NegInf;
386
      }
387
      out += R::dgeom(kappa[j] - 1.0, pi, 1); // first arg must be cast to double
388
    }
389
      }
390
    } else {
391
      // only the cases listed in 'i' are retained
392
      size_t length_i = static_cast<size_t>(LENGTH(i));
393
      Rcpp::IntegerVector vec_i(i);
394
      for (size_t k = 0; k < length_i; k++) {
395
    size_t j = vec_i[k] - 1; // offset
396
    if (kappa[j] != NA_INTEGER) {
397
      if (kappa[j] < 1 || kappa[j] > K) {
398
        return  R_NegInf;
399
      }
400
      out += R::dgeom(kappa[j] - 1.0, pi, 1); // first arg must be cast to double
401
    }
402
      }
403
    }
404
405
    return out;
406
  } else { // use of a customized likelihood function
407
    Rcpp::Function f = Rcpp::as<Rcpp::Function>(l[0]);
408
    int arity = l[1];
409
    if (arity == 3) return Rcpp::as<double>(f(data, param, i));
410
    return Rcpp::as<double>(f(data, param));
411
  }
412
}
413
414
415
double cpp_ll_reporting(Rcpp::List data, Rcpp::List param, size_t i,
416
              Rcpp::RObject custom_function) {
417
  SEXP si = PROTECT(Rcpp::wrap(i));
418
  double ret = cpp_ll_reporting(data, param, si, custom_function);
419
  UNPROTECT(1);
420
  return ret;
421
}
422
423
424
425
426
427
// ---------------------------
428
429
// This likelihood corresponds to the probability of observing a a reported
430
// contact between cases and their ancestors. See
431
// src/likelihoods.cpp for details of the Rcpp implmentation.
432
433
// The likelihood is based on the contact status between a case and its
434
// ancestor; this is extracted from a pairwise contact matrix (data$C), the
435
// log-likelihood is computed as:
436
// true_pos*eps + false_pos*eps*xi +
437
// false_neg*(1- eps) + true_neg*(1 - eps*xi)
438
//
439
// with:
440
// 'eps' is the contact reporting coverage
441
// 'lambda' is the non-infectious contact rate
442
// 'true_pos' is the number of contacts between transmission pairs
443
// 'false_pos' is the number of contact between non-transmission pairs
444
// 'false_neg' is the number of transmission pairs without contact
445
// 'true_neg' is the number of non-transmission pairs without contact
446
447
double cpp_ll_contact(Rcpp::List data, Rcpp::List param, SEXP i,
448
               Rcpp::RObject custom_function) {
449
  Rcpp::NumericMatrix contacts = data["contacts"];
450
  if (contacts.ncol() < 1) return 0.0;
451
452
  size_t C_combn = static_cast<size_t>(data["C_combn"]);
453
  size_t C_nrow = static_cast<size_t>(data["C_nrow"]);
454
455
  size_t N = static_cast<size_t>(data["N"]);
456
  if (N < 2) return 0.0;
457
458
  Rcpp::List l;
459
  if (custom_function != R_NilValue) {
460
    l = Rcpp::as<Rcpp::List>(custom_function);
461
  }
462
  if (custom_function == R_NilValue || (l.size() > 0 && l[0] == R_NilValue)) {
463
464
    double out;
465
    double eps = Rcpp::as<double>(param["eps"]);
466
    double lambda = Rcpp::as<double>(param["lambda"]);
467
    Rcpp::IntegerVector alpha = param["alpha"];
468
    Rcpp::IntegerVector kappa = param["kappa"];
469
470
    size_t true_pos = 0;
471
    size_t false_pos = 0;
472
    size_t false_neg = 0;
473
    size_t true_neg = 0;
474
    size_t imports = 0;
475
    size_t unobsv_case = 0;
476
477
    // p(eps < 0 || lambda < 0) = 0
478
    if (eps < 0.0 || lambda < 0.0) {
479
      return R_NegInf;
480
    }
481
482
    // all cases are retained (currently no support for i subsetting)
483
    for (size_t j = 0; j < N; j++) {
484
      if (alpha[j] == NA_INTEGER) {
485
    imports += 1;
486
      } else if (kappa[j] > 1) {
487
    unobsv_case += 1;
488
      } else {
489
    true_pos += contacts(j, alpha[j] - 1); // offset
490
      }
491
    }
492
493
    false_pos = C_nrow - true_pos;
494
    false_neg = N - imports - unobsv_case - true_pos;
495
    true_neg = C_combn - true_pos - false_pos - false_neg;
496
497
    // deal with special case when lambda == 0 and eps == 1, to avoid log(0)
498
    if(lambda == 0.0) {
499
      if(false_pos > 0) {
500
    out = R_NegInf;
501
      } else {
502
    out = log(eps) * (double) true_pos +
503
      log(1 - eps) * (double) false_neg +
504
      log(1 - eps*lambda) * (double) true_neg;
505
      }
506
    } else if(eps == 1.0) {
507
      if(false_neg > 0) {
508
    out = R_NegInf;
509
      } else {
510
    out = log(eps) * (double) true_pos +
511
      log(eps*lambda) * (double) false_pos +
512
      log(1 - eps*lambda) * (double) true_neg;
513
      }
514
    } else {
515
      out = log(eps) * (double) true_pos +
516
    log(eps*lambda) * (double) false_pos +
517
    log(1 - eps) * (double) false_neg +
518
    log(1 - eps*lambda) * (double) true_neg;
519
    }
520
521
    return out;
522
    
523
  } else { //use of a customized likelihood function
524
    Rcpp::Function f = Rcpp::as<Rcpp::Function>(l[0]);
525
    int arity = l[1];
526
    if (arity == 3) return Rcpp::as<double>(f(data, param, i));
527
    return Rcpp::as<double>(f(data, param));
528
  }
529
}
530
531
532
double cpp_ll_contact(Rcpp::List data, Rcpp::List param, size_t i,
533
              Rcpp::RObject custom_function) {
534
  SEXP si = PROTECT(Rcpp::wrap(i));
535
  double ret = cpp_ll_contact(data, param, si, custom_function);
536
  UNPROTECT(1);
537
  return ret;
538
}
539
540
541
542
543
544
// ---------------------------
545
546
// This likelihood corresponds to the sums of the separate timing likelihoods,
547
// which include:
548
549
// - p(infection dates): see function cpp_ll_timing_infections
550
// - p(collection dates): see function cpp_ll_timing_sampling
551
552
double cpp_ll_timing(Rcpp::List data, Rcpp::List param, SEXP i,
553
             Rcpp::RObject custom_functions) {
554
555
  if (custom_functions == R_NilValue) {
556
    return cpp_ll_timing_infections(data, param, i) +
557
      cpp_ll_timing_sampling(data, param, i);
558
  } else { // use of a customized likelihood functions
559
    Rcpp::List list_functions = Rcpp::as<Rcpp::List>(custom_functions);
560
    return cpp_ll_timing_infections(data, param, i, list_functions["timing_infections"]) +
561
      cpp_ll_timing_sampling(data, param, i, list_functions["timing_sampling"]);
562
563
  }
564
}
565
566
567
double cpp_ll_timing(Rcpp::List data, Rcpp::List param, size_t i,
568
              Rcpp::RObject custom_function) {
569
  SEXP si = PROTECT(Rcpp::wrap(i));
570
  double ret = cpp_ll_timing(data, param, si, custom_function);
571
  UNPROTECT(1);
572
  return ret;
573
}
574
575
576
577
578
// ---------------------------
579
580
// This likelihood corresponds to the sums of the separate likelihoods, which
581
// include:
582
583
// - p(infection dates): see function cpp_ll_timing_infections
584
// - p(collection dates): see function cpp_ll_timing_sampling
585
// - p(genetic diversity): see function cpp_ll_genetic
586
// - p(missing cases): see function cpp_ll_reporting
587
// - p(contact): see function cpp_ll_contact 
588
589
double cpp_ll_all(Rcpp::List data, Rcpp::List param, SEXP i,
590
          Rcpp::RObject custom_functions) {
591
592
  if (custom_functions == R_NilValue) {
593
594
    return cpp_ll_timing_infections(data, param, i) +
595
      cpp_ll_timing_sampling(data, param, i) +
596
      cpp_ll_genetic(data, param, i) +
597
      cpp_ll_reporting(data, param, i) +
598
      cpp_ll_contact(data, param, i);
599
600
  }  else { // use of a customized likelihood functions
601
    Rcpp::List list_functions = Rcpp::as<Rcpp::List>(custom_functions);
602
603
    return cpp_ll_timing_infections(data, param, i, list_functions["timing_infections"]) +
604
      cpp_ll_timing_sampling(data, param, i, list_functions["timing_sampling"]) +
605
      cpp_ll_genetic(data, param, i, list_functions["genetic"]) +
606
      cpp_ll_reporting(data, param, i, list_functions["reporting"]) +
607
      cpp_ll_contact(data, param, i, list_functions["contact"]);
608
609
  }
610
}
611
612
613
double cpp_ll_all(Rcpp::List data, Rcpp::List param, size_t i,
614
              Rcpp::RObject custom_function) {
615
  SEXP si = PROTECT(Rcpp::wrap(i));
616
  double ret = cpp_ll_all(data, param, si, custom_function);
617
  UNPROTECT(1);
618
  return ret;
619
}