Skip to content

Gather tensor-parallel sharded parameters on read in a trace#677

Open
khaiwang wants to merge 2 commits into
devfrom
fix/vllm-tp-param-gather
Open

Gather tensor-parallel sharded parameters on read in a trace#677
khaiwang wants to merge 2 commits into
devfrom
fix/vllm-tp-param-gather

Conversation

@khaiwang

Copy link
Copy Markdown
Contributor

Problem

Under tensor parallelism vLLM shards parameters across ranks — lm_head/embeddings are vocab-sharded, attention/MLP projections are output- or input-sharded. nnsight handed intervention code the local shard, so reading a parameter inside a trace returned the wrong values on a rank that doesn't own them.

Concretely, a steering cell builds its direction from lm_head.weight[token_id]. Measured on Qwen2.5-0.5B under tp=2: on the rank that doesn't own the token, that index returned a different vocab row (lm_head.weight.shape[0] was 75968 = half-vocab in the trace), so the steered output diverged from single-GPU (maxabs 26, the wrong global token surfacing in the top-5) — silently.

Fix

The parameter analogue of the existing activation gather (VLLMBatcher already gathers RowParallelLinear/ColumnParallelLinear I/O):

  • Envoy.__getattr__ routes a tensor attribute read through interleaver.batcher.gather_param only while interleaving (so the collective all_gather fires on every rank; tp=1 / non-vLLM are untouched — the base Batcher.gather_param is identity).
  • VLLMBatcher.gather_param all-gathers the shard to its full logical shape. The sharded dim comes from the module classRowParallelLinear → input dim, ColumnParallelLinear/VocabParallelEmbedding → output/vocab dim — because vLLM sets both output_dim and input_dim on every linear weight (they label the dims, not which is sharded). Vocab padding is stripped to org_vocab_size.

Read-only; tp=1 is byte-identical.

Verification (Qwen2.5-0.5B, tp=2)

  • lm_head, qkv_proj, gate_up_proj, o_proj, down_proj all gather to the full tp=1 shape with matching Frobenius norms.
  • A steering write reading lm_head.weight[token_id] goes from divergent (maxabs 26) to equivalent.
  • New tests/test_tp_param_gather.py (full-vocab lm_head, upper-shard token row, row+column parallel weights); the pre-existing tests/test_tp_stream_fix.py still passes. Run with pytest tests/test_tp_param_gather.py --tp 2.

🤖 Generated with Claude Code

https://claude.ai/code/session_015y5Sy9vzzc9YJZXCtewSdQ

khaiwang and others added 2 commits June 5, 2026 01:00
…red across invokes)

barrier() is broken on the vLLM path (reproduces at tp1/pp1, non-PP): each invoke
is serialized into its own globals, so each gets a private copy of the Barrier with
its own participants set. The count never reaches n, both invokes take the no-op
send(BARRIER, None) branch, the workers block at the barrier and are abandoned, and
all post-barrier code is silently dropped. Diagnosed (instrumentation reverted);
fix is a design choice (interleaver-owned barrier registry keyed by a
serialization-stable id, preferred; or graft the Barrier into canonical globals).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Under tensor parallelism vLLM shards parameters across ranks (lm_head/embeddings are
vocab-sharded; attention/MLP projections are output- or input-sharded). nnsight handed
intervention code the local shard, so reading a parameter inside a trace -- e.g.
`lm_head.weight[token_id]` to build a steering direction -- returned the wrong row on a
rank that does not own that token, silently diverging from single-GPU.

This is the parameter analogue of the existing activation gather (VLLMBatcher gathers
RowParallelLinear/ColumnParallelLinear I/O). Parameter reads now route through the batcher:
Envoy.__getattr__ delegates a tensor attribute to interleaver.batcher.gather_param while
interleaving (so the collective fires on every rank); the base Batcher returns it unchanged
(non-vLLM and tp=1 untouched), and VLLMBatcher all-gathers the shard to its full logical
shape. The sharded dim comes from the module class (RowParallelLinear -> input dim;
ColumnParallelLinear/VocabParallelEmbedding -> output/vocab dim), because vLLM sets BOTH
output_dim and input_dim on every linear weight (they label the dims, not which is sharded);
vocab padding is stripped to org_vocab_size.

Verified on Qwen2.5-0.5B tp=2: lm_head, qkv_proj, gate_up_proj, o_proj and down_proj all
gather to the full tp=1 shape with matching norms; a steering write reading lm_head[token_id]
goes from divergent (maxabs 26, wrong global token in the top-5) to equivalent. Regression
tests in tests/test_tp_param_gather.py.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_015y5Sy9vzzc9YJZXCtewSdQ
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant