VK_EXT_shader_float8

This extension enables support for 8-bit floating point operations in shaders.

1. Problem Statement

With machine learning algorithms commonly being run on GPUs, it has become desirable to support smaller types in GPUs to allow increased throughput for large networks. This extension enables two 8-bit floating point types: E4M3 and E5M2 as defined by the "FP8 Formats For Deep Learning" whitepaper (https://arxiv.org/abs/2209.05433).

2. Solution Space

Machine learning algorithms frequently use SPV_KHR_cooperative_matrix.

Any proposal here has to support that functionality, as well as basic manipulation of data for these types.

3. Proposal

3.1. SPIR-V Changes

This extension adds two new Floating Point Encoding values, enabling the operand to be specified when creating a floating point type:

FP Encoding Width(s) Enabling Capabilities

4214

Float8E4M3EXT
The floating point type is encoded as an FP8 E4M3 type, as specified in the "FP8 Formats For Deep Learning" whitepaper (https://arxiv.org/abs/2209.05433).

8

Float8EXT

4215

Float8E5M2EXT
The floating point type is encoded as an FP8 E5M2 type, as specified in the "FP8 Formats For Deep Learning" whitepaper (https://arxiv.org/abs/2209.05433).

8

Float8EXT

New capabilities enable both the declaration of the type and its use with cooperative matrix features:

Capability Implicitly Declares

4212

Float8EXT
Uses OpTypeFloat to specify types with the Float8E4M3EXT or Float8E5M2EXT FP Encoding and values of this type with a few instructions.

4213

Float8CooperativeMatrixEXT
Uses cooperative matrix with a Component Type of OpTypeFloat with the Float8E4M3EXT or Float8E5M2EXT encoding.

Float8EXT, CooperativeMatrixKHR

The Float8EXT capability is required to use 8-bit floating point types, and Float8CooperativeMatrixEXT is required to use cooperative matrix operations with an 8-bit floating point component type.

3.2. API Changes

3.2.1. Features

This extension adds two features that map 1:1 to the capabilities exposed in that extension:

typedef struct VkPhysicalDeviceShaderFloat8FeaturesEXT {
    VkStructureType                 sType;
    void*                           pNext;
    VkBool32                        shaderFloat8;
    VkBool32                        shaderFloat8CooperativeMatrix;
} VkPhysicalDeviceShaderFloat16FeaturesEXT;
  • shaderFloat8Type indicates support for the Float8EXT capability.

  • shaderFloat8CooperativeMatrix indicates support for the Float8CooperativeMatrixEXT capability.

shaderFloat8 must be supported for this extension.

3.2.2. Interactions with VK_KHR_cooperative_matrix

Two new VkComponentTypeKHR are added that can be reported as supported by vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR:

typedef enum VkComponentTypeKHR {
    ...
    VK_COMPONENT_TYPE_FLOAT8_E4M3_EXT = ...,
    VK_COMPONENT_TYPE_FLOAT8_E5M2_EXT = ...,
} VkComponentTypeKHR;

If shaderFloat8CooperativeMatrix is supported, at least one entry in vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR::pProperties must include this type in all of its AType, BType, and CType members.

4. Issues

None.