解决Python 3.8环境中flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so导入错误
背景
调试网络时用到了FalshAttention,直接用的是flash_attn这个库,出现了以下异常
File "/usr/local/lib/python3.8/dist-packages/flash_attn/__init__.py", line 3, in <module>
from flash_attn.flash_attn_interface import (
File "/usr/local/lib/python3.8/dist-packages/flash_attn/flash_attn_interface.py", line 8, in <module>
import flash_attn_2_cuda as flash_attn_cuda
ImportError: /usr/local/lib/python3.8/dist-packages/flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops19empty_memory_format4callEN3c108ArrayRefIlEENS2_8optionalINS2_10ScalarTypeEEENS5_INS2_6LayoutEEENS5_INS2_6DeviceEEENS5_IbEENS5_INS2_12MemoryFormatEEE
原因分析
- 从异常上看,提示flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so这个库异常,这种未定义符号的异常,一般都是编译so时和当前环境不一致导致的
- 具体到flash_attn这个库,如果不是从源码编译,其对cuda版本和torch版本都是有要求的,所以在官方github的release上可以看到官方会提供很多不同cuda和torch版本的whl文件,如下所示
解决方法
方法一
- 从官方release种找到对应cuda版本和torch版本的whl文件,并下载
- 在本地使用pip3 install ${whl}的方式安装
方法二
从源码直接编译,详见官方github
作者:Garfield2005