VK_NV_cooperative_vector
This extension adds shading language support for matrix-vector multiplies, which can be used to accelerate the evaluation of small neural networks.
1. Problem Statement
Several recently developed graphics techniques involve having each shader invocation independently evaluate a small neural network, usually a multi-layer perceptron (MLP). These techniques can benefit from the same dedicated matrix multiply hardware that is used for VK_KHR_cooperative_matrix, but need to work in the usual SIMT shader programming model rather than the subgroup-wide cooperative programming model of cooperative matrix.
2. Solution Space
Three options have been considered:
- Matrix-vector multiply with matrix implicitly loaded from memory.
- Matrix-vector multiply reusing cooperative matrix types.
- Higher-level "MLP object".
The "MLP object" was deemed too high level and hard to express. The activation functions and types used vary so much that it is far more natural to write the MLP as shader code using a lower-level primitive like a matrix-vector multiply. Defining an MLP object would also preclude other possible uses that have a different network structure or are not neural networks at all. A matrix-vector multiply is high level enough to target the dedicated matrix multiply hardware, but low level enough to be flexible.
Reusing cooperative matrix types seems desirable at first glance, but breaks down in the SIMT programming model because cooperative matrices are by nature shared across many threads. Reusing them does not naturally allow each thread to reference a distinct matrix.
Having the matrix implicitly loaded by the matrix multiply function allows the matrix address to act as a "handle" so threads can each reference a distinct matrix if needed. The matrices are small enough that they can be loaded on demand and the usual caches can make this efficient.
3. Proposal
This extension adds a new set of types to the shading language known as "cooperative vector" types. Unlike cooperative matrix types, a variable with a cooperative vector type is logically stored in the invocation it belongs to, but invocations can cooperate behind the scenes when performing matrix-vector multiplies. Cooperative vectors do not require a fully occupied subgroup or uniform control flow like cooperative matrices, although these do increase the likelihood of being on the fast path. And unlike normal vector types, they have arbitrary length and support a relatively limited set of operations. These types are intended to help accelerate the evaluation of small neural networks, where each invocation is performing its own independent evaluation of the network.
There are new matrix multiply functions; only the more general one, which also adds a bias, is shown. This function performs a matrix-vector multiplication using a matrix loaded from memory and a vector passed as a parameter. The input vector has K logical components and is left-multiplied by an MxK matrix to produce a result with M components that is stored in the output parameter 'result'. It also loads a 'bias' vector with M components from memory, which is added to the product before it is stored in 'result'.
void coopVecMatMulAddNV(out coopvecNV<ResultTy, ResultComps> result,
                        coopvecNV<InputTy, InputComps> input,
                        int inputInterpretation,
                        const MatrixTy[] matrix,
                        uint matrixOffset,
                        int matrixInterpretation,
                        const BiasTy[] bias,
                        uint biasOffset,
                        int biasInterpretation,
                        uint M,
                        uint K,
                        int matrixLayout,
                        bool transpose,
                        uint matrixStride);
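The arithmetic described above can be sketched on the CPU as a plain reference loop. This is an illustration only, not the extension's implementation: the `ref_` names and the layout enum are hypothetical, all components are assumed to be 32-bit floats, offsets and the stride are assumed to be in bytes, and the `transpose` and interpretation parameters are left out.

```c
#include <stddef.h>
#include <string.h>

/* Hypothetical layout tags for this sketch only; not the extension's enums. */
enum ref_layout { REF_ROW_MAJOR, REF_COLUMN_MAJOR };

/* Read one float at a byte offset from a raw buffer. */
static float ref_load_f32(const unsigned char *base, size_t byte_offset)
{
    float v;
    memcpy(&v, base + byte_offset, sizeof v);
    return v;
}

/* result[m] = bias[m] + sum over k of matrix[m][k] * input[k], M x K matrix. */
void ref_matvec_add(float *result, const float *input,
                    const unsigned char *matrix, size_t matrix_offset,
                    const unsigned char *bias, size_t bias_offset,
                    size_t M, size_t K,
                    enum ref_layout layout, size_t stride_bytes)
{
    for (size_t m = 0; m < M; ++m) {
        /* Start the accumulator from the bias component for this row. */
        float acc = ref_load_f32(bias, bias_offset + m * sizeof(float));
        for (size_t k = 0; k < K; ++k) {
            /* Assumed: the stride separates rows (row-major) or columns (column-major). */
            size_t elem = (layout == REF_ROW_MAJOR)
                ? m * stride_bytes + k * sizeof(float)
                : k * stride_bytes + m * sizeof(float);
            acc += ref_load_f32(matrix, matrix_offset + elem) * input[k];
        }
        result[m] = acc;
    }
}
```

Passing the matrix as a byte buffer plus an offset mirrors how the shader function addresses weights, so several matrices can share one buffer.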
There are also functions to load/store vectors from memory:
void coopVecLoadNV(out coopvecNV<VectorElemTy, NumComps> v, volatile coherent ArrayElemTy[] buf, uint offset);
void coopVecStoreNV(coopvecNV<VectorElemTy, NumComps> v, volatile coherent ArrayElemTy[] buf, uint offset);
In the API, there are three new commands. The first two convert a matrix to an optimal layout: one executes on the host and the other on the device. They are meant to be used when loading network weights into memory.
VKAPI_ATTR VkResult VKAPI_CALL vkConvertCooperativeVectorMatrixNV(
    VkDevice device,
    const VkConvertCooperativeVectorMatrixInfoNV* pInfo);

VKAPI_ATTR void VKAPI_CALL vkCmdConvertCooperativeVectorMatrixNV(
    VkCommandBuffer commandBuffer,
    uint32_t infoCount,
    const VkConvertCooperativeVectorMatrixInfoNV* pInfos);

typedef struct VkConvertCooperativeVectorMatrixInfoNV {
    VkStructureType sType;
    const void* pNext;
    size_t srcSize;
    VkDeviceOrHostAddressConstKHR srcData;
    size_t* pDstSize;
    VkDeviceOrHostAddressKHR dstData;
    VkComponentTypeKHR srcComponentType;
    VkComponentTypeKHR dstComponentType;
    uint32_t numRows;
    uint32_t numColumns;
    VkCooperativeVectorMatrixLayoutNV srcLayout;
    size_t srcStride;
    VkCooperativeVectorMatrixLayoutNV dstLayout;
    size_t dstStride;
} VkConvertCooperativeVectorMatrixInfoNV;
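To make the roles of numRows, numColumns, srcStride, and dstStride concrete, here is a minimal host-side sketch of one transparent conversion: row-major to column-major. The optimal layout is not described in this proposal, so it is not attempted here. The function name is hypothetical, components are assumed to be 32-bit floats, and strides are assumed to be in bytes.

```c
#include <stddef.h>
#include <string.h>

/*
 * Repack a row-major matrix of floats into column-major order.
 * Returns 0 on success, -1 if either stride is too small for its layout.
 */
int ref_row_to_column_major(const unsigned char *src, size_t src_stride,
                            unsigned char *dst, size_t dst_stride,
                            size_t num_rows, size_t num_columns)
{
    /* A source row holds num_columns floats; a destination column holds num_rows. */
    if (src_stride < num_columns * sizeof(float) ||
        dst_stride < num_rows * sizeof(float))
        return -1;

    for (size_t r = 0; r < num_rows; ++r) {
        for (size_t c = 0; c < num_columns; ++c) {
            /* Element (r, c) moves from row r of src to column c of dst. */
            memcpy(dst + c * dst_stride + r * sizeof(float),
                   src + r * src_stride + c * sizeof(float),
                   sizeof(float));
        }
    }
    return 0;
}
```

A real application would call the conversion commands instead, since only the implementation knows its optimal layouts.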
The third new command advertises the type combinations supported by the matrix-vector multiply:
VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDeviceCooperativeVectorPropertiesNV(
    VkPhysicalDevice physicalDevice,
    uint32_t* pPropertyCount,
    VkCooperativeVectorPropertiesNV* pProperties);

typedef struct VkCooperativeVectorPropertiesNV {
    VkStructureType sType;
    void* pNext;
    VkComponentTypeKHR inputType;
    VkComponentTypeKHR inputInterpretation;
    VkComponentTypeKHR matrixInterpretation;
    VkComponentTypeKHR biasInterpretation;
    VkComponentTypeKHR resultType;
    VkBool32 transpose;
} VkCooperativeVectorPropertiesNV;
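Once the list of supported combinations has been retrieved, an application would typically search it for the combination its shader uses. The sketch below shows that search with stand-in types so it compiles without the Vulkan headers: the `ref_` struct mirrors the fields above minus sType and pNext, and the enum values are arbitrary. It also assumes that the transpose field means "transposed matrices are supported for this combination", which this proposal does not state explicitly.

```c
#include <stdbool.h>
#include <stddef.h>

/* Stand-in for the component-type enum; values are arbitrary for this sketch. */
enum ref_component { REF_FLOAT16, REF_FLOAT32, REF_SINT8, REF_SINT32 };

/* Mirrors the fields of VkCooperativeVectorPropertiesNV, minus sType/pNext. */
struct ref_coopvec_props {
    enum ref_component input_type;
    enum ref_component input_interpretation;
    enum ref_component matrix_interpretation;
    enum ref_component bias_interpretation;
    enum ref_component result_type;
    bool transpose;
};

/* Return true if some advertised entry matches every field of `wanted`. */
bool ref_combination_supported(const struct ref_coopvec_props *props,
                               size_t count,
                               const struct ref_coopvec_props *wanted)
{
    for (size_t i = 0; i < count; ++i) {
        const struct ref_coopvec_props *p = &props[i];
        if (p->input_type == wanted->input_type &&
            p->input_interpretation == wanted->input_interpretation &&
            p->matrix_interpretation == wanted->matrix_interpretation &&
            p->bias_interpretation == wanted->bias_interpretation &&
            p->result_type == wanted->result_type &&
            /* Assumed: transpose only matters when the caller needs it. */
            (!wanted->transpose || p->transpose))
            return true;
    }
    return false;
}
```

Checking the full tuple rather than each type separately matters because support is advertised per combination, not per component type.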
4. Examples
Example showing the evaluation of an MLP with two 32-wide hidden layers in GLSL:
restrict buffer {
float16_t matrixData[];
} matrixBuf;
const int inputDim = 6;
coopvecNV<float16_t, inputDim> inputVec = coopvecNV<float16_t, inputDim>(float16_t(materialstate), float16_t(shininess), ... );
const int MLPDim = 32;
coopvecNV<float16_t, MLPDim> mlpVec;
coopVecMatMulNV(mlpVec, inputVec, gl_ComponentTypeFloat16NV, matrixBuf.matrixData, offset1, gl_ComponentTypeFloat16NV, MLPDim, inputDim, gl_CooperativeVectorMatrixLayoutRowMajorNV, false, MLPDim*sizeof(float16_t));
// ReLU activation
mlpVec = max(coopvecNV<float16_t, MLPDim>(float16_t(0)), mlpVec);
coopVecMatMulNV(mlpVec, mlpVec, gl_ComponentTypeFloat16NV, matrixBuf.matrixData, offset2, gl_ComponentTypeFloat16NV, MLPDim, MLPDim, gl_CooperativeVectorMatrixLayoutRowMajorNV, false, MLPDim*sizeof(float16_t));
// tanh activation
mlpVec = tanh(mlpVec);
const int resultDim = 8;
coopvecNV<float16_t, resultDim> resultVec;
// Each row of this row-major matrix holds MLPDim (K) components, so the stride spans MLPDim elements
coopVecMatMulNV(resultVec, mlpVec, gl_ComponentTypeFloat16NV, matrixBuf.matrixData, offset3, gl_ComponentTypeFloat16NV, resultDim, MLPDim, gl_CooperativeVectorMatrixLayoutRowMajorNV, false, MLPDim*sizeof(float16_t));
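The example leaves offset1, offset2, and offset3 undefined. One plausible arrangement, assumed here rather than taken from the proposal, packs the three row-major weight matrices back to back in the buffer. The host-side sketch below computes such offsets. The `ref_` names, the struct, and the alignment parameter are all hypothetical, and the stride is assumed to be the byte distance between consecutive rows.

```c
#include <stddef.h>

/* Round value up to the next multiple of alignment (alignment must be nonzero). */
static size_t ref_align_up(size_t value, size_t alignment)
{
    return (value + alignment - 1) / alignment * alignment;
}

/* Describes one row-major weight matrix; hypothetical helper type. */
struct ref_layer {
    size_t rows;          /* M: number of outputs */
    size_t cols;          /* K: number of inputs */
    size_t stride_bytes;  /* bytes from one row to the next */
};

/*
 * Fill offsets[i] with the byte offset of layer i when the matrices are
 * packed back to back, each start rounded up to `alignment`.
 * Returns the total byte size, or 0 if any stride is too small to hold
 * a full row of elem_size-byte components.
 */
size_t ref_pack_offsets(const struct ref_layer *layers, size_t count,
                        size_t elem_size, size_t alignment, size_t *offsets)
{
    size_t cursor = 0;
    for (size_t i = 0; i < count; ++i) {
        if (layers[i].stride_bytes < layers[i].cols * elem_size)
            return 0;
        cursor = ref_align_up(cursor, alignment);
        offsets[i] = cursor;
        cursor += layers[i].rows * layers[i].stride_bytes;
    }
    return cursor;
}
```

With the example's dimensions (32x6, 32x32, and 8x32), 2-byte components, a 64-byte row stride for every layer, and an assumed 64-byte alignment, the three offsets come out as 0, 2048, and 4096 bytes.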