From 0b2262c1cfaa8174acc165f9efc3eb1ff5895412 Mon Sep 17 00:00:00 2001
From: Jack Tang
Date: Sat, 22 Jul 2023 23:20:14 +0800
Subject: [PATCH 1/2] #460 add gpu problem demo

---
 docs/source/problems/parallelization.ipynb | 56 +++++++++++++++++++++-
 1 file changed, 55 insertions(+), 1 deletion(-)

diff --git a/docs/source/problems/parallelization.ipynb b/docs/source/problems/parallelization.ipynb
index 80fd66efc..bf0fc51ae 100644
--- a/docs/source/problems/parallelization.ipynb
+++ b/docs/source/problems/parallelization.ipynb
@@ -13,6 +13,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -24,6 +25,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -37,6 +39,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -76,6 +79,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -104,6 +108,40 @@
    ]
   },
   {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## GPU Acceleration\n",
+    "\n",
+    "If evaluating the problem takes a lot of time, the vectorized matrix operations above can be sped up further through GPU acceleration. Modern tensor frameworks such as pyTorch or JAX make this easy."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from pymoo.core.problem import Problem\n",
+    "\n",
+    "class MyProblem(Problem):\n",
+    "\n",
+    "    def __init__(self, **kwargs):\n",
+    "        super().__init__(n_var=10, n_obj=1, n_ieq_constr=0, xl=-5, xu=5, **kwargs)\n",
+    "\n",
+    "    def _evaluate(self, x, out, *args, **kwargs):\n",
+    "        # move the population matrix to the GPU, falling back to the CPU if none is available\n",
+    "        x = torch.from_numpy(x).to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+    "        f = torch.sum(torch.pow(x, 2), dim=1)  # evaluate all solutions in one batched operation\n",
+    "        out[\"F\"] = f.cpu().numpy()  # copy the result back to the CPU as a numpy array for pymoo\n",
+    "\n",
+    "problem = MyProblem()"
+   ]
+  },
+  {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -143,6 +181,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -154,6 +193,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -198,6 +238,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -241,6 +282,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -252,6 +294,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -263,6 +306,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -308,6 +352,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -319,6 +364,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -330,6 +376,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -341,6 +388,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -352,6 +400,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -428,6 +477,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "pycharm": {
@@ -489,7 +539,11 @@
    ]
   }
  ],
- "metadata": {},
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
  "nbformat": 4,
  "nbformat_minor": 4
 }

From 5505cf95028b056bad3ccca3643bba363baa0375 Mon Sep 17 00:00:00 2001
From: Jack Tang
Date: Mon, 24 Jul 2023 13:53:29 +0800
Subject: [PATCH 2/2] add jax accelerated problem

---
 docs/source/problems/parallelization.ipynb | 60 +++++++++++++++++++++-
 1 file changed, 59 insertions(+), 1 deletion(-)

diff --git a/docs/source/problems/parallelization.ipynb b/docs/source/problems/parallelization.ipynb
index bf0fc51ae..e76892006 100644
--- a/docs/source/problems/parallelization.ipynb
+++ b/docs/source/problems/parallelization.ipynb
@@ -114,7 +114,20 @@
    "source": [
     "## GPU Acceleration\n",
     "\n",
-    "If evaluating the problem takes a lot of time, the vectorized matrix operations above can be sped up further through GPU acceleration. Modern tensor frameworks such as pyTorch or JAX make this easy."
+    "If evaluating the problem takes a lot of time, the vectorized matrix operations above can be sped up further through GPU acceleration. Modern tensor frameworks such as PyTorch or JAX make this easy."
    ]
   },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### PyTorch\n",
+    "\n",
+    "Evaluating the problem with PyTorch follows three steps:\n",
+    "1. Convert the numpy matrix of solutions to a tensor and copy it to the CUDA device\n",
+    "1. Compute the objective (and constraint) values with tensor operations\n",
+    "1. Copy the results back to the CPU as numpy arrays so that pymoo can continue with the next iteration"
+   ]
+  },
   {
@@ -140,6 +153,51 @@
     "problem = MyProblem()"
    ]
   },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### JAX\n",
+    "\n",
+    "JAX can be thought of as accelerated numpy: it provides a numpy-inspired interface for convenience. By default, JAX executes operations one at a time, in sequence. With the just-in-time (JIT) compilation decorator, a sequence of operations can be optimized together and run at once. To apply the JIT decorator, the evaluation logic is moved into private helper functions such as `_eval_F` (and `_eval_G` for constrained problems), which are then compiled with `jax.jit`.\n",
+    "\n",
+    "**IMPORTANT:** turn on the float64 configuration if the problem's dtype is float64; otherwise precision is lost and the results may differ."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "import numpy as np\n",
+    "from functools import partial\n",
+    "from pymoo.core.problem import Problem\n",
+    "\n",
+    "jax.config.update(\"jax_enable_x64\", True)   # JAX defaults to float32\n",
+    "jax.config.update(\"jax_disable_jit\", False)  # set to True to debug without JIT\n",
+    "\n",
+    "class MyProblem(Problem):\n",
+    "\n",
+    "    def __init__(self, **kwargs):\n",
+    "        super().__init__(n_var=10, n_obj=1, n_ieq_constr=0, xl=-50, xu=50, **kwargs)\n",
+    "\n",
+    "    def _evaluate(self, x, out, *args, **kwargs):\n",
+    "        _x = jnp.array(x)  # copy the numpy matrix of solutions onto the default JAX device\n",
+    "        f = self._eval_F(_x)\n",
+    "        out[\"F\"] = np.asarray(f)  # convert the result back to numpy for pymoo\n",
+    "\n",
+    "    # JIT-compile the evaluation; self is marked static so the method can be traced\n",
+    "    @partial(jax.jit, static_argnums=0)\n",
+    "    def _eval_F(self, x):\n",
+    "        return jnp.sum(jnp.power(x, 2), axis=1)\n",
+    "\n",
+    "problem = MyProblem()"
+   ]
+  },
   {
    "attachments": {},
    "cell_type": "markdown",
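
A minimal usage sketch (not part of the patches above): either GPU-backed `MyProblem` can be dropped into pymoo's standard `minimize` call, since pymoo hands `_evaluate` the whole population as a single matrix per call. It assumes one of the `MyProblem` classes from the patches has already been defined; the algorithm choice and the `pop_size`/`n_gen` values are illustrative only.

    from pymoo.algorithms.soo.nonconvex.ga import GA
    from pymoo.optimize import minimize

    # MyProblem: one of the PyTorch- or JAX-backed problems defined in the patches above
    problem = MyProblem()

    # a population-based algorithm benefits most from batched GPU evaluation,
    # because each generation is evaluated in a single _evaluate call
    res = minimize(problem, GA(pop_size=100), ("n_gen", 50), seed=1, verbose=False)

    print("best x:", res.X)
    print("best f:", res.F)

Larger populations amortize the host-to-device transfer better; for very small populations the copy overhead can outweigh the GPU speedup.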