In this vignette, we consider optimal tensor transport (OTT), an extension of optimal transport (OT) that handles tensors of any order by learning one or more transport plans.
Here, we reproduce the experiments of the original paper (Kerdoncuff 2022). For details of the methodology, see the original paper.
library("otTensor")
# Draw a matrix as a grayscale image in the same orientation in which the
# matrix is printed (row 1 at the top, column 1 on the left).
#
# mat:  numeric matrix to display.
# main: plot title (defaults to an empty string).
#
# Called for its plotting side effect on the active graphics device.
.show_matrix <- function(mat, main = "") {
  # image() lays the rows of its input along the x axis and draws columns
  # bottom-up, so reverse the row order and transpose first.
  # Indexing with drop = FALSE (instead of apply(mat, 2, rev)) keeps a
  # one-row matrix a matrix, so its orientation is handled like any other.
  flipped <- mat[rev(seq_len(nrow(mat))), , drop = FALSE]
  mat_rev <- t(flipped)
  # 256-level grayscale palette, no axes or axis labels.
  image(mat_rev, col = gray((0:255) / 255), xaxt = "n", yaxt = "n",
    xlab = "", ylab = "", axes = FALSE, main = main)
}

# ---- Experiment 1 setup: D = 1, A = 1 ----
# D, A, Is, Ks and f are handed to OTT() unchanged further below.
# Is and Ks are also used as the dimensions of the two input arrays.
D <- 1
A <- 1
Is <- c(4)
Ks <- c(7)
f <- c(1)
# Zero-filled containers; the values are assigned in the following chunk.
arrX <- array(rep(0, prod(Is)), Is)
arrY <- array(rep(0, prod(Ks)), Ks)
# ---- Experiment 1 (D = 1, A = 1): fill data, build marginals, run, plot ----

# Each entry of the two vectors is set to its own index.
for (i1 in seq_len(Is[1])) {
  arrX[i1] <- i1
}
for (k1 in seq_len(Ks[1])) {
  arrY[k1] <- k1
}

# ps[[a]]: weights over the first mode d of arrX for which f[d] == a.
# Entries 1 and 3 get weight 1, all others 0.01, then normalised to sum to 1.
ps <- list()
for (a in seq_len(A)) {
  ds <- which(f == a)
  d <- ds[1]
  length_of_p_a <- dim(arrX)[d]
  ps[[a]] <- rep(0.01, length_of_p_a)
  ps[[a]][c(1, 3)] <- 1
  ps[[a]] <- ps[[a]] / sum(ps[[a]])
}

# qs[[a]]: same construction on arrY, but uniform weights with entries
# 2 and 3 set to exactly 0 before normalisation.
qs <- list()
for (a in seq_len(A)) {
  ds <- which(f == a)
  d <- ds[1]
  length_of_q_a <- dim(arrY)[d]
  qs[[a]] <- rep(1, length_of_q_a)
  qs[[a]][c(2, 3)] <- 0
  qs[[a]] <- qs[[a]] / sum(qs[[a]])
}

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

# Absolute difference is used as the element-wise loss.
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs, num.sample = 1000,
  loss = function(x, y) {
    abs(x - y)
  },
  num.iter = 200, epsilon = 1e-10)

# 3 x 2 panel: the two marginals, the returned Ts[[1]], and the input vectors.
options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
plot(arrX, type = "h", col = "black", main = "arrX")
plot(arrY, type = "h", col = "black", main = "arrY")
# ---- Experiment 2 (D = 2, A = 2): one marginal pair per matrix mode ----

D <- 2
A <- 2
Is <- c(4, 5)
Ks <- c(7, 8)
f <- c(1, 2)
arrX <- array(rep(0, prod(Is)), Is)
arrY <- array(rep(0, prod(Ks)), Ks)

# Each entry is the sum of its row and column indices.
for (i1 in seq_len(Is[1])) {
  for (i2 in seq_len(Is[2])) {
    arrX[i1, i2] <- i1 + i2
  }
}
for (k1 in seq_len(Ks[1])) {
  for (k2 in seq_len(Ks[2])) {
    arrY[k1, k2] <- k1 + k2
  }
}

