Fixing the flash_attn_2_cuda module import error in a Python 3.10 environment

Problem description

Environment:

Linux
transformers 4.39.0
tokenizers 0.15.2
torch 2.1.2+cu121
flash-attn 2.3.3
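
For reference, these versions can be confirmed from inside the environment without importing the packages themselves; this is a minimal sketch using importlib.metadata (importing flash_attn directly at this point would already trigger the error shown below):

# Query installed distribution versions without importing the packages.
from importlib.metadata import version

for pkg in ("transformers", "tokenizers", "torch", "flash-attn"):
    print(pkg, version(pkg))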

When loading xverse/XVERSE-13B-256K (code below):

import torch
from transformers import AutoModelForSequenceClassification

qwen_model = AutoModelForSequenceClassification.from_pretrained(
    args.pre_train,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",   # or "balanced_low_0"
    num_labels=5
)

the following error is raised:

Traceback (most recent call last):
  File "/usr/local/app/.local/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 1364, in _get_module
    return importlib.import_module("." + module_name, self.__name__)
  File "/data/miniconda3/envs/xxx/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/usr/local/app/.local/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 49, in <module>
    from flash_attn import flash_attn_func, flash_attn_varlen_func
  File "/usr/local/app/.local/lib/python3.10/site-packages/flash_attn/__init__.py", line 3, in <module>
    from flash_attn.flash_attn_interface import (
  File "/usr/local/app/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 10, in <module>
    import flash_attn_2_cuda as flash_attn_cuda
ImportError: /usr/local/app/.local/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEi

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/cfs/xxx/xxx/long-context/xxx/train.py", line 434, in <module>
    qwen_model = AutoModelForCausalLM.from_pretrained(
  File "/usr/local/app/.local/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 565, in from_pretrained
    model_class = _get_model_class(config, cls._model_mapping)
  File "/usr/local/app/.local/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 387, in _get_model_class
    supported_models = model_mapping[type(config)]
  File "/usr/local/app/.local/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 740, in __getitem__
    return self._load_attr_from_module(model_type, model_name)
  File "/usr/local/app/.local/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 754, in _load_attr_from_module
    return getattribute_from_module(self._modules[module_name], attr)
  File "/usr/local/app/.local/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 698, in getattribute_from_module
    if hasattr(module, attr):
  File "/usr/local/app/.local/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 1354, in __getattr__
    module = self._get_module(self._class_to_module[name])
  File "/usr/local/app/.local/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 1366, in _get_module
    raise RuntimeError(
RuntimeError: Failed to import transformers.models.qwen2.modeling_qwen2 because of the following error (look up to see its traceback):
/usr/local/app/.local/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEi

Solution

The undefined symbol _ZN3c104cuda9SetDeviceEi demangles to c10::cuda::SetDevice(int), a function exported by PyTorch's C++ libraries. The flash_attn_2_cuda extension referencing a symbol that the installed torch does not provide means the flash-attn 2.3.3 binary was built against a different PyTorch version than the installed torch 2.1.2+cu121. Reinstalling flash-attn so that the binary matches the local torch resolves the import error:

pip install flash-attn==2.5.9.post1
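
After reinstalling, a quick smoke test can confirm that the CUDA extension now loads and runs; this is a minimal sketch, with arbitrary tensor shapes chosen only for the check:

# If this import succeeds, the undefined-symbol problem is resolved.
import torch
from flash_attn import flash_attn_func

# (batch, seqlen, num_heads, head_dim); bfloat16 on GPU, since flash-attn requires fp16/bf16.
q = torch.randn(1, 128, 8, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([1, 128, 8, 64])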

Author: Cyril_KI
