我们很高兴正式推出torchao,这是一个基于PyTorch的库,通过利用低位宽数据类型、量化和稀疏性使模型更快、更小。torchao 是一个易于使用的工具集,其中包含(大部分)易于理解的PyTorch代码,涵盖推理和训练。本博客将帮助您选择最适合您工作负载的技术。
我们在流行的GenAI模型如LLama 3 和 Diffusion 模型上测试了我们的技术,并且发现了最小的精度损失。除非另有说明,基准是在A100 80GB GPU 上使用bf16运行的。
我们的LLama 3 顶级指标如下:
- 使用autoquant进行int4权重量化和hqq,在Llama 3 8B推理中速度提升97%
- 在128K上下文长度下,带有量化KV缓存的Llama 3.1 8B推理VRAM峰值降低了73%
- 使用H100上的float8训练,在Llama 3 70B预训练中速度提升50%
- 使用4位量化优化器,在Llama 3 8B推理中的VRAM峰值降低了30%
我们的Diffusion模型推理顶级指标如下:
- 在H100上的flux1.dev上使用float8动态量化和float8按行缩放推理时速度提升53%
- 在CogVideoX上使用int8动态量化时,模型VRAM减少50%
下面我们将介绍在torchao中可用的一些技术,您可以将其应用于模型的推理和训练。
推理
我们的推理量化算法 可以用于任意包含nn.Linear层的PyTorch模型。对于各种数据类型和稀疏布局,可以选择使用我们的顶层quantize_
API 进行仅权重或动态激活量化。
from torchao.quantization import (
quantize_,
int4_weight_only,
)
quantize_(model, int4_weight_only())
有时候分层量化可能会因开销导致其变慢,如果您希望我们为您选择合适的量化方式,则可以运行以下命令:
model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
quantize_
API根据您的模型是计算受限还是内存受限有不同的选项。
from torchao.quantization import (
# Memory bound models
int4_weight_only,
int8_weight_only,
# Compute bound models
int8_dynamic_activation_int8_semi_sparse_weight,
int8_dynamic_activation_int8_weight,
# Device capability 8.9+
float8_weight_only,
float8_dynamic_activation_float8_weight,
)
我们还与HuggingFace diffusers团队合作,在diffusers-torchao上有大量的基准测试,并展示了Flux.1-Dev上53.88%的速度提升以及CogVideoX-5b的27.33%速度提升。
我们的API具有可组合性,例如我们可以将稀疏性和量化相结合,实现了ViT-H推理中5%的速度提升 参考。
也可以做其他事情,比如将权重量化到int4,KV缓存量化到int8,以支持Llama 3.1 8B在128K上下文长度下,在不到18.9GB的VRAM中运行。
量化感知训练(QAT)
后训练量化对于少于4位的量化可能会出现严重的精度下降。使用 量化感知训练 (QAT),我们成功恢复了高达96%的精度损失。我们在torchtune中集成了这一端到端的实例,并提供了一个简短的 教程。
训练
低精度计算与通信
torchao 提供了易于使用的端到端工作流程,可用于在训练过程中减少计算和分布式通信的精度。从torch.nn.Linear
层的float8开始,以下是用于将您的训练计算gemms 转换为float8的一行代码:
from torchao.float8 import convert_to_float8_training
convert_to_float8_training(model)
有关如何使用float8加快LLaMa 3 70B 预训练达至最高1.5倍速度的完整示例,请参