# install_flsh_attn.py
import spaces  # Hugging Face Spaces (ZeroGPU) helper; imported before torch so GPU handling is patched early
from huggingface_hub import hf_hub_download
import subprocess
import importlib
import site
import torch
# Re-discover all .pth/.egg-link files in site-packages
for sitedir in site.getsitepackages():
    site.addsitedir(sitedir)
# Clear caches so importlib will pick up new modules
importlib.invalidate_caches()
def sh(cmd):
    """Run a shell command, raising CalledProcessError on failure."""
    subprocess.check_call(cmd, shell=True)
flash_attention_installed = False
try:
    print("Attempting to download and install FlashAttention wheel...")
    flash_attention_wheel = hf_hub_download(
        repo_id="alexnasa/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )
    sh(f"pip install {flash_attention_wheel}")
    # Tell Python to re-scan site-packages so the freshly installed wheel is importable
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()
    flash_attention_installed = True
    print("FlashAttention installed successfully.")
except Exception as e:
    print(f"⚠️ Could not install FlashAttention: {e}")
    print("Continuing without FlashAttention...")
# Pick the attention backend and dtype based on whether the wheel installed
attn_implementation = "flash_attention_2" if flash_attention_installed else "sdpa"
dtype = torch.bfloat16 if flash_attention_installed else None  # None leaves the loader's default dtype
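
# Illustrative usage (a minimal sketch, not part of this script): the two
# values above are typically passed straight to a Transformers model load.
# The model id below is a placeholder assumption.
#
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "org/model-id",                           # hypothetical model id
#       attn_implementation=attn_implementation,  # "flash_attention_2" or "sdpa"
#       torch_dtype=dtype,                        # bfloat16, or None for the default
#   )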