英伟达又赚到了!FlashAttention3来了:H100利用率飙升至75%

机器之心报道

编辑:陈陈、小舟

740 TFLOPS!迄今最强 FlashAttention 来了。

随着大型语言模型(LLM)加速落地,扩展模型上下文窗口变得越来越重要。然而,Transformer 架构的核心 —— 注意力层的时间复杂度和空间复杂度与输入序列长度的平方成正比。这使得扩展模型上下文窗口存在挑战。


2022 年,一种快速、内存高效的注意力算法 ——FlashAttention 问世,该算法无需任何近似即可加速注意力并减少内存占用。


FlashAttention 对注意力计算进行重新排序的算法,并利用 tiling 和重计算来显著加快计算速度,将内存使用量从序列长度的二次减少到线性。



2023 年,研究团队宣布推出 FlashAttention-2,在算法、并行化和工作分区等方面有了显著改进。


现在,来自 Meta、英伟达、Together AI 等机构的研究者宣布推出 FlashAttention-3,它采用了加速 Hopper GPU 注意力的三种主要技术:


  • 通过 warp-specialization 重叠整体计算和数据移动;

  • 交错分块 matmul 和 softmax 运算;

  • 利用硬件支持 FP8 低精度的不连贯处理。


FlashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍,高达 740 TFLOPS,即 H100 理论最大 FLOPS 利用率为 75%。使用 FP8,FlashAttention-3 的速度更是接近 1.2 PFLOPS。


FlashAttention-3 的改进将带来:


  • 更高效的 GPU 利用率:H100 理论最大 FLOPS 利用率为 75%,而之前仅为 35%。这使得 LLM 的训练和运行速度比以前的版本快得多。

  • 较低精度下更好的性能:FlashAttention-3 可以在保持精度的同时使用较低精度的数字 (FP8)。这可以实现更快的处理速度并可能降低内存使用量,从而为运行大规模人工智能操作的客户节省成本并提高效率。

  • 能够在 LLM 中使用更长的上下文:通过加速注意力机制,FlashAttention-3 使 AI 模型能够更有效地处理更长的文本片段。这使得应用程序能够理解并生成更长、更复杂的内容而不会减慢速度。



论文标题:FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision

论文地址:https://tridao.me/publications/flash3/flash3.pdf


论文作者之一 、FlashAttention1-3 版本的参与者 Tri Dao 表示:FlashAttention 被广泛用于加速 Transformers,已经使注意力速度提高了 4-8 倍,但尚未利用现代 GPU。因而他们发布了 FlashAttention-3:在 FP16 上速度提高了 1.5-2 倍,在 H100 上高达 740 TFLOPS(75% 实用性),FP8 接近 1.2 PFLOPS!



Hopper GPU 硬件特性:WGMMA、TMA、FP8


虽然 FlashAttention-2 在 Ampere (A100) GPU 上可以实现 70% 的理论最大 FLOPS,但它尚未利用 Hopper GPU 上的新功能来最大限度地提高性能。接下来文章描述了一些新的 Hopper 特定功能,以及它们为何如此重要。


首先是 WGMMA(Warpgroup Matrix Multiply-Accumulate),该功能利用了 Hopper 架构上新的张量内核,比 Ampere 架构具有更高的吞吐量。



然后是 TMA(Tensor Memory Accelerator),这是一个特殊的硬件单元,可以加速全局内存和共享内存之间的数据传输,用于处理所有索引计算和边界外预测。这样一来寄存器就释放了,寄存器是增加 tile 大小和效率的宝贵资源。



低精度 FP8,让 Tensor Core 吞吐量翻了一倍。



FlashAttention-3 充分利用了 Hopper 架构的所有这些新功能。


异步:GEMM 和 Softmax 重叠


注意力机制主要有两个操作,GEMM 和 softmax。为什么要将它们重叠?


