Optimal Tensor Transport

Koki Tsuyuzaki

Laboratory for Bioinformatics Research, RIKEN Center for Biosystems Dynamics Research
k.t.the-answer@hotmail.co.jp

2026-05-08

Introduction

In this vignette, we consider optimal tensor transport (OTT), an extension of optimal transport (OT) that can handle tensors of any order by learning one or more transport plans.

Here, we reproduce the experiments in the original paper (Kerdoncuff et al. 2022). For details of the methodology, see the original paper.

library("otTensor")

# Display a matrix as a 256-level grayscale image.
#
# The row order is reversed and the result transposed before it is handed to
# image(); this is the usual compensation for image() drawing its input
# rotated relative to the printed matrix layout.
#
# mat:  a two-dimensional matrix of numeric values.
# main: plot title, passed on to image().
#
# Called for its plotting side effect; returns whatever image() returns.
.show_matrix <- function(mat, main = ""){
    # Reverse rows by index with drop = FALSE so that a single-row matrix
    # stays two-dimensional (apply(mat, 2, rev) would collapse it to a plain
    # vector and the transpose would then have the wrong orientation).
    flipped <- t(mat[rev(seq_len(nrow(mat))), , drop = FALSE])

    # grayscale
    image(flipped, col = gray((0:255) / 255), xaxt = "n", yaxt = "n",
        xlab = "", ylab = "", axes = FALSE, main = main)
}

OTT_1 (OT)

# OTT_1 setup: one mode (D = 1) and one transport plan (A = 1).
# f gives, for each mode, the index of the plan it is matched to below
# via which(f == a).
D <- 1
A <- 1
Is <- c(4)
Ks <- c(7)
f <- c(1)

# Each entry holds its own index: arrX[i1] = i1, arrY[k1] = k1.
# Assigning through [] keeps the 1-d array shape and double storage.
arrX <- array(0, Is)
arrY <- array(0, Ks)
arrX[] <- seq_len(Is[1])
arrY[] <- seq_len(Ks[1])

# Source marginals, one per plan: weight 1 on positions 1 and 3, 0.01
# elsewhere, normalised to sum to 1. The length is the size of the first
# mode of arrX assigned to plan a.
ps <- lapply(seq_len(A), function(a) {
    d <- which(f == a)[1]
    p <- rep(0.01, dim(arrX)[d])
    p[c(1, 3)] <- 1
    p / sum(p)
})

# Target marginals, one per plan: uniform except zero weight on positions
# 2 and 3, normalised to sum to 1.
qs <- lapply(seq_len(A), function(a) {
    d <- which(f == a)[1]
    q <- rep(1, dim(arrY)[d])
    q[c(2, 3)] <- 0
    q / sum(q)
})

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
    ps = ps, qs = qs, num.sample = 1000,
    loss = function(x, y) {
        abs(x - y)
    },
    num.iter = 200, epsilon = 1e-10)

# Marginals, the learned plan, and the two input vectors on one 3 x 2 page.
options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
plot(arrX, type = "h", col = "black", main = "arrX")
plot(arrY, type = "h", col = "black", main = "arrY")

OTT_12 (Co-OT)

# OTT_12 setup: two modes (D = 2), each matched to its own transport plan
# (A = 2, f = c(1, 2)).
D <- 2
A <- 2
Is <- c(4, 5)
Ks <- c(7, 8)
f <- c(1, 2)

# Every entry is the sum of its indices: arrX[i1, i2] = i1 + i2.
# arrayInd() lists the index tuple of each cell in storage order, so the
# row sums can be written straight back into the array.
arrX <- array(0, Is)
arrY <- array(0, Ks)
arrX[] <- rowSums(arrayInd(seq_along(arrX), Is))
arrY[] <- rowSums(arrayInd(seq_along(arrY), Ks))

