[Gluon][Tutorial] Persistent attention #7298
Conversation
wow!
For posterity, these are the best results prior to converting the kernel to persistent
I don't see a "cutlass" in the kernel names?
Before:
After:
I'm not sure if I'm interpreting it incorrectly, but it seems like perf dropped based on these numbers?
Great stuff. A couple of small nits though.
_, corr_bar, corr_producer = corr_producer.acquire()

p = gl.join(p0, p1).permute(0, 2, 1).reshape([config.SPLIT_M, config.BLOCK_N])
p = gl.convert_layout(p, config.qk_layout)
This shouldn't be needed any more after I introduced the slice layout for split, right?
The convert_layout coming out of the split is no longer needed, but:
ValueError('Layout mismatch in broadcast:
SliceLayout(dim=1, parent=BlockedLayout(size_per_thread=[1, 128], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]))
vs
SliceLayout(dim=1, parent=DistributedLinearLayout(reg_bases=[[0, 64], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp_bases=[[32, 0], [64, 0]], block_bases=[], shape=[128, 128]))')
It seems that p ends up with a linear layout instead of a blocked layout. I am not sure why though -- I believe the layout inference should try a blocked layout first before falling back to a linear layout.
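To make the two sides of that mismatch concrete, here is a small sketch (not code from the PR) that just rebuilds the conflicting parent layouts from the error message. The import path and the assumption that these layout classes can be constructed directly at the Python level are mine; the field names are taken from the reprs in the ValueError above.

```python
# Sketch only: reconstruct the "expected" broadcast operand layout from the ValueError.
# Assumptions: the gluon import path, and that BlockedLayout/SliceLayout are plain
# Python-constructible layout objects (keyword names copied from the error repr).
from triton.experimental.gluon import language as gl

blocked = gl.BlockedLayout(
    size_per_thread=[1, 128],
    threads_per_warp=[32, 1],
    warps_per_cta=[4, 1],
    order=[0, 1],
    ctas_per_cga=[1, 1],
    cta_split_num=[1, 1],
    cta_order=[1, 0],
)
expected = gl.SliceLayout(dim=1, parent=blocked)  # what the broadcast wants

# After the join/permute/reshape, p instead carries a slice of a DistributedLinearLayout,
# so the two broadcast operands disagree. The explicit gl.convert_layout(p, config.qk_layout)
# in the diff above forces p back onto the blocked layout before the broadcast.
print(expected)
```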
name = "gluon_attention" | ||
# Up to 150 TFLOPS faster for fp8! | ||
if specialization.constants["dtype"] == gl.float8e5: | ||
name = "cutlass_" + name |
very cool... did you check if other names change the scheduling (e.g. because of non-determinism or code alignment), or if it's literally just special-cased for cutlass?
> it's literally just special-cased for cutlass.
Yup
wow! You literally beat the nvcc team!
@AlexMaclean Just an FYI, in case you can prod the right folks on your side: there must be a better way to enable this optimization. A PTX directive, perhaps, if ptxas can't figure out the right thing by itself?
@Mogball have you checked the accuracy, is it the same? The DeepSeek technical report mentioned that fp8 tensor cores use a reduced mantissa for the accumulator; maybe this is what is indirectly enabled/disabled by the name of the kernel.
> The DeepSeek technical report mentioned that fp8 tensor cores use a reduced mantissa for the accumulator; maybe this is what is indirectly enabled/disabled by the name of the kernel.
That's only on Hopper
By disassembling ptxas, it is indeed hard-coded: they have logic like strstr(kernel_name, "cutlass").
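For anyone who wants a quick, unofficial sanity check without a disassembler, scanning the ptxas binary for the literal string is enough to see whether it is baked in. This is only a sketch; the fallback path below is an assumption and depends on your CUDA install.

```python
# Hypothetical quick check (not from this PR): look for the literal "cutlass" string
# inside the ptxas binary. The fallback path is an assumption; point it at your
# CUDA toolkit's ptxas if it is not on PATH.
import shutil

ptxas = shutil.which("ptxas") or "/usr/local/cuda/bin/ptxas"
with open(ptxas, "rb") as f:
    data = f.read()
print("contains 'cutlass':", b"cutlass" in data)
```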
That's interesting! I'm curious whether it's feasible to modify the asm code of ptxas to make the al return register always true (maybe we could modify the code at the addresses between 2165-216c). Did you give it a try?
Admittedly it is feasible. But it is more likely that this is an unstable, experimental, aggressive optimization by NVIDIA, and blindly always enabling it may produce some elusive bugs.
For D64 it did drop quite a bit during the transition to persistent. This is due to a scheduling issue in ptxas that I couldn't find a workaround for.
Rewrite the attention kernel to be persistent. This gives better performance at low contexts. However, fp16 at large contexts has suffered a bit due to a ptxas instruction scheduling issue in the softmax partition. fp8 is ~100 TFLOPS faster when the kernel name has "cutlass" in it.