KIOSHIROI's CS-learning Road

一、浮点数

在深度学习中,几乎所有的运算都基于浮点数。计算机中常见的浮点数格式通常为单精度浮点和双精度浮点。

1.1 FP32(单精度)

单精度浮点总共32位,由1位浮点数、8位指数位、23位尾数位构成。

这是Pytorch的默认数据类型 torch.float32, 其可表示的范围覆盖-3.4e+38~3.4e+38,精度高,动态范围大,几乎不会出现溢出或下溢。

1.2 FP64(双精度)

64位,更高精度,由1位浮点数、11位指数位,32位尾数位构成。

因为成本太高,很少在深度学习中使用。

1.3 FP16(半精度)

16位,由1位浮点数、5位指数位,10位尾数位构成。

动态范围比FP32小很多,但尾数精度仍然不错。

1.4 BF16(Brain Floating Point 16)

也是16位,不过是由1位符号数、8位指数位、7位尾数位构成。

与FP16不同的是,通过牺牲一些尾数精度,喊来与FP32几乎相同的动态范围(指数位同为8位)。

这是目前大模型训练中最常用的格式(尤其是混合精度训练)。

二、低精度格式

2.1 FP16 / BF16(2017-2019年流行)

混合精度训练(AMP)让训练速度提升2-3倍,同时精度几乎无损。

NVIDIA A100及以后的GPU原生支持。

2.2 INT8 量化(2018年起广泛用于推理)

将权重和激活值量化到8位整数。

推理速度提升2-4倍,内存减半。

缺点是:训练阶段难以直接使用(梯度问题),且对某些模型精度损失较大。

2.3 FP8(2022 年起)

FP8有两种常见变体:

  • E4M3:4位指数 + 3位尾数。动态范围大,适合权重。
  • E5M2:5位指数 + 2位尾数。动态范围更大,适合激活层。

NVIDIA H100/Hopper 架构开始原生支持 FP8 Tensor Core。

相比 BF16,内存再减半,理论吞吐量再翻倍。

三、我们想要更低的比特

3.1 四位比特

使用FP4 / INT4可以进一步减少内存,加速推理。但直接使用4位浮点(如E2M1)可表示的数值范围极小,很容易导致神经网络权重/激活层下溢或溢出。结果就是模型精度急剧下降,甚至完全无法收敛。

四、块浮点与MX格式

4.1 传统定点与浮点的权衡

  • 定点表示:
    • 优点:存储和计算极快,所有值共用一个隐含的缩放因子。
    • 缺点:动态范围小。
  • 浮点表示:
    • 优点:每个数都有自己的指数位,动态范围大,几乎不会溢出。
    • 缺点:每个数都要独处指数位的存储开销,导致平均位宽较高。

神经网络权重和激活值通常呈现长尾分布:大部分值很小,少数值很大。这使得纯定点容易出问题(溢出),而纯浮点又太浪费位宽。

4.2 块浮点——折中思想

块浮点是定点和浮点的混合体:

  • 将张量分成很多小块,每个块通常包含16~64个连续元素。
  • 块内所有元素共享同一个缩放因子。
  • 每个元素只存储低位宽的“尾数部分”,不在存储各自的指数(就像定点数那样)。
  • 实际数值 = 存储的低精度值 x 共享的 scale

这样做大幅拓展了动态范围,且极大地降低了存储开销,硬件实现也很友好。

4.3 Microscaling(MX)格式——标准化的块浮点

Microscaling(MX)格式转为AI工作负载设计。核心特点如下:

  • 固定块大小:标准推荐32个元素共享一个scale;
  • scale的专用格式:E8M0(8位纯指数,无尾数,无符号)。
    • 只能表示2的整数幂
    • 动态范围极大,且乘法时只需加减指数,硬件极简。
    • 无NaN,无Inf,无负数,转为缩放设计。
  • 元素数据格式:
    • MXFP8:标准FP8
    • MXFP6:6位浮点(E3M2 / E2M3)
    • MXFP4:E2M1

MXFP4实现了平均”约4.25位“的超低存储,同时动态范围接近甚至超过FP16.

MX 格式为什么能保持高精度?

  1. 共享 scale 缓解了 outliers 问题
  2. 块大小 32 是经验最优
  3. E8M0 scale 的 2 的幂特性:两个MX格式相乘只需将scale相加

五、PyTorch 实现

5.1 PyTorch中的低位宽基本dtype

PyTorch从2.3开始原生支持了一系列低位宽的浮点类型。这些不是完整的“MX格式”,而是MX格式的构件模块:

  • torch.float8_e4m3fn, 标准FP8,适合权重;
  • torch.float8_e5m2fn, 标准FP8,适合激活值;
  • torch.float8_e8m0fnu, 无符号,专门用于MX的scale,只能表示2的幂;
  • torch.float4.e2m1fn
  • torch.float4_e2m1fn_x2, 打包的4位浮点,每字节存储两个4位元素,用于高效存储和GPU加载

这些格式本身不是块浮点格式,只是普通的低精度浮点。

5.2 MXFP4的实现

MXFP4不是一个原生的单一 dtype,而是通过 Tensor Subclass(张量子类)实现的派生数据类型。在torchao库中实现如下:

  • 内部结构:
    • 一个MXFP4张量实际上持有两个部分:
      1. 数据部分:形状为原张量形状,dtype=torch.float4_e2m1fn_x2
      2. scale部分:形状为[…, ceil(N/32)], dtpe=torch.float8_e8m0fnu

量化过程:

def to_mxfp4(tensor):
	block_size = 32
	# 按块找最大绝对值
	abs_max = tensor.abs().reshape(-1, block_size).
amax
(dim=-1)
	# 计算 scale = abs_max / E8M0表示的最大值
	scale = compute_e8m0_scale(abs_max)
	# 归一化
	normalized = tensor / scale.
unsqueeze(-1)

	quantized_data = normalized.to(torch.float4_e2m1fn_x2)
	# 返回派生类型
	return MXFPTensor(quantized_data, scale)

反量化:

反量化的目的是降量化后的data(低精度FP4尾数)和scale(块缩放因子)恢复为近似原始值,代码逻辑完全对应量化公式的逆操作:

dequantized = mxfp4_tensor.data.
to(torch.float32)
 * mxfp4_tensor.scale.unsqueeze(-1)

5.3 MXFP8的类似实现

  • 数据部分:用 torch.float8_e4m3fntorch.float8_e5m2fn
  • scale部分:仍用 torch.float8_e8m0fnu
  • 平均位宽 ≈ 8.25位,比普通FP8多一点开销,但动态范围更大。

当前主流库支持情况(before 2025.12):

  • torchao:最完整的开源实现,支持MXFP4/6/8的量化;
  • NVDIA Transformer Engine:支持MX格式;
  • vLLM/TensorRT-LLM:推理时支持MX格式权重加载。

在 Notion 参与讨论

本文托管在 Notion,欢迎到原文评论区留言交流

在 Notion 打开
MXFP——为量化而生的浮点数们
https://kioshiroi.github.io/blog/mxfp
Author KIOSHIROI
Published at 2026年1月28日
Comment seems to stuck. Try to refresh?✨