# Source marginals, one per plan: weight 1 on positions 1 and 3, 0.01
# elsewhere, normalised to sum to 1. The length is the size of the first
# mode of arrX assigned to plan a.
ps <- lapply(seq_len(A), function(a) {
    d <- which(f == a)[1]
    p <- rep(0.01, dim(arrX)[d])
    p[c(1, 3)] <- 1
    p / sum(p)
})

# Target marginals, one per plan: uniform except zero weight on positions
# 2 and 3, normalised to sum to 1.
qs <- lapply(seq_len(A), function(a) {
    d <- which(f == a)[1]
    q <- rep(1, dim(arrY)[d])
    q[c(2, 3)] <- 0
    q / sum(q)
})

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
    ps = ps, qs = qs, num.sample = 1000,
    loss = function(x, y) {
        abs(x - y)
    },
    num.iter = 200, epsilon = 1e-10)

# One 3 x 2 page per transport plan: its marginals, the learned plan, and
# the two input matrices.
options(repr.plot.width = 6, repr.plot.height = 10)
for (a in seq_len(A)) {
    par(mfrow = c(3, 2))
    plot(ps[[a]], type = "h", col = "red", ylim = c(0, 1),
        main = sprintf("ps[[%d]]", a))
    plot(qs[[a]], type = "h", col = "red", ylim = c(0, 1),
        main = sprintf("qs[[%d]]", a))
    .show_matrix(out$Ts[[a]], main = sprintf("Ts[[%d]]", a))
    .show_matrix(arrX, main = "arrX")
    .show_matrix(arrY, main = "arrY")
}

OTT_11 (GW)

# OTT_11 setup: two modes (D = 2) that share a single transport plan
# (A = 1, f = c(1, 1)).
D <- 2
A <- 1
Is <- c(4, 4)
Ks <- c(6, 6)
f <- c(1, 1)

# Every entry is the sum of its indices: arrX[i1, i2] = i1 + i2.
# arrayInd() lists the index tuple of each cell in storage order, so the
# row sums can be written straight back into the array.
arrX <- array(0, Is)
arrY <- array(0, Ks)
arrX[] <- rowSums(arrayInd(seq_along(arrX), Is))
arrY[] <- rowSums(arrayInd(seq_along(arrY), Ks))

# Source marginals, one per plan: weight 1 on positions 1 and 3, 0.01
# elsewhere, normalised to sum to 1. The length is the size of the first
# mode of arrX assigned to plan a.
ps <- lapply(seq_len(A), function(a) {
    d <- which(f == a)[1]
    p <- rep(0.01, dim(arrX)[d])
    p[c(1, 3)] <- 1
    p / sum(p)
})

# Target marginals, one per plan: uniform except zero weight on positions
# 2 and 3, normalised to sum to 1.
qs <- lapply(seq_len(A), function(a) {
    d <- which(f == a)[1]
    q <- rep(1, dim(arrY)[d])
    q[c(2, 3)] <- 0
    q / sum(q)
})

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
    ps = ps, qs = qs, num.sample = 1000,
    loss = function(x, y) {
        abs(x - y)
    },
    num.iter = 200, epsilon = 1e-10)

# Marginals, the learned plan, and the two input matrices on one 3 x 2 page.
options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
.show_matrix(arrX, main = "arrX")
.show_matrix(arrY, main = "arrY")

OTT_111 (triplets)

# OTT_111 setup: three modes (D = 3) that share a single transport plan
# (A = 1, f = c(1, 1, 1)).
D <- 3
A <- 1
Is <- c(4, 4, 4)
Ks <- c(6, 6, 6)
f <- c(1, 1, 1)

# Every entry is the sum of its indices: arrX[i1, i2, i3] = i1 + i2 + i3.
# arrayInd() lists the index tuple of each cell in storage order, so the
# row sums can be written straight back into the array.
arrX <- array(0, Is)
arrY <- array(0, Ks)
arrX[] <- rowSums(arrayInd(seq_along(arrX), Is))
arrY[] <- rowSums(arrayInd(seq_along(arrY), Ks))

