jax-cuda12-plugin 0.8.0
JAX Plugin for NVIDIA GPUs
pip install jax-cuda12-plugin
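After installing, you can confirm which versions pip actually put in the environment. The sketch below uses only the standard library's `importlib.metadata`. The helper name `installed_version` is hypothetical and not part of the package. The plugin's metadata pins `jax-cuda12-pjrt` to `==0.8.0`, so both distributions are expected to report the same version.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(distribution):
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return version(distribution)
    except PackageNotFoundError:
        # Raised when no distribution with that name is installed.
        return None


if __name__ == "__main__":
    # Check the plugin and the PJRT package it pins.
    for dist in ("jax-cuda12-plugin", "jax-cuda12-pjrt"):
        print(dist, installed_version(dist) or "not installed")
```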
Requires Python
>=3.11
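A minimal sketch of checking the `>=3.11` constraint before installing. It compares only the major and minor components of `sys.version_info`. The helper `meets_requires_python` is hypothetical; pip performs this check itself from the package's `Requires-Python` metadata.

```python
import sys

# Minimum interpreter version from the package metadata (Requires Python >=3.11).
MIN_PYTHON = (3, 11)


def meets_requires_python(version_info=None, minimum=MIN_PYTHON):
    """Return True if the given (or running) interpreter version is >= minimum."""
    if version_info is None:
        version_info = sys.version_info
    # Tuple comparison is lexicographic: (3, 12) >= (3, 11) and (3, 10) < (3, 11).
    return tuple(version_info[:2]) >= tuple(minimum)


if __name__ == "__main__":
    print(f"Python {sys.version.split()[0]} satisfies >=3.11: {meets_requires_python()}")
```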
Dependencies
- jax-cuda12-pjrt==0.8.0
- nvidia-cublas-cu12>=12.1.3.1; sys_platform == "linux" and extra == "with-cuda"
- nvidia-cuda-cupti-cu12>=12.1.105; sys_platform == "linux" and extra == "with-cuda"
- nvidia-cuda-nvcc-cu12>=12.6.85; sys_platform == "linux" and extra == "with-cuda"
- nvidia-cuda-runtime-cu12>=12.1.105; sys_platform == "linux" and extra == "with-cuda"
- nvidia-cudnn-cu12<10.0,>=9.8; sys_platform == "linux" and extra == "with-cuda"
- nvidia-cufft-cu12>=11.0.2.54; sys_platform == "linux" and extra == "with-cuda"
- nvidia-cusolver-cu12>=11.4.5.107; sys_platform == "linux" and extra == "with-cuda"
- nvidia-cusparse-cu12>=12.1.0.106; sys_platform == "linux" and extra == "with-cuda"
- nvidia-nccl-cu12>=2.18.1; sys_platform == "linux" and extra == "with-cuda"
- nvidia-nvjitlink-cu12>=12.1.105; sys_platform == "linux" and extra == "with-cuda"
- nvidia-cuda-nvrtc-cu12>=12.1.55; sys_platform == "linux" and extra == "with-cuda"
- nvidia-nvshmem-cu12>=3.2.5; sys_platform == "linux" and extra == "with-cuda"
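Only `jax-cuda12-pjrt` is unconditional. Every NVIDIA wheel carries the marker `sys_platform == "linux" and extra == "with-cuda"`, so it is pulled in only on Linux and only when the `with-cuda` extra is requested (with standard pip extras syntax, `pip install "jax-cuda12-plugin[with-cuda]"`). The sketch below models that selection. It is a simplification: it hard-codes this one marker shape rather than evaluating general PEP 508 markers, which pip does internally (the `packaging` library's `Marker.evaluate` is the general-purpose tool). `REQUIREMENTS` and `applicable_requirements` are hypothetical names.

```python
# Hypothetical helper: a simplified model of this package's dependency metadata.
# Every marker here has the same shape, so it is stored as (platform, extra).
CUDA_MARKER = ("linux", "with-cuda")  # sys_platform == "linux" and extra == "with-cuda"

REQUIREMENTS = [
    ("jax-cuda12-pjrt", "==0.8.0", None),  # unconditional
    ("nvidia-cublas-cu12", ">=12.1.3.1", CUDA_MARKER),
    ("nvidia-cuda-cupti-cu12", ">=12.1.105", CUDA_MARKER),
    ("nvidia-cuda-nvcc-cu12", ">=12.6.85", CUDA_MARKER),
    ("nvidia-cuda-runtime-cu12", ">=12.1.105", CUDA_MARKER),
    ("nvidia-cudnn-cu12", "<10.0,>=9.8", CUDA_MARKER),
    ("nvidia-cufft-cu12", ">=11.0.2.54", CUDA_MARKER),
    ("nvidia-cusolver-cu12", ">=11.4.5.107", CUDA_MARKER),
    ("nvidia-cusparse-cu12", ">=12.1.0.106", CUDA_MARKER),
    ("nvidia-nccl-cu12", ">=2.18.1", CUDA_MARKER),
    ("nvidia-nvjitlink-cu12", ">=12.1.105", CUDA_MARKER),
    ("nvidia-cuda-nvrtc-cu12", ">=12.1.55", CUDA_MARKER),
    ("nvidia-nvshmem-cu12", ">=3.2.5", CUDA_MARKER),
]


def applicable_requirements(platform, extras=()):
    """Return the requirement strings that apply for a sys.platform value and a set of extras."""
    selected = []
    for name, spec, marker in REQUIREMENTS:
        if marker is None:
            selected.append(name + spec)
            continue
        required_platform, required_extra = marker
        # Both halves of the "and" marker must hold.
        if platform == required_platform and required_extra in extras:
            selected.append(name + spec)
    return selected


if __name__ == "__main__":
    import sys

    for req in applicable_requirements(sys.platform, {"with-cuda"}):
        print(req)
```

On Linux without the extra, only `jax-cuda12-pjrt==0.8.0` is selected; with `{"with-cuda"}` all thirteen requirements apply, and on any other platform the extra adds nothing.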