---
title: "Optimal Tensor Transport"
author:
- name: Koki Tsuyuzaki
  affiliation: Laboratory for Bioinformatics Research, RIKEN Center for Biosystems Dynamics Research
  email: k.t.the-answer@hotmail.co.jp
date: "`r Sys.Date()`"
bibliography: bibliography.bib
package: otTensor
output: rmarkdown::html_vignette
vignette: |
  %\VignetteIndexEntry{Optimal Tensor Transport}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

# Introduction

In this vignette, we consider optimal tensor transport (OTT), which is an
extension of optimal transport (OT) that handles tensors of any order by
learning possibly multiple transport plans. Here, we reproduce the experiments
in the original paper [@ott]. For the details of the methodology, see the
original paper.

```{r setting, echo=TRUE}
library("otTensor")

# Draw a matrix as a grayscale image in the orientation in which it is
# printed (row 1 at the top, column 1 on the left).
#   mat  : numeric matrix to display
#   main : plot title
.show_matrix <- function(mat, main = "") {
  # image() lays rows out along the x axis and draws y upwards, so reverse
  # the row order and transpose before plotting.
  flipped <- t(apply(mat, 2, rev))
  image(flipped, col = gray((0:255) / 255), xaxt = "n", yaxt = "n",
    xlab = "", ylab = "", axes = FALSE, main = main)
}

# Source marginals, one per transport plan a = 1, ..., A.
# The length of the a-th vector is the size of the first mode d of arrX
# with f[d] == a. Entries 1 and 3 get weight 1, all others 0.01, and the
# vector is then normalised to sum to one.
.source_marginals <- function(arrX, f, A) {
  lapply(seq_len(A), function(a) {
    d <- which(f == a)[1]
    p <- rep(0.01, dim(arrX)[d])
    p[c(1, 3)] <- 1
    p / sum(p)
  })
}

# Target marginals, one per transport plan a = 1, ..., A.
# Same sizing rule as .source_marginals() but on arrY. Entries 2 and 3 get
# weight 0, all others 1, and the vector is then normalised to sum to one.
.target_marginals <- function(arrY, f, A) {
  lapply(seq_len(A), function(a) {
    d <- which(f == a)[1]
    q <- rep(1, dim(arrY)[d])
    q[c(2, 3)] <- 0
    q / sum(q)
  })
}
```

<!-- The examples below are shown but not run when the vignette is built
(eval=FALSE); switch to eval=TRUE to execute them and draw the figures. -->

# OTT_1 (OT)

```{r ott_1, echo=TRUE, eval=FALSE}
# D modes per tensor, A transport plans; Is / Ks are the dimensions of
# arrX / arrY, and f[d] is the plan index used for mode d.
D <- 1
A <- 1
Is <- c(4)
Ks <- c(7)
f <- c(1)

# Each entry holds its own index.
arrX <- array(0, Is)
arrY <- array(0, Ks)
for (i1 in seq_len(Is[1])) {
  arrX[i1] <- i1
}
for (k1 in seq_len(Ks[1])) {
  arrY[k1] <- k1
}

ps <- .source_marginals(arrX, f, A)
qs <- .target_marginals(arrY, f, A)

X <- as.tensor(arrX)
Y <- as.tensor(arrY)
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs, num.sample = 1000,
  loss = function(x, y) abs(x - y),
  num.iter = 200, epsilon = 1e-10)

options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
plot(arrX, type = "h", col = "black", main = "arrX")
plot(arrY, type = "h", col = "black", main = "arrY")
```

# OTT_12 (Co-OT)

