"""Multi-task model components: a Perception Encoder backbone adapted with
LoRA/DoRA (via PEFT) or MT-LoRA, plus task-specific pooling and prediction heads."""

from collections import OrderedDict
import os
import sys
from typing import Dict, List, Mapping, Optional, Union

import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
from dotenv import load_dotenv

from core.vision_encoder.pe import SelfAttention, AttentionPooling
from utils.task_config import Task
from src.dlora import *
from peft import PeftModel, get_peft_model, LoraConfig

DROPOUT_P = 0.5
|
|
|
|
|
|
|
|
|
class MTLModel(nn.Module):
    """Multi-task wrapper around a Perception Encoder backbone with task-specific
    attention pooling and prediction heads, optionally adapted with LoRA/DoRA
    (via PEFT) or with an MT-LoRA final block."""

    def __init__(self, backbone, tasks: List[Task], device,
                 rank: int = 64,
                 use_lora: bool = True,
                 truncate_idx: int = 22,
                 last_lora_layers: int = -99,
                 lora_dropout: float = 0.5,
                 use_mtl_lora: bool = False,
                 use_deep_head: bool = False,
                 use_batch_norm: bool = True,
                 use_mtl_attn_pool: bool = True,
                 use_dora: bool = True,
                 ):
        super().__init__()
        self.use_mtl_attn_pool = use_mtl_attn_pool
        self.tasks = tasks
        self.use_mtl_lora = use_mtl_lora
        self.use_deep_head = use_deep_head
        self.use_lora = use_lora
        # Legacy alias; the save/load helpers below use this spelling.
        self.use_mtlora = use_mtl_lora

        output_dim = backbone.output_dim

        # One learnable log-variance per task (typically used for uncertainty-based
        # loss weighting).
        self.log_vars = nn.Parameter(torch.zeros(len(tasks)))
        task_names = [task.name for task in tasks]
        self.backbone = backbone
        width = backbone.width
        heads = backbone.heads
        rope = backbone.rope
|
|
|
|
|
|
        if self.use_mtl_lora:
            # Keep references to the pretrained final block, post-norm and attention
            # pooling before truncating the backbone.
            orig_last_block = backbone.transformer.resblocks[-1]
            self.ln_post = backbone.ln_post
            orig_attn_pool = backbone.attn_pool.to(device)

            # Truncate the backbone at `truncate_idx`; `_forward_mtl_block` runs the
            # MT-LoRA block below on top of the truncated features.
            self.backbone.truncate(layer_idx=truncate_idx)

            self.mtl_layer = MTLoRAResidualAttentionBlock(
                d_model=width,
                n_head=heads,
                rope=rope,
                r={'shared': rank, **{name: rank for name in task_names}},
                tasks=task_names,
                shared_mode='matrix',
                lora_shared_scale=0.0,
            )

            self.mtl_layer.load_from_original_block(orig_last_block)
            print("MTL-LoRA final block created and initialized from pretrained weights.")
|
|
|
|
|
|
|
|
|
if self.use_mtl_attn_pool:
|
|
|
self.attn_pool = MTLoRAAttentionPooling(
|
|
|
embed_dim=width,
|
|
|
num_heads=8,
|
|
|
tasks=task_names,
|
|
|
r={'shared': rank, **{name: rank for name in task_names}},
|
|
|
lora_dropout=lora_dropout,
|
|
|
lora_task_scale=1.0,
|
|
|
lora_shared_scale=0.0
|
|
|
)
|
|
|
self.attn_pool.load_from_original(orig_attn_pool)
|
|
|
else:
|
|
|
self.task_specific_attn_pool = nn.ModuleDict({
|
|
|
task.name: AttentionPooling(embed_dim=width, num_heads=8)
|
|
|
for task in self.tasks
|
|
|
})
|
|
|
for task in self.tasks:
|
|
|
self.task_specific_attn_pool[task.name].load_state_dict(orig_attn_pool.state_dict())
|
|
|
print("Task-specific Attention Pooling layers created and initialized.")
|
|
|
|
|
|
|
|
|
del self.backbone.attn_pool
|
|
|
|
|
|
|
|
|
|
|
|
        if use_lora:
            # Collect the nn.Linear sub-modules that should receive LoRA/DoRA adapters.
            target_layers = ["attn.in_proj", "attn.out_proj", "mlp.c_fc", "mlp.c_proj"]
            target_modules = []
            for name, module in self.backbone.named_modules():
                if not isinstance(module, nn.Linear):
                    continue
                is_target_layer = any(s in name for s in target_layers)
                if is_target_layer:
                    if "attn_pool" in name:
                        target_modules.append(name)
                    elif "transformer.resblocks" in name:
                        # Only adapt blocks from `last_lora_layers` onwards; the default
                        # of -99 adapts every block.
                        layer_idx = int(name.split('.')[2])
                        if layer_idx >= last_lora_layers:
                            target_modules.append(name)

            lora_config = LoraConfig(
                r=rank,
                lora_alpha=rank,
                target_modules=target_modules,
                use_dora=use_dora,
                lora_dropout=lora_dropout,
                bias="none",
            )

            self.backbone = get_peft_model(self.backbone, lora_config)
            print("PEFT LoRA module added")
|
|
|
|
|
|
|
|
|
        if not self.use_deep_head:
            self.prediction_layers = nn.ModuleDict({
                task.name: nn.Sequential(
                    nn.BatchNorm1d(backbone.output_dim) if use_batch_norm else nn.Identity(),
                    nn.Dropout(p=DROPOUT_P),
                    nn.Linear(backbone.output_dim, len(task.class_labels))
                )
                for task in self.tasks
            })
            print("Task-specific prediction heads created.")
        else:
            self.prediction_layers = nn.ModuleDict({
                task.name: nn.Sequential(
                    nn.BatchNorm1d(backbone.output_dim) if use_batch_norm else nn.Identity(),
                    nn.Dropout(p=DROPOUT_P),
                    nn.Linear(backbone.output_dim, backbone.output_dim),
                    nn.GELU(),
                    nn.Linear(backbone.output_dim, len(task.class_labels)),
                )
                for task in self.tasks
            })
            print("Task-specific prediction deep-heads created.")

        self.backbone.del_muda()
|
|
|
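    # --- Hypothetical usage sketch (not part of the original code) -------------
    # Assuming a Perception Encoder vision backbone and `Task` objects from
    # utils.task_config with "Age" / "Gender" tasks, construction could look like:
    #
    #     tasks = [Task(name="Age", ...), Task(name="Gender", ...)]
    #     model = MTLModel(backbone, tasks, device="cuda",
    #                      rank=64, use_lora=True, use_mtl_lora=False)
    #     logits = model(images)  # {"Age": [B, n_age_classes], "Gender": [B, n_gender_classes]}
    # ----------------------------------------------------------------------------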
|
|
|
|
|
|
|
|
|
def enable_gradient_checkpointing(self):
|
|
|
"""Call this method after setting up parameter requires_grad"""
|
|
|
backbone_has_trainable = any(param.requires_grad for param in self.backbone.parameters())
|
|
|
if backbone_has_trainable:
|
|
|
self.backbone.set_grad_checkpointing()
|
|
|
print("Gradient checkpointing enabled for backbone (has trainable parameters)")
|
|
|
else:
|
|
|
print("Gradient checkpointing not enabled - backbone has no trainable parameters")
|
|
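    # Usage note (sketch, not prescribed by the original code): configure
    # requires_grad first, then enable checkpointing, e.g.
    #
    #     for name, p in model.named_parameters():
    #         p.requires_grad = ("lora_" in name) or ("prediction_layers" in name)
    #     model.enable_gradient_checkpointing()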
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
|
if self.use_mtl_lora:
|
|
|
return self._forward_mtl_block(x)
|
|
|
else:
|
|
|
return self._forward_shared(x)
|
|
|
|
|
|
    def _forward_shared(self, x: torch.Tensor):
        logits = {}

        features = self.backbone(x)

        for task in self.tasks:
            logits[task.name] = self.prediction_layers[task.name](features)

        return logits
|
|
|
|
|
|
    def _forward_mtl_block(self, x: torch.Tensor, return_feat=False, feat_to_return="None"):
        # Shared feature map from the truncated backbone (no final norm, CLS token kept).
        features = self.backbone.forward_features(x, norm=False, strip_cls_token=False)

        # Every task starts from the same shared feature map.
        task_features_input = {task.name: features for task in self.tasks}

        # The MT-LoRA block returns a shared output (unused here) and one map per task.
        _, task_features = self.mtl_layer(features, x_tasks=task_features_input)

        normalized_task_features = {
            task.name: self.ln_post(task_features[task.name])
            for task in self.tasks
        }

        if self.use_mtl_attn_pool:
            pooled_features = self.attn_pool(normalized_task_features)
        else:
            pooled_features = {}
            for task in self.tasks:
                feat = normalized_task_features[task.name]
                pooled_features[task.name] = self.task_specific_attn_pool[task.name](feat)

        if return_feat:
            if feat_to_return == "Age":
                return pooled_features['Age']
            elif feat_to_return == "Emotion":
                return pooled_features['Emotion']
            elif feat_to_return == "Gender":
                return pooled_features['Gender']

        logits = {}
        for task in self.tasks:
            # Pooling returns [B, 1, D]; drop the probe dimension before the head.
            pooled_feat = pooled_features[task.name].squeeze(1)
            logits[task.name] = self.prediction_layers[task.name](pooled_feat)

        return logits
|
|
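    # Example (hypothetical, assuming an "Age" task is configured): return the pooled
    # per-task embedding instead of the logits.
    #
    #     feats = model._forward_mtl_block(images, return_feat=True, feat_to_return="Age")
    #     # feats: [B, 1, output_dim] embedding produced by the "Age" probe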
|
|
|
|
def save_whole_model(self, filepath: str):
|
|
|
print(f"Saving model state_dict to {filepath}")
|
|
|
torch.save(self.state_dict(), filepath)
|
|
|
|
|
|
    def load_model(self, filepath: str, map_location='cuda'):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.use_lora or self.use_mtlora:
            # Merge the adapters into the base weights; the return value is discarded,
            # which keeps the PEFT wrapper (and its key names) in place.
            self.backbone.merge_and_unload()
        self.to(device)
        state_dict = torch.load(filepath, map_location=map_location)
        self.load_state_dict(state_dict, strict=True)
|
|
|
|
|
|
def save_adapters_peft(self, save_directory: str):
|
|
|
|
|
|
print(f"Saving adapters to directory: {save_directory}")
|
|
|
os.makedirs(save_directory, exist_ok=True)
|
|
|
|
|
|
custom_layers_state_dict = {
|
|
|
'prediction_layers': self.prediction_layers.state_dict()
|
|
|
}
|
|
|
|
|
|
if self.use_lora:
|
|
|
self.backbone.save_pretrained(save_directory)
|
|
|
|
|
|
if self.use_mtlora:
|
|
|
custom_layers_state_dict['mtl_layer'] = self.mtl_layer.state_dict()
|
|
|
|
|
|
custom_layers_state_dict['mtl_attn_pool'] = self.attn_pool.state_dict()
|
|
|
|
|
|
|
|
|
torch.save(custom_layers_state_dict, os.path.join(save_directory, 'custom_layers.pt'))
|
|
|
print("Successfully saved PEFT backbone and custom task heads.")
|
|
|
|
|
|
    def load_heads(self, filepaths: List[str], device='cuda'):
        """Load prediction-head weights from checkpoints that used the old
        `heads.<task>` naming scheme, remapping the keys to this class."""
        for ckpt in filepaths:
            checkpoint = torch.load(ckpt, map_location=device)
            model_state_dict = self.state_dict()

            if "prediction_layers" in checkpoint:
                for loaded_key, value in checkpoint["prediction_layers"].items():
                    new_key = loaded_key

                    # Old naming scheme: heads.<task> -> prediction_layers.<Task>
                    if new_key.startswith('heads.emotion.'):
                        new_key = new_key.replace('heads.emotion.', 'prediction_layers.Emotion.')
                    if new_key.startswith('heads.age.'):
                        new_key = new_key.replace('heads.age.', 'prediction_layers.Age.')
                    if new_key.startswith('heads.gender.'):
                        new_key = new_key.replace('heads.gender.', 'prediction_layers.Gender.')

                    # The Sequential index of the final layer shifted by one between
                    # the old and new head definitions.
                    if '.5.' in new_key:
                        new_key = new_key.replace('.5.', '.4.')

                    # Copy in place; state_dict() tensors share storage with the params.
                    if new_key in model_state_dict:
                        if model_state_dict[new_key].shape == value.shape:
                            model_state_dict[new_key].copy_(value)
|
|
|
|
|
|
    def load_adapters_peft(self, load_directory: str, custom_head_name: str = 'custom_layers.pt'):
        print(f"Loading adapters from directory: {load_directory}")
        if self.use_lora:
            # Fold any currently attached adapters into the base weights, then attach
            # the adapters stored in `load_directory`.
            self.backbone = self.backbone.merge_and_unload()
            self.backbone = PeftModel.from_pretrained(self.backbone, load_directory)

        custom_layers_path = os.path.join(load_directory, custom_head_name)
        if not os.path.exists(custom_layers_path):
            raise FileNotFoundError(f"Custom task heads file not found at {custom_layers_path}")

        checkpoint = torch.load(custom_layers_path, map_location=("cuda" if torch.cuda.is_available() else "cpu"))

        self.prediction_layers.load_state_dict(checkpoint['prediction_layers'])

        if self.use_mtlora:
            # Some checkpoints wrap the MTL block's state dict in a container; fall back
            # to loading it directly.
            try:
                self.mtl_layer.load_state_dict(checkpoint['mtl_layer'][0])
            except KeyError:
                self.mtl_layer.load_state_dict(checkpoint['mtl_layer'])
            self.attn_pool.load_state_dict(checkpoint['mtl_attn_pool'])

        print("Successfully loaded PEFT backbone and custom task heads.")
|
|
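    # Round-trip sketch (hypothetical directory name):
    #
    #     model.save_adapters_peft("checkpoints/run1")
    #     ...
    #     model.load_adapters_peft("checkpoints/run1")   # expects custom_layers.pt inside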
|
|
|
|
    def save_trained(self, filepath: str):
        """Save only the trainable parameters, plus the buffers of the modules that own them."""
        trainable_param_names = {name for name, param in self.named_parameters() if param.requires_grad}
        trainable_module_paths = {'.'.join(name.split('.')[:-1]) for name in trainable_param_names}
|
|
|
|
|
|
state_to_save = {}
|
|
|
full_state_dict = self.state_dict()
|
|
|
|
|
|
for key, value in full_state_dict.items():
|
|
|
if key in trainable_param_names:
|
|
|
state_to_save[key] = value
|
|
|
continue
|
|
|
|
|
|
|
|
|
current_module_path = '.'.join(key.split('.')[:-1])
|
|
|
if current_module_path in trainable_module_paths:
|
|
|
state_to_save[key] = value
|
|
|
|
|
|
print(f"Saving {len(state_to_save)} state entries (parameters and buffers) to {filepath}")
|
|
|
torch.save(state_to_save, filepath)
|
|
|
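    # A partial checkpoint written by save_trained can be restored with a non-strict
    # load, e.g. (sketch, hypothetical filename):
    #
    #     model.load_state_dict(torch.load("trained_only.pt"), strict=False)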
|
|
|
|
|
|
    def load_trained_legacy(self, filepath: str, device='cuda'):
        """Some checkpoints were trained with a different model class, so their key
        names must be remapped to match this class before loading."""
        print(f"Loading trained states from structured checkpoint: {filepath}")

        checkpoint = torch.load(filepath, map_location=device)
|
|
|
|
|
|
model_state_dict = self.state_dict()
|
|
|
|
|
|
loaded_keys_count = 0
|
|
|
skipped_keys = []
|
|
|
remapped_keys_examples = {}
|
|
|
|
|
|
if "backbone_state_dict" in checkpoint:
|
|
|
print("\n--- Processing Backbone Weights ---")
|
|
|
for loaded_key, value in checkpoint["backbone_state_dict"].items():
|
|
|
new_key = loaded_key
|
|
|
|
|
|
if new_key.startswith('strategy.backbone.'):
|
|
|
new_key = new_key.replace('strategy.backbone.', 'backbone.')
|
|
|
|
|
|
if 'attn.in_proj_weight' in new_key and 'attn.in_proj.weight' not in new_key:
|
|
|
new_key = new_key.replace('attn.in_proj_weight', 'attn.in_proj.weight')
|
|
|
if 'attn.in_proj_bias' in new_key and 'attn.in_proj.bias' not in new_key:
|
|
|
new_key = new_key.replace('attn.in_proj_bias', 'attn.in_proj.bias')
|
|
|
|
|
|
if new_key in model_state_dict:
|
|
|
if model_state_dict[new_key].shape == value.shape:
|
|
|
model_state_dict[new_key].copy_(value)
|
|
|
loaded_keys_count += 1
|
|
|
if loaded_key != new_key and len(remapped_keys_examples) < 5:
|
|
|
remapped_keys_examples[loaded_key] = new_key
|
|
|
else:
|
|
|
skipped_keys.append(f"{loaded_key} (Shape Mismatch: Model {model_state_dict[new_key].shape} vs Ckpt {value.shape})")
|
|
|
else:
|
|
|
skipped_keys.append(f"{loaded_key} (as {new_key}) -> Not found in model")
|
|
|
|
|
|
if "prediction_layers" in checkpoint:
|
|
|
print("\n--- Processing Prediction Head Weights ---")
|
|
|
for loaded_key, value in checkpoint["prediction_layers"].items():
|
|
|
new_key = loaded_key
|
|
|
|
|
|
if new_key.startswith('heads.emotion.'):
|
|
|
new_key = new_key.replace('heads.emotion.', 'prediction_layers.Emotion.')
|
|
|
|
|
|
if new_key.startswith('heads.age.'):
|
|
|
new_key = new_key.replace('heads.age.', 'prediction_layers.Age.')
|
|
|
|
|
|
if new_key.startswith('heads.gender.'):
|
|
|
new_key = new_key.replace('heads.gender.', 'prediction_layers.Gender.')
|
|
|
|
|
|
if '.5.' in new_key:
|
|
|
new_key = new_key.replace('.5.', '.4.')
|
|
|
|
|
|
|
|
|
if new_key in model_state_dict:
|
|
|
if model_state_dict[new_key].shape == value.shape:
|
|
|
model_state_dict[new_key].copy_(value)
|
|
|
loaded_keys_count += 1
|
|
|
if loaded_key != new_key and len(remapped_keys_examples) < 10:
|
|
|
remapped_keys_examples[loaded_key] = new_key
|
|
|
else:
|
|
|
skipped_keys.append(f"{loaded_key} (Shape Mismatch: Model {model_state_dict[new_key].shape} vs Ckpt {value.shape})")
|
|
|
else:
|
|
|
skipped_keys.append(f"{loaded_key} (as {new_key}) -> Not found in model")
|
|
|
|
|
|
if "attn_pool" in checkpoint:
|
|
|
print("\n--- Processing Attention Pool Weights ---")
|
|
|
for loaded_key, value in checkpoint["attn_pool"].items():
|
|
|
|
|
|
new_key = loaded_key.replace('strategy.backbone.attn_pool.', 'backbone.attn_pool.')
|
|
|
|
|
|
|
|
|
if new_key in model_state_dict:
|
|
|
if model_state_dict[new_key].shape == value.shape:
|
|
|
model_state_dict[new_key].copy_(value)
|
|
|
loaded_keys_count += 1
|
|
|
if loaded_key != new_key and len(remapped_keys_examples) < 15:
|
|
|
remapped_keys_examples[loaded_key] = new_key
|
|
|
else:
|
|
|
skipped_keys.append(f"{loaded_key} (Shape Mismatch: Model {model_state_dict[new_key].shape} vs Ckpt {value.shape})")
|
|
|
else:
|
|
|
skipped_keys.append(f"{loaded_key} (as {new_key}) -> Not found in model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        if loaded_keys_count == 0:
            # Nothing matched the structured format; fall back to a plain, non-strict load.
            print("Loaded 0 keys from the structured checkpoint; falling back to a direct state_dict load.")
            self.load_state_dict(torch.load(filepath, map_location=device), strict=False)
|
|
|
|
|
|
|
|
|
|
|
|
class MTLoRAResidualAttentionBlock(nn.Module):
    """Adaptation of the Perception Encoder ResidualAttentionBlock with MT-LoRA,
    producing one task-specific feature map per task plus a shared feature map."""

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        drop_path: float = 0.0,
        rope: Optional[nn.Module] = None,
        r: Union[int, Mapping[str, int]] = 0,
        lora_shared_scale: float = 1.0,
        lora_task_scale: float = 1.0,
        lora_dropout: float = DROPOUT_P,
        tasks=None,
        trainable_scale_shared=False,
        trainable_scale_per_task=False,
        shared_mode: str = 'matrix',
    ):
        super().__init__()
        self.tasks = tasks
        self.num_heads = n_head
        self.head_dim = d_model // n_head
        self.scale = self.head_dim ** -0.5
        self.rope = rope

        # Every task starts from the same LoRA scale.
        task_scales = {t: lora_task_scale for t in tasks}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.attn = MTLoRAQKV(
|
|
|
in_features=d_model,
|
|
|
out_features=d_model,
|
|
|
r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
|
|
|
lora_dropout=lora_dropout, tasks=tasks, trainable_scale_shared=trainable_scale_shared,
|
|
|
trainable_scale_per_task=trainable_scale_per_task, shared_mode=shared_mode
|
|
|
)
|
|
|
|
|
|
|
|
|
self.out_proj = MTLoRALinear(
|
|
|
in_features=d_model,
|
|
|
out_features=d_model,
|
|
|
r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
|
|
|
lora_dropout=lora_dropout, tasks=tasks, trainable_scale_shared=trainable_scale_shared,
|
|
|
trainable_scale_per_task=trainable_scale_per_task, shared_mode=shared_mode
|
|
|
)
|
|
|
|
|
|
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
|
|
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
|
|
|
|
|
self.ln_1 = norm_layer(d_model)
|
|
|
self.ln_2 = norm_layer(d_model)
|
|
|
|
|
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
|
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
|
|
|
|
|
|
|
|
mlp_width = int(d_model * mlp_ratio)
|
|
|
self.mlp = nn.Sequential(
|
|
|
OrderedDict([
|
|
|
("c_fc", MTLoRALinear(
|
|
|
d_model, mlp_width, r=r, lora_shared_scale=lora_shared_scale,
|
|
|
lora_task_scale=task_scales, lora_dropout=lora_dropout, tasks=tasks,
|
|
|
trainable_scale_shared=trainable_scale_shared, trainable_scale_per_task=trainable_scale_per_task,
|
|
|
shared_mode=shared_mode
|
|
|
)),
|
|
|
("gelu", act_layer()),
|
|
|
("c_proj", MTLoRALinear(
|
|
|
mlp_width, d_model, r=r, lora_shared_scale=lora_shared_scale,
|
|
|
lora_task_scale=task_scales, lora_dropout=lora_dropout, tasks=tasks,
|
|
|
trainable_scale_shared=trainable_scale_shared, trainable_scale_per_task=trainable_scale_per_task,
|
|
|
shared_mode=shared_mode
|
|
|
)),
|
|
|
])
|
|
|
)
|
|
|
|
|
|
def _call_attn(
|
|
|
self,
|
|
|
x_shared: torch.Tensor,
|
|
|
attn_mask: Optional[torch.Tensor] = None,
|
|
|
x_tasks: Optional[Dict[str, torch.Tensor]] = None,
|
|
|
):
|
|
|
|
|
|
proj, proj_tasks = self.attn(x_shared, x_tasks)
|
|
|
|
|
|
        def compute_attention(projection_tensor):
            # The fused QKV projection has shape [B, S, 3 * D]; split it into Q, K, V.
            split_size = self.attn.q.linear.out_features
            q, k, v = projection_tensor.unflatten(-1, (3, split_size)).permute(2, 0, 1, 3).contiguous()

            q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
            k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
            v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)

            if self.rope:
                q, k = self.rope(q, k)

            attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=self.scale)
            return rearrange(attn_output, "b h s d -> b s (h d)")
|
|
|
|
|
|
|
|
|
attn_result = compute_attention(proj)
|
|
|
|
|
|
|
|
|
attn_tasks_results = {}
|
|
|
if proj_tasks:
|
|
|
for task, task_proj in proj_tasks.items():
|
|
|
attn_tasks_results[task] = compute_attention(task_proj)
|
|
|
|
|
|
|
|
|
|
|
|
shared_out, tasks_out = self.out_proj(attn_result, x_tasks=attn_tasks_results if attn_tasks_results else None)
|
|
|
|
|
|
return shared_out, tasks_out
|
|
|
|
|
|
def forward(
|
|
|
self,
|
|
|
x: torch.Tensor,
|
|
|
attn_mask: Optional[torch.Tensor] = None,
|
|
|
x_tasks: Optional[Dict[str, torch.Tensor]] = None,
|
|
|
):
|
|
|
|
|
|
norm_x = self.ln_1(x)
|
|
|
norm_x_tasks = {task: self.ln_1(x_tasks[task]) for task in self.tasks} if x_tasks else None
|
|
|
|
|
|
attn_out, attn_tasks_out = self._call_attn(norm_x, attn_mask=attn_mask, x_tasks=norm_x_tasks)
|
|
|
|
|
|
x = x + self.drop_path1(self.ls_1(attn_out))
|
|
|
if attn_tasks_out and x_tasks:
|
|
|
for task in self.tasks:
|
|
|
x_tasks[task] = x_tasks[task] + self.drop_path1(self.ls_1(attn_tasks_out[task]))
|
|
|
|
|
|
|
|
|
norm_x = self.ln_2(x)
|
|
|
norm_x_tasks = {task: self.ln_2(x_tasks[task]) for task in self.tasks} if x_tasks else None
|
|
|
|
|
|
|
|
|
mlp_fc_out, mlp_fc_tasks_out = self.mlp.c_fc(norm_x, norm_x_tasks)
|
|
|
gelu_out = self.mlp.gelu(mlp_fc_out)
|
|
|
gelu_tasks_out = {task: self.mlp.gelu(mlp_fc_tasks_out[task]) for task in self.tasks} if mlp_fc_tasks_out else None
|
|
|
|
|
|
mlp_proj_out, mlp_proj_tasks_out = self.mlp.c_proj(gelu_out, gelu_tasks_out)
|
|
|
|
|
|
x = x + self.drop_path2(self.ls_2(mlp_proj_out))
|
|
|
if mlp_proj_tasks_out and x_tasks:
|
|
|
for task in self.tasks:
|
|
|
x_tasks[task] = x_tasks[task] + self.drop_path2(self.ls_2(mlp_proj_tasks_out[task]))
|
|
|
|
|
|
return x, x_tasks
|
|
|
|
|
|
def load_from_original_block(self, original_block):
|
|
|
"""
|
|
|
Initializes the weights of this block from a pre-trained ResidualAttentionBlock.
|
|
|
The LoRA-specific parameters are reset to their initial state.
|
|
|
"""
|
|
|
with torch.no_grad():
|
|
|
|
|
|
self.ln_1.load_state_dict(original_block.ln_1.state_dict())
|
|
|
self.ln_2.load_state_dict(original_block.ln_2.state_dict())
|
|
|
self.ls_1.load_state_dict(original_block.ls_1.state_dict())
|
|
|
self.ls_2.load_state_dict(original_block.ls_2.state_dict())
|
|
|
|
|
|
|
|
|
self.mlp.c_fc.linear.load_state_dict(original_block.mlp.c_fc.state_dict())
|
|
|
self.mlp.c_proj.linear.load_state_dict(original_block.mlp.c_proj.state_dict())
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(original_block.attn, SelfAttention):
|
|
|
|
|
|
|
|
|
original_block.attn.migrate_weights()
|
|
|
|
|
|
|
|
|
qkv_weight = original_block.attn.in_proj.weight
|
|
|
qkv_bias = original_block.attn.in_proj.bias
|
|
|
|
|
|
q_w, k_w, v_w = qkv_weight.chunk(3)
|
|
|
q_b, k_b, v_b = qkv_bias.chunk(3)
|
|
|
|
|
|
|
|
|
self.attn.q.linear.weight.copy_(q_w)
|
|
|
self.attn.q.linear.bias.copy_(q_b)
|
|
|
|
|
|
self.attn.k.linear.weight.copy_(k_w)
|
|
|
self.attn.k.linear.bias.copy_(k_b)
|
|
|
|
|
|
self.attn.v.linear.weight.copy_(v_w)
|
|
|
self.attn.v.linear.bias.copy_(v_b)
|
|
|
|
|
|
|
|
|
self.out_proj.linear.load_state_dict(original_block.attn.out_proj.state_dict())
|
|
|
elif isinstance(original_block.attn, nn.MultiheadAttention):
|
|
|
self.attn.q.linear.weight.copy_(original_block.attn.in_proj_weight[:self.attn.q.linear.out_features, :])
|
|
|
self.attn.q.linear.bias.copy_(original_block.attn.in_proj_bias[:self.attn.q.linear.out_features])
|
|
|
|
|
|
self.attn.k.linear.weight.copy_(original_block.attn.in_proj_weight[self.attn.q.linear.out_features:2*self.attn.q.linear.out_features, :])
|
|
|
self.attn.k.linear.bias.copy_(original_block.attn.in_proj_bias[self.attn.q.linear.out_features:2*self.attn.q.linear.out_features])
|
|
|
|
|
|
self.attn.v.linear.weight.copy_(original_block.attn.in_proj_weight[2*self.attn.q.linear.out_features:3*self.attn.q.linear.out_features, :])
|
|
|
self.attn.v.linear.bias.copy_(original_block.attn.in_proj_bias[2*self.attn.q.linear.out_features:3*self.attn.q.linear.out_features])
|
|
|
|
|
|
self.out_proj.linear.weight.copy_(original_block.attn.out_proj.weight)
|
|
|
self.out_proj.linear.bias.copy_(original_block.attn.out_proj.bias)
|
|
|
|
|
|
else:
|
|
|
raise TypeError(f"Unsupported attention module type in original_block: {type(original_block.attn)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.attn.reset_parameters()
|
|
|
self.out_proj.reset_parameters()
|
|
|
self.mlp.c_fc.reset_parameters()
|
|
|
self.mlp.c_proj.reset_parameters()
|
|
|
|
|
|
print("Successfully loaded weights from original ResidualAttentionBlock and reset LoRA parameters.")
|
|
|
|
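    # Shape sketch (illustrative): with x of shape [B, S, d_model] and
    # x_tasks = {"Age": x, "Gender": x}, forward() returns
    #     shared: [B, S, d_model]
    #     tasks:  {"Age": [B, S, d_model], "Gender": [B, S, d_model]}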
|
|
|
|
|
class MTLoRAAttentionPooling(nn.Module):
|
|
|
"""
|
|
|
    An MT-LoRA equivalent of the AttentionPooling transformer block.
|
|
|
|
|
|
This module replicates the full original architecture:
|
|
|
1. Task-specific probes for attention pooling.
|
|
|
2. MT-LoRA enabled Q/K/V and Output projections.
|
|
|
3. A LayerNorm layer.
|
|
|
4. An MLP block with MT-LoRA enabled linear layers.
|
|
|
5. A final residual connection, matching the original's structure.
|
|
|
"""
|
|
|
def __init__(
|
|
|
self,
|
|
|
embed_dim: int,
|
|
|
num_heads: int,
|
|
|
tasks: List[str],
|
|
|
r: Union[int, Mapping[str, int]] = 0,
|
|
|
lora_shared_scale: float = 1.0,
|
|
|
lora_task_scale: float = 1.0,
|
|
|
lora_dropout: float = 0.0,
|
|
|
mlp_ratio: int = 4,
|
|
|
act_layer = nn.GELU,
|
|
|
norm_layer = nn.LayerNorm,
|
|
|
):
|
|
|
super().__init__()
|
|
|
self.tasks = tasks
|
|
|
self.num_heads = num_heads
|
|
|
|
|
|
self.probe = nn.ParameterDict({
|
|
|
task: nn.Parameter(torch.randn(1, 1, embed_dim))
|
|
|
for task in tasks
|
|
|
})
|
|
|
|
|
|
task_scales = {t: lora_task_scale for t in tasks}
|
|
|
|
|
|
self.q_proj = MTLoRALinear(
|
|
|
embed_dim, embed_dim, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
|
|
|
lora_dropout=lora_dropout, tasks=tasks
|
|
|
)
|
|
|
self.k_proj = MTLoRALinear(
|
|
|
embed_dim, embed_dim, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
|
|
|
lora_dropout=lora_dropout, tasks=tasks
|
|
|
)
|
|
|
self.v_proj = MTLoRALinear(
|
|
|
embed_dim, embed_dim, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
|
|
|
lora_dropout=lora_dropout, tasks=tasks
|
|
|
)
|
|
|
self.out_proj = MTLoRALinear(
|
|
|
embed_dim, embed_dim, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
|
|
|
lora_dropout=lora_dropout, tasks=tasks
|
|
|
)
|
|
|
|
|
|
self.layernorm = norm_layer(embed_dim)
|
|
|
mlp_width = int(embed_dim * mlp_ratio)
|
|
|
self.mlp = nn.Sequential(
|
|
|
OrderedDict([
|
|
|
("c_fc", MTLoRALinear(
|
|
|
embed_dim, mlp_width, r=r, lora_shared_scale=lora_shared_scale,
|
|
|
lora_task_scale=task_scales, lora_dropout=lora_dropout, tasks=tasks
|
|
|
)),
|
|
|
("gelu", nn.GELU()),
|
|
|
("c_proj", MTLoRALinear(
|
|
|
mlp_width, embed_dim, r=r, lora_shared_scale=lora_shared_scale,
|
|
|
lora_task_scale=task_scales, lora_dropout=lora_dropout, tasks=tasks
|
|
|
)),
|
|
|
])
|
|
|
)
|
|
|
|
|
|
def load_from_original(self, original_pool: AttentionPooling):
|
|
|
"""Initializes all weights from the pretrained AttentionPooling block."""
|
|
|
with torch.no_grad():
|
|
|
original_attn = original_pool.attn
|
|
|
|
|
|
for task in self.tasks:
|
|
|
self.probe[task].copy_(original_pool.probe)
|
|
|
|
|
|
q_w, k_w, v_w = original_attn.in_proj_weight.chunk(3)
|
|
|
q_b, k_b, v_b = original_attn.in_proj_bias.chunk(3)
|
|
|
|
|
|
self.q_proj.linear.weight.copy_(q_w)
|
|
|
self.q_proj.linear.bias.copy_(q_b)
|
|
|
self.k_proj.linear.weight.copy_(k_w)
|
|
|
self.k_proj.linear.bias.copy_(k_b)
|
|
|
self.v_proj.linear.weight.copy_(v_w)
|
|
|
self.v_proj.linear.bias.copy_(v_b)
|
|
|
|
|
|
self.out_proj.linear.load_state_dict(original_attn.out_proj.state_dict())
|
|
|
|
|
|
self.layernorm.load_state_dict(original_pool.layernorm.state_dict())
|
|
|
|
|
|
self.mlp.c_fc.linear.load_state_dict(original_pool.mlp.c_fc.state_dict())
|
|
|
self.mlp.c_proj.linear.load_state_dict(original_pool.mlp.c_proj.state_dict())
|
|
|
|
|
|
self.q_proj.reset_parameters()
|
|
|
self.k_proj.reset_parameters()
|
|
|
self.v_proj.reset_parameters()
|
|
|
self.out_proj.reset_parameters()
|
|
|
self.mlp.c_fc.reset_parameters()
|
|
|
self.mlp.c_proj.reset_parameters()
|
|
|
print("Full MT-LoRA Attention Pooling block created and initialized from pretrained weights.")
|
|
|
|
|
|
    def forward(self, x_tasks: Dict[str, torch.Tensor]):
        """
        Forward pass that handles a separate input per task.

        K and V are computed inside the loop from the task-specific input `x`,
        and each task has its own probe.
        """
|
|
|
|
|
|
|
|
|
final_outputs = {}
|
|
|
for task, x in x_tasks.items():
|
|
|
B, N, C = x.shape
|
|
|
probe = self.probe[task].repeat(B, 1, 1)
|
|
|
|
|
|
|
|
|
_, q_task_dict = self.q_proj(probe, x_tasks={task: probe})
|
|
|
q = q_task_dict[task]
|
|
|
|
|
|
_, k_task_dict = self.k_proj(x, x_tasks={task: x})
|
|
|
k = k_task_dict[task]
|
|
|
|
|
|
_, v_task_dict = self.v_proj(x, x_tasks={task: x})
|
|
|
v = v_task_dict[task]
|
|
|
|
|
|
q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
|
|
|
k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads)
|
|
|
v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads)
|
|
|
|
|
|
attn_out = F.scaled_dot_product_attention(q, k, v)
|
|
|
attn_out_rearranged = rearrange(attn_out, 'b h n d -> b n (h d)')
|
|
|
|
|
|
_, out_proj_dict = self.out_proj(attn_out_rearranged, x_tasks={task: attn_out_rearranged})
|
|
|
x_attn = out_proj_dict[task]
|
|
|
|
|
|
norm_attn = self.layernorm(x_attn)
|
|
|
|
|
|
_, fc_task_dict = self.mlp.c_fc(norm_attn, x_tasks={task: norm_attn})
|
|
|
gelu_out = self.mlp.gelu(fc_task_dict[task])
|
|
|
_, proj_task_dict = self.mlp.c_proj(gelu_out, x_tasks={task: gelu_out})
|
|
|
mlp_out = proj_task_dict[task]
|
|
|
|
|
|
final_outputs[task] = x_attn + mlp_out
|
|
|
|
|
|
return final_outputs
|
|
|
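# ---------------------------------------------------------------------------
# Minimal shape-check sketch for MTLoRAAttentionPooling (hypothetical values;
# relies on the MTLoRALinear implementation imported from src.dlora):
#
#     pool = MTLoRAAttentionPooling(embed_dim=1024, num_heads=8,
#                                   tasks=["Age", "Gender"],
#                                   r={"shared": 8, "Age": 8, "Gender": 8})
#     feats = {t: torch.randn(2, 197, 1024) for t in ["Age", "Gender"]}
#     out = pool(feats)   # {"Age": [2, 1, 1024], "Gender": [2, 1, 1024]}
# ---------------------------------------------------------------------------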
|
|
|
|