问题在于在现代加速器上,非矩阵乘法(matmul)运算比矩阵乘法运算慢。特殊函数如指数运算(如 softmax 函数)的吞吐量甚至低于浮点乘加操作;这些运算是由多功能单元处理的,这是一个与浮点乘加或矩阵乘加不同的单元。


理想情况下,研究者希望矩阵乘法和 softmax 能够并行操作。当 Tensor Cores 忙于矩阵乘法时,多功能单元应当在计算指数运算! 


Inter-warpgroup 重叠


重叠 GEMM 和 softmax 最简单的方法是什么都不做,warp 调度程序会免费完成部分重叠。下图说明了 pingpong 调度,其中相同的颜色表示相同的迭代。



Intra-warpgroup 重叠


即使在一个 warpgroup 中,研究者也可以在运行该 warpgroup 的 GEMM 时运行 softmax 的某些部分。如图所示,相同的颜色表示相同的迭代。



这种 pipeline 流程可以将 FP16 注意力前向传播的吞吐量从大约 620 TFLOPS 提高到 640-660 TFLOPS,但代价是更高的寄存器压力,因而需要更多的寄存器来同时保存 GEMM 的累加器以及 Softmax 的输入 / 输出。


低精度:使用非相干处理减少量化误差


激活 LLM 可能存在一些极端值,导致量化困难,从而产生较大的量化误差。本文采用非相干处理(incoherent processing),该技术通过将查询和键与一个随机正交矩阵相乘来「分散(spread out)」极端值,从而减少量化误差。特别地,该研究使用了 Hadamard 变换,它可以在每个注意力头中以 O (d log d) 的时间复杂度完成,而不是 O (d^2),其中 d 是头部维度。


研究者发现非相干处理可以将量化误差减少很多,具体的数值误差比较见下表。



实验


文中展示了 FlashAttention-3 的一些结果,并将其与 FlashAttention-2 以及 Triton 和 cuDNN 中的实现进行了比较(两者都已经使用了 Hopper GPU 的新硬件功能)。


在 FP16 精度下,FlashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍。



对于 FP8,FlashAttention-3 接近 1.2 PFLOPS。



扩展阅读:


斯坦福提出新型Attention算法!提速2-4倍,BERT单节点训练最快 

比标准Attention提速5-9倍,大模型都在用的FlashAttention v2来了


参考链接:

https://tridao.me/blog/2024/flash3/



© THE END 

转载请联系本公众号获得授权

投稿或寻求报道:content@jiqizhixin.com

相关推荐

  • 五年后的今天,训练GPT-2只需不到700刀、24小时,Karpathy又整新活
  • 【机器学习】XGBoost和LightGBM时间序列预测对比
  • 【仅限10名,留言领会议门票】ICCBD+AI 2024 群贤汇聚,期待您参会投稿!
  • 我是真的后悔从国家电网离职了。。
  • RAG落地环节的15个控制点及优化思路:兼看KG-RAG技术总结线上分享
  • 看完这篇,你的API服务设计能力将再次进化!
  • Python 中的 @wraps 到底是个啥东西?
  • 一分钟原画变3D角色,清华VAST成果入选图形学顶会SIGGRAPH
  • AI慢思考蒸馏进快思考,Llama2跃升至GPT-4水平,不写过程也能做对题
  • 苏妈掷48亿现金吞下AI模型公司,英伟达有的AMD也要有
  • 程序员如何用好“AI搭子”?实操演示来了,揭秘多元业务场景如何用AI工具提效降本
  • H100利用率飙升至75%!英伟达亲自下场FlashAttention三代升级,比标准注意力快16倍
  • 告警:MyBatis-Plus中慎用@Transactional注解,坑的差点被开了...
  • A800,它真免费送啊!
  • SOA 和微服务有何区别?
  • @Schedule定时任务+分布式环境,这些坑你一定得注意!!!
  • 不服不行,这才是后端API接口应该有的样子!
  • 20 个好看又酷炫的 404 页面【附源码】
  • 中国数据库前世今生:90年代的群雄争霸与技术革新
  • Chrome 浏览器权限升级:竟可访问系统资源!