Spaces:
Runtime error
Runtime error
File size: 16,839 Bytes
fc0ff8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 |
import torch
from einops import rearrange
from torch import nn
from .helpers import PerceiverResampler
from torch.distributed.fsdp.wrap import (
enable_wrap,
wrap,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
)
from .utils import apply_with_stopping_condition
class Flamingo(nn.Module):
    """Flamingo-style vision-language model.

    Combines a vision encoder, a PerceiverResampler and a language model that
    has been extended with gated cross-attention ("Flamingo") layers. Images are
    encoded into visual tokens that condition the language model's decoder
    layers, and positions in the text holding ``media_token_id`` mark where the
    media occurs.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        lang_encoder: nn.Module,
        eoc_token_id: int,
        media_token_id: int,
        vis_dim: int,
        cross_attn_every_n_layers: int = 1,
        gradient_checkpointing: bool = False,
        compute_all_grads: bool = False,
    ):
        """
        Args:
            vision_encoder (nn.Module): CLIP-style model exposing a ``.visual``
                submodule. Only ``vision_encoder.visual`` is kept and used as the
                image encoder (presumably an OpenCLIP model -- confirm).
            lang_encoder (nn.Module): causal language model that provides
                ``init_flamingo`` and has a config with either ``d_model``
                (e.g. MPT) or ``hidden_size``.
            eoc_token_id (int): Token id for <|endofchunk|>; used as the EOS
                token in ``generate``.
            media_token_id (int): Token id for <image>.
            vis_dim (int): Dimension of the visual features. Passed as ``dim``
                to the PerceiverResampler and as ``vis_hidden_size`` to
                ``lang_encoder.init_flamingo``.
            cross_attn_every_n_layers (int, optional): How often to apply cross
                attention after a transformer layer. Defaults to 1.
            gradient_checkpointing (bool, optional): Forwarded to
                ``init_flamingo``. Also stored on this module and on the
                perceiver as ``_use_gradient_checkpointing``. Defaults to False.
            compute_all_grads (bool, optional): If True, autograd is enabled
                while running the vision encoder, so gradients can flow back to
                the input images. If False, autograd is disabled there.
                Defaults to False.
        """
        super().__init__()
        self.eoc_token_id = eoc_token_id
        self.media_token_id = media_token_id
        self.vis_dim = vis_dim
        if hasattr(lang_encoder.config, "d_model"):
            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
        else:
            self.lang_dim = lang_encoder.config.hidden_size

        self.vision_encoder = vision_encoder.visual
        self.perceiver = PerceiverResampler(dim=self.vis_dim)
        self.lang_encoder = lang_encoder
        self.lang_encoder.init_flamingo(
            media_token_id=media_token_id,
            lang_hidden_size=self.lang_dim,
            vis_hidden_size=self.vis_dim,
            cross_attn_every_n_layers=cross_attn_every_n_layers,
            gradient_checkpointing=gradient_checkpointing,
        )
        self._use_gradient_checkpointing = gradient_checkpointing
        self.perceiver._use_gradient_checkpointing = gradient_checkpointing
        self.compute_all_grads = compute_all_grads

    def forward(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        clear_conditioned_layers: bool = True,
        past_key_values=None,
        use_cache: bool = False,
    ):
        """
        Forward pass of Flamingo.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W) with F=1.
                Must be None when media has been cached with ``cache_media()``.
            lang_x (torch.Tensor): Language input ids
                shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            labels (torch.Tensor, optional): Labels. Defaults to None.
            clear_conditioned_layers: if True, clear the conditioned layers
                once the forward pass is completed. Set this to False if the
                same set of images will be reused in a subsequent
                forward pass.
            past_key_values: pre-computed values to pass to the language model.
                See the past_key_values documentation in Hugging Face
                CausalLM models.
            use_cache: whether to use cached key values. See the use_cache
                documentation in Hugging Face CausalLM models.

        Returns:
            Whatever ``lang_encoder(...)`` returns (presumably a Hugging Face
            ``CausalLMOutputWithPast``).
        """
        assert (
            self.lang_encoder.initialized_flamingo
        ), "Flamingo layers are not initialized. Please call `init_flamingo` first."

        assert (
            self.lang_encoder._use_cached_vision_x or vision_x is not None
        ), "Must provide either vision_x or have precached media using cache_media()."

        if self.lang_encoder._use_cached_vision_x:
            # Case: use cached; vision_x should be cached and other
            # vision-related inputs should not be provided.
            assert (
                vision_x is None
            ), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first."
            assert self.lang_encoder.is_conditioned()

        else:
            # Case: do not use caching (i.e. this is a standard forward pass);
            self._encode_vision_x(vision_x=vision_x)
            self._condition_media_locations(input_ids=lang_x)

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()

        return output

    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        num_beams=1,
        min_new_tokens=None,
        max_new_tokens=None,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        no_repeat_ngram_size=0,
        repetition_penalty=1.0,
        prefix_allowed_tokens_fn=None,
        length_penalty=1.0,
        num_return_sequences=1,
        do_sample=False,
        early_stopping=False,
    ):
        """
        Generate text conditioned on vision and language inputs.

        The visual features are encoded once and kept in the decoder layers
        (``_use_cached_vision_x = True``) for the whole generation. The
        conditioning state is always cleared afterwards, even if encoding or
        generation raises.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                images in the same chunk are collated along T_img, and frames are collated along F
                currently only F=1 is supported (single-frame videos)
            lang_x (torch.Tensor): Language input
                shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            num_beams (int, optional): Number of beams. Defaults to 1. When > 1,
                vision_x is repeated ``num_beams`` times along the batch
                dimension (``repeat_interleave``).
            min_new_tokens (int, optional): Minimum new tokens. Defaults to None.
            max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
            temperature (float, optional): Temperature. Defaults to 1.0.
            top_k (int, optional): Top k. Defaults to 0.
            top_p (float, optional): Top p. Defaults to 1.0.
            no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
            repetition_penalty (float, optional): Repetition penalty. Defaults to 1.0.
            prefix_allowed_tokens_fn (callable, optional): Constrains the allowed
                tokens at each step. Defaults to None.
            length_penalty (float, optional): Length penalty. Defaults to 1.0.
            num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
            do_sample (bool, optional): Do sample. Defaults to False.
            early_stopping (bool, optional): Early stopping. Defaults to False.

            All of the above generation options are forwarded unchanged to
            ``lang_encoder.generate``, with ``eos_token_id=self.eoc_token_id``.

        Returns:
            torch.Tensor: the output of ``lang_encoder.generate``, i.e. lang_x
            with generated tokens appended to it.
        """
        # NOTE(review): only num_beams is accounted for when expanding vision_x.
        # If the underlying generate also expands the batch by
        # num_return_sequences (e.g. when sampling), the visual batch may no
        # longer match the text batch -- TODO confirm for the transformers
        # version in use.
        if num_beams > 1:
            vision_x = vision_x.repeat_interleave(num_beams, dim=0)

        self.lang_encoder._use_cached_vision_x = True
        try:
            self._encode_vision_x(vision_x=vision_x)

            output = self.lang_encoder.generate(
                input_ids=lang_x,
                attention_mask=attention_mask,
                eos_token_id=self.eoc_token_id,
                num_beams=num_beams,
                min_new_tokens=min_new_tokens,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                no_repeat_ngram_size=no_repeat_ngram_size,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
                early_stopping=early_stopping,
            )
        finally:
            # Always undo the temporary conditioning. Otherwise a failed call
            # would leave the model in "cached" mode with stale visual features,
            # breaking (or silently corrupting) the next forward().
            self.lang_encoder.clear_conditioned_layers()
            self.lang_encoder._use_cached_vision_x = False
        return output

    def _encode_vision_x(self, vision_x: torch.Tensor):
        """
        Compute media tokens from vision input and condition the language model.

        Equivalent to running ``_get_vision_embedding`` (vision encoder) followed
        by ``_encode_vision_embedding`` (perceiver + conditioning of every
        decoder layer via ``condition_vis_x``).

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)
        """
        vision_embedding = self._get_vision_embedding(vision_x)
        self._encode_vision_embedding(vision_embedding)

    def _get_vision_embedding(self, vision_x: torch.Tensor):
        """
        Run only the vision encoder (no perceiver, no conditioning).

        NOTE: originally marked "not yet checked with new version".

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)

        Returns:
            torch.Tensor: vision encoder features reshaped to (B, T_img, F, v, d).

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        b, T, F = vision_x.shape[:3]
        assert F == 1, "Only single frame supported"

        # Flatten batch, time and frame axes so the encoder sees a plain image batch.
        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        # NOTE: set_grad_enabled(True) also re-enables autograd inside an outer
        # torch.no_grad() block when compute_all_grads is set.
        with torch.set_grad_enabled(self.compute_all_grads):
            # Index [1] selects the encoder's second output; presumably the
            # per-token features (v tokens of width d) -- confirm against the encoder.
            vision_x = self.vision_encoder(vision_x)[1]
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
        return vision_x

    def _encode_vision_embedding(self, vision_x_embedding: torch.Tensor):
        """
        Pass vision-encoder features through the perceiver and condition the LM.

        Args:
            vision_x_embedding (torch.Tensor): output of ``_get_vision_embedding``
                (not yet passed through the perceiver).
        """
        # The perceiver reshapes to (b, T, n, d) per the original comment.
        vision_x_embedding = self.perceiver(vision_x_embedding)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x_embedding)

    def wrap_fsdp(self, wrapper_kwargs, device_id):
        """
        Manually wraps submodules for FSDP and moves other parameters to device_id.

        Why manually wrap?
        - all parameters within the FSDP wrapper must have the same requires_grad.
            We have a mix of frozen and unfrozen parameters.
        - the vision encoder (``self.vision_encoder``, i.e. the ``.visual``
            submodule) needs to be individually wrapped or encode_vision_x errors
            See: https://github.com/pytorch/pytorch/issues/82461#issuecomment-1269136344

        The rough wrapping structure is:
        - FlamingoModel
            - FSDP(FSDP(vision_encoder))
            - FSDP(FSDP(perceiver))
            - lang_encoder
                - FSDP(FSDP(input_embeddings))
                - FlamingoLayers
                    - FSDP(FSDP(gated_cross_attn_layer))
                    - FSDP(FSDP(decoder_layer))
                - FSDP(FSDP(output_embeddings))
                - other parameters

        Known issues:
        - Our FSDP strategy is not compatible with tied embeddings. If the LM embeddings are tied,
            train with DDP or set the --freeze_lm_embeddings flag to true.
        - With FSDP + gradient ckpting, one can increase the batch size with seemingly no upper bound.
            Although the training curves look okay, we found that downstream performance dramatically
            degrades if the batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M).

        FAQs about our FSDP wrapping strategy:
        Why double wrap?
            As of torch==2.0.1, FSDP's _post_forward_hook and _post_backward_hook
            only free gathered parameters if the module is NOT FSDP root.

        Why unfreeze the decoder_layers?
            See https://github.com/pytorch/pytorch/issues/95805
            As of torch==2.0.1, FSDP's _post_backward_hook is only registered if the flat param
            requires_grad=True. We need the postback to fire to avoid OOM.
            To effectively freeze the decoder layers, we exclude them from the optimizer.

        What is assumed to be frozen v. unfrozen?
            We assume that the model is being trained under normal Flamingo settings
            with these lines being called in factory.py:
                ```
                # Freeze all parameters
                model.requires_grad_(False)
                assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

                # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
                model.perceiver.requires_grad_(True)
                model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
                [optional] model.lang_encoder.get_input_embeddings().requires_grad_(True)
                ```

        Side effects: replaces submodules with FSDP wrappers, marks the original
        decoder-block parameters with ``exclude_from_optimizer = True`` and
        installs ``self.clip_grad_norm_``.
        """
        # unfreeze the decoder layers
        for block in self.lang_encoder.old_decoder_blocks:
            block.requires_grad_(True)

        # wrap in FSDP (wrap() only takes effect inside enable_wrap)
        with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
            self.perceiver = wrap(wrap(self.perceiver))
            self.lang_encoder.old_decoder_blocks = nn.ModuleList(
                wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
            )
            self.lang_encoder.gated_cross_attn_layers = nn.ModuleList(
                wrap(wrap(layer)) if layer is not None else None
                for layer in self.lang_encoder.gated_cross_attn_layers
            )
            self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
            self.lang_encoder.set_input_embeddings(
                wrap(wrap(self.lang_encoder.get_input_embeddings()))
            )
            self.lang_encoder.set_output_embeddings(
                wrap(wrap(self.lang_encoder.get_output_embeddings()))
            )
            self.vision_encoder = wrap(wrap(self.vision_encoder))  # frozen

        # manually move non-FSDP managed parameters to device_id
        # these are all in lang_encoder
        apply_with_stopping_condition(
            module=self.lang_encoder,
            apply_fn=lambda m: m.to(device_id),
            apply_condition=lambda m: len(list(m.children())) == 0,
            stopping_condition=lambda m: isinstance(m, FSDP),
        )

        # exclude the original decoder layers from the optimizer
        for block in self.lang_encoder.old_decoder_blocks:
            for p in block.parameters():
                p.exclude_from_optimizer = True

        # set up clip_grad_norm_ function: FSDP modules must clip their own
        # sharded gradients, so delegate to each trainable wrapped submodule.
        def clip_grad_norm_(max_norm):
            self.perceiver.clip_grad_norm_(max_norm)
            for layer in self.lang_encoder.gated_cross_attn_layers:
                if layer is not None:
                    layer.clip_grad_norm_(max_norm)

            self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm)

        self.clip_grad_norm_ = clip_grad_norm_

    def _condition_media_locations(self, input_ids: torch.Tensor):
        """
        Compute the media token locations from lang_x and condition the language model on these.

        Args:
            input_ids (torch.Tensor): Language input
                shape (B, T_txt)
        """
        # Boolean mask, same shape as input_ids, True where the <image> token sits.
        media_locations = input_ids == self.media_token_id

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_media_locations(media_locations)

    def cache_media(self, input_ids: torch.Tensor, vision_x: torch.Tensor):
        """
        Pre-cache a prompt/sequence of images / text for log-likelihood evaluations.
        All subsequent calls to forward() will generate attending to the LAST
        image in vision_x.
        This is not meant to be used to cache things for generate().

        Args:
            input_ids (torch.Tensor): Language input
                shape (B, T_txt)
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)
        """
        self._encode_vision_x(vision_x=vision_x)
        self._condition_media_locations(input_ids=input_ids)
        self.lang_encoder._use_cached_vision_x = True

    def uncache_media(self):
        """
        Clear all conditioning set up by ``cache_media()`` and leave cached mode.
        """
        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_vision_x = False
|