Datarax GPU Testing Guide
This document describes how to run Datarax tests on GPU hardware.
Prerequisites
- NVIDIA GPU with CUDA support
- CUDA Toolkit 12.x installed
- Python 3.11 virtual environment with JAX GPU support
- Datarax development dependencies installed
Setting Up the Environment
The recommended way to set up the environment is to use the main setup script:
# Set up the development environment with automatic GPU detection
./setup.sh
# Activate the virtual environment (loads .env with CUDA configuration)
source activate.sh
This approach:
- Automatically detects NVIDIA GPUs and configures CUDA
- Creates a .env file with the proper LD_LIBRARY_PATH for CUDA libraries
- Installs all dependencies, including GPU support, via uv sync --extra all
- Creates activate.sh, which loads the environment configuration
Running GPU Tests
We provide a dedicated script, scripts/run_gpu_tests.sh, for running tests on GPU.
This script will:
- Check for GPU availability
- Set up the required environment variables (JAX_PLATFORMS=cuda)
- Run selected tests with GPU support
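The three steps above can be sketched in Python as follows. This is only an illustration of the flow, not the contents of the actual script: the helper names (gpu_visible, build_gpu_env, build_pytest_command, run_gpu_tests) are hypothetical, and the GPU check uses nvidia-smi as a stand-in for whatever check the script really performs.

```python
import os
import shutil
import subprocess
import sys


def gpu_visible() -> bool:
    """Return True if nvidia-smi is on PATH and lists at least one GPU.

    Assumption: nvidia-smi stands in for the project's own GPU check.
    """
    if shutil.which("nvidia-smi") is None:
        return False
    try:
        result = subprocess.run(
            ["nvidia-smi", "-L"], capture_output=True, text=True, timeout=30
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0 and "GPU" in result.stdout


def build_gpu_env(base_env=None) -> dict:
    """Copy the environment and force JAX onto the CUDA platform."""
    env = dict(os.environ if base_env is None else base_env)
    env["JAX_PLATFORMS"] = "cuda"
    return env


def build_pytest_command(targets=()) -> list:
    """Build the pytest invocation used for GPU runs."""
    return ["uv", "run", "pytest", "--device=gpu", *targets]


def run_gpu_tests(targets=()) -> int:
    """Check for a GPU, then run pytest with the GPU environment."""
    if not gpu_visible():
        print("No NVIDIA GPU detected; not running GPU tests.", file=sys.stderr)
        return 1
    command = build_pytest_command(targets)
    return subprocess.run(command, env=build_gpu_env()).returncode
```

Calling run_gpu_tests(["tests/operators/"]) would then run only that directory and return the pytest exit code.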
Manual GPU Testing
If you want more control over which tests run on GPU, set the platform yourself and invoke pytest directly:
# Set the environment to use CUDA
export JAX_PLATFORMS="cuda"
# Run all tests with GPU device selection
uv run pytest --device=gpu
# Run the tests in a specific directory on GPU
uv run pytest --device=gpu tests/operators/
Troubleshooting
If you encounter issues with GPU tests:
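- Verify that the GPU is detected, for example with scripts/check_gpu.py.
- Check the CUDA installation (CUDA Toolkit 12.x, see Prerequisites).
- Memory issues: If tests fail with out-of-memory (OOM) errors, lower the memory fraction via XLA_PYTHON_CLIENT_MEM_FRACTION (see Running Full GPU Test Suite).
- GPU acceleration not used: Ensure JAX is using the GPU, for example by confirming that JAX_PLATFORMS=cuda is set.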
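The checks above can be gathered into one small diagnostic helper. The sketch below is an assumption about what is useful to inspect, not part of Datarax: collect_gpu_diagnostics is a hypothetical name, and it only reports what it finds (NVIDIA tools on PATH, the relevant environment variables, and the backend and devices JAX sees).

```python
import os
import shutil


def collect_gpu_diagnostics() -> dict:
    """Gather basic facts that help explain why GPU tests misbehave."""
    report = {
        # Driver and toolkit command-line tools found on PATH?
        "nvidia_smi_on_path": shutil.which("nvidia-smi") is not None,
        "nvcc_on_path": shutil.which("nvcc") is not None,
        # Environment variables that influence JAX on GPU.
        "JAX_PLATFORMS": os.environ.get("JAX_PLATFORMS"),
        "XLA_PYTHON_CLIENT_MEM_FRACTION": os.environ.get(
            "XLA_PYTHON_CLIENT_MEM_FRACTION"
        ),
        "jax_default_backend": None,
        "jax_devices": [],
        "jax_error": None,
    }
    try:
        import jax

        # "gpu" here means JAX selected a GPU backend; "cpu" means it did not.
        report["jax_default_backend"] = jax.default_backend()
        report["jax_devices"] = [str(device) for device in jax.devices()]
    except ImportError:
        report["jax_error"] = "jax is not installed"
    except RuntimeError as exc:
        # Raised, for example, when JAX_PLATFORMS=cuda but no GPU is usable.
        report["jax_error"] = str(exc)
    return report


def print_gpu_diagnostics() -> None:
    """Print the report, one key per line."""
    for key, value in collect_gpu_diagnostics().items():
        print(f"{key}: {value}")
```

If jax_default_backend reports cpu while a GPU is present, the environment configuration is the first thing to check.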
How GPU Testing Works
The GPU testing infrastructure consists of:
- Pytest --device Option: The conftest.py provides a --device flag that accepts cpu, gpu, tpu, or all. When --device=gpu is specified, TPU-specific tests are skipped.
- Shell Script (scripts/run_gpu_tests.sh):
  - Verifies GPU availability using scripts/check_gpu.py
  - Sets required environment variables (JAX_PLATFORMS=cuda)
  - Runs pytest with the --device=gpu flag
- Test Markers: Tests can use @pytest.mark.gpu or @pytest.mark.gpu_required to indicate GPU requirements. Currently, most tests run on any device, with only a few explicitly marked as GPU-specific.
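To make the interaction between the --device flag and the markers concrete, here is a minimal conftest.py-style sketch. It is not Datarax's actual conftest.py: the default value "all", the "tpu" marker name, the skip_reason helper, and the rule that gpu_required tests are skipped unless GPU is selected are all assumptions made for illustration. Only the flag name, its accepted values, and the gpu and gpu_required markers come from the description above.

```python
"""Hypothetical conftest.py sketch; not Datarax's actual implementation."""

DEVICE_CHOICES = ("cpu", "gpu", "tpu", "all")


def skip_reason(marker_names, device):
    """Return a skip reason for a test, or None if it should run."""
    if device == "all":
        return None
    # Assumption: TPU-specific tests carry a marker named "tpu".
    if "tpu" in marker_names and device != "tpu":
        return f"TPU-specific test skipped with --device={device}"
    # Assumption: gpu_required tests only run when GPU is selected.
    if "gpu_required" in marker_names and device != "gpu":
        return f"GPU-required test skipped with --device={device}"
    return None


def pytest_addoption(parser):
    """Register the --device command-line flag."""
    parser.addoption(
        "--device",
        action="store",
        default="all",  # assumed default
        choices=DEVICE_CHOICES,
        help="Device type to run tests for: cpu, gpu, tpu, or all.",
    )


def pytest_configure(config):
    """Register the markers so pytest does not warn about unknown marks."""
    config.addinivalue_line("markers", "gpu: test benefits from a GPU")
    config.addinivalue_line("markers", "gpu_required: test requires a GPU")
    config.addinivalue_line("markers", "tpu: TPU-specific test (assumed name)")


def pytest_collection_modifyitems(config, items):
    """Attach skip markers according to the selected device."""
    import pytest  # imported lazily so skip_reason is usable without pytest

    device = config.getoption("--device")
    for item in items:
        names = {mark.name for mark in item.iter_markers()}
        reason = skip_reason(names, device)
        if reason is not None:
            item.add_marker(pytest.mark.skip(reason=reason))
```

Keeping the decision in a plain function such as skip_reason makes the skip rules easy to unit test without starting a pytest session.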
Adding New GPU Tests
Most Datarax tests are device-agnostic and run on whatever JAX backend is available. Use GPU markers when a test:
- Requires GPU (would fail on CPU): Use @pytest.mark.gpu_required
- Benefits from GPU (runs faster): Use @pytest.mark.gpu
Example: GPU-Required Test
import pytest
import jax

@pytest.mark.gpu_required
def test_multi_gpu_sharding():
    """Test that requires multiple GPU devices."""
    devices = jax.devices("gpu")
    if len(devices) < 2:
        pytest.skip("Requires at least 2 GPUs")
    # Test multi-GPU functionality
Example: GPU-Beneficial Test
import pytest

@pytest.mark.gpu
def test_large_batch_processing():
    """Test that benefits from GPU acceleration."""
    # This test runs on any device but is faster on GPU
    pass
Test File Location
Place GPU-intensive tests in appropriate directories:
- tests/distributed/ - Multi-device and sharding tests
- tests/sharding/ - Data sharding tests
- tests/benchmarks/ - Performance benchmarks
Testing Status
The GPU testing infrastructure supports:
- Automatic device detection: Tests adapt to available hardware
- Selective test execution: Use --device=gpu to focus on GPU-relevant tests
- Memory management: Environment variables control GPU memory allocation
Running Full GPU Test Suite
# Run all tests on GPU
JAX_PLATFORMS=cuda uv run pytest --device=gpu tests/
# Run with memory limits (useful for shared GPUs)
XLA_PYTHON_CLIENT_MEM_FRACTION=0.5 JAX_PLATFORMS=cuda uv run pytest --device=gpu tests/
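The same variables can also be set from Python, for instance in a helper that launches tests. The sketch below is a hypothetical illustration (gpu_memory_env and apply_gpu_memory_env are not Datarax functions). It assumes the variables are applied before JAX initializes its GPU backend, so the safest place to call it is before the first import of jax.

```python
import os


def gpu_memory_env(fraction: float) -> dict:
    """Return the environment variables for a GPU run with a memory cap.

    fraction is the share of GPU memory JAX may preallocate, in (0, 1].
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    return {
        "JAX_PLATFORMS": "cuda",
        "XLA_PYTHON_CLIENT_MEM_FRACTION": str(fraction),
    }


def apply_gpu_memory_env(fraction: float, environ=None):
    """Write the variables into environ (os.environ by default)."""
    target = os.environ if environ is None else environ
    target.update(gpu_memory_env(fraction))
    return target
```

For example, calling apply_gpu_memory_env(0.5) at the top of a launcher script mirrors the second command above.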
For more testing information, see the Testing Guide.