# ps[[a]]: weights over the first mode d of arrX for which f[d] == a.
# Entries 1 and 3 get weight 1, all others 0.01, then normalised to sum to 1.
ps <- list()
for (a in seq_len(A)) {
  ds <- which(f == a)
  d <- ds[1]
  length_of_p_a <- dim(arrX)[d]
  ps[[a]] <- rep(0.01, length_of_p_a)
  ps[[a]][c(1, 3)] <- 1
  ps[[a]] <- ps[[a]] / sum(ps[[a]])
}

# qs[[a]]: same construction on arrY, but uniform weights with entries
# 2 and 3 set to exactly 0 before normalisation.
qs <- list()
for (a in seq_len(A)) {
  ds <- which(f == a)
  d <- ds[1]
  length_of_q_a <- dim(arrY)[d]
  qs[[a]] <- rep(1, length_of_q_a)
  qs[[a]][c(2, 3)] <- 0
  qs[[a]] <- qs[[a]] / sum(qs[[a]])
}

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

# Absolute difference is used as the element-wise loss.
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs, num.sample = 1000,
  loss = function(x, y) {
    abs(x - y)
  },
  num.iter = 200, epsilon = 1e-10)

# Panel for a = 1: marginals, Ts[[1]], and the two input matrices.
options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
.show_matrix(arrX, main = "arrX")
.show_matrix(arrY, main = "arrY")

# Panel for a = 2.
par(mfrow = c(3, 2))
plot(ps[[2]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[2]]")
plot(qs[[2]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[2]]")
.show_matrix(out$Ts[[2]], main = "Ts[[2]]")
.show_matrix(arrX, main = "arrX")
.show_matrix(arrY, main = "arrY")
# ---- Experiment 3 (D = 2, A = 1): both matrix modes share f == 1 ----

D <- 2
A <- 1
Is <- c(4, 4)
Ks <- c(6, 6)
f <- c(1, 1)
arrX <- array(rep(0, prod(Is)), Is)
arrY <- array(rep(0, prod(Ks)), Ks)

# Each entry is the sum of its row and column indices.
for (i1 in seq_len(Is[1])) {
  for (i2 in seq_len(Is[2])) {
    arrX[i1, i2] <- i1 + i2
  }
}
for (k1 in seq_len(Ks[1])) {
  for (k2 in seq_len(Ks[2])) {
    arrY[k1, k2] <- k1 + k2
  }
}

# Both modes have f == 1, so a single marginal pair is built; its length is
# taken from the first matching mode (ds[1]).
# ps: entries 1 and 3 get weight 1, all others 0.01, then normalised.
ps <- list()
for (a in seq_len(A)) {
  ds <- which(f == a)
  d <- ds[1]
  length_of_p_a <- dim(arrX)[d]
  ps[[a]] <- rep(0.01, length_of_p_a)
  ps[[a]][c(1, 3)] <- 1
  ps[[a]] <- ps[[a]] / sum(ps[[a]])
}

# qs: uniform weights with entries 2 and 3 set to exactly 0, then normalised.
qs <- list()
for (a in seq_len(A)) {
  ds <- which(f == a)
  d <- ds[1]
  length_of_q_a <- dim(arrY)[d]
  qs[[a]] <- rep(1, length_of_q_a)
  qs[[a]][c(2, 3)] <- 0
  qs[[a]] <- qs[[a]] / sum(qs[[a]])
}

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

# Absolute difference is used as the element-wise loss.
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs, num.sample = 1000,
  loss = function(x, y) {
    abs(x - y)
  },
  num.iter = 200, epsilon = 1e-10)

# 3 x 2 panel: marginals, Ts[[1]], and the two input matrices.
options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
.show_matrix(arrX, main = "arrX")
.show_matrix(arrY, main = "arrY")
# ---- Experiment 4 (D = 3, A = 1): all three modes share f == 1 ----

D <- 3
A <- 1
Is <- c(4, 4, 4)
Ks <- c(6, 6, 6)
f <- c(1, 1, 1)
arrX <- array(rep(0, prod(Is)), Is)
arrY <- array(rep(0, prod(Ks)), Ks)

# Each entry is the sum of its three indices.
for (i1 in seq_len(Is[1])) {
  for (i2 in seq_len(Is[2])) {
    for (i3 in seq_len(Is[3])) {
      arrX[i1, i2, i3] <- i1 + i2 + i3
    }
  }
}
for (k1 in seq_len(Ks[1])) {
  for (k2 in seq_len(Ks[2])) {
    for (k3 in seq_len(Ks[3])) {
      arrY[k1, k2, k3] <- k1 + k2 + k3
    }
  }
}

# All modes have f == 1, so a single marginal pair is built; its length is
# taken from the first matching mode (ds[1]).
# ps: entries 1 and 3 get weight 1, all others 0.01, then normalised.
ps <- list()
for (a in seq_len(A)) {
  ds <- which(f == a)
  d <- ds[1]
  length_of_p_a <- dim(arrX)[d]
  ps[[a]] <- rep(0.01, length_of_p_a)
  ps[[a]][c(1, 3)] <- 1
  ps[[a]] <- ps[[a]] / sum(ps[[a]])
}

# qs: uniform weights with entries 2 and 3 set to exactly 0, then normalised.
qs <- list()
for (a in seq_len(A)) {
  ds <- which(f == a)
  d <- ds[1]
  length_of_q_a <- dim(arrY)[d]
  qs[[a]] <- rep(1, length_of_q_a)
  qs[[a]][c(2, 3)] <- 0
  qs[[a]] <- qs[[a]] / sum(qs[[a]])
}

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

# Absolute difference is used as the element-wise loss.
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs, num.sample = 1000,
  loss = function(x, y) {
    abs(x - y)
  },
  num.iter = 200, epsilon = 1e-10)

