Spaces: Runtime error
Update policy.py

policy.py (CHANGED)
@@ -1,59 +1,53 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer

-def load_policy_model():
-    model_name = "microsoft/phi-2"
-
-    print(">>> LOADING PHI-2...")
-    tokenizer = AutoTokenizer.from_pretrained(model_name)

-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        device_map="auto",
-        torch_dtype=torch.float16
-    )
-
-    # -----------------------------------------------------------
-    # 1. Identify the REAL lm_head and embedding weights
-    # -----------------------------------------------------------
-    embed = model.model.embed_tokens
-    old_lm_head = model.lm_head  # This is actually tied to embed
-
-    print(">>> UNTIEING LM HEAD...")

-    # -----------------------------------------------------------
-    # 2. Create a new untied lm_head
-    # -----------------------------------------------------------
-    vocab_size, hidden_size = old_lm_head.weight.shape
-    new_lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=True)

-    new_lm_head.weight.data = old_lm_head.weight.data.clone()
-    if old_lm_head.bias is not None:
-        new_lm_head.bias.data = old_lm_head.bias.data.clone()

-    model.lm_head = new_lm_head

-    # -----------------------------------------------------------
-    # 3. Freeze ALL parameters
-    # -----------------------------------------------------------
     for name, param in model.named_parameters():
-        param.requires_grad = False

-    # -----------------------------------------------------------
-    # 4. Unfreeze ONLY the new lm_head
-    # -----------------------------------------------------------
     for name, param in model.named_parameters():
-        if "lm_head" in name:
-            param.requires_grad = True
             print("TRAINABLE:", name)

-    trainable = [p for p in model.parameters() if p.requires_grad]

-    model.optimizer = torch.optim.Adam(trainable, lr=1e-4)
     return model, tokenizer
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer

+MODEL_NAME = "microsoft/phi-2"
+
+
+def load_policy_model():
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_NAME,
+        torch_dtype=torch.float16,
+        device_map=None,
+    ).to(device)
+
+    # Untie lm_head and freeze everything except lm_head
+    model.lm_head = torch.nn.Linear(
+        model.lm_head.in_features,
+        model.lm_head.out_features,
+        bias=True,
+        device=device,
+        dtype=torch.float16,
+    )
     for name, param in model.named_parameters():
+        param.requires_grad = name.startswith("lm_head")

+    print(">>> UNTYING LM HEAD...")
     for name, param in model.named_parameters():
+        if param.requires_grad:
             print("TRAINABLE:", name)

+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print(">>> FINAL TRAINABLE PARAM COUNT:", trainable_params)
+
+    # Optimizer: only lm_head, small LR
+    optimizer = torch.optim.AdamW(
+        (p for p in model.parameters() if p.requires_grad),
+        lr=1e-5,
+    )
+    model.optimizer = optimizer
+
+    # Sanity check: no NaN / Inf in fresh weights
+    with torch.no_grad():
+        for p in model.parameters():
+            if torch.isnan(p).any() or torch.isinf(p).any():
+                raise RuntimeError("Loaded model checkpoint has NaN/Inf parameters.")

     return model, tokenizer
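A note on the new head: `torch.nn.Linear` initializes its weights randomly, so the untied `lm_head` created in this commit starts from scratch rather than from the pretrained (tied) weights that the previous revision cloned. If the pretrained logits should be preserved, something along these lines would work; this is a sketch only, and the helper name `untie_lm_head` is illustrative, not part of this repo:

import torch

def untie_lm_head(model):
    """Replace a (possibly weight-tied) lm_head with an independent copy.

    Assumes model.lm_head is an nn.Linear, as in phi-2.
    """
    old_head = model.lm_head
    new_head = torch.nn.Linear(
        old_head.in_features,
        old_head.out_features,
        bias=old_head.bias is not None,
        device=old_head.weight.device,
        dtype=old_head.weight.dtype,
    )
    with torch.no_grad():
        # Clone so the new head no longer shares storage with the embeddings.
        new_head.weight.copy_(old_head.weight)
        if old_head.bias is not None:
            new_head.bias.copy_(old_head.bias)
    model.lm_head = new_head
    return model

Starting from the tied weights keeps generation coherent from step zero; a freshly initialized head only produces sensible logits after training.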
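For reference, a hypothetical caller could exercise the returned handles like this. The prompt, loss, and single step are illustrative only, and note that pure-float16 training without a gradient scaler is numerically fragile:

import torch

model, tokenizer = load_policy_model()

# One illustrative optimization step on the lm_head-only parameters.
batch = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
outputs = model(**batch, labels=batch["input_ids"])  # HF shifts labels internally

model.optimizer.zero_grad()
outputs.loss.backward()
model.optimizer.step()
print("loss:", outputs.loss.item())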