Matrix multiplication is undoubtedly the most common operation carried out by GPUs. It is the elementary building block of linear algebra and shows up across a wide spectrum of fields such as graphics, physics simulations and scientific computing, while being ubiquitous in machine learning.
In today's article, we'll break down the conceptual implementation of general matrix-matrix multiplication (GEMM) while introducing several optimisation ideas such as tiling and memory coalescing. Finally, we'll implement GEMM in Triton!
This article is the second in a series on Triton and GPU kernels. If you are not familiar with Triton or need a refresher on GPU fundamentals, check out the previous article! All of the code showcased in this article is available on GitHub.
Disclaimer: all of the following figures and animations were made by the author unless stated otherwise.
Naive GEMM
Let's start simple: we want to multiply two matrices X and Y with shapes (M, N) and (N, K) respectively. The output matrix Z = X @ Y will therefore have shape (M, K).
This operation involves computing the dot products of all pairs of rows and columns in X and Y respectively. A straightforward NumPy implementation might look something like this:
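Here is a minimal sketch of such an implementation, using explicit Python loops (rather than a single np.matmul call) so the memory access pattern stays visible:

import numpy as np

def naive_gemm(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    # X: (M, N), Y: (N, K) -> Z: (M, K)
    M, N = X.shape
    N_, K = Y.shape
    assert N == N_, "inner dimensions must match"
    Z = np.zeros((M, K), dtype=X.dtype)
    for m in range(M):        # one row of X at a time
        for k in range(K):    # every column of Y for that row
            Z[m, k] = np.dot(X[m, :], Y[:, k])
    return Z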
While easy to write, read and understand, this implementation is extremely inefficient in terms of memory access and caching. As mentioned in the first article of this series, a fundamental aspect of GPU optimisation is minimising data transfers.
However, our current implementation starts by loading a row from X, iteratively loads all K columns of Y, computes their dot products and repeats the process for every row in X. This results in a total of M(K+1) loading operations.
As seen in the animation, the memory access pattern is wasteful, as every column of Y is loaded M times. As an analogy: this is like running to the grocery store (global memory) every time you need a new ingredient for a dish, instead of preparing all the ingredients on your kitchen counter (shared memory). Ideally, we want to minimise the number of times each chunk of data is loaded and maximise its reuse once it has been loaded. This leaves us with two main axes of optimisation:
- How can we improve the access pattern to minimise redundant loads?
- How much data can we load at once, and where should it be stored on the GPU?
Tiled GEMM
As mentioned previously, the naive approach to GEMM results in many redundant loads, which induces unnecessary overhead. Ideally, we would like to load each piece of data only once and perform all of the operations in which it is used before dropping it from memory.
An elegant approach to this problem is tiling, which involves dividing large matrices into smaller "tiles" or sub-matrices. Consider two matrices X and Y with shapes (4, 6) and (6, 4) respectively; X @ Y results in a matrix Z with shape (4, 4).
In order to compute the first element of Z, Z[0,0], we need to compute the dot product between the first row of X and the first column of Y: Z[0,0] = dot(X[0, :], Y[:, 0]). We can also break the dot product down into smaller chunks, for instance in groups of three elements: Z[0,0] = dot(X[0, 0:3], Y[0:3, 0]) + dot(X[0, 3:6], Y[3:6, 0]).
Alternatively, we can extend this approach to two dimensions and compute a whole (2,2) block of Z at a time: Z[0:2, 0:2] = dot(X[0:2, 0:2], Y[0:2, 0:2]) + dot(X[0:2, 2:4], Y[2:4, 0:2]) + dot(X[0:2, 4:6], Y[4:6, 0:2]).
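This decomposition is easy to sanity-check with NumPy (a quick sketch using the (4, 6) and (6, 4) shapes from the example):

import numpy as np

X = np.random.randn(4, 6)
Y = np.random.randn(6, 4)

# accumulate the top-left (2,2) block of Z from three (2,2) partial products
Z_block = np.zeros((2, 2))
for n in range(0, 6, 2):
    Z_block += X[0:2, n:n+2] @ Y[n:n+2, 0:2]

assert np.allclose(Z_block, (X @ Y)[0:2, 0:2])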
Here's a visual illustration of tiled matrix multiplication:

The above animation illustrates how data is reused in tiled GEMM. For each 2×2 block in X and Y, we compute four dot products, which results in a (2,2) output matrix in Z. Since each tile contains 3 blocks, we need to accumulate 3 of these matrices to compute the final (2,2) output in Z. This accumulation is represented by coloured cells in Z.
In the kitchen analogy, this is like fetching ingredients from the store and preparing them on the kitchen counter (i.e. the small shared memory), reusing them several times before going back to the store.
Importantly, reusing loaded data over several steps allows this approach to drastically reduce the number of load operations. For (2,2) blocks, each X row and Y column is used in two dot products. Therefore, we are performing twice as many operations with each block of loaded data, roughly halving the number of load operations! Note that this generalises to larger blocks as well: using a (32,32) block would reduce the number of loads by a factor of around 32.
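Written as plain NumPy, tiled GEMM is just three nested loops over blocks, with one accumulator per output tile (a sketch assuming all dimensions are divisible by the block size):

import numpy as np

def tiled_gemm(X: np.ndarray, Y: np.ndarray, block: int = 2) -> np.ndarray:
    M, N = X.shape
    _, K = Y.shape
    assert M % block == 0 and N % block == 0 and K % block == 0
    Z = np.zeros((M, K), dtype=X.dtype)
    for m in range(0, M, block):          # tile of output rows
        for k in range(0, K, block):      # tile of output columns
            acc = np.zeros((block, block), dtype=X.dtype)
            for n in range(0, N, block):  # accumulate over the shared dimension
                acc += X[m:m+block, n:n+block] @ Y[n:n+block, k:k+block]
            Z[m:m+block, k:k+block] = acc
    return Z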
Now you're probably wondering: "how big can these blocks be?". To answer this question, let's recall how memory is managed in modern GPUs.
GPU Reminiscence Hierarchy
We distinguish four main types of memory in Nvidia GPUs. Here, we take the example of an A100:
- Registers: The fastest and smallest type of memory on the GPU, residing directly within each Streaming Multiprocessor (SM). On the A100, each SM provides 256 KB of register file space (65,536 × 32-bit registers), distributed among its threads. Each thread gets its own private 32-bit registers for storing temporary variables and intermediate results, avoiding memory traffic altogether. However, register usage per thread directly impacts occupancy, as using too many registers per thread limits how many threads can run concurrently.
- L1/Shared Memory: On an A100, each SM has 192 KB of SRAM that can be flexibly configured as either a hardware-managed L1 cache or a programmer-managed shared memory. For performance-critical kernels like matrix multiplication, we explicitly use this space as shared memory to stage data tiles close to the compute units, bypassing the L1 cache entirely. This gives us fine-grained control over data reuse.
- L2 cache: This cache is slower than L1 but much larger, with around 40 MB shared across all SMs on the A100. It serves as a global cache for both data and instructions, reducing the number of accesses to high-latency HBM memory. The L2 cache is coherent across SMs, meaning that updates from one SM are visible to others, enabling synchronisation between thread blocks. Its bandwidth can reach several terabytes per second, acting as a buffer between the fast on-chip SRAM and the slower HBM.
- High Bandwidth Memory (HBM): This is the device memory; it has a capacity of either 40 GB or 80 GB depending on the A100 model. It provides extremely high bandwidth (up to 2 TB/s on the 80 GB variant) but with much higher latency than on-chip caches. HBM is where large tensors, model weights, and datasets reside during execution. Since accessing HBM is expensive, efficient kernels aim to minimise data movement and maximise on-chip data reuse through registers and shared memory.
As you can see, the memory hierarchy generally trades off capacity against latency. Therefore, maximising performance boils down to loading data from HBM into shared memory efficiently and reusing it as much as possible.

Choosing our block size is crucial. We want blocks to be large enough to create plenty of parallel work, but small enough that their data fits in the SM's shared memory and registers. A BLOCK_SIZE of 64 is a common starting point because it's a multiple of the warp size (32 threads), ensuring full hardware utilisation.
Parallel Tiled GEMM
With these considerations in mind, a natural follow-up to our tiled GEMM is to parallelise the computation of each pair of tiles over multiple thread blocks, as depicted in the following animation.

Reminiscence Coalescing
Before writing tiled GEMM in Triton, we need to consider one last detail: memory coalescing, a technique that allows optimal use of global memory bandwidth. Memory coalescing is achieved when consecutive threads in a warp access consecutive memory addresses. Imagine a librarian needing to fetch books for a customer: if all the books are side by side on one shelf, they can grab them all at once. In contrast, if the books are spread across different shelves, they have to grab them one by one, which takes considerably longer.
To understand how this applies to our case, note that matrices are stored linearly in memory; in other words, a (2,2) matrix is stored as a sequence of 4 consecutive elements. Frameworks like PyTorch adopt a row-major layout, meaning that the elements of each row of a matrix are contiguous in memory. For instance, the elements of our (2,2) matrix would be stored as follows: [(0,0), (0,1), (1,0), (1,1)]. Notice that elements of the same row are contiguous (touching) while elements of the same column have a stride of 2 (they are separated by one element).

This implies that we can load rows using coalesced loads, but columns do not satisfy this condition. However, we need to access columns of Y to compute dot products. In order to maximise performance, a good practice is to transpose Y so that we iterate over its rows rather than its columns.
However, transposing Y is not enough to alter its layout in memory. As mentioned previously, PyTorch stores matrices in a flat array. Each matrix dimension is associated with a stride attribute, denoting the jump necessary to go from one element to the next along that dimension. For instance, a (10,10) matrix would have strides=(10,1). Indeed, starting from element [0,0], element [1,0] is 10 memory slots (i.e. one row) away, whereas element [0,1] is adjacent.
When transposing a tensor, PyTorch doesn't modify the layout in memory but simply recomputes the strides. In order to make the transpose effective from a memory standpoint, we need to call Y.T.contiguous().
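A quick way to see this in PyTorch (a small sketch; the shapes are arbitrary):

import torch

Y = torch.randn(10, 20)
print(Y.stride())                 # (20, 1): row-major layout
print(Y.T.stride())               # (1, 20): same memory, only the strides are swapped
print(Y.T.is_contiguous())        # False: columns of the original Y are still strided
print(Y.T.contiguous().stride())  # (10, 1): the data is actually rewritten in the new layout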
These are the steps required to load columns of Y efficiently; however, we will need to transpose the loaded blocks within the kernel to perform the dot product properly: z_block = tl.dot(X_block, Y_block.T).

Triton Implementation
From here on, we first describe the kernel without memory coalescing to simplify the logic and pointer arithmetic, before summarising the changes required to make the load operations coalesced on the columns of Y.
Let's start by focusing on the PyTorch wrapper around the kernel. We need to read M, N, K from the input matrices and compute their strides, since these constants will be useful later in the kernel. Then, we define the BLOCK_SIZE and declare the grid.
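The wrapper could look like the following sketch (the argument order mirrors the coalesced version shown later; BLOCK_SIZE = 64 and the 2D grid of output tiles follow the description above):

import torch
import triton

def block_matmul(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    M, N = X.shape
    _, K = Y.shape
    Z = torch.empty((M, K), device="cuda", dtype=X.dtype)

    # strides are needed by the kernel to build block pointers
    x_stride_m, x_stride_n = X.stride()
    y_stride_n, y_stride_k = Y.stride()
    z_stride_m, z_stride_k = Z.stride()

    BLOCK_SIZE = 64
    # one program per (BLOCK_SIZE, BLOCK_SIZE) tile of the output
    grid = (triton.cdiv(M, BLOCK_SIZE), triton.cdiv(K, BLOCK_SIZE))

    block_matmul_kernel[grid](
        X, x_stride_m, x_stride_n,
        Y, y_stride_n, y_stride_k,
        Z, z_stride_m, z_stride_k,
        M, N, K,
        BLOCK_SIZE,
    )
    return Z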
Now let's dive into the actual kernel code. We're going to make use of Triton's make_block_ptr utility, which simplifies the pointer arithmetic. We create one block pointer per matrix and pass the matrix shape, its strides, and the size of the block as inputs. Additionally, we specify the offset, the coordinate of the top-left element of the current block. For X, this corresponds to (m_idx * BLOCK_SIZE, 0) where m_idx is the index of the current block along the M dimension.
From there, we define z_acc, a zero matrix that accumulates the partial dot products as we iterate through tiles. We then iterate through the shared dimension N, loading blocks of size (BLOCK_SIZE, BLOCK_SIZE), and accumulate their dot products in z_acc. Finally, we move the block pointers along the shared dimension using .advance.
You might have noticed that when loading data, we use boundary_check and padding_option instead of mask and other as in the previous article. These arguments are specific to the use of block pointers: they specify which axes to check for out-of-bounds accesses (here (0, 1), i.e. both axes) and how to handle the invalid values. Here we pad them with zeros so that they have no effect on the dot product.
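Putting these pieces together, the kernel could look like the following sketch (the names and the row-major order=(1, 0) hint are assumptions; they simply match the wrapper above):

import triton
import triton.language as tl

@triton.jit
def block_matmul_kernel(
    X_ptr, X_m_stride, X_n_stride,
    Y_ptr, Y_n_stride, Y_k_stride,
    Z_ptr, Z_m_stride, Z_k_stride,
    M, N, K,
    BLOCK_SIZE: tl.constexpr,
):
    # one program instance per (BLOCK_SIZE, BLOCK_SIZE) tile of Z
    m_idx = tl.program_id(axis=0)
    k_idx = tl.program_id(axis=1)

    x_block_ptr = tl.make_block_ptr(
        base=X_ptr, shape=(M, N), strides=(X_m_stride, X_n_stride),
        offsets=(m_idx * BLOCK_SIZE, 0),
        block_shape=(BLOCK_SIZE, BLOCK_SIZE), order=(1, 0),
    )
    y_block_ptr = tl.make_block_ptr(
        base=Y_ptr, shape=(N, K), strides=(Y_n_stride, Y_k_stride),
        offsets=(0, k_idx * BLOCK_SIZE),
        block_shape=(BLOCK_SIZE, BLOCK_SIZE), order=(1, 0),
    )
    z_block_ptr = tl.make_block_ptr(
        base=Z_ptr, shape=(M, K), strides=(Z_m_stride, Z_k_stride),
        offsets=(m_idx * BLOCK_SIZE, k_idx * BLOCK_SIZE),
        block_shape=(BLOCK_SIZE, BLOCK_SIZE), order=(1, 0),
    )

    # accumulator for the partial dot products of this output tile
    z_acc = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)
    for _ in range(0, N, BLOCK_SIZE):
        x = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")
        y = tl.load(y_block_ptr, boundary_check=(0, 1), padding_option="zero")
        z_acc += tl.dot(x, y)
        # move both block pointers along the shared dimension N
        x_block_ptr = tl.advance(x_block_ptr, offsets=(0, BLOCK_SIZE))
        y_block_ptr = tl.advance(y_block_ptr, offsets=(BLOCK_SIZE, 0))

    tl.store(pointer=z_block_ptr, value=z_acc, boundary_check=(0, 1))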
We can now test the performance of this kernel using the following function:
import numpy as np
import torch
import triton
from tqdm import tqdm

def bench(fn: callable, x: torch.Tensor, y: torch.Tensor, repeat: int):
    flops = []
    med_latency = []
    for _ in tqdm(range(repeat), desc=f"Benchmarking {fn.__name__}"):
        latency_ms = triton.testing.do_bench(
            lambda: fn(x, y),
            quantiles=[0.5],  # get the median latency
            return_mode="all",
        )
        n_flops = 2 * M * N * K  # a matmul roughly requires 2*M*N*K operations
        tflops = n_flops / (latency_ms / 1e3) / 1e12
        med_latency.append(latency_ms)
        flops.append(tflops)
    flops = np.array(flops)
    med_latency = np.array(med_latency)
    print(f"Absolute Error: {torch.sum(torch.abs(X@Y - fn(x, y)))}")
    print(f"Median Latency: {med_latency.mean():.4f} ± {med_latency.std():.3f} ms")
    print(f"Throughput: {flops.mean():.4f} ± {flops.std():.3f} TeraFLOPS")
M = 8192
N = 6144
K = 4096
X = torch.randn((M, N), device="cuda", dtype=torch.float32)
Y = torch.randn((N, K), device="cuda", dtype=torch.float32)
bench(block_matmul, X, Y, repeat=10)
We get the following outputs (using a T4 GPU on Colab):
Absolute Error: 0.0 # the kernel outputs the correct result!
Median Latency: 130.7831 ± 1.794 ms
Throughput: 3.1533 ± 0.043 TeraFLOPS
Now let's review the changes required for coalesced loads on Y: we mainly need to flip the shape, strides and offsets when defining the block pointer for Y. Additionally, we update the block pointer to move along the column dimension (previously the row dimension). The full code for this implementation is available on GitHub.
import torch
import triton
import triton.language as tl

@triton.jit
def coalesced_block_matmul_kernel(
    X_ptr, X_m_stride, X_n_stride,
    Y_ptr, Y_k_stride, Y_n_stride,
    Z_ptr, Z_m_stride, Z_k_stride,
    M, N, K,
    BLOCK_SIZE: tl.constexpr,
):
    ...
    y_block_ptr = tl.make_block_ptr(
        base=Y_ptr,
        # flip the shape, strides and offsets to match Y.T
        shape=(K, N),
        strides=(Y_k_stride, Y_n_stride),
        offsets=(k_idx * BLOCK_SIZE, 0),
        block_shape=(BLOCK_SIZE, BLOCK_SIZE),
        order=(0, 1),
    )
    ...
    for _ in range(0, N, BLOCK_SIZE):
        ...  # loads
        z_acc += tl.dot(x, y.T)  # transpose Y back for the dot product
        x_block_ptr = tl.advance(x_block_ptr, offsets=(0, BLOCK_SIZE))
        # advance the block pointer along columns of Y.T (i.e. rows of Y)
        y_block_ptr = tl.advance(y_block_ptr, offsets=(0, BLOCK_SIZE))
    tl.store(pointer=z_block_ptr, value=z_acc, boundary_check=(0, 1))
def coalesced_block_matmul(X, Y):
    Y = Y.T.contiguous()  # Y is now (K, N)
    M, N = X.shape
    K, _ = Y.shape
    Z = torch.empty((M, K), device="cuda")
    x_stride_m, x_stride_n = X.stride()
    y_stride_k, y_stride_n = Y.stride()
    z_stride_m, z_stride_k = Z.stride()
    ...  # define BLOCK_SIZE and grid
    coalesced_block_matmul_kernel[grid](
        X, x_stride_m, x_stride_n,
        Y, y_stride_k, y_stride_n,  # (K, N) strides, matching the kernel signature
        Z, z_stride_m, z_stride_k,
        M, N, K,
        BLOCK_SIZE,
    )
    return Z
Here are the results of our benchmark for the kernel with coalesced loads for Y:
Absolute Error: 0.0 # Again, the kernel is correct!
Median Latency: 261.9420 ± 0.858 ms
Throughput: 1.5741 ± 0.005 TeraFLOPS
Surprisingly, the throughput of this second kernel is only half of what we obtained with the first one, despite improving the efficiency of the load operations 🤔
A quick inspection using nsight (Nvidia's kernel profiler, more on that in a future article) reveals that the transpose operation within the kernel creates a "traffic jam". Specifically, the transpose causes bank conflicts, leaving threads idle most of the time. Notably, the warp scheduler has no eligible warp to dispatch 87.6% of the time, as warps are waiting for the bank conflicts to resolve. Additionally, the report reads:
-----------------------  -----------  ------------
Metric Name              Metric Unit  Metric Value
-----------------------  -----------  ------------
...
DRAM Throughput          %            8.20
Compute (SM) Throughput  %            21.14
...
This indicates that the kernel is latency bound (i.e. neither memory nor compute bound; refer to the previous article for more details). In contrast, the first kernel is compute bound (i.e. increasing compute would increase performance), since its compute throughput is high compared to its DRAM throughput.
-----------------------  -----------  ------------
Metric Name              Metric Unit  Metric Value
-----------------------  -----------  ------------
...
DRAM Throughput          %            29.35
Compute (SM) Throughput  %            74.39
...
Conclusion
This experiment highlights the importance of profiling and empirical validation. Even well-intentioned optimisations like coalescing memory accesses can introduce new bottlenecks if they are not evaluated carefully. The first kernel, though simpler, was compute-bound and better matched the hardware's characteristics.
In the next articles of this series, we'll implement a softmax kernel, paying particular attention to integrating Triton with PyTorch's autograd and to profiling kernels using Nsight.
Until next time! 👋