
Disable JAX GPU preallocation by default #1

Open
YonghaoZhao722 wants to merge 2 commits into raphael-group:main from YonghaoZhao722:main


YonghaoZhao722 commented May 8, 2026

Summary

Disable JAX GPU memory preallocation by default for MGW.

Motivation

MGW uses both PyTorch and JAX/OTT. By default, JAX preallocates a large fraction of total GPU memory (75% in current releases, per the JAX GPU memory allocation documentation) when the first JAX operation runs, even when the actual working set is much smaller. This makes MGW appear to consume excessive GPU memory and can cause out-of-memory conflicts with PyTorch in the same process, with other processes, or in shared-GPU environments.

Disabling preallocation lets JAX allocate GPU memory on demand, so usage tracks the actual workload more closely. The MGW algorithm and benchmark behavior are unchanged.

Changes

  • Set XLA_PYTHON_CLIENT_PREALLOCATE=false before JAX/OTT is initialized.
  • Use os.environ.setdefault(...) so that an allocator setting already present in the user's environment is not overwritten.
  • Keep the change limited to runtime memory allocation behavior; no algorithmic logic is changed.
  • Users can restore preallocation by setting XLA_PYTHON_CLIENT_PREALLOCATE=true in their environment before launching MGW.
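
For reviewers, the pattern described in the list above can be sketched roughly as follows. This is a minimal illustration, not the exact code in this PR: the helper name `disable_jax_preallocation` is hypothetical, and where the call lives in the MGW package is an assumption. The only pieces taken from the description are the environment variable `XLA_PYTHON_CLIENT_PREALLOCATE` and the use of `os.environ.setdefault`.

```python
import os


def disable_jax_preallocation() -> str:
    """Turn off JAX GPU preallocation unless the user already chose a value.

    Hypothetical helper name, used here only for illustration.
    Returns the value that is in effect after the call.
    """
    # setdefault writes "false" only when the variable is not already set,
    # so a value exported by the user takes precedence.
    return os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")


effective = disable_jax_preallocation()

# JAX/OTT must be imported only after the environment is configured;
# the setting is read when the GPU backend is initialized.
# import jax  # (left commented out so the sketch runs without JAX installed)
```

Because `setdefault` is used, a user who exports `XLA_PYTHON_CLIENT_PREALLOCATE=true` before launching keeps the original preallocating behavior.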

Test plan

  • Verified that JAX/OTT operations still run correctly without preallocation.
  • Verified MGW benchmark scripts run successfully with the new default.
  • Tested in memory-constrained and shared-GPU settings to confirm that JAX no longer reserves GPU memory it does not use.
