import spaces
from huggingface_hub import hf_hub_download
import subprocess
import importlib
import site

import torch

# Re-discover all .pth/.egg-link files
for sitedir in site.getsitepackages():
    site.addsitedir(sitedir)

# Clear caches so importlib will pick up new modules
importlib.invalidate_caches()

def sh(cmd):
    subprocess.check_call(cmd, shell=True)

flash_attention_installed = False

try:
    print("Attempting to download and install FlashAttention wheel...")
    flash_attention_wheel = hf_hub_download(
        repo_id="alexnasa/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )
    sh(f"pip install {flash_attention_wheel}")

    # tell Python to re-scan site-packages now that the egg-link exists
    import importlib, site
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()

    flash_attention_installed = True
    print("FlashAttention installed successfully.")

except Exception as e:
    print(f"⚠️ Could not install FlashAttention: {e}")
    print("Continuing without FlashAttention...")

attn_implementation = "flash_attention_2" if flash_attention_installed else "sdpa"
dtype = torch.bfloat16 if flash_attention_installed else None
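
# A minimal sketch of how these two values are typically consumed further down
# in a Space: they are passed to a Transformers from_pretrained() call so the
# model uses FlashAttention when the wheel installed, and falls back to SDPA
# otherwise. The model id below is a hypothetical placeholder, not something
# defined in this listing; kept as comments so it does not alter the script.
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained(
#         "some-org/some-model",              # placeholder model id
#         attn_implementation=attn_implementation,
#         torch_dtype=dtype,                  # None lets Transformers pick a default
#     )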