解决Python 3.8环境中flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so导入错误

背景

调试网络时用到了FalshAttention,直接用的是flash_attn这个库,出现了以下异常

File "/usr/local/lib/python3.8/dist-packages/flash_attn/__init__.py", line 3, in <module>
    from flash_attn.flash_attn_interface import (
File "/usr/local/lib/python3.8/dist-packages/flash_attn/flash_attn_interface.py", line 8, in <module>
    import flash_attn_2_cuda as flash_attn_cuda
ImportError: /usr/local/lib/python3.8/dist-packages/flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops19empty_memory_format4callEN3c108ArrayRefIlEENS2_8optionalINS2_10ScalarTypeEEENS5_INS2_6LayoutEEENS5_INS2_6DeviceEEENS5_IbEENS5_INS2_12MemoryFormatEEE

原因分析

  1. 从异常上看,提示flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so这个库异常,这种未定义符号的异常,一般都是编译so时和当前环境不一致导致的
  2. 具体到flash_attn这个库,如果不是从源码编译,其对cuda版本和torch版本都是有要求的,所以在官方github的release上可以看到官方会提供很多不同cuda和torch版本的whl文件,如下所示

解决方法

方法一

  1. 从官方release种找到对应cuda版本和torch版本的whl文件,并下载
  2. 在本地使用pip3 install ${whl}的方式安装

方法二

从源码直接编译,详见官方github

作者:Garfield2005

物联沃分享整理
物联沃-IOTWORD物联网 » 解决Python 3.8环境中flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so导入错误

发表回复