I’ve spent the last few weeks deconstructing FlashAttention. While the original paper is brilliant, I found that just reading it didn't give me a "gut feeling" for why certain engineering choices were made (for example, the transition from v1 to v2).
I decided to rebuild it from scratch using Triton. This post is a chronicle of that journey—moving beyond the high-level algorithm and into the "performance archaeology" of the GPU:
- Profiling with Nsight Compute to find the real bottlenecks.
- Looking at the generated PTX and SASS code.
- Debugging shared memory bank conflicts and MIO bottlenecks.
- Iterating through the logic to see why tiling and online softmax are hardware-necessitated, not just mathematical tricks (a quick sketch of the online-softmax update follows this list).
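To make that last point concrete, here's a stripped-down NumPy sketch of the online-softmax update. It's an illustration of the math only, not the Triton kernel from the post; the sizes and block width are arbitrary placeholders:

    import numpy as np

    def online_softmax_attention(q, K, V, block=16):
        """Attention output for one query vector, consuming K/V in tiles."""
        d = q.shape[0]
        m = -np.inf           # running max of all scores seen so far
        l = 0.0               # running sum of exp(score - m)
        acc = np.zeros(d)     # running, unnormalized weighted sum of V rows

        for start in range(0, K.shape[0], block):
            k_tile = K[start:start + block]      # the tile that would sit in SRAM
            v_tile = V[start:start + block]
            s = k_tile @ q / np.sqrt(d)          # scores for this tile only

            m_new = max(m, s.max())              # new running max
            scale = np.exp(m - m_new)            # rescales everything accumulated so far
            p = np.exp(s - m_new)                # unnormalized probabilities for this tile

            l = l * scale + p.sum()
            acc = acc * scale + p @ v_tile
            m = m_new

        return acc / l                           # normalize once at the end

    # sanity check against the naive "materialize the whole score row" softmax
    rng = np.random.default_rng(0)
    q = rng.standard_normal(64)
    K = rng.standard_normal((256, 64))
    V = rng.standard_normal((256, 64))
    s = K @ q / np.sqrt(64)
    w = np.exp(s - s.max())
    ref = (w / w.sum()) @ V
    assert np.allclose(online_softmax_attention(q, K, V), ref)

The running max and running sum let each K/V tile be consumed as soon as it's loaded, which is what makes the tiling compatible with the small on-chip SRAM in the first place.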
I’ve tried to keep it in the spirit of Simon Boehm’s matmul deep dive. Would love to hear from any GPU engineers on whether my interpretations of the SASS/bank conflict behavior match what you've seen in production.
I think it's a combination of multiple factors. I've worked with GPU kernel code before, and the code you write has a tendency to never be updated or modified: once it works, it works, and you don't touch it. If you get new hardware, you rewrite it from scratch. So readability typically just isn't that valuable. Also, you're never working with variables that mean something to a human; it's never anything tangible, it's always tiles, offsets, and indices. At least when I was writing GPU code, I didn't think spending visual space on better variable names was worthwhile.
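For anyone who hasn't written one of these kernels, here's roughly what the address arithmetic at the top of a typical Triton kernel looks like. This is a generic tile-load preamble I'm using as an illustration, not code from the post; the block sizes and strides are placeholders:

    import triton
    import triton.language as tl

    @triton.jit
    def load_tile_example(x_ptr, M, N, stride_m, stride_n,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        # which (BLOCK_M x BLOCK_N) tile this program instance owns
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)      # row indices of the tile
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)      # column indices of the tile
        ptrs = x_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)  # guard the ragged edges
        tile = tl.load(ptrs, mask=mask, other=0.0)
        # ...a real kernel would now compute with `tile` and store a result

Everything in there is a tile, an offset, or an index; there's simply no "principal" or "interest rate" to name.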
I did an experiment on FlashAttention in Triton to measure the impact of caching tiles in shared memory. Surprisingly, the effect of prefetching these tiles was non-monotonic, and it was kernel-dependent: the attention kernel benefits from prefetching, while the MLP W1 kernel doesn't.
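In Triton the prefetch depth is exposed as the `num_stages` launch parameter (how many loop iterations' worth of loads the compiler software-pipelines ahead of the compute). A minimal sweep over it could look like the sketch below; the toy row-sum kernel, shapes, and block size are my own placeholders, not the kernels from the experiment above:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def row_sum_kernel(x_ptr, out_ptr, N, BLOCK: tl.constexpr):
        # one program per row; the inner loop walks the row tile by tile,
        # and num_stages decides how many tile loads are pipelined ahead
        row = tl.program_id(0)
        acc = tl.zeros((BLOCK,), dtype=tl.float32)
        for start in range(0, N, BLOCK):
            offs = start + tl.arange(0, BLOCK)
            acc += tl.load(x_ptr + row * N + offs, mask=offs < N, other=0.0)
        tl.store(out_ptr + row, tl.sum(acc, axis=0))

    x = torch.randn(4096, 16384, device="cuda")
    out = torch.empty(4096, device="cuda")
    for num_stages in (1, 2, 3, 4, 5):
        ms = triton.testing.do_bench(
            lambda ns=num_stages: row_sum_kernel[(x.shape[0],)](
                x, out, x.shape[1], BLOCK=256,
                num_stages=ns, num_warps=4))
        print(f"num_stages={num_stages}: {ms:.3f} ms")

Whether deeper pipelining helps depends on how much shared memory and how many registers the extra stages cost the kernel, which is presumably why the attention and MLP kernels behave differently.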
It’s the equivalent of doing this for compound interest rate calculation:
    # A = P * (1 + r/n)^(n*t)
    P = 10000
    r = 0.06
    n = 12
    t = 5
    A = P * (1 + r / n) ** (n * t)
Compared to this:
    principal = 10_000
    annual_interest_rate = 0.06
    compounds_per_year = 12
    years = 5

    future_value = principal * (1 + annual_interest_rate / compounds_per_year) ** (compounds_per_year * years)
My question is partly rhetorical - I know the answer lies in the code's tight research and mathematical origins. But that makes it research code IMO, not what I would consider high-quality software code.