[博客翻译]PyTorch原生架构优化:Torchao


原文地址:https://pytorch.org/blog/pytorch-native-architecture-optimization/


我们很高兴正式推出torchao,这是一个基于PyTorch的库,通过利用低位宽数据类型、量化和稀疏性使模型更快、更小。torchao 是一个易于使用的工具集,其中包含(大部分)易于理解的PyTorch代码,涵盖推理和训练。本博客将帮助您选择最适合您工作负载的技术。

我们在流行的GenAI模型如LLama 3 和 Diffusion 模型上测试了我们的技术,并且发现了最小的精度损失。除非另有说明,基准是在A100 80GB GPU 上使用bf16运行的。

我们的LLama 3 顶级指标如下:

  • 使用autoquant进行int4权重量化和hqq,在Llama 3 8B推理中速度提升97%
  • 在128K上下文长度下,带有量化KV缓存的Llama 3.1 8B推理VRAM峰值降低了73%
  • 使用H100上的float8训练,在Llama 3 70B预训练中速度提升50%
  • 使用4位量化优化器,在Llama 3 8B推理中的VRAM峰值降低了30%

我们的Diffusion模型推理顶级指标如下:

  • 在H100上的flux1.dev上使用float8动态量化和float8按行缩放推理时速度提升53%
  • 在CogVideoX上使用int8动态量化时,模型VRAM减少50%

下面我们将介绍在torchao中可用的一些技术,您可以将其应用于模型的推理和训练。

推理

我们的推理量化算法 可以用于任意包含nn.Linear层的PyTorch模型。对于各种数据类型和稀疏布局,可以选择使用我们的顶层quantize_ API 进行仅权重或动态激活量化。

from torchao.quantization import (
quantize_,
int4_weight_only,
)
quantize_(model, int4_weight_only())

有时候分层量化可能会因开销导致其变慢,如果您希望我们为您选择合适的量化方式,则可以运行以下命令:

model = torchao.autoquant(torch.compile(model, mode='max-autotune'))

quantize_ API根据您的模型是计算受限还是内存受限有不同的选项。

from torchao.quantization import (  
    # Memory bound models  
    int4_weight_only,  
    int8_weight_only,

    # Compute bound models  
    int8_dynamic_activation_int8_semi_sparse_weight,  
    int8_dynamic_activation_int8_weight,  
      
    # Device capability 8.9+  
    float8_weight_only,  
    float8_dynamic_activation_float8_weight,  
)

我们还与HuggingFace diffusers团队合作,在diffusers-torchao上有大量的基准测试,并展示了Flux.1-Dev上53.88%的速度提升以及CogVideoX-5b的27.33%速度提升。

我们的API具有可组合性,例如我们可以将稀疏性和量化相结合,实现了ViT-H推理中5%的速度提升 参考

也可以做其他事情,比如将权重量化到int4,KV缓存量化到int8,以支持Llama 3.1 8B在128K上下文长度下,在不到18.9GB的VRAM中运行

量化感知训练(QAT)

后训练量化对于少于4位的量化可能会出现严重的精度下降。使用 量化感知训练 (QAT),我们成功恢复了高达96%的精度损失。我们在torchtune中集成了这一端到端的实例,并提供了一个简短的 教程

训练

低精度计算与通信

torchao 提供了易于使用的端到端工作流程,可用于在训练过程中减少计算和分布式通信的精度。从torch.nn.Linear 层的float8开始,以下是用于将您的训练计算gemms 转换为float8的一行代码:

from torchao.float8 import convert_to_float8_training
convert_to_float8_training(model)

有关如何使用float8加快LLaMa 3 70B 预训练达至最高1.5倍速度的完整示例,请参