VK_NV_cooperative_vector

This extension adds shading language support for matrix-vector multiplies which can be used to accelerate evaluation of small neural networks.

1. Problem Statement

Several recently developed graphics techniques involve having each shader invocation independently evaluate a small neural network, usually a multi-layer perceptron (MLP). These techniques can benefit from the same dedicated matrix multiply hardware that is used for VK_KHR_cooperative_matrix, but need to work in the usual SIMT shader programming model rather than the subgroup-wide cooperative programming model of cooperative matrix.

2. Solution Space

Three options have been considered:

  1. Matrix-vector multiply with matrix implicitly loaded from memory.

  2. Matrix-vector multiply reusing cooperative matrix types.

  3. Higher-level "MLP object".

The "MLP object" was deemed too high level and hard to express. There is a lot of variability in the activation functions and types used, which makes it much more natural to write the MLP as shader code using a lower-level primitive like a matrix-vector multiply. Defining an MLP object would also preclude other possible uses that have a different network structure or are not neural networks at all. A matrix-vector multiply is high level enough to target the dedicated matrix multiply hardware, but low level enough to be flexible.

Reusing cooperative matrix types seems desirable at first glance, but breaks down in the SIMT programming model because cooperative matrices are by nature shared across many threads. It does not naturally allow each thread to reference a distinct matrix.

Having the matrix implicitly loaded by the matrix multiply function allows the matrix address to act as a "handle" so threads can each reference a distinct matrix if needed. The matrices are small enough that they can be loaded on demand and the usual caches can make this efficient.
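The "handle" idea can be illustrated with a small CPU-side sketch in C. It is not part of the extension: `mat_vec`, `run_invocations`, and `material_id` are hypothetical names, the matrices are fixed at 2x2, and offsets are counted in elements rather than bytes to keep the code short. Each emulated invocation picks its own matrix from one shared weight buffer purely by computing an offset.

```c
#include <assert.h>
#include <stddef.h>

#define ROWS 2
#define COLS 2

/* y = W * x for one tightly packed row-major matrix that starts at
 * element index `offset` inside the shared weight buffer. */
static void mat_vec(float *y, const float *weights, size_t offset, const float *x)
{
    for (int r = 0; r < ROWS; ++r) {
        y[r] = 0.0f;
        for (int c = 0; c < COLS; ++c)
            y[r] += weights[offset + (size_t)r * COLS + c] * x[c];
    }
}

/* Emulates `count` invocations. Invocation i reads the matrix selected by
 * material_id[i]; the offset plays the role of the matrix "handle". */
void run_invocations(float *out, const float *weights,
                     const int *material_id, const float *in, int count)
{
    for (int i = 0; i < count; ++i) {
        assert(material_id[i] >= 0);
        size_t offset = (size_t)material_id[i] * ROWS * COLS;
        mat_vec(out + (size_t)i * ROWS, weights, offset, in + (size_t)i * COLS);
    }
}
```

Neighbouring invocations that select the same matrix read the same addresses, which is the case where the usual caches help most.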

3. Proposal

This extension adds a new set of types to the shading language known as "cooperative vector" types. Unlike cooperative matrix types, a variable with a cooperative vector type is logically stored in the invocation it belongs to, but invocations can cooperate behind the scenes when performing matrix-vector multiplies. Cooperative vectors do not require a fully occupied subgroup or uniform control flow like cooperative matrices, although these do increase the likelihood of being on the fast path. And unlike normal vector types, they have arbitrary length and support a relatively limited set of operations. These types are intended to help accelerate the evaluation of small neural networks, where each invocation is performing its own independent evaluation of the network.

There are new matrix multiply functions; only the most general one is shown here. This function performs a matrix-vector multiplication using a matrix loaded from memory and a vector passed as a parameter. The input vector has K logical components and is left-multiplied by an MxK matrix to produce a result with M components that is stored in the output parameter 'result'. It also loads a 'bias' vector with M components from memory, which is added to the product before it is stored in 'result'.

    void coopVecMatMulAddNV(out coopvecNV<ResultTy, ResultComps> result,
                            coopvecNV<InputTy, InputComps> input,
                            int inputInterpretation,
                            const MatrixTy[] matrix,
                            uint matrixOffset,
                            int matrixInterpretation,
                            const BiasTy[] bias,
                            uint biasOffset,
                            int biasInterpretation,
                            uint M,
                            uint K,
                            int matrixLayout,
                            bool transpose,
                            uint matrixStride);
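
The arithmetic described above can be written as a scalar reference in C. This is only a sketch under stated assumptions: `ref_mat_mul_add` is a hypothetical helper, it uses `float` in place of the interpretation parameters, it handles only a row-major matrix with no transpose, and it assumes that the offsets and the stride are expressed in bytes.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Scalar reference for result = matrix * input + bias.
 * Hypothetical helper, not part of the extension. Assumes float elements,
 * row-major layout, no transpose, and byte-based offsets and stride. */
void ref_mat_mul_add(float *result,              /* M components (output) */
                     const float *input,         /* K components */
                     const uint8_t *matrix_buf,  /* memory holding the matrix */
                     uint32_t matrix_offset,     /* byte offset of the matrix */
                     const uint8_t *bias_buf,    /* memory holding the bias */
                     uint32_t bias_offset,       /* byte offset of the bias */
                     uint32_t M, uint32_t K,
                     uint32_t matrix_stride)     /* bytes between rows */
{
    /* A row must be able to hold K elements. */
    assert(matrix_stride >= K * sizeof(float));

    for (uint32_t row = 0; row < M; ++row) {
        const uint8_t *row_ptr =
            matrix_buf + matrix_offset + (size_t)row * matrix_stride;

        /* Start the accumulator from the bias component for this row. */
        float acc;
        memcpy(&acc, bias_buf + bias_offset + row * sizeof(float), sizeof acc);

        for (uint32_t col = 0; col < K; ++col) {
            float w;
            memcpy(&w, row_ptr + col * sizeof(float), sizeof w);
            acc += w * input[col];
        }
        result[row] = acc;
    }
}
```

The `memcpy` reads mirror the fact that the matrix and bias are addressed as raw memory plus an offset, rather than being passed as typed values.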

There are also functions to load/store vectors from memory:

    void coopVecLoadNV(out coopvecNV<VectorElemTy, NumComps> v, volatile coherent ArrayElemTy[] buf, uint offset);

    void coopVecStoreNV(coopvecNV<VectorElemTy, NumComps> v, volatile coherent ArrayElemTy[] buf, uint offset);

In the API, there are three new commands. The first two convert a matrix from one layout and component type to another, for example to an optimal layout. One executes on the host and the other on the device, and they are meant to be used when loading network weights into memory.

    VKAPI_ATTR VkResult VKAPI_CALL vkConvertCooperativeVectorMatrixNV(
        VkDevice                                    device,
        const VkConvertCooperativeVectorMatrixInfoNV* pInfo);

    VKAPI_ATTR void VKAPI_CALL vkCmdConvertCooperativeVectorMatrixNV(
        VkCommandBuffer                             commandBuffer,
        uint32_t                                    infoCount,
        const VkConvertCooperativeVectorMatrixInfoNV* pInfos);

    typedef struct VkConvertCooperativeVectorMatrixInfoNV {
        VkStructureType                      sType;
        const void*                          pNext;
        size_t                               srcSize;
        VkDeviceOrHostAddressConstKHR        srcData;
        size_t*                              pDstSize;
        VkDeviceOrHostAddressKHR             dstData;
        VkComponentTypeKHR                   srcComponentType;
        VkComponentTypeKHR                   dstComponentType;
        uint32_t                             numRows;
        uint32_t                             numColumns;
        VkCooperativeVectorMatrixLayoutNV    srcLayout;
        size_t                               srcStride;
        VkCooperativeVectorMatrixLayoutNV    dstLayout;
        size_t                               dstStride;
    } VkConvertCooperativeVectorMatrixInfoNV;
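
To make the roles of the layout and stride members concrete, the following C sketch converts a matrix between row-major and column-major storage. It is a CPU-side illustration only, not the driver's implementation: the `Layout` enum, `elem_offset`, and `convert_layout` are hypothetical names, elements are assumed to be `float` with no component type conversion, strides are assumed to be in bytes, and implementation-defined optimal layouts are not modeled.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical stand-in for the two explicit layouts. */
typedef enum { LAYOUT_ROW_MAJOR, LAYOUT_COLUMN_MAJOR } Layout;

/* Byte offset of element (row, col), given a layout and a stride in bytes.
 * The stride is the distance between rows (row-major) or columns
 * (column-major). */
static size_t elem_offset(Layout layout, size_t stride, uint32_t row, uint32_t col)
{
    return layout == LAYOUT_ROW_MAJOR
        ? row * stride + col * sizeof(float)
        : col * stride + row * sizeof(float);
}

/* Copies a num_rows x num_columns float matrix from src to dst, changing
 * the layout and stride on the way. The buffers must not overlap. */
void convert_layout(const uint8_t *src, Layout src_layout, size_t src_stride,
                    uint8_t *dst, Layout dst_layout, size_t dst_stride,
                    uint32_t num_rows, uint32_t num_columns)
{
    assert(src != dst);
    for (uint32_t r = 0; r < num_rows; ++r) {
        for (uint32_t c = 0; c < num_columns; ++c) {
            memcpy(dst + elem_offset(dst_layout, dst_stride, r, c),
                   src + elem_offset(src_layout, src_stride, r, c),
                   sizeof(float));
        }
    }
}
```

A 2x3 row-major matrix `{1, 2, 3, 4, 5, 6}` with a source stride of `3 * sizeof(float)` would be written out as `{1, 4, 2, 5, 3, 6}` when the destination is column-major with a stride of `2 * sizeof(float)`.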

The third new command advertises types supported by the matrix-vector multiply:

    VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDeviceCooperativeVectorPropertiesNV(
        VkPhysicalDevice                            physicalDevice,
        uint32_t*                                   pPropertyCount,
        VkCooperativeVectorPropertiesNV*            pProperties);

    typedef struct VkCooperativeVectorPropertiesNV {
        VkStructureType       sType;
        void*                 pNext;
        VkComponentTypeKHR    inputType;
        VkComponentTypeKHR    inputInterpretation;
        VkComponentTypeKHR    matrixInterpretation;
        VkComponentTypeKHR    biasInterpretation;
        VkComponentTypeKHR    resultType;
        VkBool32              transpose;
    } VkCooperativeVectorPropertiesNV;

4. Examples

Example showing the evaluation of an MLP with two hidden layers of 32 components each in GLSL:

    restrict buffer {
        float16_t matrixData[];
    } matrixBuf;

    const int inputDim = 6;
    coopvecNV<float16_t, inputDim> inputVec = coopvecNV<float16_t, inputDim>(float16_t(materialstate), float16_t(shininess), ... );

    const int MLPDim = 32;
    coopvecNV<float16_t, MLPDim> mlpVec;
    coopVecMatMulNV(mlpVec, inputVec, gl_ComponentTypeFloat16NV, matrixBuf.matrixData, offset1, gl_ComponentTypeFloat16NV, MLPDim, inputDim, gl_CooperativeVectorMatrixLayoutRowMajorNV, false, MLPDim*sizeof(float16_t));

    // ReLU activation
    mlpVec = max(coopvecNV<float16_t, MLPDim>(float16_t(0)), mlpVec);

    coopVecMatMulNV(mlpVec, mlpVec, gl_ComponentTypeFloat16NV, matrixBuf.matrixData, offset2, gl_ComponentTypeFloat16NV, MLPDim, MLPDim, gl_CooperativeVectorMatrixLayoutRowMajorNV, false, MLPDim*sizeof(float16_t));

    // tanh activation
    mlpVec = tanh(mlpVec);

    const int resultDim = 8;
    coopvecNV<float16_t, resultDim> resultVec;

    // Row-major resultDim x MLPDim matrix: each row holds MLPDim elements
    coopVecMatMulNV(resultVec, mlpVec, gl_ComponentTypeFloat16NV, matrixBuf.matrixData, offset3, gl_ComponentTypeFloat16NV, resultDim, MLPDim, gl_CooperativeVectorMatrixLayoutRowMajorNV, false, MLPDim*sizeof(float16_t));
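
For checking shader output against a CPU result, the shape of the network in the GLSL example (6 inputs, two layers of 32, 8 outputs, no bias) can be mirrored in plain C. This is a sketch under assumptions: `layer`, `mlp_forward`, `relu`, and the `Activation` type are hypothetical names, weights are `float` rather than `float16_t`, and each matrix is tightly packed row-major. The activations are passed in as function pointers so the block has no dependency on the math library.

```c
#include <assert.h>
#include <stddef.h>

#define MAX_DIM 32  /* widest layer in the GLSL example */

typedef float (*Activation)(float);

float relu(float x) { return x > 0.0f ? x : 0.0f; }

/* One layer: out = act(W * in), where W is a tightly packed row-major
 * M x K matrix. act may be NULL to skip the activation. */
static void layer(float *out, const float *in, const float *W,
                  size_t M, size_t K, Activation act)
{
    float tmp[MAX_DIM];
    assert(M <= MAX_DIM && K <= MAX_DIM);
    for (size_t r = 0; r < M; ++r) {
        float acc = 0.0f;
        for (size_t c = 0; c < K; ++c)
            acc += W[r * K + c] * in[c];
        tmp[r] = act ? act(acc) : acc;
    }
    /* Copy at the end so that out may alias in, as mlpVec does in GLSL. */
    for (size_t r = 0; r < M; ++r)
        out[r] = tmp[r];
}

/* 6 -> 32 -> 32 -> 8, with act1 after the first layer and act2 after the
 * second. The final layer has no activation. */
void mlp_forward(float result[8], const float input[6],
                 const float *w1, const float *w2, const float *w3,
                 Activation act1, Activation act2)
{
    float hidden[MAX_DIM];
    layer(hidden, input, w1, 32, 6, act1);
    layer(hidden, hidden, w2, 32, 32, act2);
    layer(result, hidden, w3, 8, 32, NULL);
}
```

To match the activations in the GLSL example, a caller would pass `relu` as `act1` and `tanhf` from `<math.h>` as `act2`. Results will differ slightly from the shader because of the `float16_t` precision used there.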