import spaces  # must be imported before torch on ZeroGPU Spaces
from huggingface_hub import hf_hub_download
import subprocess
import importlib
import site

import torch

# Re-discover all .pth/.egg-link files
for sitedir in site.getsitepackages():
    site.addsitedir(sitedir)

# Clear caches so importlib will pick up new modules
importlib.invalidate_caches()


def sh(cmd):
    """Run a shell command, raising CalledProcessError on failure."""
    subprocess.check_call(cmd, shell=True)


flash_attention_installed = False

try:
    print("Attempting to download and install FlashAttention wheel...")
    flash_attention_wheel = hf_hub_download(
        repo_id="alexnasa/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )

    sh(f"pip install {flash_attention_wheel}")

    # Tell Python to re-scan site-packages now that the wheel is installed
    # (importlib and site are already imported at module level).
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()

    flash_attention_installed = True
    print("FlashAttention installed successfully.")

except Exception as e:
    print(f"⚠️ Could not install FlashAttention: {e}")
    print("Continuing without FlashAttention...")

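# NOTE: the wheel above ships FlashAttention 3 (module `flash_attn_3`), while
# transformers' "flash_attention_2" backend imports the `flash_attn` package;
# depending on your transformers version, "flash_attention_3" may be the
# implementation to request here instead.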
attn_implementation = "flash_attention_2" if flash_attention_installed else "sdpa"
dtype = torch.bfloat16 if flash_attention_installed else None
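
# A minimal, hypothetical sketch (not part of the original file) of how the
# attn_implementation / dtype values above would typically be passed to a
# transformers model; the function name and model_id are placeholders.
def _example_load_model(model_id: str):
    from transformers import AutoModelForCausalLM

    return AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation=attn_implementation,  # "flash_attention_2" or "sdpa"
        torch_dtype=dtype,  # None lets transformers use the checkpoint default
    )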