```{r ott_12, echo=TRUE, eval=FALSE}
# Two modes, each with its own transport plan.
D <- 2
A <- 2
Is <- c(4, 5)
Ks <- c(7, 8)
f <- c(1, 2)

# Each entry holds the sum of its indices.
arrX <- array(0, Is)
arrY <- array(0, Ks)
for (i1 in seq_len(Is[1])) {
  for (i2 in seq_len(Is[2])) {
    arrX[i1, i2] <- i1 + i2
  }
}
for (k1 in seq_len(Ks[1])) {
  for (k2 in seq_len(Ks[2])) {
    arrY[k1, k2] <- k1 + k2
  }
}

ps <- .source_marginals(arrX, f, A)
qs <- .target_marginals(arrY, f, A)

X <- as.tensor(arrX)
Y <- as.tensor(arrY)
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs, num.sample = 1000,
  loss = function(x, y) abs(x - y),
  num.iter = 200, epsilon = 1e-10)

options(repr.plot.width = 6, repr.plot.height = 10)
# One 3 x 2 panel per transport plan.
for (a in seq_len(A)) {
  par(mfrow = c(3, 2))
  plot(ps[[a]], type = "h", col = "red", ylim = c(0, 1),
    main = sprintf("ps[[%d]]", a))
  plot(qs[[a]], type = "h", col = "red", ylim = c(0, 1),
    main = sprintf("qs[[%d]]", a))
  .show_matrix(out$Ts[[a]], main = sprintf("Ts[[%d]]", a))
  .show_matrix(arrX, main = "arrX")
  .show_matrix(arrY, main = "arrY")
}
```

# OTT_11 (GW)

```{r ott_11, echo=TRUE, eval=FALSE}
# Two modes sharing a single transport plan.
D <- 2
A <- 1
Is <- c(4, 4)
Ks <- c(6, 6)
f <- c(1, 1)

# Each entry holds the sum of its indices.
arrX <- array(0, Is)
arrY <- array(0, Ks)
for (i1 in seq_len(Is[1])) {
  for (i2 in seq_len(Is[2])) {
    arrX[i1, i2] <- i1 + i2
  }
}
for (k1 in seq_len(Ks[1])) {
  for (k2 in seq_len(Ks[2])) {
    arrY[k1, k2] <- k1 + k2
  }
}

ps <- .source_marginals(arrX, f, A)
qs <- .target_marginals(arrY, f, A)

X <- as.tensor(arrX)
Y <- as.tensor(arrY)
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs, num.sample = 1000,
  loss = function(x, y) abs(x - y),
  num.iter = 200, epsilon = 1e-10)

options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
.show_matrix(arrX, main = "arrX")
.show_matrix(arrY, main = "arrY")
```

# OTT_111 (triplets)

```{r ott_111, echo=TRUE, eval=FALSE}
# Three modes sharing a single transport plan.
D <- 3
A <- 1
Is <- c(4, 4, 4)
Ks <- c(6, 6, 6)
f <- c(1, 1, 1)

# Each entry holds the sum of its indices.
arrX <- array(0, Is)
arrY <- array(0, Ks)
for (i1 in seq_len(Is[1])) {
  for (i2 in seq_len(Is[2])) {
    for (i3 in seq_len(Is[3])) {
      arrX[i1, i2, i3] <- i1 + i2 + i3
    }
  }
}
for (k1 in seq_len(Ks[1])) {
  for (k2 in seq_len(Ks[2])) {
    for (k3 in seq_len(Ks[3])) {
      arrY[k1, k2, k3] <- k1 + k2 + k3
    }
  }
}

ps <- .source_marginals(arrX, f, A)
qs <- .target_marginals(arrY, f, A)

X <- as.tensor(arrX)
Y <- as.tensor(arrY)
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs, num.sample = 1000,
  loss = function(x, y) abs(x - y),
  num.iter = 200, epsilon = 1e-10)

options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
# Only the first frontal slice of each tensor is drawn.
.show_matrix(arrX[, , 1], main = "arrX[,,1]")
.show_matrix(arrY[, , 1], main = "arrY[,,1]")
```

# OTT_123 (triCo-OT)

