{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "58b7d540",
   "metadata": {},
   "source": [
    "## GPU Usage\n",
    "\n",
    "author: Jacob Schreiber <br>\n",
    "contact: jmschreiber91@gmail.com\n",
    "\n",
    "Because pomegranate models are all instances of `torch.nn.Module`, you can do anything with them that you could do with other PyTorch models. This includes using GPUs or any other device that is supported by PyTorch using exactly the same method calls. Here, we will see how to use GPUs to speed up training and inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "94f8d0a5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-16T03:34:38.820898Z",
     "start_time": "2023-04-16T03:34:32.232716Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n",
      "numpy      : 1.23.4\n",
      "scipy      : 1.9.3\n",
      "torch      : 1.13.0\n",
      "pomegranate: 1.0.0\n",
      "\n",
      "Compiler    : GCC 11.2.0\n",
      "OS          : Linux\n",
      "Release     : 4.15.0-208-generic\n",
      "Machine     : x86_64\n",
      "Processor   : x86_64\n",
      "CPU cores   : 8\n",
      "Architecture: 64bit\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "%pylab inline\n",
    "import seaborn; seaborn.set_style('whitegrid')\n",
    "\n",
    "import torch\n",
    "\n",
    "numpy.random.seed(0)\n",
    "numpy.set_printoptions(suppress=True)\n",
    "\n",
    "%load_ext watermark\n",
    "%watermark -m -n -p numpy,scipy,torch,pomegranate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a479bbc4",
   "metadata": {},
   "source": [
    "### Overview\n",
    "\n",
    "All models and methods in pomegranate can be sped up using GPUs using the exact same methods as other PyTorch models. Let's see that in action."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "124bc1ec",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-16T03:34:38.873756Z",
     "start_time": "2023-04-16T03:34:38.825577Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Parameter containing:\n",
       " tensor([-0.0236, -0.0205, -0.0137,  0.0102,  0.0062]),\n",
       " Parameter containing:\n",
       " tensor([[ 1.0031e+00, -7.6397e-04,  1.2268e-03,  1.8175e-02, -1.2175e-02],\n",
       "         [-7.6397e-04,  1.0070e+00,  1.0800e-02,  1.4855e-02,  2.5852e-02],\n",
       "         [ 1.2268e-03,  1.0800e-02,  1.0058e+00,  7.3037e-03, -1.7321e-03],\n",
       "         [ 1.8175e-02,  1.4855e-02,  7.3037e-03,  1.0094e+00, -7.5648e-03],\n",
       "         [-1.2175e-02,  2.5852e-02, -1.7321e-03, -7.5648e-03,  9.8117e-01]]))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pomegranate.distributions import Normal\n",
    "\n",
    "X = torch.randn(5000, 5)\n",
    "dist = Normal().fit(X)\n",
    "\n",
    "dist.means, dist.covs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b160ed65",
   "metadata": {},
   "source": [
    "All we need to do is use the `.cuda()` method or the `.to(device)` method on the data and the model. Similar to PyTorch, the model will not automatically move data to the GPU for you. You have to do this yourself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6e558641",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-16T03:34:39.496200Z",
     "start_time": "2023-04-16T03:34:38.876199Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Parameter containing:\n",
       " tensor([-0.0236, -0.0205, -0.0137,  0.0102,  0.0062], device='cuda:0'),\n",
       " Parameter containing:\n",
       " tensor([[ 1.0031e+00, -7.6398e-04,  1.2268e-03,  1.8175e-02, -1.2175e-02],\n",
       "         [-7.6398e-04,  1.0070e+00,  1.0800e-02,  1.4855e-02,  2.5852e-02],\n",
       "         [ 1.2268e-03,  1.0800e-02,  1.0058e+00,  7.3037e-03, -1.7321e-03],\n",
       "         [ 1.8175e-02,  1.4855e-02,  7.3037e-03,  1.0094e+00, -7.5648e-03],\n",
       "         [-1.2175e-02,  2.5852e-02, -1.7321e-03, -7.5648e-03,  9.8117e-01]],\n",
       "        device='cuda:0'))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dist = Normal().cuda().fit(X.cuda())\n",
    "\n",
    "dist.means, dist.covs"
   ]
  },
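  {
   "cell_type": "markdown",
   "id": "a1c4e7d2",
   "metadata": {},
   "source": [
    "The `.to(device)` form is convenient when the same code has to run on machines with and without a GPU. The following is a minimal sketch rather than an executed cell: the `device` variable and the CPU fallback are assumptions added for illustration, while `Normal` and `fit` are used exactly as in the cells above.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from pomegranate.distributions import Normal\n",
    "\n",
    "# Assumption: fall back to the CPU when no CUDA device is visible.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "X = torch.randn(5000, 5)\n",
    "\n",
    "# Move both the model and the data; the model will not move the data for you.\n",
    "dist = Normal().to(device).fit(X.to(device))\n",
    "\n",
    "# The fitted parameters live on the same device as the data.\n",
    "print(dist.means.device, dist.covs.device)\n",
    "```\n",
    "\n",
    "Written this way, the only line that changes between a CPU run and a GPU run is the one that defines `device`."
   ]
  },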
  {
   "cell_type": "markdown",
   "id": "c9d0cc1c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-16T01:17:06.819163Z",
     "start_time": "2023-04-16T01:17:06.810940Z"
    }
   },
   "source": [
    "All models operate in the same way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6b24e63f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-16T03:34:39.626392Z",
     "start_time": "2023-04-16T03:34:39.499859Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1] Improvement: 134.63671875, Time: 0.001251s\n",
      "[2] Improvement: 38.76953125, Time: 0.001229s\n",
      "[3] Improvement: 17.02734375, Time: 0.001173s\n",
      "[4] Improvement: 9.13671875, Time: 0.001179s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "GeneralMixtureModel(\n",
       "  (distributions): ModuleList(\n",
       "    (0-1): 2 x Normal()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pomegranate.gmm import GeneralMixtureModel\n",
    "\n",
    "model = GeneralMixtureModel([Normal(), Normal()], max_iter=5, verbose=True).cuda()\n",
    "model.fit(X.cuda())"
   ]
  },
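  {
   "cell_type": "markdown",
   "id": "b7e2f9c3",
   "metadata": {},
   "source": [
    "Anything a GPU model returns also lives on the GPU, so results have to be moved back before they can be handed to NumPy or a plotting library. The sketch below is illustrative and was not run as part of this notebook; the `device` fallback and the `labels` name are assumptions, and it relies on the mixture model's `predict` method.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from pomegranate.distributions import Normal\n",
    "from pomegranate.gmm import GeneralMixtureModel\n",
    "\n",
    "# Assumption: fall back to the CPU when no CUDA device is visible.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "X = torch.randn(5000, 5).to(device)\n",
    "\n",
    "model = GeneralMixtureModel([Normal(), Normal()], max_iter=5).to(device)\n",
    "model.fit(X)\n",
    "\n",
    "# The predictions are a tensor on the same device as the model and data.\n",
    "y_hat = model.predict(X)\n",
    "\n",
    "# Copy them to host memory before converting to a NumPy array.\n",
    "labels = y_hat.cpu().numpy()\n",
    "print(labels.shape)\n",
    "```\n",
    "\n",
    "Calling `.cpu()` on a tensor that is already on the CPU is a no-op, so the same code works with or without a GPU."
   ]
  },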
  {
   "cell_type": "markdown",
   "id": "20fbb008",
   "metadata": {},
   "source": [
    "### Timing Examples\n",
    "\n",
    "Using a GPU helps the most when the workload is complex. So, we will see only minimal gains on small, simple, workloads like basic probability distributions. For these evaluations we will do timings on the CPU, using the GPU but including the time to transfer everything there, and using the GPU once everything is already there. All evaluations are done on an A100."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "55c8bc9d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-16T03:34:53.804660Z",
     "start_time": "2023-04-16T03:34:39.629276Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28.4 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
      "91.3 µs ± 361 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
      "54.5 µs ± 170 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
     ]
    }
   ],
   "source": [
    "n, d = 100, 5\n",
    "\n",
    "X = torch.randn(n, d)\n",
    "X_cuda = X.cuda()\n",
    "\n",
    "mu = torch.randn(d)\n",
    "cov = torch.exp(torch.randn(d))\n",
    "\n",
    "d = Normal(mu, cov, covariance_type='diag')\n",
    "d_cuda = Normal(mu, cov, covariance_type='diag').cuda()\n",
    "\n",
    "%timeit d.log_probability(X)\n",
    "%timeit d.cuda().log_probability(X.cuda())\n",
    "%timeit d_cuda.log_probability(X_cuda)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08e206ab",
   "metadata": {},
   "source": [
    "For this extremely small workload, using a GPU is much slower than just doing everything on the CPU. Let's try a larger example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9fc26808",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-16T03:35:11.130177Z",
     "start_time": "2023-04-16T03:34:53.806166Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "131 µs ± 7.32 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
      "256 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
      "54.2 µs ± 68.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
     ]
    }
   ],
   "source": [
    "n, d = 10000, 50\n",
    "\n",
    "X = torch.randn(n, d)\n",
    "X_cuda = X.cuda()\n",
    "\n",
    "mu = torch.randn(d)\n",
    "cov = torch.exp(torch.randn(d))\n",
    "\n",
    "d = Normal(mu, cov, covariance_type='diag')\n",
    "d_cuda = Normal(mu, cov, covariance_type='diag').cuda()\n",
    "\n",
    "%timeit d.log_probability(X)\n",
    "%timeit d.cuda().log_probability(X.cuda())\n",
    "%timeit d_cuda.log_probability(X_cuda)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c144010",
   "metadata": {},
   "source": [
    "Looks like the CPU time increased several times over but the GPU times didn't change as much. This is because using a GPU has a relatively large fixed cost but the variable cost associating with increasing the size of the data doesn't increase nearly as fast. Let's try a huge example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a21e44ac",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-16T03:35:31.914911Z",
     "start_time": "2023-04-16T03:35:11.136952Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "844 ms ± 90 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "213 ms ± 639 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "11.7 ms ± 2.97 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "n, d = 100000, 5000\n",
    "\n",
    "X = torch.randn(n, d)\n",
    "X_cuda = X.cuda()\n",
    "\n",
    "mu = torch.randn(d)\n",
    "cov = torch.exp(torch.randn(d))\n",
    "\n",
    "d = Normal(mu, cov, covariance_type='diag')\n",
    "d_cuda = Normal(mu, cov, covariance_type='diag').cuda()\n",
    "\n",
    "%timeit d.log_probability(X)\n",
    "%timeit d.cuda().log_probability(X.cuda())\n",
    "%timeit d_cuda.log_probability(X_cuda)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ed1eafc",
   "metadata": {},
   "source": [
    "We can see the expected results here: despite having to transfer data to and from the GPU it is faster to do it than use a CPU for large data sets, and if you're in a setting where everything is already on the GPU you can get huge speed boosts.\n",
    "\n",
    "Now, if you have a more complicated model, you can unlock even larger speed boosts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d50706e3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-16T03:35:42.652225Z",
     "start_time": "2023-04-16T03:35:31.917040Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10.1 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n",
      "643 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "from pomegranate.kmeans import KMeans\n",
    "\n",
    "model1 = KMeans(512)\n",
    "model2 = KMeans(512)\n",
    "\n",
    "%timeit -n 1 -r 1 model1.fit(X)\n",
    "%timeit -n 1 -r 1 model2.cuda().fit(X.cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7e617469",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-16T03:35:42.754909Z",
     "start_time": "2023-04-16T03:35:42.654315Z"
    }
   },
   "outputs": [],
   "source": [
    "del X, model1, model2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b1b51e4",
   "metadata": {},
   "source": [
    "Seems significantly faster.\n",
    "\n",
    "Now, let's try with an even more complex model: the dense hidden Markov model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a7540a37",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-16T03:35:52.233950Z",
     "start_time": "2023-04-16T03:35:42.757731Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8.23 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n",
      "1.08 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "from pomegranate.hmm import DenseHMM\n",
    "\n",
    "n, l, d = 1000, 25, 15\n",
    "X = torch.randn(n, l, d)\n",
    "\n",
    "k = 256\n",
    "\n",
    "dists1, dists2 = [], []\n",
    "for i in range(k):\n",
    "    mu = torch.randn(d)\n",
    "    covs = torch.exp(torch.randn(d))\n",
    "    \n",
    "    dist1 = Normal(mu, covs, covariance_type='diag')\n",
    "    dist2 = Normal(mu, covs, covariance_type='diag').cuda()\n",
    "    \n",
    "    dists1.append(dist1)\n",
    "    dists2.append(dist2)\n",
    "    \n",
    "\n",
    "model1 = DenseHMM(dists1, max_iter=3)\n",
    "model2 = DenseHMM(dists2, max_iter=3).cuda()\n",
    "\n",
    "X_cuda = X.cuda()\n",
    "\n",
    "%timeit -n 1 -r 1 model1.fit(X)\n",
    "%timeit -n 1 -r 1 model2.fit(X_cuda)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