# 3 x 2 panel: marginals, Ts[[1]], and the first slice along the third mode
# of each input array.
options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
.show_matrix(arrX[, , 1], main = "arrX[,,1]")
.show_matrix(arrY[, , 1], main = "arrY[,,1]")
# ---- Experiment 5 (D = 3, A = 3): one marginal pair per tensor mode ----

D <- 3
A <- 3
Is <- c(4, 5, 6)
Ks <- c(7, 8, 9)
f <- c(1, 2, 3)
arrX <- array(rep(0, prod(Is)), Is)
arrY <- array(rep(0, prod(Ks)), Ks)

# Each entry is the sum of its three indices.
for (i1 in seq_len(Is[1])) {
  for (i2 in seq_len(Is[2])) {
    for (i3 in seq_len(Is[3])) {
      arrX[i1, i2, i3] <- i1 + i2 + i3
    }
  }
}
for (k1 in seq_len(Ks[1])) {
  for (k2 in seq_len(Ks[2])) {
    for (k3 in seq_len(Ks[3])) {
      arrY[k1, k2, k3] <- k1 + k2 + k3
    }
  }
}

# ps[[a]]: weights over the first mode d of arrX for which f[d] == a.
# Entries 1 and 3 get weight 1, all others 0.01, then normalised to sum to 1.
ps <- list()
for (a in seq_len(A)) {
  ds <- which(f == a)
  d <- ds[1]
  length_of_p_a <- dim(arrX)[d]
  ps[[a]] <- rep(0.01, length_of_p_a)
  ps[[a]][c(1, 3)] <- 1
  ps[[a]] <- ps[[a]] / sum(ps[[a]])
}

# qs[[a]]: same construction on arrY, but uniform weights with entries
# 2 and 3 set to exactly 0 before normalisation.
qs <- list()
for (a in seq_len(A)) {
  ds <- which(f == a)
  d <- ds[1]
  length_of_q_a <- dim(arrY)[d]
  qs[[a]] <- rep(1, length_of_q_a)
  qs[[a]][c(2, 3)] <- 0
  qs[[a]] <- qs[[a]] / sum(qs[[a]])
}

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

# Absolute difference is used as the element-wise loss.
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs, num.sample = 1000,
  loss = function(x, y) {
    abs(x - y)
  },
  num.iter = 200, epsilon = 1e-10)

# Panel for a = 1, shown with the first slice along the third mode.
options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
.show_matrix(arrX[, , 1], main = "arrX[,,1]")
.show_matrix(arrY[, , 1], main = "arrY[,,1]")