```{r ott_123, echo=TRUE, eval=FALSE}
# Three modes, each with its own transport plan.
D <- 3
A <- 3
Is <- c(4, 5, 6)
Ks <- c(7, 8, 9)
f <- c(1, 2, 3)

# Each entry holds the sum of its indices.
arrX <- array(0, Is)
arrY <- array(0, Ks)
for (i1 in seq_len(Is[1])) {
  for (i2 in seq_len(Is[2])) {
    for (i3 in seq_len(Is[3])) {
      arrX[i1, i2, i3] <- i1 + i2 + i3
    }
  }
}
for (k1 in seq_len(Ks[1])) {
  for (k2 in seq_len(Ks[2])) {
    for (k3 in seq_len(Ks[3])) {
      arrY[k1, k2, k3] <- k1 + k2 + k3
    }
  }
}

ps <- .source_marginals(arrX, f, A)
qs <- .target_marginals(arrY, f, A)

X <- as.tensor(arrX)
Y <- as.tensor(arrY)
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs, num.sample = 1000,
  loss = function(x, y) abs(x - y),
  num.iter = 200, epsilon = 1e-10)

options(repr.plot.width = 6, repr.plot.height = 10)
# One 3 x 2 panel per transport plan; panel a shows frontal slice a.
for (a in seq_len(A)) {
  par(mfrow = c(3, 2))
  plot(ps[[a]], type = "h", col = "red", ylim = c(0, 1),
    main = sprintf("ps[[%d]]", a))
  plot(qs[[a]], type = "h", col = "red", ylim = c(0, 1),
    main = sprintf("qs[[%d]]", a))
  .show_matrix(out$Ts[[a]], main = sprintf("Ts[[%d]]", a))
  .show_matrix(arrX[, , a], main = sprintf("arrX[,,%d]", a))
  .show_matrix(arrY[, , a], main = sprintf("arrY[,,%d]", a))
}
```

# OTT_112 (GW Collection)

```{r ott_112, echo=TRUE, eval=FALSE}
# Three modes: the first two share plan 1, the third uses plan 2.
D <- 3
A <- 2
Is <- c(4, 4, 5)
Ks <- c(6, 6, 7)
f <- c(1, 1, 2)

# Each entry holds the sum of its indices.
arrX <- array(0, Is)
arrY <- array(0, Ks)
for (i1 in seq_len(Is[1])) {
  for (i2 in seq_len(Is[2])) {
    for (i3 in seq_len(Is[3])) {
      arrX[i1, i2, i3] <- i1 + i2 + i3
    }
  }
}
for (k1 in seq_len(Ks[1])) {
  for (k2 in seq_len(Ks[2])) {
    for (k3 in seq_len(Ks[3])) {
      arrY[k1, k2, k3] <- k1 + k2 + k3
    }
  }
}

ps <- .source_marginals(arrX, f, A)
qs <- .target_marginals(arrY, f, A)

X <- as.tensor(arrX)
Y <- as.tensor(arrY)
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs, num.sample = 1000,
  loss = function(x, y) abs(x - y),
  num.iter = 200, epsilon = 1e-10)

options(repr.plot.width = 6, repr.plot.height = 10)
# One 3 x 2 panel per transport plan; panel a shows frontal slice a.
for (a in seq_len(A)) {
  par(mfrow = c(3, 2))
  plot(ps[[a]], type = "h", col = "red", ylim = c(0, 1),
    main = sprintf("ps[[%d]]", a))
  plot(qs[[a]], type = "h", col = "red", ylim = c(0, 1),
    main = sprintf("qs[[%d]]", a))
  .show_matrix(out$Ts[[a]], main = sprintf("Ts[[%d]]", a))
  .show_matrix(arrX[, , a], main = sprintf("arrX[,,%d]", a))
  .show_matrix(arrY[, , a], main = sprintf("arrY[,,%d]", a))
}
```

# Session Information {.unnumbered}

```{r sessionInfo, echo=FALSE}
sessionInfo()
```

# References