# Source marginals, one per plan: weight 1 on positions 1 and 3, 0.01
# elsewhere, normalised to sum to 1. The length is the size of the first
# mode of arrX assigned to plan a.
ps <- lapply(seq_len(A), function(a) {
    d <- which(f == a)[1]
    p <- rep(0.01, dim(arrX)[d])
    p[c(1, 3)] <- 1
    p / sum(p)
})

# Target marginals, one per plan: uniform except zero weight on positions
# 2 and 3, normalised to sum to 1.
qs <- lapply(seq_len(A), function(a) {
    d <- which(f == a)[1]
    q <- rep(1, dim(arrY)[d])
    q[c(2, 3)] <- 0
    q / sum(q)
})

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
    ps = ps, qs = qs, num.sample = 1000,
    loss = function(x, y) {
        abs(x - y)
    },
    num.iter = 200, epsilon = 1e-10)

# Marginals, the learned plan, and the first frontal slice of each input
# tensor on one 3 x 2 page.
options(repr.plot.width = 6, repr.plot.height = 10)
par(mfrow = c(3, 2))
plot(ps[[1]], type = "h", col = "red", ylim = c(0, 1), main = "ps[[1]]")
plot(qs[[1]], type = "h", col = "red", ylim = c(0, 1), main = "qs[[1]]")
.show_matrix(out$Ts[[1]], main = "Ts[[1]]")
.show_matrix(arrX[, , 1], main = "arrX[,,1]")
.show_matrix(arrY[, , 1], main = "arrY[,,1]")

OTT_123 (triCo-OT)

# OTT_123 setup: three modes (D = 3), each matched to its own transport plan
# (A = 3, f = c(1, 2, 3)).
D <- 3
A <- 3
Is <- c(4, 5, 6)
Ks <- c(7, 8, 9)
f <- c(1, 2, 3)

# Every entry is the sum of its indices: arrX[i1, i2, i3] = i1 + i2 + i3.
# arrayInd() lists the index tuple of each cell in storage order, so the
# row sums can be written straight back into the array.
arrX <- array(0, Is)
arrY <- array(0, Ks)
arrX[] <- rowSums(arrayInd(seq_along(arrX), Is))
arrY[] <- rowSums(arrayInd(seq_along(arrY), Ks))

# Source marginals, one per plan: weight 1 on positions 1 and 3, 0.01
# elsewhere, normalised to sum to 1. The length is the size of the first
# mode of arrX assigned to plan a.
ps <- lapply(seq_len(A), function(a) {
    d <- which(f == a)[1]
    p <- rep(0.01, dim(arrX)[d])
    p[c(1, 3)] <- 1
    p / sum(p)
})

# Target marginals, one per plan: uniform except zero weight on positions
# 2 and 3, normalised to sum to 1.
qs <- lapply(seq_len(A), function(a) {
    d <- which(f == a)[1]
    q <- rep(1, dim(arrY)[d])
    q[c(2, 3)] <- 0
    q / sum(q)
})

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
    ps = ps, qs = qs, num.sample = 1000,
    loss = function(x, y) {
        abs(x - y)
    },
    num.iter = 200, epsilon = 1e-10)

# One 3 x 2 page per transport plan a: its marginals, the learned plan, and
# the a-th frontal slice of each input tensor.
options(repr.plot.width = 6, repr.plot.height = 10)
for (a in seq_len(A)) {
    par(mfrow = c(3, 2))
    plot(ps[[a]], type = "h", col = "red", ylim = c(0, 1),
        main = sprintf("ps[[%d]]", a))
    plot(qs[[a]], type = "h", col = "red", ylim = c(0, 1),
        main = sprintf("qs[[%d]]", a))
    .show_matrix(out$Ts[[a]], main = sprintf("Ts[[%d]]", a))
    .show_matrix(arrX[, , a], main = sprintf("arrX[,,%d]", a))
    .show_matrix(arrY[, , a], main = sprintf("arrY[,,%d]", a))
}

