Skip to content

Commit ed8db98

Browse files
authored
Add cupy backend for GPU support (#1048)
1 parent 9fe6163 commit ed8db98

24 files changed

+4759
-5183
lines changed

.github/workflows/extended_tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949
5050
- name: Compile Rust extension (no wheel)
5151
run: |
52-
pixi r maturin-develop
52+
pixi r -e dev maturin-develop
5353
5454
- name: Run long tests with coverage
5555
run: pixi run tests-extended

README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,25 @@ python -m pip install pyfixest[plots]
8686

8787
Note that matplotlib is included by default, so you can always use the matplotlib backend for plotting even without installing the optional lets-plot dependency.
8888

89+
### GPU Acceleration (Optional)
90+
91+
PyFixest supports GPU-accelerated fixed effects demeaning via CuPy. To enable GPU acceleration, install CuPy matching your CUDA version:
92+
93+
```bash
94+
# For CUDA 11.x, 12.x, 13.x
95+
pip install cupy-cuda11x
96+
pip install cupy-cuda12x
97+
pip install cupy-cuda13x
98+
```
99+
100+
Once installed, you can use GPU-accelerated demeaning by setting the `demean_backend` parameter:
101+
102+
```python
103+
# Use GPU with float32 and float64 precision
104+
pf.feols("Y ~ X1 | f1 + f2", data=data, demean_backend="cupy32")
105+
pf.feols("Y ~ X1 | f1 + f2", data=data, demean_backend="cupy64")
106+
```
107+
89108
## Benchmarks
90109

91110
All benchmarks follow the [fixest

benchmarks/complex_benchmarks.png

114 KB
Loading

benchmarks/gpu_benchmarks.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
}
1313
},
1414
"source": [
15-
"# PyFixest on the GPU \n",
15+
"# PyFixest on the GPU via JAX\n",
1616
"\n",
1717
"Through its JAX integration, it is possible to run PyFixest on the GPU. In this notebook, we benchmark the performance of PyFixest on the GPU via its \n",
1818
"`jax` backend and compare it to the performance of PyFixest on the CPU (via the default `numba` backend). \n",

0 commit comments

Comments
 (0)