This document proposes extending support for cooperative matrices to make them more broadly useful.
1. Problem Statement
While VK_KHR_cooperative_matrix can be used to accelerate "simple" GEMMs, it has some serious limitations:
Subgroup-scope matrices are small enough that using them in the natural way in a GEMM does not lead to good performance. Shader authors can work around this by explicitly tiling and staging through shared memory, but doing so is cumbersome. Similarly, networks that need slightly larger matrices that should still fit in the register file again require explicit tiling.
Some networks require operations like reductions or matrix Use conversions, and currently these require a round trip through shared memory.
Simple row/col-major ordering is fine for 2D tensors, but many networks have tensors with more dimensions, more complex slicing of tensors, and padding at tensor boundaries. VK_KHR_cooperative_matrix does not address these use cases.
Many large networks store weights in a quantized format, and there is no good way to dequantize them as they are loaded into a cooperative matrix.
Fortunately, VK_KHR_cooperative_matrix provides a solid foundation that we can extend to address all of these issues.
We hope that this new functionality is useful both for handwritten shaders and as a target for graph compilers.
2. Solution Space
This proposal includes several new features, each with its own solution space:
Compilers can easily scale matrix sizes by internally tiling them, and more flexible matrix dimensions are clearly valuable for ease of use. The main question is whether the sizes should be completely arbitrary, or defined as a multiple of some minimum granularity. Completely arbitrary sizes might lead to unexpected complexity or performance issues, so for simplicity this extension adds minimum granularities, and any multiple is supported up to some relatively large upper bound. Additionally, supporting matrices at workgroup scope allows implementations to more efficiently support larger matrix sizes, and to optimize their loads by automatically staging them through shared memory.
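The granularity rule above can be sketched on the CPU as follows. This is a minimal illustration, not the extension's API: the `DimRange` struct, its field names, and both helper functions are hypothetical, and the numbers used with them are made up rather than taken from any implementation.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical description of one supported size range; the field names are
// illustrative and are not taken from the extension.
struct DimRange {
    uint32_t granularity; // supported sizes are multiples of this value
    uint32_t maxSize;     // largest supported size
};

// True if `size` is a nonzero multiple of the granularity within the bound.
bool isSupportedDim(const DimRange& r, uint32_t size) {
    assert(r.granularity > 0);
    return size != 0 && size <= r.maxSize && size % r.granularity == 0;
}

// Number of granularity-sized pieces a compiler could tile `size` into.
uint32_t internalTileCount(const DimRange& r, uint32_t size) {
    assert(isSupportedDim(r, size));
    return size / r.granularity;
}
```

With an assumed granularity of 16 and an upper bound of 256, a dimension of 64 would be accepted and split into four internal tiles, while 40 (not a multiple) and 512 (too large) would be rejected.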
Matrix Use conversions are important for writing fused network kernels, so that an accumulator can be used as an operand for another multiply. The most important conversions seem to be Accumulator→A, and Accumulator→B with optional transpose. To minimize complexity, this is all that is supported for now.
Reductions are common in networks: for example, softmax activation requires a reduction over rows, and max-pooling requires a reduction over 2x2 neighborhoods. We support these in Accumulator matrices since that is the primary place they are needed. It is also possible to compute a sum-reduction of a row of an A matrix or a column of a B matrix by multiplying by a matrix of all ones, so we do not add an additional way to do this.
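To make the two reduction paths concrete, here is a scalar CPU sketch. It is only a reference model under stated assumptions: `Matrix`, `reduceRows`, `matMul`, and `onesColumn` are hypothetical helpers written for this illustration, and they do not correspond to functions in the extension.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Matrix = std::vector<std::vector<float>>; // row-major, rows x cols

// Reduce each row to a single value with a caller-supplied combine function,
// e.g. max for softmax stabilization or + for a row sum.
std::vector<float> reduceRows(const Matrix& m,
                              const std::function<float(float, float)>& combine) {
    std::vector<float> out;
    for (const auto& row : m) {
        assert(!row.empty());
        float acc = row[0];
        for (std::size_t c = 1; c < row.size(); ++c) acc = combine(acc, row[c]);
        out.push_back(acc);
    }
    return out;
}

// Plain matrix multiply, used to show the "multiply by ones" trick.
Matrix matMul(const Matrix& a, const Matrix& b) {
    assert(!a.empty() && !b.empty() && a[0].size() == b.size());
    Matrix out(a.size(), std::vector<float>(b[0].size(), 0.0f));
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; j < b[0].size(); ++j)
            for (std::size_t k = 0; k < b.size(); ++k)
                out[i][j] += a[i][k] * b[k][j];
    return out;
}

// An n x 1 matrix of ones; A * onesColumn(cols) yields the row sums of A.
Matrix onesColumn(std::size_t n) {
    return Matrix(n, std::vector<float>(1, 1.0f));
}
```

Calling `reduceRows` with a max combiner models the row-wise reduction, while `matMul(a, onesColumn(k))` produces the same values as a sum reduction over each row of `a`, which is why no separate sum-reduce path is needed for A and B matrices.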
The existing row/col-major pitch-based addressing covers only the simplest use cases and misses many real-world ones, such as slicing and reinterpreting views of tensors, and clamping/padding at tensor boundaries. We considered fully programmable addressing (e.g. a function pointer to calculate addresses), but that generality makes it difficult for the implementation to perform optimizations like coalescing loads or staging through shared memory. The solution we settled on, tensor layouts and views, is conceptually similar to how tensors are handled in APIs like PyTorch, so it handles common use cases. The regular structure imposed by the tensor layout helps the compiler perform the aforementioned optimizations. And while the tensor layout and view calculations appear expensive in their full generality, applications "pay for what they use" and the simpler cases remain cheap.
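The kind of address calculation a tensor layout implies can be sketched as below. This is a simplified 2D model under assumptions of our own: the `Layout2D` struct, its fields, and `loadElement` are hypothetical, and the padding behavior shown (returning a fixed value for out-of-bounds coordinates) is just one possible clamping mode.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical 2D layout: size and element stride of the full tensor, plus
// the origin and extent of the slice being accessed.
struct Layout2D {
    std::array<int32_t, 2> dim;    // full tensor size per dimension
    std::array<int32_t, 2> stride; // element stride per dimension
    std::array<int32_t, 2> offset; // slice origin within the tensor
    std::array<int32_t, 2> span;   // slice size per dimension
};

// Load element (i, j) of the slice. Coordinates that fall outside the tensor
// return `padValue` instead of reading memory.
float loadElement(const std::vector<float>& mem, const Layout2D& l,
                  int32_t i, int32_t j, float padValue) {
    assert(i >= 0 && i < l.span[0] && j >= 0 && j < l.span[1]);
    const int32_t r = l.offset[0] + i;
    const int32_t c = l.offset[1] + j;
    if (r < 0 || r >= l.dim[0] || c < 0 || c >= l.dim[1]) return padValue;
    return mem[static_cast<std::size_t>(r * l.stride[0] + c * l.stride[1])];
}
```

Because the offset, span, and strides are plain per-dimension integers rather than arbitrary code, a compiler can reason about which elements are contiguous and which fall outside the tensor, which is the regularity that makes coalescing and shared-memory staging tractable.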
Quantized weights are often stored in sub-byte types (e.g. int4) or as part of a larger block, and need to be scaled and/or run through a lookup table to produce the dequantized value. Previously, the values had to be dequantized manually and copied into shared memory before being loaded into a matrix. We support block dimensions in the addressing calculations and a decode callback function that allows the values to be dequantized as they are loaded into the matrix.
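A CPU sketch of the decode-callback idea follows. The block format here (eight 4-bit values plus one scale) is invented for illustration and is not a format defined by the extension; `QuantBlock`, `decodeInt4`, and `loadDequantized` are likewise hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical block format: one scale plus eight 4-bit values packed into
// four bytes, low nibble first.
struct QuantBlock {
    float scale;
    std::array<uint8_t, 4> packed;
};
constexpr std::size_t kBlockElems = 8;

// Callback type: given a block and an index within it, return one value.
using DecodeFn = std::function<float(const QuantBlock&, std::size_t)>;

// Example decode: extract the nibble, recenter it to [-8, 7], apply the scale.
float decodeInt4(const QuantBlock& b, std::size_t idx) {
    assert(idx < kBlockElems);
    const uint8_t byte = b.packed[idx / 2];
    const int nibble = (idx % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    return b.scale * static_cast<float>(nibble - 8);
}

// Fill `count` elements by locating each element's block and invoking the
// decode callback, mirroring "dequantize as the matrix is loaded".
std::vector<float> loadDequantized(const std::vector<QuantBlock>& blocks,
                                   std::size_t count, const DecodeFn& decode) {
    assert(count <= blocks.size() * kBlockElems);
    std::vector<float> out(count);
    for (std::size_t e = 0; e < count; ++e)
        out[e] = decode(blocks[e / kBlockElems], e % kBlockElems);
    return out;
}
```

The loader only needs to know the block dimensions to find the right block and in-block index; everything format-specific lives in the callback, so a lookup-table format could be supported by swapping in a different decode function.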
3. Proposal
The Solution Space section describes the separate features of the proposal and the solutions we have chosen. The remaining question is how to package these features. Most of them can be used independently as incremental additions to VK_KHR_cooperative_matrix. They are specified as independent feature enables (and independent SPIR-V capabilities) so that implementations can adopt them incrementally (or so that they can be promoted incrementally) if needed. There are no specific required features.
4. Examples
A workgroup scope GEMM:
uvec2 tileID = uvec2(gl_WorkGroupID.xy);
// Initialize result to zero
coopmat<C_TYPE, gl_ScopeWorkgroup, TILE_M, TILE_N, gl_MatrixUseAccumulator> result;
result = coopmat<C_TYPE, gl_ScopeWorkgroup, TILE_M, TILE_N, gl_MatrixUseAccumulator>(0.0);
// Create tensor layouts for A, B, Accumulator
tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
tensorLayoutNV<2> tensorLayoutC = createTensorLayoutNV(2);
tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, M, K);
tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, K, N);
tensorLayoutC = setTensorLayoutDimensionNV(tensorLayoutC, M, N);
// Loop over K dimension and accumulate partial results
for (uint chunkK = 0; chunkK < K; chunkK += TILE_K) {
    // Load the current TILE_M x TILE_K slice of A
    coopmat<A_TYPE, gl_ScopeWorkgroup, TILE_M, TILE_K, gl_MatrixUseA> matrixA;
    coopMatLoadTensorNV(matrixA, inputA.x, 0, sliceTensorLayoutNV(tensorLayoutA, TILE_M * tileID.y, TILE_M, chunkK, TILE_K));
    // Load the current TILE_K x TILE_N slice of B
    coopmat<A_TYPE, gl_ScopeWorkgroup, TILE_K, TILE_N, gl_MatrixUseB> matrixB;
    coopMatLoadTensorNV(matrixB, inputB.x, 0, sliceTensorLayoutNV(tensorLayoutB, chunkK, TILE_K, TILE_N * tileID.x, TILE_N));
    result = coopMatMulAdd(matrixA, matrixB, result);
}
// Load C and compute alpha*(A*B)+beta*C
coopmat<C_TYPE, gl_ScopeWorkgroup, TILE_M, TILE_N, gl_MatrixUseAccumulator> matrixC;
coopMatLoadTensorNV(matrixC, inputC.x, 0, sliceTensorLayoutNV(tensorLayoutC, TILE_M * tileID.y, TILE_M, TILE_N * tileID.x, TILE_N));
result = C_TYPE(alpha) * result + C_TYPE(beta) * matrixC;
// Store result to memory
coopMatStoreTensorNV(result, outputD.x, 0, sliceTensorLayoutNV(tensorLayoutC, TILE_M * tileID.y, TILE_M, TILE_N * tileID.x, TILE_N));
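For checking the output of a shader like the one above, a scalar CPU reference for a single output tile can be useful. The following is a sketch under our own assumptions: `Mat` and `gemmTile` are hypothetical helpers, the matrices are dense row-major floats, and K is assumed to be a multiple of the K tile size.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major dense matrix; a hypothetical helper type for this sketch.
struct Mat {
    std::size_t rows, cols;
    std::vector<float> data;
    float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
    float& at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
};

// Compute the tileM x tileN tile of D = alpha*(A*B) + beta*C whose top-left
// corner is (row0, col0), accumulating over K in chunks of tileK.
void gemmTile(const Mat& a, const Mat& b, const Mat& c, Mat& d,
              std::size_t row0, std::size_t col0,
              std::size_t tileM, std::size_t tileN, std::size_t tileK,
              float alpha, float beta) {
    assert(a.cols == b.rows && tileK > 0 && a.cols % tileK == 0);
    assert(row0 + tileM <= a.rows && col0 + tileN <= b.cols);
    std::vector<float> acc(tileM * tileN, 0.0f); // plays the role of `result`
    for (std::size_t chunkK = 0; chunkK < a.cols; chunkK += tileK)
        for (std::size_t i = 0; i < tileM; ++i)
            for (std::size_t j = 0; j < tileN; ++j)
                for (std::size_t k = 0; k < tileK; ++k)
                    acc[i * tileN + j] +=
                        a.at(row0 + i, chunkK + k) * b.at(chunkK + k, col0 + j);
    for (std::size_t i = 0; i < tileM; ++i)
        for (std::size_t j = 0; j < tileN; ++j)
            d.at(row0 + i, col0 + j) =
                alpha * acc[i * tileN + j] + beta * c.at(row0 + i, col0 + j);
}
```

Here `row0` and `col0` correspond to `TILE_M * tileID.y` and `TILE_N * tileID.x` in the shader, so calling `gemmTile` once per workgroup ID reproduces the full output matrix.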
A reduction computing max over each row of a matrix:
float16_t maxReduce(const in float16_t x, const in float16_t y) {
    return max(x, y);
}
mat = ...;
coopmat<float16_t, gl_ScopeWorkgroup, TILE_M, TILE_N, gl_MatrixUseAccumulator> rowMax;
coopMatReduceNV(rowMax, mat, gl_CooperativeMatrixReduceRowNV, maxReduce);