## ----setup, include=FALSE-----------------------------------------------------
knitr::opts_chunk$set(echo = TRUE, fig.width = 6, fig.height = 4)

## ----ot-intuition, fig.height=5-----------------------------------------------
# Two simple 1D distributions (each sums to 1)
source_dist <- c(0.4, 0.1, 0.4, 0.1)
target_dist <- c(0.1, 0.3, 0.1, 0.3, 0.2)

# par() returns the previous values of the parameters it changes; passing
# that whole list back to par() restores them.
oldpar <- par(mfrow = c(1, 2), mar = c(4, 4, 3, 1))
barplot(source_dist, col = "steelblue", main = "Source distribution",
        names.arg = seq_along(source_dist), ylim = c(0, 0.5),
        xlab = "Bin", ylab = "Mass")
barplot(target_dist, col = "tomato", main = "Target distribution",
        names.arg = seq_along(target_dist), ylim = c(0, 0.5),
        xlab = "Bin", ylab = "Mass")
par(oldpar)

## ----tensor-illustration, fig.height=4----------------------------------------
oldpar <- par(mfrow = c(1, 3), mar = c(2, 2, 3, 1))
gray_palette <- gray((0:255) / 255)

# Vector (order 1)
barplot(c(3, 1, 4, 1, 5), col = "steelblue", main = "Order 1: Vector")

# Matrix (order 2): a 3 x 4 matrix
mat <- matrix(c(1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6), nrow = 3)
image(mat, col = gray_palette, axes = FALSE, main = "Order 2: Matrix")

# 3D tensor (order 3): a 3 x 4 x 2 array; only the first slice is drawn
arr <- array(0, dim = c(3, 4, 2))
arr[, , 1] <- mat
arr[, , 2] <- matrix(c(6, 5, 4, 3, 5, 4, 3, 2, 4, 3, 2, 1), nrow = 3)
image(arr[, , 1], col = gray_palette, axes = FALSE,
      main = "Order 3: Tensor\n(slice 1)")
par(oldpar)

## ----quickstart, message=FALSE------------------------------------------------
library("otTensor")
library("rTensor")

## ----create-tensors-----------------------------------------------------------
# Source: a 4 x 5 matrix with entries arrX[i, j] = i + j.
# outer() fills the whole grid in one vectorized call; the double inputs keep
# the result stored as double rather than integer.
arrX <- outer(as.numeric(1:4), as.numeric(1:5), "+")

# Target: a 6 x 7 matrix with entries arrY[i, j] = i + j
# (a different size from the source is OK)
arrY <- outer(as.numeric(1:6), as.numeric(1:7), "+")

# Convert to Tensor objects
X <- as.tensor(arrX)
Y <- as.tensor(arrY)
## ----set-f--------------------------------------------------------------------
# Modes of X and Y to compute transport plans for -- presumably 1 = rows and
# 2 = columns, matching the plot titles below (see ?OTT to confirm).
f <- c(1, 2)

## ----run-ott------------------------------------------------------------------
result <- OTT(X = X, Y = Y, f = f, num.sample = 500, num.iter = 100)

## ----inspect-results----------------------------------------------------------
# Transport plan dimensions
cat("Transport plan 1:", dim(result$Ts[[1]]), "\n")
cat("Transport plan 2:", dim(result$Ts[[2]]), "\n")

## ----visualize-results, fig.height=5, fig.width=6-----------------------------
# Draw a matrix with image() so it looks the way it prints: row 1 at the top,
# column 1 at the left. image() places z[i, j] at (x = i, y = j), so the rows
# are reversed and the result transposed. drop = FALSE keeps a single-row
# matrix as a matrix instead of collapsing it to a vector.
.show_matrix <- function(mat, main = "") {
  flipped <- t(mat[rev(seq_len(nrow(mat))), , drop = FALSE])
  image(flipped, col = gray((0:255) / 255), axes = FALSE,
        xlab = "", ylab = "", main = main)
}

# par() returns the previous values of the parameters it changes; passing
# that whole list back to par() restores them.
oldpar <- par(mfrow = c(2, 2), mar = c(2, 2, 3, 1))
.show_matrix(arrX, main = "Source (X)")
.show_matrix(arrY, main = "Target (Y)")
.show_matrix(result$Ts[[1]], main = "Transport Plan 1\n(rows)")
.show_matrix(result$Ts[[2]], main = "Transport Plan 2\n(columns)")
par(oldpar)

## ----sessionInfo, echo=FALSE--------------------------------------------------
sessionInfo()