
Fix deepseek v4 #45892

Merged — ArthurZucker merged 16 commits into main from fix-deepseek-v4-csa-per-query-mask, May 12, 2026

Conversation

@ArthurZucker ArthurZucker commented May 11, 2026

What does this PR do?

Attention-mask layout per layer type

Tiny-config visualization (sliding_window=8, CSA m=4, HCA m'=8, index_topk=2, S=16) of the actual cat([sliding_mask, block_bias]) each DeepseekV4Attention layer feeds to eager_attention_forward. Green cells = attended-to, dim slate = masked, red = causally available but the indexer's top-k didn't pick. Wide green blocks in the compressor section bundle m source positions into one KV slot; dashed separators inside the block show which tokens were compressed together.

Sliding-only — plain sliding-window-causal [S, S], no compressor section:

sliding mask

CSA — sliding KV + entry-view of the Lightning Indexer's top-k picks; red cells = available but not picked:

CSA mask

HCA — sliding KV + every compressed entry (no indexer); each C_w summarises m' source positions:

HCA mask

Reproduce with `python docs/source/en/imgs/deepseek_v4/visualize_attention_masks.py --svg docs/source/en/imgs/deepseek_v4` from the repo root.
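For the sliding-only case described above, the [S, S] layout is just a causal mask intersected with a window constraint. A minimal standalone sketch (not the visualization script itself; names and the boolean convention are illustrative):

```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # True where query i may attend key j: causal (j <= i)
    # and within the sliding window (i - j < window)
    i = torch.arange(seq_len).unsqueeze(-1)
    j = torch.arange(seq_len)
    return (j <= i) & (i - j < window)

# tiny config from the figures: S=16, sliding_window=8
mask = sliding_window_causal_mask(16, 8)
```

In the CSA/HCA figures this [S, S] block is what gets concatenated with the compressor-section bias along the key axis.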

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Sawyer117 added a commit to Sawyer117/transformers that referenced this pull request May 11, 2026
@Sawyer117

This comment was marked as off-topic.

@Sawyer117

This comment was marked as off-topic.

@ArthurZucker
Collaborator Author

Hey! I don't get how you want to run sdpa when it does not have attention sinks baked into it? You won't ever get good results.

@ArthurZucker ArthurZucker marked this pull request as ready for review May 12, 2026 04:58
@I-hercules

@ArthurZucker Thanks for your hard work.

@ArthurZucker
Collaborator Author

Sorry everyone I trusted Claude too much

@I-hercules

> Sorry everyone I trusted Claude too much

That's only human. Keep going!

Member

@Cyrilvallez Cyrilvallez left a comment


Alright, I don't know the specifics of deepseek v4, so I just reviewed the general parts, not the exact mathematics.
I'm mostly a bit worried about all the dtype upcasting everywhere; it's quite expensive in general, both for speed and memory, and I believe some of it is actually useless. And when it's performed on weights directly, let's use keep_in_fp32_strict to always keep the weights in fp32 instead of upcasting them every forward!
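The cost concern here is the repeated cast, not the arithmetic itself. A toy sketch of the two patterns (plain torch, not the transformers mechanism the reviewer mentions):

```python
import torch

w = torch.randn(4, 4, dtype=torch.bfloat16)  # a "weight" stored in low precision
x = torch.randn(4, dtype=torch.float32)

# pattern 1: upcast inside every forward call (pays a cast each time)
y1 = x @ w.float().T

# pattern 2: upcast once at load time and keep the fp32 copy around
w32 = w.float()
y2 = x @ w32.T
```

Both produce identical results; the second just moves the cast out of the hot path, which is what keeping the weight in fp32 from the start achieves.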

Six review threads on src/transformers/models/deepseek_v4/modular_deepseek_v4.py (five marked outdated).
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: deepseek_v4

@ArthurZucker
Collaborator Author

Checked batch version!

@ArthurZucker ArthurZucker merged commit a1b77cc into main May 12, 2026
41 checks passed
@ArthurZucker ArthurZucker deleted the fix-deepseek-v4-csa-per-query-mask branch May 12, 2026 08:26
@Sawyer117
Contributor

> Hey! I don't get how you want to run sdpa when it does not have attention sinks baked into it? You won't ever get good results.

Hi, sorry: the sdpa test was just comparing the two mask forms apples-to-apples on a non-sink-aware backend. Anyway, please ignore sdpa; the main point I want to make is that the [S, S*top_k] mask is unnecessary. [S, compressed_len] is mathematically identical with better performance (both speed and memory).

Current (DeepseekV4CSACompressor.forward):

# gathers the top_k picks per query, then builds a dense
# [S, S, top_k] bias (mostly -inf) flattened to [S, S * top_k]
gathered = flat_kv.index_select(0, flat_indices).view(batch, 1, -1, self.head_dim)
block_bias = gathered.new_full((batch, 1, seq_len, seq_len, top_k), float("-inf"))
allowed = torch.where(valid, gathered.new_zeros(()), gathered.new_full((), float("-inf")))
block_bias[:, 0, query_indices, query_indices, :] = allowed
return gathered, block_bias.view(batch, 1, seq_len, seq_len * top_k)

Alt:

# clamp invalid picks to a sacrificial extra column, scatter zeros
# at the picked indices, then drop the extra column
safe_indices = torch.where(valid, top_k_indices, torch.full_like(top_k_indices, compressed_len))
block_bias = compressed_kv.new_full((batch, 1, seq_len, compressed_len + 1), float("-inf"))
block_bias.scatter_(-1, safe_indices.unsqueeze(1), 0.0)
return compressed_kv, block_bias[..., :compressed_len]

Each query's softmax has the same top_k non-(-inf) entries either way (the same indices into the compressed_kv axis, the same dot-product values). Only the row width differs: a dense S*top_k row mostly filled with -inf versus a compressed_len row with top_k zero (allowed) entries.
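A quick standalone way to convince oneself of the compact form's behavior (made-up shapes and logits, not the PR's actual tensors): after scattering zeros at the picked indices, all softmax mass lands exactly on the picks and nowhere else.

```python
import torch

torch.manual_seed(0)
S, compressed_len, top_k = 6, 4, 2

# made-up logits of each query against each compressed KV entry
scores = torch.randn(S, compressed_len)
top_k_indices = torch.randint(0, compressed_len, (S, top_k))
valid = torch.ones(S, top_k, dtype=torch.bool)

# compact bias: one row of width compressed_len per query,
# 0.0 at the indexer's picks, -inf everywhere else
safe_indices = torch.where(valid, top_k_indices,
                           torch.full_like(top_k_indices, compressed_len))
block_bias = scores.new_full((S, compressed_len + 1), float("-inf"))
block_bias.scatter_(-1, safe_indices, 0.0)
block_bias = block_bias[:, :compressed_len]

probs = torch.softmax(scores + block_bias, dim=-1)

# boolean view of the picks, to check where the mass can land
picked = torch.zeros(S, compressed_len, dtype=torch.bool)
picked.scatter_(1, top_k_indices, True)
```

Since exp(-inf) is exactly 0, the probability at every non-picked slot is exactly 0, matching what the dense [S, S*top_k] layout would produce at its top_k finite entries.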

ArthurZucker added a commit that referenced this pull request May 13, 2026
* up

* small updates

* a mistery how this got through

* updates

* update

* final fixes?

* update

* up

* yup

* last nit!

* up

* update

* revert shitty AI work

* up

* yyups
@ArjunSrivastava1

Regarding the issue I mentioned: it's confirmed. It wasn't opened by me (it was flagged by someone else), but I've been looking at it, the paper, and the current implementation.

A brief summary of where this occurs in the current forward cell block:

Block: [E, F, G, PAD]

1. Softmax runs over all 4 positions (including PAD)
2. PAD contributes to the compressed entry
3. Polluted KV entry → incorrect attention

Further details are in the issue itself. V4 is, after all, a model unlike others; it can't be helped that issues keep surfacing even after fixes.
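The pollution and the usual remedy are easy to see in isolation. A toy sketch of one compression block (hypothetical softmax-weighted pooling, invented names; the real compressor may pool differently):

```python
import torch

torch.manual_seed(0)
m, d = 4, 8  # block size and head dim, made up

block = torch.randn(m, d)  # rows stand for [E, F, G, PAD]
pad_mask = torch.tensor([False, False, False, True])
logits = torch.randn(m)    # made-up pooling logits over the block

# buggy: PAD participates in the pooling softmax,
# so its row leaks into the compressed KV entry
polluted = torch.softmax(logits, dim=0) @ block

# fixed: mask PAD to -inf before the softmax so its weight is exactly 0
masked_logits = logits.masked_fill(pad_mask, float("-inf"))
clean_weights = torch.softmax(masked_logits, dim=0)
clean = clean_weights @ block
```

With the mask applied, PAD's pooling weight is exactly 0 and the compressed entry is a convex combination of the real tokens only.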

@ArjunSrivastava1

Seeing as there's also a bit of a lack of diagrams in this very PR, which is focused on solving v4 issues, and I imagine we may need them, I'm again attaching a few here for everyone's reference, taken from the paper itself.

Fig 1: CSA architecture (screenshot from the paper)
Fig 2: HCA architecture (screenshot from the paper)

This provides a more complete picture for whoever opens and uses this PR.


Labels

for patch Tag issues / labels that should be included in the next patch


6 participants