# Panel for a = 2, shown with the second slice.
par(mfrow = c(3, 2))
plot(ps[[2]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[2]]")
plot(qs[[2]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[2]]")
.show_matrix(out$Ts[[2]], main = "Ts[[2]]")
.show_matrix(arrX[, , 2], main = "arrX[,,2]")
.show_matrix(arrY[, , 2], main = "arrY[,,2]")

# Panel for a = 3, shown with the third slice.
par(mfrow = c(3, 2))
plot(ps[[3]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[3]]")
plot(qs[[3]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[3]]")
.show_matrix(out$Ts[[3]], main = "Ts[[3]]")
.show_matrix(arrX[, , 3], main = "arrX[,,3]")
.show_matrix(arrY[, , 3], main = "arrY[,,3]")
# ---- Experiment 6 (D = 3, A = 2): modes 1 and 2 share f == 1, mode 3 has f == 2 ----

D <- 3
A <- 2
Is <- c(4, 4, 5)
Ks <- c(6, 6, 7)
f <- c(1, 1, 2)
arrX <- array(rep(0, prod(Is)), Is)
arrY <- array(rep(0, prod(Ks)), Ks)

# Each entry is the sum of its three indices.
for (i1 in seq_len(Is[1])) {
  for (i2 in seq_len(Is[2])) {
    for (i3 in seq_len(Is[3])) {
      arrX[i1, i2, i3] <- i1 + i2 + i3
    }
  }
}
for (k1 in seq_len(Ks[1])) {
  for (k2 in seq_len(Ks[2])) {
    for (k3 in seq_len(Ks[3])) {
      arrY[k1, k2, k3] <- k1 + k2 + k3
    }
  }
}

# ps[[a]]: weights over the first mode d of arrX for which f[d] == a
# (mode 1 for a = 1, mode 3 for a = 2).
# Entries 1 and 3 get weight 1, all others 0.01, then normalised to sum to 1.
ps <- list()
for (a in seq_len(A)) {
  ds <- which(f == a)
  d <- ds[1]
  length_of_p_a <- dim(arrX)[d]
  ps[[a]] <- rep(0.01, length_of_p_a)
  ps[[a]][c(1, 3)] <- 1
  ps[[a]] <- ps[[a]] / sum(ps[[a]])
}

# qs[[a]]: same construction on arrY, but uniform weights with entries
# 2 and 3 set to exactly 0 before normalisation.
qs <- list()
for (a in seq_len(A)) {
  ds <- which(f == a)
  d <- ds[1]
  length_of_q_a <- dim(arrY)[d]
  qs[[a]] <- rep(1, length_of_q_a)
  qs[[a]][c(2, 3)] <- 0
  qs[[a]] <- qs[[a]] / sum(qs[[a]])
}

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

# Absolute difference is used as the element-wise loss.
out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
  ps = ps, qs = qs, num.sample = 1000,
  loss = function(x, y) {
    abs(x - y)
  },
  num.iter = 200, epsilon = 1e-10)

# Panel for a = 1, shown with the first slice along the third mode.
options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
.show_matrix(arrX[, , 1], main = "arrX[,,1]")
.show_matrix(arrY[, , 1], main = "arrY[,,1]")

# Panel for a = 2, shown with the second slice.
par(mfrow = c(3, 2))
plot(ps[[2]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[2]]")
plot(qs[[2]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[2]]")
.show_matrix(out$Ts[[2]], main = "Ts[[2]]")
.show_matrix(arrX[, , 2], main = "arrX[,,2]")
.show_matrix(arrY[, , 2], main = "arrY[,,2]")
## R version 3.6.3 (2020-02-29)
## Platform: x86_64-conda-linux-gnu (64-bit)
## Running under: Rocky Linux 9.5 (Blue Onyx)
##
## Matrix products: default
## BLAS: /home/koki/miniconda3/lib/libblas.so.3.9.0
## LAPACK: /home/koki/miniconda3/lib/liblapack.so.3.9.0
##
## locale:
## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=C
## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] rTensor_1.4.8 otTensor_0.99.0
##
## loaded via a namespace (and not attached):
## [1] digest_0.6.31 R6_2.5.1 jsonlite_1.8.4 evaluate_0.20
## [5] highr_0.10 rlang_0.4.11 jquerylib_0.1.4 bslib_0.3.1
## [9] rmarkdown_2.11 tools_3.6.3 xfun_0.38 yaml_2.3.7
## [13] fastmap_1.1.1 compiler_3.6.3 htmltools_0.5.5 knitr_1.42
## [17] sass_0.4.0