System Info
Observed on main after DeepSeek-V4 support was added in #45643.
Relevant file:
src/transformers/models/deepseek_v4/modeling_deepseek_v4.py
DeepseekV4CSACompressor.forward
DeepseekV4Attention.forward
Who can help?
@ArthurZucker
Information
Tasks
Reproduction
This is a code-reading report rather than a failing runtime script. I may be missing an intended invariant in the eager path, but the current tensor shapes look non-equivalent to per-query CSA sparse attention when seq_len > 1.
In DeepseekV4CSACompressor.forward, CSA first obtains per-query top-k compressed KV indices:
topk = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) # [B, S, k]
expanded = compressed_kv.unsqueeze(2).expand(-1, -1, seq_len, -1, -1) # [B, 1, S, T, D]
idx = topk.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, self.head_dim) # [B, 1, S, k, D]
return torch.gather(expanded, 3, idx).reshape(batch, 1, -1, self.head_dim) # [B, 1, S*k, D]
Conceptually, the gathered tensor before reshape is per-query selected compressed KV:
selected[b, 0, t, :, :] = compressed entries selected for query t
shape: [B, 1, S, k, D]
After .reshape(batch, 1, -1, D), the selected entries for all query positions are flattened into a single KV axis of shape [B, 1, S*k, D].
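As a sanity check on this shape bookkeeping, here is a standalone sketch with toy sizes (the names and dimensions are illustrative, not the actual module):

```python
import torch

# Toy sizes: batch, queries, compressed entries, top-k, head dim.
B, S, T, k, D = 1, 3, 8, 2, 4
seq_len = S

compressed_kv = torch.randn(B, 1, T, D)    # shared compressed cache, [B, 1, T, D]
topk = torch.randint(0, T, (B, S, k))      # per-query top-k indices, [B, S, k]

expanded = compressed_kv.unsqueeze(2).expand(-1, -1, seq_len, -1, -1)  # [B, 1, S, T, D]
idx = topk.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, D)         # [B, 1, S, k, D]
selected = torch.gather(expanded, 3, idx)                              # [B, 1, S, k, D]
flat = selected.reshape(B, 1, -1, D)                                   # [B, 1, S*k, D]

assert flat.shape == (B, 1, S * k, D)
# flat[:, :, t*k:(t+1)*k, :] holds exactly the entries selected for query t.
assert torch.equal(flat[:, :, 0:k, :], selected[:, :, 0])
```

So the per-query structure survives the reshape only as a positional convention: entry ownership is encoded solely in the segment offset t*k within the flattened axis.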
Then DeepseekV4Attention.forward concatenates those entries after the sliding-window KV branch:
compressed_kv = self.compressor(...)
kv = torch.cat([kv, compressed_kv], dim=2)
if isinstance(attention_mask, torch.Tensor) and kv.shape[2] > attention_mask.shape[-1]:
    attention_mask = F.pad(attention_mask, (0, kv.shape[2] - attention_mask.shape[-1]), value=0.0)
The dense mask is right-padded with 0.0, which appears to make the entire flattened compressed segment visible to every query. For S > 1, that means query t0 can attend to compressed entries that were selected for query t1, t2, etc.
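Under the additive-mask convention (0.0 = attend, -inf = blocked), right-padding with 0.0 unmasks the padded columns for every query row. A minimal sketch (toy sizes, not the actual tensors):

```python
import torch
import torch.nn.functional as F

# Additive mask convention: 0.0 = attend, -inf = blocked.
S, kv_len, k = 3, 4, 2
mask = torch.zeros(1, 1, S, kv_len)           # sliding-window part (all visible here)
padded = F.pad(mask, (0, S * k), value=0.0)   # right-pad for the flattened compressed segment

# Every query row sees every padded column: the whole compressed segment
# is unmasked for all S queries, not just each query's own k entries.
assert padded.shape[-1] == kv_len + S * k
assert (padded[..., kv_len:] == 0.0).all()
```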
A minimal shape example:
B = 1, S = 3, k = 2
logical selected entries:
q0 -> [A0, A1]
q1 -> [B0, B1]
q2 -> [C0, C1]
after flatten:
[A0, A1, B0, B1, C0, C1]
if the compressed segment is mask-padded with 0.0:
q0 can attend to A*, B*, C*
q1 can attend to A*, B*, C*
q2 can attend to A*, B*, C*
That seems different from query-specific sparse attention unless an additional block mask maps each query t to only its own flattened segment [t*k : (t+1)*k].
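For concreteness, such a block mask could look like the following sketch (a hypothetical fix, not code from the repository): each compressed column j is owned by query j // k, and only the owning query row is left unmasked.

```python
import torch

# Hypothetical per-query block mask for the flattened [S*k] compressed segment:
# query t may only attend to its own slice [t*k : (t+1)*k].
S, k = 3, 2
q_idx = torch.arange(S).unsqueeze(-1)        # [S, 1]
kv_owner = torch.arange(S * k) // k          # [S*k], owning query of each flattened entry
vis = q_idx == kv_owner                      # [S, S*k] boolean visibility
block = torch.zeros(S, S * k).masked_fill(~vis, float("-inf"))  # additive mask

# block[1] unmasks only columns 2..3 (the entries selected for query 1).
assert (block[1, 2:4] == 0.0).all() and torch.isinf(block[1, :2]).all()
```

This mask would be concatenated (along the KV axis) after the sliding-window mask instead of the current 0.0 padding.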
There is a second related question around causal visibility. The paper describes the index score between a query token t and a preceding compressed block s (s < floor(t / m)). In the current eager code, I do not see an explicit query-dependent visible-range mask before:
return index_scores.topk(topk, dim=-1).indices
So for multi-token prefill/training-style forwards, it is not obvious where future-containing compressed blocks are excluded from top-k selection.
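A query-dependent visible-range restriction, if it were applied in the eager path, might look like this sketch (hypothetical names and sizes; only intended to illustrate the s < floor(t / m) constraint from the paper):

```python
import torch

# Hypothetical visible-range mask applied before top-k index selection:
# query t may only score compressed blocks s with s < floor(t / m).
S, num_blocks, m, topk = 6, 3, 2, 2
index_scores = torch.randn(1, S, num_blocks)         # [B, S, num_blocks]

t = torch.arange(S).unsqueeze(-1)                    # [S, 1]
s = torch.arange(num_blocks)                         # [num_blocks]
visible = s < (t // m)                               # [S, num_blocks]
masked = index_scores.masked_fill(~visible, float("-inf"))
idx = masked.topk(topk, dim=-1).indices              # future-containing blocks never win

# Note: rows with fewer than topk visible blocks (small t) still return
# -inf-scored indices from topk and would need separate handling.
```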
Expected behavior
For CSA with seq_len > 1, each query position should attend only to the compressed KV entries selected for that same query, plus the intended sliding-window KV entries.
Equivalently, one of these should hold:
- the selected compressed KV remains logically shaped as [B, H_kv, S, k, D] and the attention kernel consumes it per query;
- the flattened [B, H_kv, S*k, D] layout is paired with a query-dependent block mask so query t only sees its own segment [t*k : (t+1)*k]; or
- the eager path is documented/guarded as decode-only for CSA (S == 1), if that is the intended supported usage.
If the current implementation relies on some invariant that makes this equivalent, could you point me to where that masking/visibility constraint is enforced?