
Conversation

@Mogball Mogball commented Jun 24, 2025

Rewrite the attention kernel to be persistent. This gives better performance at low context lengths. However, fp16 at large context has regressed a bit due to a ptxas instruction scheduling issue in the softmax partition. fp8 is ~100 TFLOPS faster when the kernel name has "cutlass" in it.

Attention Z=4 H=32 D=64 causal=False:
     N_CTX  triton-fp16  triton-fp8
0   1024.0   359.574448  370.119987
1   2048.0   612.103928  641.204555
2   4096.0   653.868402  682.337948
3   8192.0   692.102228  721.555690
4  16384.0   696.972041  726.190035
5  32768.0   698.723685  727.983456
6  65536.0   699.865817  728.558321
Attention Z=4 H=32 D=64 causal=True:
     N_CTX  triton-fp16  triton-fp8
0   1024.0   181.879039  177.982453
1   2048.0   441.315463  454.310072
2   4096.0   532.170527  539.995252
3   8192.0   633.620646  638.544937
4  16384.0   667.687180  670.681255
5  32768.0   684.276329  688.571907
6  65536.0   692.953202  694.648353
Attention Z=4 H=32 D=128 causal=False:
     N_CTX  triton-fp16   triton-fp8
0   1024.0   718.580015   709.863720
1   2048.0  1133.490258  1222.548477
2   4096.0  1247.605551  1369.800195
3   8192.0  1243.482713  1406.799697
4  16384.0  1125.744367  1514.857403
5  32768.0  1124.116305  1521.267973
6  65536.0  1064.588719  1518.738037
Attention Z=4 H=32 D=128 causal=True:
     N_CTX  triton-fp16   triton-fp8
0   1024.0   355.642522   351.161232
1   2048.0   846.404095   854.547917
2   4096.0  1013.840017  1021.676435
3   8192.0  1176.258395  1152.844234
4  16384.0  1190.290681  1325.786204
5  32768.0  1063.658200  1394.413325
6  65536.0   970.531569  1413.282610
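
For readers unfamiliar with the term: a persistent kernel launches a fixed grid of programs (roughly one per SM) and has each program loop over many output tiles, instead of launching one program per tile. Below is a minimal sketch of that pattern in plain Triton on a hypothetical element-wise kernel; it is illustrative only and is not the Gluon attention kernel from this PR.

import triton
import triton.language as tl

@triton.jit
def persistent_add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # A fixed grid is launched; each program strides over many tiles.
    pid = tl.program_id(0)
    num_programs = tl.num_programs(0)
    num_tiles = tl.cdiv(n_elements, BLOCK)
    for tile in range(pid, num_tiles, num_programs):
        offs = tile * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        x = tl.load(x_ptr + offs, mask=mask)
        y = tl.load(y_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x + y, mask=mask)

Launched with a grid sized to the number of SMs (e.g. grid = (NUM_SMS,)), the loop amortizes launch and prologue overhead across tiles, which is presumably where the low-context gains come from.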

@ThomasRaoux ThomasRaoux left a comment

wow!

@Mogball Mogball commented Jul 9, 2025

For posterity, these are the best results prior to converting the kernel to persistent:

Attention Z=4 H=32 D=64 causal=False:
     N_CTX  triton-fp16  triton-fp8  cudnn-fp16
0   1024.0   382.890516  412.882937  564.837281
1   2048.0   564.572331  613.796259  802.967790
2   4096.0   651.779895  711.258057  890.712337
3   8192.0   718.153704  786.906882  940.327292
4  16384.0   746.519007  815.458990  944.850564
5  32768.0   758.416978  830.055643  939.287109
6  65536.0   766.176045  837.979637  925.739296
Attention Z=4 H=32 D=64 causal=True:
     N_CTX  triton-fp16  triton-fp8  cudnn-fp16
0   1024.0   181.758722  190.690577  381.906099
1   2048.0   313.497481  326.500187  626.967949
2   4096.0   463.538482  483.472677  777.606926
3   8192.0   586.226812  618.900682  805.998776
4  16384.0   683.741305  708.737060  853.282336
5  32768.0   734.844555  762.845981  912.526865
6  65536.0   767.292419  793.280126  924.780010
Attention Z=4 H=32 D=128 causal=False:
     N_CTX  triton-fp16   triton-fp8   cudnn-fp16
0   1024.0   655.417393   730.561798   926.859512
1   2048.0   970.734621  1057.033298  1267.867719
2   4096.0  1118.226666  1210.507191  1428.037959
3   8192.0  1182.149430  1332.290127  1488.733746
4  16384.0  1227.000687  1372.951364  1358.394870
5  32768.0  1254.096611  1409.254506  1314.970965
6  65536.0  1231.680630  1426.040751  1313.822094
Attention Z=4 H=32 D=128 causal=True:
     N_CTX  triton-fp16   triton-fp8   cudnn-fp16
0   1024.0   312.399981   345.242273   553.117042
1   2048.0   534.902248   590.947330   877.822759
2   4096.0   782.786229   871.178240  1122.610667
3   8192.0   961.045037  1105.459197  1319.639575
4  16384.0  1114.273933  1256.257370  1317.439900
5  32768.0  1192.714280  1341.112079  1275.191200
6  65536.0  1195.453344  1386.801400  1310.388518

@Jokeren Jokeren commented Jul 9, 2025

Rewrite the attention kernel to be persistent. This gives better performance at low context lengths. However, fp16 at large context has regressed a bit due to a ptxas instruction scheduling issue in the softmax partition. fp8 is ~100 TFLOPS faster when the kernel name has "cutlass" in it.


I don't see a "cutlass" in the kernel names?

@Mogball Mogball commented Jul 9, 2025

def attention_repr(specialization):
    name = "gluon_attention"
    # Up to 150 TFLOPS faster for fp8!
    if specialization.constants["dtype"] == gl.float8e5:
        name = "cutlass_" + name
    return name
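
For context, a hedged sketch of how such a naming hook might be attached; the decorator wiring and the placeholder kernel below are assumptions for illustration, not code from this PR.

import triton
import triton.language as tl

def my_repr(specialization):
    # Hypothetical hook: rename the kernel so that ptxas sees "cutlass_" in its name.
    return "cutlass_my_kernel"

# Assumption: the jit decorator accepts a `repr` callback that receives the kernel
# specialization and returns the name emitted into the generated PTX.
@triton.jit(repr=my_repr)
def my_kernel(x_ptr, BLOCK: tl.constexpr):
    pass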

@Jokeren Jokeren commented Jul 9, 2025

Before:

Attention Z=4 H=32 D=64 causal=False:
     N_CTX  triton-fp16  triton-fp8  cudnn-fp16
0   1024.0   382.890516  412.882937  564.837281
1   2048.0   564.572331  613.796259  802.967790
2   4096.0   651.779895  711.258057  890.712337
3   8192.0   718.153704  786.906882  940.327292
4  16384.0   746.519007  815.458990  944.850564
5  32768.0   758.416978  830.055643  939.287109

After:

Attention Z=4 H=32 D=64 causal=False:
     N_CTX  triton-fp16  triton-fp8
0   1024.0   359.574448  370.119987
1   2048.0   612.103928  641.204555
2   4096.0   653.868402  682.337948
3   8192.0   692.102228  721.555690
4  16384.0   696.972041  726.190035
5  32768.0   698.723685  727.983456
6  65536.0   699.865817  728.558321

I'm not sure if I'm interpreting it incorrectly, but it seems like perf dropped based on these numbers?

@peterbell10 peterbell10 left a comment

Great stuff. A couple of small nits, though.

_, corr_bar, corr_producer = corr_producer.acquire()

p = gl.join(p0, p1).permute(0, 2, 1).reshape([config.SPLIT_M, config.BLOCK_N])
p = gl.convert_layout(p, config.qk_layout)

Contributor

This shouldn't be needed any more after I introduced the slice layout for split, right?

Collaborator Author

The convert layout coming out of the split is no longer needed, but

ValueError('Layout mismatch in broadcast: 

SliceLayout(dim=1, parent=BlockedLayout(size_per_thread=[1, 128], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])) 
vs 
SliceLayout(dim=1, parent=DistributedLinearLayout(reg_bases=[[0, 64], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp_bases=[[32, 0], [64, 0]], block_bases=[], shape=[128, 128]))')

It seems that p ends up with a linear layout instead of a blocked layout. I am not sure why though -- I believe the layout inference should try a blocked layout first before falling back to linear layout.

name = "gluon_attention"
# Up to 150 TFLOPS faster for fp8!
if specialization.constants["dtype"] == gl.float8e5:
name = "cutlass_" + name

Contributor

very cool... did you check if other names change the scheduling (e.g. because of non-determinism or code alignment), or if it's literally just special-cased for cutlass?

Collaborator Author

it's literally just special cased for cutlass.

Yup


wow! You literally beat the nvcc team!


@AlexMaclean Just an FYI, in case you can prod the right folks on your side. There must be a better way to enable this optimization. A PTX directive, perhaps, if ptxas can't figure out the right thing by itself?


@Mogball have you checked the accuracy, is it the same? The DeepSeek technical report mentioned that fp8 tensor cores use reduced mantissa for the accumulator; maybe this is what's indirectly enabled/disabled by the name of the kernel.

Collaborator

The DeepSeek technical report mentioned that fp8 tensor cores use reduced mantissa for the accumulator; maybe this is what's indirectly enabled/disabled by the name of the kernel.

That's only on Hopper

@yhx-12243 yhx-12243 Jul 12, 2025

By disassembling ptxas, it is indeed hard-coded: they have logic like strstr(kernel_name, "cutlass").


Contributor

By disassembling ptxas, it is indeed hard-coded: they have logic like strstr(kernel_name, "cutlass").

That's interesting! I'm curious whether it's feasible to patch the ptxas binary so that the al return register is always true (maybe by modifying the code between addresses 2165-216c). Did you give it a try?


Admittedly it is feasible. But it is more likely that this is an unstable, experimental, aggressive optimization by NVIDIA, and blindly enabling it everywhere may produce some elusive bugs.

@Mogball Mogball commented Jul 9, 2025

I'm not sure if I'm interpreting it incorrectly, but it seems like perf dropped based on these numbers?

For D64 it did drop quite a bit during the transition to persistent. This is due to a scheduling issue in ptxas that I couldn't find a workaround for.
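
Quantifying that from the D=64, causal=False tables quoted above (fp16 at N_CTX=32768, numbers copied verbatim from the benchmark output):

before, after = 758.416978, 698.723685  # TFLOPS before vs. after the persistent rewrite
print(f"{(after - before) / before:+.1%}")  # about -7.9% for fp16 at this context length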

Mogball added 2 commits July 9, 2025 09:58
@Mogball Mogball enabled auto-merge (squash) July 9, 2025 18:19
@Mogball Mogball merged commit ade3d49 into main Jul 9, 2025
9 checks passed
@Mogball Mogball deleted the mogball/persistent branch July 9, 2025 18:24
@hyz0906 hyz0906 mentioned this pull request Jul 11, 2025