MXFP——为量化而生的浮点数们
This note is hosted on Notion: MXFP——为量化而生的浮点数们
一、浮点数
在深度学习中,几乎所有的运算都基于浮点数。计算机中常见的浮点数格式通常为单精度浮点和双精度浮点。
1.1 FP32(单精度)
单精度浮点总共32位,由1位浮点数、8位指数位、23位尾数位构成。
这是Pytorch的默认数据类型 torch.float32, 其可表示的范围覆盖-3.4e+38~3.4e+38,精度高,动态范围大,几乎不会出现溢出或下溢。
1.2 FP64(双精度)
64位,更高精度,由1位浮点数、11位指数位,32位尾数位构成。
因为成本太高,很少在深度学习中使用。
1.3 FP16(半精度)
16位,由1位浮点数、5位指数位,10位尾数位构成。
动态范围比FP32小很多,但尾数精度仍然不错。
1.4 BF16(Brain Floating Point 16)
也是16位,不过是由1位符号数、8位指数位、7位尾数位构成。
与FP16不同的是,通过牺牲一些尾数精度,喊来与FP32几乎相同的动态范围(指数位同为8位)。
这是目前大模型训练中最常用的格式(尤其是混合精度训练)。
二、低精度格式
2.1 FP16 / BF16(2017-2019年流行)
混合精度训练(AMP)让训练速度提升2-3倍,同时精度几乎无损。
NVIDIA A100及以后的GPU原生支持。
2.2 INT8 量化(2018年起广泛用于推理)
将权重和激活值量化到8位整数。
推理速度提升2-4倍,内存减半。
缺点是:训练阶段难以直接使用(梯度问题),且对某些模型精度损失较大。
2.3 FP8(2022 年起)
FP8有两种常见变体:
- E4M3:4位指数 + 3位尾数。动态范围大,适合权重。
- E5M2:5位指数 + 2位尾数。动态范围更大,适合激活层。
NVIDIA H100/Hopper 架构开始原生支持 FP8 Tensor Core。
相比 BF16,内存再减半,理论吞吐量再翻倍。
三、我们想要更低的比特
3.1 四位比特
使用FP4 / INT4可以进一步减少内存,加速推理。但直接使用4位浮点(如E2M1)可表示的数值范围极小,很容易导致神经网络权重/激活层下溢或溢出。结果就是模型精度急剧下降,甚至完全无法收敛。
四、块浮点与MX格式
4.1 传统定点与浮点的权衡
- 定点表示:
- 优点:存储和计算极快,所有值共用一个隐含的缩放因子。
- 缺点:动态范围小。
- 浮点表示:
- 优点:每个数都有自己的指数位,动态范围大,几乎不会溢出。
- 缺点:每个数都要独处指数位的存储开销,导致平均位宽较高。
神经网络权重和激活值通常呈现长尾分布:大部分值很小,少数值很大。这使得纯定点容易出问题(溢出),而纯浮点又太浪费位宽。
4.2 块浮点——折中思想
块浮点是定点和浮点的混合体:
- 将张量分成很多小块,每个块通常包含16~64个连续元素。
- 块内所有元素共享同一个缩放因子。
- 每个元素只存储低位宽的“尾数部分”,不在存储各自的指数(就像定点数那样)。
- 实际数值 = 存储的低精度值 x 共享的 scale
这样做大幅拓展了动态范围,且极大地降低了存储开销,硬件实现也很友好。
4.3 Microscaling(MX)格式——标准化的块浮点
Microscaling(MX)格式转为AI工作负载设计。核心特点如下:
- 固定块大小:标准推荐32个元素共享一个scale;
- scale的专用格式:E8M0(8位纯指数,无尾数,无符号)。
- 只能表示2的整数幂
- 动态范围极大,且乘法时只需加减指数,硬件极简。
- 无NaN,无Inf,无负数,转为缩放设计。
- 元素数据格式:
- MXFP8:标准FP8
- MXFP6:6位浮点(E3M2 / E2M3)
- MXFP4:E2M1
MXFP4实现了平均”约4.25位“的超低存储,同时动态范围接近甚至超过FP16.
MX 格式为什么能保持高精度?
- 共享 scale 缓解了 outliers 问题
- 块大小 32 是经验最优
- E8M0 scale 的 2 的幂特性:两个MX格式相乘只需将scale相加
五、PyTorch 实现
5.1 PyTorch中的低位宽基本dtype
PyTorch从2.3开始原生支持了一系列低位宽的浮点类型。这些不是完整的“MX格式”,而是MX格式的构件模块:
- torch.float8_e4m3fn, 标准FP8,适合权重;
- torch.float8_e5m2fn, 标准FP8,适合激活值;
- torch.float8_e8m0fnu, 无符号,专门用于MX的scale,只能表示2的幂;
- torch.float4.e2m1fn
- torch.float4_e2m1fn_x2, 打包的4位浮点,每字节存储两个4位元素,用于高效存储和GPU加载
这些格式本身不是块浮点格式,只是普通的低精度浮点。
5.2 MXFP4的实现
MXFP4不是一个原生的单一 dtype,而是通过 Tensor Subclass(张量子类)实现的派生数据类型。在torchao库中实现如下:
- 内部结构:
- 一个MXFP4张量实际上持有两个部分:
- 数据部分:形状为原张量形状,dtype=torch.float4_e2m1fn_x2
- scale部分:形状为[…, ceil(N/32)], dtpe=torch.float8_e8m0fnu
- 一个MXFP4张量实际上持有两个部分:
量化过程:
def to_mxfp4(tensor):
block_size = 32
# 按块找最大绝对值
abs_max = tensor.abs().reshape(-1, block_size).
amax
(dim=-1)
# 计算 scale = abs_max / E8M0表示的最大值
scale = compute_e8m0_scale(abs_max)
# 归一化
normalized = tensor / scale.
unsqueeze(-1)
quantized_data = normalized.to(torch.float4_e2m1fn_x2)
# 返回派生类型
return MXFPTensor(quantized_data, scale)
反量化:
反量化的目的是降量化后的data(低精度FP4尾数)和scale(块缩放因子)恢复为近似原始值,代码逻辑完全对应量化公式的逆操作:
dequantized = mxfp4_tensor.data.
to(torch.float32)
* mxfp4_tensor.scale.unsqueeze(-1)
5.3 MXFP8的类似实现
- 数据部分:用 torch.float8_e4m3fn 或 torch.float8_e5m2fn
- scale部分:仍用 torch.float8_e8m0fnu
- 平均位宽 ≈ 8.25位,比普通FP8多一点开销,但动态范围更大。
当前主流库支持情况(before 2025.12):
- torchao:最完整的开源实现,支持MXFP4/6/8的量化;
- NVDIA Transformer Engine:支持MX格式;
- vLLM/TensorRT-LLM:推理时支持MX格式权重加载。
在 Notion 参与讨论
本文托管在 Notion,欢迎到原文评论区留言交流