Diff of /src/velocity_umap.R [000000] .. [fb5156]

Switch to unified view

a b/src/velocity_umap.R
1
show.velocity.on.embedding.cor_umapEmbed<-function(emb, vel, n = 100, cell.colors = NULL, corr.sigma = 0.05, 
2
                                                   show.grid.flow = FALSE, grid.n = 20, grid.sd = NULL, min.grid.cell.mass = 1, 
3
                                                   min.arrow.size = NULL, arrow.scale = 1, max.grid.arrow.length = NULL, 
4
                                                   fixed.arrow.length = FALSE, plot.grid.points = FALSE, scale = "log", 
5
                                                   nPcs = NA, arrow.lwd = 1, xlab = "", ylab = "", n.cores = 1, 
6
                                                   do.par = T, show.cell = NULL, cell.border.alpha = 0.3, cc = NULL, 
7
                                                   return.details = FALSE, expression.scaling = FALSE){
8
  randomize <- FALSE
9
  if (do.par) 
10
    par(mfrow = c(1, 1), mar = c(3.5, 3.5, 2.5, 1.5), mgp = c(2, 
11
                                                              0.65, 0), cex = 0.85)
12
  celcol <- "white"
13
  if (is.null(show.cell)) {
14
    celcol <- cell.colors[rownames(emb)]
15
  }
16
  plot(emb, bg = celcol, pch = 21, col = ac(1, alpha = cell.border.alpha), 
17
       xlab = xlab, ylab = ylab)
18
  em <- as.matrix(vel$current)
19
  ccells <- intersect(rownames(emb), colnames(em))
20
  em <- em[, ccells]
21
  emb <- emb[ccells, ]
22
  nd <- as.matrix(vel$deltaE[, ccells])
23
  cgenes <- intersect(rownames(em), rownames(nd))
24
  nd <- nd[cgenes, ]
25
  em <- em[cgenes, ]
26
  if (randomize) {
27
    nd <- t(apply(nd, 1, function(x) (rbinom(length(x), 1, 
28
                                             0.5) * 2 - 1) * abs(sample(x))))
29
  }
30
  if (is.null(cc)) {
31
    cat("delta projections ... ")
32
    if (scale == "log") {
33
      cat("log ")
34
      cc <- velocyto.R:::colDeltaCorLog10(em, (log10(abs(nd) + 1) * 
35
                                                 sign(nd)), nthreads = n.cores)
36
    }
37
    else if (scale == "sqrt") {
38
      cat("sqrt ")
39
      cc <- velocyto.R:::colDeltaCorSqrt(em, (sqrt(abs(nd)) * sign(nd)), 
40
                                         nthreads = n.cores)
41
    }
42
    else if (scale == "rank") {
43
      cat("rank ")
44
      cc <- velocyto.R:::colDeltaCor((apply(em, 2, rank)), (apply(nd, 
45
                                                                  2, rank)), nthreads = n.cores)
46
    }
47
    else {
48
      cat("linear ")
49
      cc <- velocyto.R:::colDeltaCor(em, nd, nthreads = n.cores)
50
    }
51
    colnames(cc) <- rownames(cc) <- colnames(em)
52
    diag(cc) <- 0
53
  }
54
  cat("knn ... ")
55
  if (n > nrow(cc)) {
56
    n <- nrow(cc)
57
  }
58
  emb.knn <- velocyto.R:::balancedKNN(t(emb), k = n, maxl = nrow(emb), dist = "euclidean", 
59
                                      n.threads = n.cores)
60
  diag(emb.knn) <- 1
61
  cat("transition probs ... ")
62
  tp <- exp(cc/corr.sigma) * emb.knn
63
  tp <- t(t(tp)/Matrix::colSums(tp,na.rm=T))
64
  tp <- as(tp, "dgCMatrix")
65
  cat("done\n")
66
  if (!is.null(show.cell)) {
67
    i <- match(show.cell, rownames(emb))
68
    if (is.na(i)) 
69
      stop(paste("specified cell", i, "is not in the embedding"))
70
    points(emb, pch = 19, col = ac(val2col(tp[rownames(emb), 
71
                                              show.cell], gradient.range.quantile = 1), alpha = 0.5))
72
    points(emb[show.cell, 1], emb[show.cell, 2], pch = 3, 
73
           cex = 1, col = 1)
74
    di <- t(t(emb) - emb[i, ])
75
    di <- di/sqrt(Matrix::rowSums(di^2)) * arrow.scale
76
    di[i, ] <- 0
77
    dir <- Matrix::colSums(di * tp[, i])
78
    dic <- Matrix::colSums(di * (tp[, i] > 0)/sum(tp[, i] > 
79
                                                    0))
80
    dia <- dir - dic
81
    suppressWarnings(arrows(emb[colnames(em)[i], 1], emb[colnames(em)[i], 
82
                                                         2], emb[colnames(em)[i], 1] + dic[1], emb[colnames(em)[i], 
83
                                                                                                   2] + dic[2], length = 0.05, lwd = 1, col = "blue"))
84
    suppressWarnings(arrows(emb[colnames(em)[i], 1], emb[colnames(em)[i], 
85
                                                         2], emb[colnames(em)[i], 1] + dir[1], emb[colnames(em)[i], 
86
                                                                                                   2] + dir[2], length = 0.05, lwd = 1, col = "red"))
87
    suppressWarnings(arrows(emb[colnames(em)[i], 1] + dic[1], 
88
                            emb[colnames(em)[i], 2] + dic[2], emb[colnames(em)[i], 
89
                                                                  1] + dir[1], emb[colnames(em)[i], 2] + dir[2], 
90
                            length = 0.05, lwd = 1, lty = 1, col = "grey50"))
91
    suppressWarnings(arrows(emb[colnames(em)[i], 1], emb[colnames(em)[i], 
92
                                                         2], emb[colnames(em)[i], 1] + dia[1], emb[colnames(em)[i], 
93
                                                                                                   2] + dia[2], length = 0.05, lwd = 1, col = "black"))
94
  }
95
  else {
96
    cat("calculating arrows ... ")
97
    tp=tp
98
    tp[is.na(tp)] <- 0
99
    arsd <- data.frame(t(velocyto.R:::embArrows(as.matrix(emb), tmp, 1, 
100
                                                n.cores)))
101
    rownames(arsd) <- rownames(emb)
102
    if (expression.scaling) {
103
      tpb <- tp > 0
104
      tpb <- t(t(tpb)/colSums(tpb))
105
      es <- as.matrix(em %*% tp) - as.matrix(em %*% as.matrix(tpb))
106
      pl <- pmin(1, pmax(0, apply(as.matrix(vel$deltaE[, 
107
                                                       colnames(es)]) * es, 2, sum)/sqrt(colSums(es * 
108
                                                                                                   es))))
109
      arsd <- arsd * pl
110
    }
111
    ars <- data.frame(cbind(emb, emb + arsd))
112
    colnames(ars) <- c("x0", "y0", "x1", "y1")
113
    colnames(arsd) <- c("xd", "yd")
114
    rownames(ars) <- rownames(emb)
115
    cat("done\n")
116
    if (show.grid.flow) {
117
      cat("grid estimates ... ")
118
      rx <- range(c(range(ars$x0), range(ars$x1)))
119
      ry <- range(c(range(ars$y0), range(ars$y1)))
120
      gx <- seq(rx[1], rx[2], length.out = grid.n)
121
      gy <- seq(ry[1], ry[2], length.out = grid.n)
122
      if (is.null(grid.sd)) {
123
        grid.sd <- sqrt((gx[2] - gx[1])^2 + (gy[2] - 
124
                                               gy[1])^2)/2
125
        cat("grid.sd=", grid.sd, " ")
126
      }
127
      if (is.null(min.arrow.size)) {
128
        min.arrow.size <- sqrt((gx[2] - gx[1])^2 + (gy[2] - 
129
                                                      gy[1])^2) * 0.01
130
        cat("min.arrow.size=", min.arrow.size, " ")
131
      }
132
      if (is.null(max.grid.arrow.length)) {
133
        max.grid.arrow.length <- sqrt(sum((par("pin")/c(length(gx), 
134
                                                        length(gy)))^2)) * 0.25
135
        cat("max.grid.arrow.length=", max.grid.arrow.length, 
136
            " ")
137
      }
138
      garrows <- do.call(rbind, lapply(gx, function(x) {
139
        cd <- sqrt(outer(emb[, 2], -gy, "+")^2 + (x - 
140
                                                    emb[, 1])^2)
141
        cw <- dnorm(cd, sd = grid.sd)
142
        gw <- Matrix::colSums(cw)
143
        cws <- pmax(1, Matrix::colSums(cw))
144
        gxd <- Matrix::colSums(cw * arsd$xd)/cws
145
        gyd <- Matrix::colSums(cw * arsd$yd)/cws
146
        al <- sqrt(gxd^2 + gyd^2)
147
        vg <- gw >= min.grid.cell.mass & al >= min.arrow.size
148
        cbind(rep(x, sum(vg)), gy[vg], x + gxd[vg], gy[vg] + 
149
                gyd[vg])
150
      }))
151
      colnames(garrows) <- c("x0", "y0", "x1", "y1")
152
      if (fixed.arrow.length) {
153
        suppressWarnings(arrows(garrows[, 1], garrows[, 
154
                                                      2], garrows[, 3], garrows[, 4], length = 0.05, 
155
                                lwd = arrow.lwd))
156
      }
157
      else {
158
        alen <- pmin(max.grid.arrow.length, sqrt(((garrows[, 
159
                                                           3] - garrows[, 1]) * par("pin")[1]/diff(par("usr")[c(1, 
160
                                                                                                                2)]))^2 + ((garrows[, 4] - garrows[, 2]) * 
161
                                                                                                                             par("pin")[2]/diff(par("usr")[c(3, 4)]))^2))
162
        suppressWarnings(lapply(1:nrow(garrows), function(i) arrows(garrows[i, 
163
                                                                            1], garrows[i, 2], garrows[i, 3], garrows[i, 
164
                                                                                                                      4], length = alen[i], lwd = arrow.lwd)))
165
      }
166
      if (plot.grid.points) 
167
        points(rep(gx, each = length(gy)), rep(gy, length(gx)), 
168
               pch = ".", cex = 0.1, col = ac(1, alpha = 0.4))
169
      cat("done\n")
170
      if (return.details) {
171
        cat("expression shifts .")
172
        scale.int <- switch(scale, log = 2, sqrt = 3, 
173
                            1)
174
        if (!expression.scaling) {
175
          tpb <- tp > 0
176
          tpb <- t(t(tpb)/colSums(tpb))
177
          es <- as.matrix(em %*% tp) - as.matrix(em %*% 
178
                                                   as.matrix(tpb))
179
        }
180
        cat(".")
181
        gs <- do.call(cbind, parallel::mclapply(gx, function(x) {
182
          cd <- sqrt(outer(emb[, 2], -gy, "+")^2 + (x - 
183
                                                      emb[, 1])^2)
184
          cw <- dnorm(cd, sd = grid.sd)
185
          gw <- Matrix::colSums(cw)
186
          cws <- pmax(1, Matrix::colSums(cw))
187
          cw <- t(t(cw)/cws)
188
          gxd <- Matrix::colSums(cw * arsd$xd)
189
          gyd <- Matrix::colSums(cw * arsd$yd)
190
          al <- sqrt(gxd^2 + gyd^2)
191
          vg <- gw >= min.grid.cell.mass & al >= min.arrow.size
192
          if (any(vg)) {
193
            z <- es %*% cw[, vg]
194
          }
195
          else {
196
            NULL
197
          }
198
        }, mc.cores = n.cores, mc.preschedule = T))
199
        if (scale == "log") {
200
          nd <- (log10(abs(nd) + 1) * sign(nd))
201
        }
202
        else if (scale == "sqrt") {
203
          nd <- (sqrt(abs(nd)) * sign(nd))
204
        }
205
        cat(".")
206
        gv <- do.call(cbind, parallel::mclapply(gx, function(x) {
207
          cd <- sqrt(outer(emb[, 2], -gy, "+")^2 + (x - 
208
                                                      emb[, 1])^2)
209
          cw <- dnorm(cd, sd = grid.sd)
210
          gw <- Matrix::colSums(cw)
211
          cws <- pmax(1, Matrix::colSums(cw))
212
          cw <- t(t(cw)/cws)
213
          gxd <- Matrix::colSums(cw * arsd$xd)
214
          gyd <- Matrix::colSums(cw * arsd$yd)
215
          al <- sqrt(gxd^2 + gyd^2)
216
          vg <- gw >= min.grid.cell.mass & al >= min.arrow.size
217
          if (any(vg)) {
218
            z <- nd %*% cw[, vg]
219
          }
220
          else {
221
            NULL
222
          }
223
        }, mc.cores = n.cores, mc.preschedule = T))
224
        cat(". done\n")
225
        return(invisible(list(tp = tp, cc = cc, garrows = garrows, 
226
                              arrows = as.matrix(ars), vel = nd, eshifts = es, 
227
                              gvel = gv, geshifts = gs, scale = scale)))
228
      }
229
    }
230
    else {
231
      apply(ars, 1, function(x) {
232
        if (fixed.arrow.length) {
233
          suppressWarnings(arrows(x[1], x[2], x[3], x[4], 
234
                                  length = 0.05, lwd = arrow.lwd))
235
        }
236
        else {
237
          ali <- sqrt(((x[3] - x[1]) * par("pin")[1]/diff(par("usr")[c(1, 
238
                                                                       2)]))^2 + ((x[4] - x[2]) * par("pin")[2]/diff(par("usr")[c(3, 
239
                                                                                                                                  4)]))^2)
240
          suppressWarnings(arrows(x[1], x[2], x[3], x[4], 
241
                                  length = min(0.05, ali), lwd = arrow.lwd))
242
        }
243
      })
244
    }
245
  }
246
  return(invisible(list(tp = tp, cc = cc)))
247
}