JacobLinCool committed on
Commit d9ce662 · verified · 1 Parent(s): 6510c49

Create install_flsh_attn.py

Files changed (1)
  install_flsh_attn.py +47 -0
install_flsh_attn.py ADDED
@@ -0,0 +1,47 @@
+ import spaces
+ from huggingface_hub import hf_hub_download
+ import subprocess
+ import importlib
+ import site
+
+ import torch
+
+ # Re-discover all .pth/.egg-link files
+ for sitedir in site.getsitepackages():
+     site.addsitedir(sitedir)
+
+ # Clear caches so importlib will pick up new modules
+ importlib.invalidate_caches()
+
+
+ def sh(cmd):
+     subprocess.check_call(cmd, shell=True)
+
+
+ flash_attention_installed = False
+
+ try:
+     print("Attempting to download and install FlashAttention wheel...")
+     flash_attention_wheel = hf_hub_download(
+         repo_id="alexnasa/flash-attn-3",
+         repo_type="model",
+         filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
+     )
+
+     sh(f"pip install {flash_attention_wheel}")
+
+     # tell Python to re-scan site-packages now that the egg-link exists
+     import importlib, site
+
+     site.addsitedir(site.getsitepackages()[0])
+     importlib.invalidate_caches()
+
+     flash_attention_installed = True
+     print("FlashAttention installed successfully.")
+
+ except Exception as e:
+     print(f"⚠️ Could not install FlashAttention: {e}")
+     print("Continuing without FlashAttention...")
+
+ attn_implementation = "flash_attention_2" if flash_attention_installed else "sdpa"
+ dtype = torch.bfloat16 if flash_attention_installed else None
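
Usage note (not part of this commit): the module presumably gets imported once at Space startup so the wheel is installed before the model is constructed, and the exported attn_implementation / dtype values are then passed to the model loader. A minimal sketch under that assumption; the model id and the from_pretrained call below are illustrative, not taken from this repository:

import install_flsh_attn  # importing runs the installer and sets the exported flags
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",  # placeholder model id (assumption)
    attn_implementation=install_flsh_attn.attn_implementation,  # "flash_attention_2" or "sdpa"
    torch_dtype=install_flsh_attn.dtype,  # torch.bfloat16 when the wheel installed, else library default
)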