diff --git a/.gitignore b/.gitignore index 6763329..e545efe 100644 --- a/.gitignore +++ b/.gitignore @@ -142,3 +142,5 @@ cython_debug/ # Mac .DS_Store + +.cursor \ No newline at end of file diff --git a/Accelerated_Python_User_Guide/notebooks/Chapter_12_Intro_to_NVIDIA_Warp.ipynb b/Accelerated_Python_User_Guide/notebooks/Chapter_12_Intro_to_NVIDIA_Warp.ipynb index 23a09a8..faee91f 100644 --- a/Accelerated_Python_User_Guide/notebooks/Chapter_12_Intro_to_NVIDIA_Warp.ipynb +++ b/Accelerated_Python_User_Guide/notebooks/Chapter_12_Intro_to_NVIDIA_Warp.ipynb @@ -471,6 +471,7 @@ "metadata": {}, "outputs": [], "source": [ + "\n", "%%writefile Chapter_12_finite_difference.py\n", "\n", "import warp as wp\n", @@ -1429,6 +1430,438 @@ "**How does Warp know how to evaluate derivatives exactly?** AD systems implement the known derivatives for a finite set of elementary operations. The chain rule is used to combine the elementary derivatives together to get the overall derivative." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "___\n", + "\n", + "## Interoperating with PyTorch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let us install the PyTorch package. We will use PyTorch custom operators extensively in this section, so make sure you have PyTorch >= 2.4 installed for this support." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import warp as wp\n", + "\n", + "# Install the latest version of PyTorch from the link below\n", + "!pip install torch # For PyTorch installation, needs to be >=2.4 for PyTorch custom operators to work\n", + "# Choose the appropriate installation command for your system configuration from the link below\n", + "# https://pytorch.org/get-started/locally/\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if torch.__version__ < \"2.4\":\n", + " print(\"PyTorch version is less than 2.4, please install PyTorch >= 2.4\")\n", + "else:\n", + " print(\"PyTorch version is 2.4 or greater, all good!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Preliminary examples \n", + "Warp provides `wp.to_torch()` and `wp.from_torch()` helper functions to convert arrays to/from PyTorch tensors without copying the underlying data (works both on CPU and GPU). If an associated gradient array exists, that is also converted simultaneously. Some small examples are provided below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Warp --> PyTorch conversion\n", + "\n", + "# Construct a Warp array, including its corresponding gradient array\n", + "w = wp.array(\n", + " [1.0, 2.0, 3.0], dtype=wp.float32, requires_grad=True, device=wp.get_device()\n", + ")\n", + "\n", + "# Fill w.grad with 1.0 for now\n", + "w.grad.fill_(1.0)\n", + "\n", + "# Convert to Torch tensor\n", + "t = wp.to_torch(w)\n", + "\n", + "print(\"t = \", t)\n", + "print(\"t.grad = \", t.grad)\n", + "\n", + "# Set all t.grad to zero in PyTorch\n", + "t.grad.zero_()\n", + "\n", + "print(\n", + " \"After zeroing the grad from PyTorch interface, printing from Warp interface \",\n", + " w.grad,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# PyTorch --> Warp conversion\n", + "\n", + "# Construct a Torch tensor, and its corresponding gradient tensor\n", + "t = torch.tensor(\n", + " [1.0, 2.0, 3.0],\n", + " dtype=torch.float32,\n", + " requires_grad=True,\n", + " device=torch.device(\"cuda:0\"),\n", + ")\n", + "\n", + "# Convert Torch tensor to Warp array\n", + "\n", + "w = wp.from_torch(t)\n", + "\n", + "# Print array value and corresponding gradients\n", + "print(\"w = \", w)\n", + "print(\"w.grad = \", w.grad)\n", + "\n", + "# Set all w.grad to 1.0 and print from PyTorch interface\n", + "\n", + "w.grad.fill_(1.0)\n", + "print(\n", + " \"After setting grad values to 1 from Warp interface, printing from PyTorch interface \",\n", + " t.grad,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Writing custom operators in PyTorch using Warp\n", + "\n", + "In this example we will determine the $(x, y)$ values at which the [Rosenbrock function](https://www.sfu.ca/~ssurjano/rosen.html), a non-convex function often used for testing optimization algorithms, attains its minimum value. The function is defined as follows:\n", + "\n", + "$$\n", + "f(x,y) = (a-x)^2 + b(y-x^2)^2\n", + "$$\n", + "\n", + "where a = 1 and b = 100. Analytically, we can find that the minimum value of $f(x,y)$ occurs at $x=y=1$ with $f(1,1)=0$.\n", + "\n", + "We will make use of the PyTorch custom operators (available for PyTorch 2.4 or later), that will allow us to incorporate Warp kernel launches (in both forward and backward mode) in a PyTorch graph. PyTorch custom operators allow you to wrap Python functions (in this case, Warp kernel launches) so that they behave like PyTorch native operators. See the [PyTorch docs](https://pytorch.org/tutorials/advanced/python_custom_ops.html#adding-training-support-for-crop) for more information on PyTorch custom operators. This is particularly useful when you have a computational graph that is managed in PyTorch but you want to use Warp kernels in one or more nodes. In the following example, we use Warp to evaluate the Rosenbrock function in both forward as well as the backward pass, while using PyTorch's Adam optimizer to determine the function's minimum.\n", + "\n", + "*Note*: it is also possible to subclass `torch.autograd.function` to the same effect.\n", + "\n", + "Below, we define the Warp kernel `eval_rosenbrock` in the usual way, and wrap its forward implementation with the custom PyTorch operator `warp_rosenbrock` as well its adjoint with `warp_rosenbrock_backward`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let us define a `wp.func` for evaluating the Rosenbrock function at any given $(x, y)$ point. After that, we also define a `wp.kernel` for evaluating the Rosenbrock function on a collection of $(x,y)$ points in parallel." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the Rosenbrock function and forward kernel in Warp\n", + "@wp.func\n", + "def rosenbrock(x: float, y: float):\n", + " return (1.0 - x) ** 2.0 + 100.0 * (y - x**2.0) ** 2.0\n", + "\n", + "\n", + "@wp.kernel\n", + "def eval_rosenbrock(xy: wp.array(dtype=wp.vec2), z: wp.array(dtype=wp.float32)):\n", + " i = wp.tid()\n", + " v = xy[i]\n", + " z[i] = rosenbrock(v[0], v[1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let us first write the custom operator for the forward pass in PyTorch. The `wp::warp_rosenbrock` custom operator launches the `eval_rosenbrock` kernel through the PyTorch interface. For any custom operator, we also need to register its `FakeTensor` implementation, which allows PyTorch to determine the shape and data type of the output from the custom operator." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# wp is our namespace that groups all our custom operators\n", + "# warp_rosenbrock is the custom operator we are defining for forward pass\n", + "# mutates_args is empty since we are not modifying any input arguments in-place\n", + "@torch.library.custom_op(\"wp::warp_rosenbrock\", mutates_args=())\n", + "def warp_rosenbrock(xy: torch.Tensor, num_particles: int) -> torch.Tensor:\n", + " wp_xy = wp.from_torch(xy, dtype=wp.vec2, requires_grad=False)\n", + " wp_z = wp.zeros(num_particles, dtype=wp.float32, requires_grad=False)\n", + "\n", + " wp.launch(eval_rosenbrock, dim=num_particles, inputs=[wp_xy], outputs=[wp_z])\n", + "\n", + " return wp.to_torch(wp_z)\n", + "\n", + "\n", + "# Registers a FakeTensor implementation of warp_rosenbrock operator\n", + "# Needed to reason out the shape and type of the output from the operator, at compile-time, without actually evaluating the operator\n", + "# Each custom operator must have a register_fake function\n", + "@warp_rosenbrock.register_fake\n", + "def _(xy, num_particles):\n", + " return torch.empty(num_particles, dtype=torch.float32)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Similar to the custom operator for the forward pass, we define the custom operator for the backward pass `wp::warp_rosenbrock_backward` in Warp below. Note the `adjoint=True` in the `wp.launch(...)` kernel call that invokes the backward version of the kernel `eval_rosenbrock`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Similar to warp_rosenbrock operator, we define warp_rosenbrock_backward operator that also tracks the gradients w.r.t xy\n", + "# Notice that adjoint=True in the wp.launch(...) call for eval_rosenbrock\n", + "@torch.library.custom_op(\"wp::warp_rosenbrock_backward\", mutates_args=())\n", + "def warp_rosenbrock_backward(\n", + " xy: torch.Tensor, num_particles: int, z: torch.Tensor, adj_z: torch.Tensor\n", + ") -> torch.Tensor:\n", + " wp_xy = wp.from_torch(xy, dtype=wp.vec2)\n", + " wp_z = wp.from_torch(z, requires_grad=False)\n", + " wp_adj_z = wp.from_torch(adj_z, requires_grad=False)\n", + "\n", + " wp.launch(\n", + " eval_rosenbrock,\n", + " dim=num_particles,\n", + " inputs=[wp_xy],\n", + " outputs=[wp_z],\n", + " adj_inputs=[wp_xy.grad],\n", + " adj_outputs=[wp_adj_z],\n", + " adjoint=True,\n", + " )\n", + " return wp.to_torch(wp_xy.grad)\n", + "\n", + "\n", + "# Similar to the FakeTensor implementation for warp_rosenbrock\n", + "@warp_rosenbrock_backward.register_fake\n", + "def _(xy, num_particles, z, adj_z):\n", + " return torch.empty_like(xy)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have both the forward and backward custom operators defined using `torch.library.custom_op`, we need to register the backward custom operator so that PyTorch knows how to perform the backward pass. Please take a look at the detailed description on the PyTorch website [here](https://docs.pytorch.org/docs/stable/library.html#torch.library.register_autograd). The outputs of the `def backward(...)` function are the gradients of $z(x,y)$ with respect to $x$ and $y$. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Register backward pass implementation for automatic differentiation\n", + "# The backward function calls the custom `warp_rosenbrock_backward` operator, defined above, to compute gradients w.r.t. the inputs\n", + "def backward(ctx, adj_z):\n", + " ctx.xy.grad = warp_rosenbrock_backward(ctx.xy, ctx.num_particles, ctx.z, adj_z)\n", + " return ctx.xy.grad, None\n", + "\n", + "# setup_context builds the context object ctx that stores information from the forward pass needed for the backward function\n", + "def setup_context(ctx, inputs, output):\n", + " ctx.xy, ctx.num_particles = inputs\n", + " ctx.z = output\n", + "\n", + "# For the warp_rosenbrock operator, we register the backward function as well as the setup_context function defined above\n", + "warp_rosenbrock.register_autograd(backward, setup_context=setup_context)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "# Number of points in the x-y plane\n", + "num_particles = 1500\n", + "\n", + "# Initial point positions randomly distributed in the x-y plane\n", + "# As the optimization happens, these points should converge to (1, 1)\n", + "rng = np.random.default_rng(42)\n", + "xy = torch.tensor(\n", + " rng.normal(size=(num_particles, 2)),\n", + " dtype=torch.float32,\n", + " requires_grad=True,\n", + " device=wp.device_to_torch(wp.get_device()),\n", + ")\n", + "\n", + "# PyTorch Adam optimizer at learning rate 5e-2, defined for the xy tensor\n", + "opt = torch.optim.Adam([xy], lr=5e-2)\n", + "\n", + "# Forward pass of the function\n", + "def forward():\n", + " global xy, num_particles\n", + " z = warp_rosenbrock(xy, num_particles)\n", + " return z\n", + "\n", + "\n", + "# Single step of the optimization\n", + "# This is your typical optimization step that you might have seen when training other ML models\n", + "# The key difference here is that we are updating the (x,y) points in the x-y plane to reach the minimum of the Rosenbrock function\n", + "def step():\n", + " opt.zero_grad() # Set the gradients to zero\n", + " z = forward() # Forward pass that calls def forward(...) and ultimately calls the warp_rosenbrock operator\n", + " z.backward(torch.ones_like(z)) # Backward pass that calls the warp_rosenbrock_backward operator\n", + " opt.step() # Update the (x,y) points in-place using the gradients obtained from the backward pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib widget\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Number of points in the x-y plane on which the function is evaluated (these points should converge to (1, 1) during optimization)\n", + "num_particles = 1500\n", + "\n", + "# Initial positions of the points in the x-y plane\n", + "rng = np.random.default_rng(42)\n", + "xy = torch.tensor(\n", + " rng.normal(size=(num_particles, 2)),\n", + " dtype=torch.float32,\n", + " requires_grad=True,\n", + " device=wp.device_to_torch(wp.get_device()),\n", + ")\n", + "\n", + "# PyTorch Adam optimizer\n", + "opt = torch.optim.Adam([xy], lr=5e-2)\n", + "\n", + "# Domain\n", + "min_x, max_x = -2.0, 2.0\n", + "min_y, max_y = -2.0, 2.0\n", + "\n", + "# Create a grid of points\n", + "x = np.linspace(min_x, max_x, 100)\n", + "y = np.linspace(min_y, max_y, 100)\n", + "X, Y = np.meshgrid(x, y)\n", + "XY = np.column_stack((X.flatten(), Y.flatten()))\n", + "N = len(XY)\n", + "\n", + "XY = wp.array(XY, dtype=wp.vec2)\n", + "Z = wp.empty(N, dtype=wp.float32)\n", + "\n", + "# Evaluate the function over the domain\n", + "wp.launch(eval_rosenbrock, dim=N, inputs=[XY], outputs=[Z])\n", + "Z = Z.numpy().reshape(X.shape)\n", + "\n", + "# Plot the function as a heatmap\n", + "fig = plt.figure(figsize=(6, 6))\n", + "ax = plt.gca()\n", + "\n", + "plt.imshow(\n", + " Z,\n", + " extent=[min_x, max_x, min_y, max_y],\n", + " origin=\"lower\",\n", + " interpolation=\"bicubic\",\n", + " cmap=\"coolwarm\",\n", + ")\n", + "\n", + "plt.contour(\n", + " X,\n", + " Y,\n", + " Z,\n", + " extent=[min_x, max_x, min_y, max_y],\n", + " levels=150,\n", + " colors=\"k\",\n", + " alpha=0.5,\n", + " linewidths=0.5,\n", + ")\n", + "\n", + "# Plot the optimum as a red star\n", + "plt.plot(1, 1, \"*\", color=\"r\", markersize=10)\n", + "\n", + "plt.title(\"Rosenbrock function\")\n", + "plt.xlabel(\"x\")\n", + "plt.ylabel(\"y\")\n", + "\n", + "(mean_marker,) = ax.plot([], [], \"o\", color=\"w\", markersize=5)\n", + "\n", + "# Create a scatter plot (initially empty)\n", + "scatter_plot = ax.scatter([], [], c=\"k\", s=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.animation\n", + "import IPython\n", + "\n", + "plt.rc(\"animation\", html=\"jshtml\")\n", + "\n", + "\n", + "# Function to update the scatter plot\n", + "def render():\n", + " # Compute mean\n", + " xy_np = xy.numpy(force=True)\n", + " mean_pos = np.mean(xy_np, axis=0)\n", + "\n", + " # Update the scatter plot\n", + " scatter_plot.set_offsets(np.c_[xy_np[:, 0], xy_np[:, 1]])\n", + " mean_marker.set_data([mean_pos[0]], [mean_pos[1]])\n", + "\n", + "\n", + "# Optimize then render\n", + "def step_and_render(frame):\n", + " for i in range(200):\n", + " step()\n", + " render()\n", + "\n", + "\n", + "# Create the animation and visualize in Matplotlib\n", + "plot_anim = matplotlib.animation.FuncAnimation(\n", + " fig, step_and_render, frames=30, interval=100\n", + ")\n", + "\n", + "# Display the result\n", + "IPython.display.display(plot_anim)\n", + "plt.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the animation above, the red star represents the optimum, the white dot shows the mean of $(x, y)$ coordinates across all 1500 points, and the black dots are the individual points. As the optimization progresses, both the white dot and the individual black dots converge toward the red star, validating our hybrid PyTorch-Warp optimization scheme for the Rosenbrock function." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -1436,7 +1869,7 @@ "___\n", "## Conclusion\n", "\n", - "This notebook provided an introduction to the core components of Warp. For more examples, see the [Warp example gallery](https://github.com/NVIDIA/warp?tab=readme-ov-file#running-examples) on GitHub.\n", + "This notebook provided an introduction to the core components of Warp. We also took a look at PyTorch-Warp interoperability towards the end of the notebook. For more examples, see the [Warp example gallery](https://github.com/NVIDIA/warp?tab=readme-ov-file#running-examples) on GitHub.\n", "\n", "The repository at https://github.com/shi-eric/warp-lanl-tutorial-2025-05 also contains a set of tutorials for Warp.\n", "\n", @@ -1462,7 +1895,9 @@ "\n", "- Atilim Gunes Baydin, Barak A. Pearlmutter, Alexey Andreyevich Radul, Jeffrey Mark Siskind, \"[Automatic differentiation in machine learning: a survey](https://arxiv.org/abs/1502.05767)\", The Journal of Machine Learning Research, 18(153), 1-43, 2018.\n", "- Andreas Griewank and Andrea Walther, \"[Evaluating Derivatives: Principles and Techniques of Algorithmic Differentiation](https://books.google.com/books?id=qMLUIsgCwvUC)\", 2nd Edition, SIAM, 2008.\n", - "- Stelian Coros, Miles Macklin, Bernhard Thomaszewski, Nils Thürey, \"[Differentiable simulation](https://dl.acm.org/doi/abs/10.1145/3476117.3483433)\", SA '21: SIGGRAPH Asia 2021 Courses, 1-142, 2021." + "- Stelian Coros, Miles Macklin, Bernhard Thomaszewski, Nils Thürey, \"[Differentiable simulation](https://dl.acm.org/doi/abs/10.1145/3476117.3483433)\", SA '21: SIGGRAPH Asia 2021 Courses, 1-142, 2021.\n", + "\n", + "For more information on custom Python operators in PyTorch, please take a look at this [link](https://docs.pytorch.org/tutorials/advanced/python_custom_ops.html)." ] } ], @@ -1474,7 +1909,7 @@ "toc_visible": true }, "kernelspec": { - "display_name": ".venv", + "display_name": "warp-cfd", "language": "python", "name": "python3" }, @@ -1488,7 +1923,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.11.11" } }, "nbformat": 4,