VK_EXT_shader_float8
This extension enables support for 8-bit floating point operations in shaders.
1. Problem Statement
With machine learning algorithms commonly being run on GPUs, it has become desirable to support smaller types in GPUs to allow increased throughput for large networks. This extension enables two 8-bit floating point types: E4M3 and E5M2 as defined by the "FP8 Formats For Deep Learning" whitepaper (https://arxiv.org/abs/2209.05433).
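As an illustration of the two encodings, here is a minimal C sketch that decodes an 8-bit pattern to a 32-bit float. It assumes the layouts described in the whitepaper: E4M3 uses an exponent bias of 7, has no infinities, and reserves only the all-ones exponent and mantissa pattern for NaN, while E5M2 uses a bias of 15 and follows IEEE-style infinities and NaNs. The function names (`fp8_decode`, `fp8_e4m3_to_float`, `fp8_e5m2_to_float`) are hypothetical helpers for this example and are not part of the extension.

```c
#include <assert.h>
#include <math.h>   /* NAN and INFINITY macros only */
#include <stdint.h>

/* Multiply x by 2^e using exact float multiplications. */
static float scale_pow2(float x, int e)
{
    while (e > 0) { x *= 2.0f; e--; }
    while (e < 0) { x *= 0.5f; e++; }
    return x;
}

/*
 * Decode one 8-bit float (hypothetical helper, not part of the extension).
 * ieee_specials != 0: all-ones exponent encodes infinity/NaN (E5M2).
 * ieee_specials == 0: only all-ones exponent AND mantissa is NaN (E4M3).
 */
static float fp8_decode(uint8_t bits, int exp_bits, int man_bits,
                        int ieee_specials)
{
    assert(exp_bits + man_bits == 7); /* plus one sign bit = 8 bits */

    const int bias    = (1 << (exp_bits - 1)) - 1;
    const int exp_max = (1 << exp_bits) - 1;
    const int man_max = (1 << man_bits) - 1;
    const int sign    = bits >> 7;
    const int exp     = (bits >> man_bits) & exp_max;
    const int man     = bits & man_max;
    float value;

    if (ieee_specials && exp == exp_max) {
        value = man ? NAN : INFINITY;
    } else if (!ieee_specials && exp == exp_max && man == man_max) {
        value = NAN;
    } else if (exp == 0) {
        /* Subnormal: no implicit leading 1, exponent fixed at 1 - bias. */
        value = scale_pow2((float)man, 1 - bias - man_bits);
    } else {
        /* Normal: implicit leading 1 in front of the mantissa bits. */
        value = scale_pow2((float)(man + (1 << man_bits)),
                           exp - bias - man_bits);
    }
    return sign ? -value : value;
}

float fp8_e4m3_to_float(uint8_t bits) { return fp8_decode(bits, 4, 3, 0); }
float fp8_e5m2_to_float(uint8_t bits) { return fp8_decode(bits, 5, 2, 1); }
```

Under these assumptions, `fp8_e4m3_to_float(0x7E)` yields 448, the largest finite E4M3 value, and `fp8_e5m2_to_float(0x7B)` yields 57344, the largest finite E5M2 value. The trade-off between the two formats is visible here: E4M3 spends one more bit on precision, E5M2 on dynamic range.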
2. Solution Space
Machine learning algorithms frequently use SPV_KHR_cooperative_matrix.
Any proposal here has to support that functionality, as well as basic manipulation of data for these types.
3. Proposal
3.1. SPIR-V Changes
This extension adds two new Floating Point Encoding values, enabling the operand to be specified when creating a floating point type:
Value | FP Encoding | Width(s) | Enabling Capabilities
---|---|---|---
4214 | Float8E4M3EXT | 8 | Float8EXT
4215 | Float8E5M2EXT | 8 | Float8EXT
New capabilities enable both the declaration of the type and its use with cooperative matrix features:
Value | Capability | Implicitly Declares
---|---|---
4212 | Float8EXT |
4213 | Float8CooperativeMatrixEXT | Float8EXT, CooperativeMatrixKHR
The Float8EXT capability is required to use 8-bit floating point types, and Float8CooperativeMatrixEXT is required to use cooperative matrix operations with an 8-bit floating point component type.
3.2. API Changes
3.2.1. Features
This extension adds two features that map 1:1 to the SPIR-V capabilities described above:
typedef struct VkPhysicalDeviceShaderFloat8FeaturesEXT {
    VkStructureType    sType;
    void*              pNext;
    VkBool32           shaderFloat8;
    VkBool32           shaderFloat8CooperativeMatrix;
} VkPhysicalDeviceShaderFloat8FeaturesEXT;
- shaderFloat8 indicates support for the Float8EXT capability.
- shaderFloat8CooperativeMatrix indicates support for the Float8CooperativeMatrixEXT capability.

shaderFloat8 must be supported for this extension.
3.2.2. Interactions with VK_KHR_cooperative_matrix
Two new VkComponentTypeKHR values are added that can be reported as supported by vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR:
typedef enum VkComponentTypeKHR {
    ...
    VK_COMPONENT_TYPE_FLOAT8_E4M3_EXT = ...,
    VK_COMPONENT_TYPE_FLOAT8_E5M2_EXT = ...,
} VkComponentTypeKHR;
If shaderFloat8CooperativeMatrix is supported, at least one entry in vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR::pProperties must include this type in all of its AType, BType, and CType members.