OTT_112 (GW Collection)

# OTT_112 setup: three modes (D = 3) and two transport plans (A = 2); the
# first two modes share plan 1 and the third mode uses plan 2
# (f = c(1, 1, 2)).
D <- 3
A <- 2
Is <- c(4, 4, 5)
Ks <- c(6, 6, 7)
f <- c(1, 1, 2)

# Every entry is the sum of its indices: arrX[i1, i2, i3] = i1 + i2 + i3.
# arrayInd() lists the index tuple of each cell in storage order, so the
# row sums can be written straight back into the array.
arrX <- array(0, Is)
arrY <- array(0, Ks)
arrX[] <- rowSums(arrayInd(seq_along(arrX), Is))
arrY[] <- rowSums(arrayInd(seq_along(arrY), Ks))

# Source marginals, one per plan: weight 1 on positions 1 and 3, 0.01
# elsewhere, normalised to sum to 1. The length is the size of the first
# mode of arrX assigned to plan a.
ps <- lapply(seq_len(A), function(a) {
    d <- which(f == a)[1]
    p <- rep(0.01, dim(arrX)[d])
    p[c(1, 3)] <- 1
    p / sum(p)
})

# Target marginals, one per plan: uniform except zero weight on positions
# 2 and 3, normalised to sum to 1.
qs <- lapply(seq_len(A), function(a) {
    d <- which(f == a)[1]
    q <- rep(1, dim(arrY)[d])
    q[c(2, 3)] <- 0
    q / sum(q)
})

X <- as.tensor(arrX)
Y <- as.tensor(arrY)

out <- OTT(X = X, Y = Y, D = D, A = A, Is = Is, Ks = Ks, f = f,
    ps = ps, qs = qs, num.sample = 1000,
    loss = function(x, y) {
        abs(x - y)
    },
    num.iter = 200, epsilon = 1e-10)

# One 3 x 2 page per transport plan a: its marginals, the learned plan, and
# the a-th frontal slice of each input tensor.
options(repr.plot.width = 6, repr.plot.height = 10)
for (a in seq_len(A)) {
    par(mfrow = c(3, 2))
    plot(ps[[a]], type = "h", col = "red", ylim = c(0, 1),
        main = sprintf("ps[[%d]]", a))
    plot(qs[[a]], type = "h", col = "red", ylim = c(0, 1),
        main = sprintf("qs[[%d]]", a))
    .show_matrix(out$Ts[[a]], main = sprintf("Ts[[%d]]", a))
    .show_matrix(arrX[, , a], main = sprintf("arrX[,,%d]", a))
    .show_matrix(arrY[, , a], main = sprintf("arrY[,,%d]", a))
}

Session Information

## R version 3.6.3 (2020-02-29)
## Platform: x86_64-conda-linux-gnu (64-bit)
## Running under: Rocky Linux 9.5 (Blue Onyx)
## 
## Matrix products: default
## BLAS:   /home/koki/miniconda3/lib/libblas.so.3.9.0
## LAPACK: /home/koki/miniconda3/lib/liblapack.so.3.9.0
## 
## locale:
##  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=C              
##  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
##  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] rTensor_1.4.8   otTensor_0.99.0
## 
## loaded via a namespace (and not attached):
##  [1] digest_0.6.31   R6_2.5.1        jsonlite_1.8.4  evaluate_0.20  
##  [5] highr_0.10      rlang_0.4.11    jquerylib_0.1.4 bslib_0.3.1    
##  [9] rmarkdown_2.11  tools_3.6.3     xfun_0.38       yaml_2.3.7     
## [13] fastmap_1.1.1   compiler_3.6.3  htmltools_0.5.5 knitr_1.42     
## [17] sass_0.4.0

References

Kerdoncuff, T. et al. 2022. “Optimal Tensor Transport.” Proceedings of the AAAI Conference on Artificial Intelligence 36(7): 7124–32.