if you've read siboehm's og gemm kernel optimization blog, my solution will feel right at home.
my solution can be broken down into three main parts: loading data from global memory to lds, reading from lds to registers and performing mfma, and storing data back to global memory.
to utilize memory bandwidth as much as possible, we want the tile size to be as large as possible. i went with a 128x128 tile for both A and B along with double buffering. this takes up the entire 64 kib of lds.
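for a rough picture of how that budget is spent, here's a sketch of the lds allocation, assuming 1-byte fp8 elements (the names `BM`/`BN`/`BK`, `lds_a`, `lds_b` are illustrative, not the kernel's actual identifiers):

```cpp
#include <hip/hip_runtime.h>

constexpr int BM = 128, BN = 128, BK = 128;

__global__ void fp8_gemm_kernel(/* ... */) {
    // double-buffered 128x128 tiles for A and B, 1 byte per fp8 element
    __shared__ unsigned char lds_a[2][BM * BK];  // 2 x 16 KiB
    __shared__ unsigned char lds_b[2][BK * BN];  // 2 x 16 KiB
    static_assert(sizeof(lds_a) + sizeof(lds_b) == 64 * 1024,
                  "exactly the 64 KiB of lds available per workgroup");
    // ... load / compute / store ...
}
```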
the goal here is to do coalesced reads from global and vectorized writes to shared. since mfma expects data packed along the k dimension, we need our data in lds to be packed accordingly. however, our inputs are mn-major, so we have to transpose the data when storing to lds.
i tried several ways of loading from global to shared, but the one that gave me the best performance was reading a 4x4 tile per thread using 4 `global_load_dword` instructions (one per row), transposing it in registers, and then storing it to lds.
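here's a rough sketch of that per-thread load-and-transpose, assuming 1-byte fp8 elements; plain array accesses stand in for the actual `global_load_dword` and `ds_write` instructions, and all names are illustrative:

```cpp
#include <cstdint>
#include <cstring>

// load a 4x4 tile of fp8 elements starting at (row, col) and transpose it in
// registers so it can then be written to lds packed along k
void load_transpose_4x4(const uint8_t* a, int lda, int row, int col,
                        uint8_t out[4][4]) {
    uint32_t r[4];
    for (int i = 0; i < 4; ++i)  // four dword loads, one 4-byte row each
        std::memcpy(&r[i], &a[(row + i) * lda + col], 4);

    // transpose in registers: byte j of row i becomes row j, column i
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j)
            out[j][i] = uint8_t(r[i] >> (8 * j));
}
```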
naively doing this leads to a lot of bank conflicts since contiguous threads write to the same column, which lies in the same bank. we also need to load 8 bytes contiguously because that's what mfma expects. additionally, notice that contiguous threads write to rows that are 4 apart. so i came up with a two-level swizzling strategy where i first swizzle 8 bytes of data in groups of 32 rows to the next bank, and then swizzle 8 bytes of data to the right along each row. due to the constraint of 8-byte alignment, the swizzle isn't completely bank conflict-free. i believe there are still 2-way bank conflicts.
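in code, the idea looks something like this; the exact formula isn't spelled out above, so treat this as an illustrative two-level rotation at 8-byte granularity rather than the kernel's exact mapping:

```cpp
#include <cstdint>

// (row, col) address a byte in the 128x128 lds tile; returns the swizzled byte offset
uint32_t swizzle(uint32_t row, uint32_t col) {
    constexpr uint32_t ROW_BYTES = 128;  // 16 chunks of 8 bytes per row
    uint32_t chunk = col / 8;
    // level 1: rotate by the row's position within its 32-row group
    // level 2: shift right along the row based on the 32-row group index
    uint32_t rotated = (chunk + (row % 32) + (row / 32)) % (ROW_BYTES / 8);
    return row * ROW_BYTES + rotated * 8 + col % 8;
}
```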
i also tried 4-byte alignment for lds, which allows for a bank-conflict-free swizzling strategy. however, while reading from lds, the compiler emits `ds_read2_b32` instead of the more efficient `ds_read_b64`, which neutralizes the gains from eliminating bank conflicts. from my experiments, both 4-byte and 8-byte alignment yield similar performance.
since i used double buffering, i overlap global memory reads with mfma and lds writes, which leads to a significant runtime improvement. the compute stage is pretty straightforward: perform two 8-byte reads from lds into registers, execute 16x16x32 mfma instructions followed by scaling, and once the computation is done, each thread writes 4 output cells.
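the main loop is structured roughly like this; placeholder lambdas stand in for the actual global loads, lds writes, and mfma calls, and the names are illustrative:

```cpp
#include <hip/hip_runtime.h>

// double-buffered main loop: while the mfma stage chews on the tile already in
// lds, the global loads for the next tile are in flight
template <class LoadGlobal, class StoreLds, class Mfma>
__device__ void main_loop(int num_k_tiles, LoadGlobal load_global,
                          StoreLds store_lds, Mfma mfma_tile) {
    load_global(0);   // prefetch k-tile 0 into registers
    store_lds(0, 0);  // stage it in lds buffer 0
    __syncthreads();

    for (int kt = 1; kt < num_k_tiles; ++kt) {
        int cur = (kt - 1) & 1, nxt = kt & 1;
        load_global(kt);     // global reads for the next tile overlap with...
        mfma_tile(cur);      // ...mfma on the tile resident in lds
        store_lds(kt, nxt);  // then stage the freshly loaded tile
        __syncthreads();
    }
    mfma_tile((num_k_tiles - 1) & 1);  // drain the last staged tile
}
```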
i used persistent blocks with a column-based scheduling strategy, where i launch 304 blocks (same as the number of CUs in mi300x). instead of exiting after computing a single tile, each block persists and processes another tile. this reduces some block launch overhead and gives a small performance boost. i used column scheduling because it gave the best cache hit rates.
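a sketch of the persistent tile loop; the grid size, names, and the exact tile-index mapping here are illustrative:

```cpp
#include <hip/hip_runtime.h>

// launched with 304 blocks; each block strides over the output-tile space
__global__ void persistent_gemm(int m, int n /* , ... */) {
    const int tiles_m = (m + 127) / 128;
    const int tiles_n = (n + 127) / 128;
    const int num_tiles = tiles_m * tiles_n;

    for (int t = blockIdx.x; t < num_tiles; t += gridDim.x) {
        // column-major ordering: walk down a column of output tiles first, so
        // consecutive blocks tend to reuse the same panel of B
        int tile_m = t % tiles_m;
        int tile_n = t / tiles_m;
        // ... load, compute, and store output tile (tile_m, tile_n) ...
    }
}
```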
each block loads and computes `K / BK` tiles from both A and B. take the case where m = 1024, k = 7168, and n = 1536. we have (1024 / 128) x (1536 / 128) = 96 output tiles but 304 CUs, which leads to poor gpu utilization.
instead of considering the final result of an output tile as the unit of work, we consider a partial result of an output tile as the unit of work. let's continue with the example above. to compute the final result of a single output tile, we need to load and compute (7168 / 128) = 56 tiles of size 128x128 from A and B. so if we treat a partial result of an output tile as a work unit, there are 56 x 96 = 5376 work units in total. now it becomes immediately obvious that with naive scheduling, only the first 96 CUs do 56 units of work each while the rest sit idle. instead, we spread this work across all CUs to better utilize the gpu. with 304 CUs, we get ceil_div(5376, 304) = 18 work units per CU. so each CU works on 18 partial output tiles, which gives significantly better utilization.
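the same arithmetic in code form (`ceil_div` and the other names here are illustrative):

```cpp
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int M = 1024, N = 1536, K = 7168;
    const int BM = 128, BN = 128, BK = 128;
    const int num_cus = 304;

    const int output_tiles = (M / BM) * (N / BN);             // 8 x 12 = 96
    const int k_iters_per_tile = K / BK;                      // 56
    const int work_units = output_tiles * k_iters_per_tile;   // 5376
    const int units_per_cu = ceil_div(work_units, num_cus);   // 18

    printf("output tiles: %d, work units: %d, per cu: %d\n",
           output_tiles, work_units, units_per_cu);
}
```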
precision challenge: although streamk allows for an equal distribution of work and outperforms the non-streamk version for most shapes, there's a catch. the reference implementation does its entire accumulation in fp32, so any intermediate atomic add into the global bf16 buffer loses precision and causes a mismatch between our output and the reference. to implement streamk correctly, we need an intermediate fp32 buffer instead of writing directly to the bf16 output. this negates most of streamk's performance gains for all but a few shapes, so i only launch the specialized streamk kernel for those specific cases.
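a rough sketch of the workaround; the names, launch details, and the split into a separate conversion kernel are illustrative assumptions:

```cpp
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>

// inside the streamk kernel: accumulate a partial result into a zero-initialized
// fp32 buffer instead of the bf16 output, so no rounding happens mid-accumulation
__device__ void accumulate_partial(float* c_partial_f32, int idx, float acc) {
    atomicAdd(&c_partial_f32[idx], acc);
}

// once all partials have landed, a separate pass converts the fp32 buffer to bf16
__global__ void convert_to_bf16(const float* c_partial_f32, __hip_bfloat16* c,
                                int num_elems) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_elems)
        c[i] = __float2bfloat16(c_partial_f32[i]);
}
```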
the complete implementation is available on github with two separate kernel versions: a non-streamk kernel and a streamk kernel.
both implementations feature the core optimizations discussed above, with the streamk version providing additional work distribution capabilities for improved gpu utilization on specific matrix shapes.
you can also check out the complete repository, which contains the development history of all kernels i wrote during the amd challenge.