Flash Attention 安装记录

配环境天天被这个 flash attention 恶心,记录一下坑避免未来配环境的时候玉玉症。

不存在通用解决方案,一些原则和经验如下:

  • 直接 pip install flash-attn 大概率爆炸,通常从 GitHub releases 里面下载 wheel 安装。
  • 选 cxx11abiFALSE,否则可能会出现一个 ImportError 提示 undefined symbol。
  • 注意 torch 版本匹配(即 wheel 文件名里面那个 torch 版本和环境里面装的 torch 版本匹配)。
  • 注意 torch 的 CUDA 版本。理论上尽量和安装的 CUDA toolkit 版本(就是 nvcc -V 显示的那个)一致。
    • 如果无法更新驱动导致无法保证一致似乎只能多试一试,版本号尽量接近的有更大概率没问题。
      • 当时我实验室服务器的 docker 驱动只支持到 12.2,torch 2.5 之前的版本跑一些模型会提示有安全问题,2.6.0 之后要么 cu118 要么 cu124,结果 flash attn 2.8.3 又没有 cu11 支持,,,无语了,最后通过暴力尝试发现 2.6.0+cu124 可以和 CUDA 12.2 一起用。
  • 注意 CUDA 版本不能超过驱动所支持的版本(即 nvidia-smi 显示的那个)。
  • transformers、accelerate 这些东西大概率装最新的都没问题。