jax-cuda12-plugin 0.4.38
JAX Plugin for NVIDIA GPUs
pip install jax-cuda12-plugin
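The plugin pins `jax-cuda12-pjrt` to exactly its own version (`==0.4.38`). After installing, one way to confirm that both packages resolved to matching versions is the minimal sketch below. It assumes it runs in the same Python environment that `pip` installed into and uses only the standard-library `importlib.metadata` module. The helper names `installed_version` and `pjrt_pin_satisfied` are hypothetical and are not part of the plugin.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def pjrt_pin_satisfied(plugin: Optional[str], pjrt: Optional[str]) -> bool:
    """True when both packages are installed at the same version.

    Simplification (assumption): compares the version strings directly, which
    works when both report the same normalized form, e.g. "0.4.38".
    """
    return plugin is not None and plugin == pjrt


if __name__ == "__main__":
    plugin = installed_version("jax-cuda12-plugin")
    pjrt = installed_version("jax-cuda12-pjrt")
    print(f"jax-cuda12-plugin: {plugin}, jax-cuda12-pjrt: {pjrt}")
    print("pin satisfied:", pjrt_pin_satisfied(plugin, pjrt))
```

A mismatch or a missing `jax-cuda12-pjrt` usually means the environment was modified after installation. Reinstalling the plugin lets pip restore the pinned companion package.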
Requires Python: >=3.10
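You can check the interpreter requirement before installing. The short sketch below uses only `sys.version_info`. The function name `meets_requires_python` is hypothetical, and the `(3, 10)` floor is copied from the `>=3.10` requirement above.

```python
import sys
from typing import Optional, Tuple

# Lower bound taken from "Requires Python >=3.10".
MIN_PYTHON: Tuple[int, int] = (3, 10)


def meets_requires_python(current: Optional[Tuple[int, ...]] = None) -> bool:
    """Return True if `current` (default: the running interpreter) is >= 3.10."""
    if current is None:
        current = tuple(sys.version_info[:2])
    # Tuples compare element-wise, so (3, 9) < (3, 10) < (3, 12).
    return tuple(current[:2]) >= MIN_PYTHON


if __name__ == "__main__":
    print("Python", sys.version.split()[0], "ok:", meets_requires_python())
```

Comparing `(major, minor)` tuples avoids string comparison pitfalls. As strings, `"3.9"` would sort after `"3.10"`.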
Dependencies
- jax-cuda12-pjrt ==0.4.38
- nvidia-cublas-cu12 >=12.1.3.1; extra == "with-cuda"
- nvidia-cuda-cupti-cu12 >=12.1.105; extra == "with-cuda"
- nvidia-cuda-nvcc-cu12 >=12.6.85; extra == "with-cuda"
- nvidia-cuda-runtime-cu12 >=12.1.105; extra == "with-cuda"
- nvidia-cudnn-cu12 <10.0,>=9.1; extra == "with-cuda"
- nvidia-cufft-cu12 >=11.0.2.54; extra == "with-cuda"
- nvidia-cusolver-cu12 >=11.4.5.107; extra == "with-cuda"
- nvidia-cusparse-cu12 >=12.1.0.106; extra == "with-cuda"
- nvidia-nccl-cu12 >=2.18.1; extra == "with-cuda"
- nvidia-nvjitlink-cu12 >=12.1.105; extra == "with-cuda"
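Only `jax-cuda12-pjrt` is required unconditionally. Every NVIDIA library carries the marker `extra == "with-cuda"`, so pip installs those wheels only when that extra is requested, using the standard extras syntax `pip install "jax-cuda12-plugin[with-cuda]"`. The sketch below groups the requirement strings by the extra that enables them. It is a deliberately simplified parser, assuming each string has the form `name specifier[; extra == "..."]`. It is not a full PEP 508 implementation; the `packaging` library's `Requirement` class handles the general grammar. The names `parse`, `split_requirements`, and `REQUIREMENTS` are hypothetical, and the data is copied from the list above.

```python
import re
from typing import Dict, List, Optional, Tuple

# Requirement strings as listed on this page (Requires-Dist entries).
REQUIREMENTS: List[str] = [
    "jax-cuda12-pjrt==0.4.38",
    'nvidia-cublas-cu12>=12.1.3.1; extra == "with-cuda"',
    'nvidia-cuda-cupti-cu12>=12.1.105; extra == "with-cuda"',
    'nvidia-cuda-nvcc-cu12>=12.6.85; extra == "with-cuda"',
    'nvidia-cuda-runtime-cu12>=12.1.105; extra == "with-cuda"',
    'nvidia-cudnn-cu12<10.0,>=9.1; extra == "with-cuda"',
    'nvidia-cufft-cu12>=11.0.2.54; extra == "with-cuda"',
    'nvidia-cusolver-cu12>=11.4.5.107; extra == "with-cuda"',
    'nvidia-cusparse-cu12>=12.1.0.106; extra == "with-cuda"',
    'nvidia-nccl-cu12>=2.18.1; extra == "with-cuda"',
    'nvidia-nvjitlink-cu12>=12.1.105; extra == "with-cuda"',
]

# Distribution name: letters, digits, ".", "_" and "-"; the rest is the specifier.
_NAME = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)\s*(.*)$")
# Only recognizes markers of the form: extra == "name"
_EXTRA = re.compile(r'extra\s*==\s*["\']([^"\']+)["\']')


def parse(req: str) -> Tuple[str, str, Optional[str]]:
    """Split one requirement into (name, specifier, extra or None)."""
    body, _, marker = req.partition(";")
    m = _NAME.match(body)
    if m is None:
        raise ValueError(f"unrecognized requirement: {req!r}")
    name, spec = m.group(1), m.group(2).strip()
    extra = _EXTRA.search(marker)
    return name, spec, extra.group(1) if extra else None


def split_requirements(reqs: List[str]) -> Dict[Optional[str], List[str]]:
    """Group distribution names by the extra that enables them (None = always)."""
    groups: Dict[Optional[str], List[str]] = {}
    for req in reqs:
        name, _spec, extra = parse(req)
        groups.setdefault(extra, []).append(name)
    return groups


if __name__ == "__main__":
    for extra, names in split_requirements(REQUIREMENTS).items():
        label = "always" if extra is None else f"only with [{extra}]"
        print(f"{label}: {', '.join(names)}")
```

When the plugin is installed, `importlib.metadata.requires("jax-cuda12-plugin")` should return the installed package's own requirement strings. You could pass those to `split_requirements` in place of the hard-coded list.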