[博客翻译]PyTorch原生架构优化:Torchao


原文地址:https://pytorch.org/blog/pytorch-native-architecture-optimization/


我们很高兴正式推出torchao,这是一个基于PyTorch的库,通过利用低位宽数据类型、量化和稀疏性使模型更快、更小。torchao 是一个易于使用的工具集,其中包含(大部分)易于理解的PyTorch代码,涵盖推理和训练。本博客将帮助您选择最适合您工作负载的技术。

我们在流行的GenAI模型如LLama 3 和 Diffusion 模型上测试了我们的技术,并且发现了最小的精度损失。除非另有说明,基准是在A100 80GB GPU 上使用bf16运行的。

我们的LLama 3 顶级指标如下:

  • 使用autoquant进行int4权重量化和hqq,在Llama 3 8B推理中速度提升97%
  • 在128K上下文长度下,带有量化KV缓存的Llama 3.1 8B推理VRAM峰值降低了73%
  • 使用H100上的float8训练,在Llama 3 70B预训练中速度提升50%
  • 使用4位量化优化器,在Llama 3 8B推理中的VRAM峰值降低了30%

我们的Diffusion模型推理顶级指标如下:

  • 在H100上的flux1.dev上使用float8动态量化和float8按行缩放推理时速度提升53%
  • 在CogVideoX上使用int8动态量化时,模型VRAM减少50%

下面我们将介绍在torchao中可用的一些技术,您可以将其应用于模型的推理和训练。

推理

我们的推理量化算法 可以用于任意包含nn.Linear层的PyTorch模型。对于各种数据类型和稀疏布局,可以选择使用我们的顶层quantize_ API 进行仅权重或动态激活量化。

from torchao.quantization import (
quantize_,
int4_weight_only,
)
quantize_(model, int4_weight_only())

有时候分层量化可能会因开销导致其变慢,如果您希望我们为您选择合适的量化方式,则可以运行以下命令:

model = torchao.autoquant(torch.compile(model, mode='max-autotune'))

quantize_ API根据您的模型是计算受限还是内存受限有不同的选项。

from torchao.quantization import (  
    # Memory bound models  
    int4_weight_only,  
    int8_weight_only,

    # Compute bound models  
    int8_dynamic_activation_int8_semi_sparse_weight,  
    int8_dynamic_activation_int8_weight,  
      
    # Device capability 8.9+  
    float8_weight_only,  
    float8_dynamic_activation_float8_weight,  
)

我们还与HuggingFace diffusers团队合作,在diffusers-torchao上有大量的基准测试,并展示了Flux.1-Dev上53.88%的速度提升以及CogVideoX-5b的27.33%速度提升。

我们的API具有可组合性,例如我们可以将稀疏性和量化相结合,实现了ViT-H推理中5%的速度提升 参考

也可以做其他事情,比如将权重量化到int4,KV缓存量化到int8,以支持Llama 3.1 8B在128K上下文长度下,在不到18.9GB的VRAM中运行

量化感知训练(QAT)

后训练量化对于少于4位的量化可能会出现严重的精度下降。使用 量化感知训练 (QAT),我们成功恢复了高达96%的精度损失。我们在torchtune中集成了这一端到端的实例,并提供了一个简短的 教程

训练

低精度计算与通信

torchao 提供了易于使用的端到端工作流程,可用于在训练过程中减少计算和分布式通信的精度。从torch.nn.Linear 层的float8开始,以下是用于将您的训练计算gemms 转换为float8的一行代码:

from torchao.float8 import convert_to_float8_training
convert_to_float8_training(model)

有关如何使用float8加快LLaMa 3 70B 预训练达至最高1.5倍速度的完整示例,请参阅我们的README,以及torchtitan的博客float8 配方

float8 预训练LLaMa 3 70B与bfloat16 的性能与精度对比

(来源: https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359)

我们正在扩展我们的训练工作流程,以适应更多数据类型和布局

  1. NF4 QLoRA in torchtune
  2. Prototype int8 训练支持
  3. 加速稀疏2:4 训练

低位宽优化器

受到Bits and Bytes的启发,我们还加入了8位和4位优化器的原型支持,作为AdamW的即插即用替代方案。

from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit
optim = AdamW8bit(model.parameters())

集成

我们一直努力确保torchao能很好地适用于开源项目中最重要的一些项目。

  1. Huggingface transformers作为 推理后端
  2. 作为加速扩散模型的参考实现 在 diffusers-torchao 中
  3. 在HQQ中实现 快速4位推理
  4. torchtune 中进行PyTorch原生QLoRA和QAT配方
  5. torchchat 中进行后训练量化
  6. 在SGLang中进行 int4和int8 后训练量化

结论

如果您有兴趣提高模型在训练或推理中的效率,我们希望您会发现torchao既有用又易于集成。

pip install torchao

我们对未来有很多激动人心的计划,包括低于4位的探索,高吞吐量推理高效内核,扩展更多的层、尺度或粒度,支持MX硬件,并支持更多硬件后端。