DeepSeek-V4 CSA eager path may not preserve per-query top-k masking for S > 1 #45758

@kekmodel

System Info

Observed on main after DeepSeek-V4 support was added in #45643.

Relevant file and methods:

  • src/transformers/models/deepseek_v4/modeling_deepseek_v4.py
  • DeepseekV4CSACompressor.forward
  • DeepseekV4Attention.forward

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset

Reproduction

This is a code-reading report rather than a failing runtime script. I may be missing an intended invariant in the eager path, but the current tensor shapes look non-equivalent to per-query CSA sparse attention when seq_len > 1.

In DeepseekV4CSACompressor.forward, CSA first obtains per-query top-k compressed KV indices:

topk = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx)  # [B, S, k]
expanded = compressed_kv.unsqueeze(2).expand(-1, -1, seq_len, -1, -1)                    # [B, 1, S, T, D]
idx = topk.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, self.head_dim)               # [B, 1, S, k, D]
return torch.gather(expanded, 3, idx).reshape(batch, 1, -1, self.head_dim)               # [B, 1, S*k, D]

Conceptually, the gathered tensor before reshape is per-query selected compressed KV:

selected[b, 0, t, :, :] = compressed entries selected for query t
shape: [B, 1, S, k, D]

After .reshape(batch, 1, -1, D), the selected entries for all query positions are flattened into one KV axis:

shape: [B, 1, S*k, D]
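The shape flow above can be checked on a toy tensor. This is a minimal sketch, not the library code: the sizes, the identifiable values in compressed_kv, and the hard-coded topk indices are all assumptions made for illustration.

```python
import torch

# Toy sizes (assumptions): B=1, S=3 queries, T=4 compressed entries, k=2, D=1.
B, S, T, k, D = 1, 3, 4, 2, 1

# compressed_kv[b, 0, j, 0] = 10 * j, so each compressed entry is identifiable by value.
compressed_kv = (10.0 * torch.arange(T)).view(1, 1, T, 1).expand(B, 1, T, D)

# Hypothetical per-query top-k indices, shape [B, S, k].
topk = torch.tensor([[[0, 1], [1, 2], [2, 3]]])

expanded = compressed_kv.unsqueeze(2).expand(-1, -1, S, -1, -1)  # [B, 1, S, T, D]
idx = topk.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, D)    # [B, 1, S, k, D]
selected = torch.gather(expanded, 3, idx)                         # [B, 1, S, k, D]
flat = selected.reshape(B, 1, -1, D)                              # [B, 1, S*k, D]

# selected keeps one row of k entries per query; flat merges them into one KV axis.
print(selected[0, 0, :, :, 0])
print(flat[0, 0, :, 0])
```

After the reshape, nothing in the tensor itself records which query each of the S*k entries belongs to; that information would have to come from the mask.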

Then DeepseekV4Attention.forward concatenates those entries after the sliding-window KV branch:

compressed_kv = self.compressor(...)
kv = torch.cat([kv, compressed_kv], dim=2)

if isinstance(attention_mask, torch.Tensor) and kv.shape[2] > attention_mask.shape[-1]:
    attention_mask = F.pad(attention_mask, (0, kv.shape[2] - attention_mask.shape[-1]), value=0.0)

The dense mask is right-padded with 0.0. Assuming the mask is additive, 0.0 means "visible", so the entire flattened compressed segment appears to be visible to every query. For S > 1, that means query t0 can attend to compressed entries that were selected for queries t1, t2, and so on.

A minimal shape example:

B = 1, S = 3, k = 2

logical selected entries:
  q0 -> [A0, A1]
  q1 -> [B0, B1]
  q2 -> [C0, C1]

after flatten:
  [A0, A1, B0, B1, C0, C1]

if the compressed segment is mask-padded with 0.0:
  q0 can attend to A*, B*, C*
  q1 can attend to A*, B*, C*
  q2 can attend to A*, B*, C*

That seems different from query-specific sparse attention unless an additional block mask maps each query t to only its own flattened segment [t*k : (t+1)*k].
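To make the contrast concrete, here is a rough sketch of what such a block mask could look like next to the current 0.0 padding. It assumes an additive mask of shape [B, 1, S, W]; pad_mask_per_query is a hypothetical helper that does not exist in transformers.

```python
import torch
import torch.nn.functional as F

def pad_mask_per_query(attention_mask, seq_len, k):
    """Hypothetical helper: append S*k columns to an additive mask [B, 1, S, W],
    leaving only columns [t*k : (t+1)*k] visible to query t."""
    neg = torch.finfo(attention_mask.dtype).min
    device = attention_mask.device
    q = torch.arange(seq_len, device=device).unsqueeze(1)                # [S, 1]
    owner = torch.arange(seq_len * k, device=device).unsqueeze(0) // k   # [1, S*k]
    block = torch.full((seq_len, seq_len * k), neg,
                       dtype=attention_mask.dtype, device=device)
    block = block.masked_fill(owner == q, 0.0)                           # [S, S*k]
    block = block.view(1, 1, seq_len, seq_len * k)
    block = block.expand(attention_mask.shape[0], 1, -1, -1)
    return torch.cat([attention_mask, block], dim=-1)

S, k, W = 3, 2, 4
mask = torch.zeros(1, 1, S, W)                     # toy sliding-window mask, all visible
zero_padded = F.pad(mask, (0, S * k), value=0.0)   # padding as in the snippet above
block_padded = pad_mask_per_query(mask, S, k)      # query-dependent alternative

print((zero_padded[0, 0, :, W:] == 0).int())   # every query sees all S*k entries
print((block_padded[0, 0, :, W:] == 0).int())  # block-diagonal visibility
```

With S = 3 and k = 2, the second print shows a block-diagonal pattern in which q0 sees only columns 0-1, q1 only columns 2-3, and q2 only columns 4-5 of the compressed segment.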

There is a second, related question about causal visibility. The paper describes the index score as being computed between a query token t and a preceding compressed block s, with s < floor(t / m). In the current eager code, I do not see an explicit query-dependent visible-range mask before:

return index_scores.topk(topk, dim=-1).indices

So for multi-token prefill/training-style forwards, it is not obvious where future-containing compressed blocks are excluded from top-k selection.
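For reference, this is roughly the kind of masking I would have expected before the top-k. It is only a sketch under assumptions: index_scores is taken to be [B, S, T], position_ids to be [B, S] with 0-indexed positions, and m to be the compression ratio from the paper; causal_topk is a hypothetical name, not a function in the model file.

```python
import torch

def causal_topk(index_scores, position_ids, m, topk):
    """Hypothetical sketch: restrict top-k to blocks s with s < floor(t / m).

    index_scores: [B, S, T] (assumed layout), position_ids: [B, S].
    Returns indices [B, S, k'] and a bool mask marking which picks are causal.
    """
    T = index_scores.shape[-1]
    block_ids = torch.arange(T, device=index_scores.device).view(1, 1, T)
    limit = torch.div(position_ids, m, rounding_mode="floor").unsqueeze(-1)  # [B, S, 1]
    visible = block_ids < limit                                               # [B, S, T]
    masked = index_scores.masked_fill(~visible, float("-inf"))
    indices = masked.topk(min(topk, T), dim=-1).indices
    # Queries with fewer than k visible blocks still get k indices back,
    # so the caller needs this flag to mask out the non-causal picks.
    valid = torch.gather(visible, -1, indices)
    return indices, valid

torch.manual_seed(0)
m, T, S = 4, 3, 12
position_ids = torch.arange(S).unsqueeze(0)   # [1, 12]
index_scores = torch.randn(1, S, T)
indices, valid = causal_topk(index_scores, position_ids, m, topk=2)
print(valid[0].int())
```

Note that early queries (t < m here) have no visible block at all, so any top-k result for them has to be masked out downstream rather than trusted.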

Expected behavior

For CSA with seq_len > 1, each query position should attend only to the compressed KV entries selected for that same query, plus the intended sliding-window KV entries.

Equivalently, one of these should hold:

  1. the selected compressed KV remains logically shaped as [B, H_kv, S, k, D] and the attention kernel consumes it per query;
  2. the flattened [B, H_kv, S*k, D] layout is paired with a query-dependent block mask so query t only sees its own segment; or
  3. the eager path is documented/guarded as decode-only for CSA (S == 1) if that is the intended supported usage.
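As an illustration of option 1, a per-query eager reference could look roughly like the following. This is a sketch under assumptions, not a proposed patch: per_query_sparse_attention is a hypothetical function, keys and values are assumed to share one tensor (as the single kv tensor in the snippet above suggests), and the shapes are the ones given in the docstring.

```python
import torch

def per_query_sparse_attention(q, window_kv, selected_kv, window_mask, scale):
    """Hypothetical reference, assumed shapes:
    q: [B, H, S, D], window_kv: [B, 1, W, D], selected_kv: [B, 1, S, k, D],
    window_mask: additive [B, 1, S, W]. Keys and values share one tensor."""
    W = window_kv.shape[2]
    sel = selected_kv.squeeze(1)                                       # [B, S, k, D]
    # Scores against the sliding-window entries, shared by all queries.
    win_scores = torch.matmul(q, window_kv.transpose(-1, -2)) * scale
    win_scores = win_scores + window_mask                              # [B, H, S, W]
    # Scores against each query's own k selected compressed entries.
    sel_scores = torch.einsum("bhsd,bskd->bhsk", q, sel) * scale       # [B, H, S, k]
    # One softmax over both branches, then weight each branch's values.
    probs = torch.softmax(torch.cat([win_scores, sel_scores], dim=-1), dim=-1)
    out = torch.matmul(probs[..., :W], window_kv)                      # [B, H, S, D]
    out = out + torch.einsum("bhsk,bskd->bhsd", probs[..., W:], sel)
    return out
```

Because the compressed entries never leave the [S, k] layout, no query can see another query's selection, and no extra block mask is needed for that branch.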

If the current implementation relies on some invariant that makes this equivalent, could you point me to where that masking/visibility constraint is enforced?
