import torch
from einops import rearrange
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    enable_wrap,
    wrap,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

from .helpers import PerceiverResampler
from .utils import apply_with_stopping_condition

class Flamingo(nn.Module):
    def __init__(
        self,
        vision_encoder: nn.Module,
        lang_encoder: nn.Module,
        eoc_token_id: int,
        media_token_id: int,
        vis_dim: int,
        cross_attn_every_n_layers: int = 1,
        gradient_checkpointing: bool = False,
        compute_all_grads: bool = False,
    ):
        """
        Args:
            vision_encoder (nn.Module): HF CLIPModel
            lang_encoder (nn.Module): HF causal language model
            eoc_token_id (int): Token id for <|endofchunk|>
            media_token_id (int): Token id for <image>
            vis_dim (int): Dimension of the visual features.
                Visual features are projected to match this shape along the last dimension.
            cross_attn_every_n_layers (int, optional): How often to apply cross attention after a transformer layer. Defaults to 1.
            gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
            compute_all_grads (bool, optional): Whether to compute gradients through the vision encoder. Defaults to False.
        """
        super().__init__()
        self.eoc_token_id = eoc_token_id
        self.media_token_id = media_token_id
        self.vis_dim = vis_dim
        if hasattr(lang_encoder.config, "d_model"):
            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
        else:
            self.lang_dim = lang_encoder.config.hidden_size

        self.vision_encoder = vision_encoder.visual
        self.perceiver = PerceiverResampler(dim=self.vis_dim)
        self.lang_encoder = lang_encoder
        self.lang_encoder.init_flamingo(
            media_token_id=media_token_id,
            lang_hidden_size=self.lang_dim,
            vis_hidden_size=self.vis_dim,
            cross_attn_every_n_layers=cross_attn_every_n_layers,
            gradient_checkpointing=gradient_checkpointing,
        )
        self._use_gradient_checkpointing = gradient_checkpointing
        self.perceiver._use_gradient_checkpointing = gradient_checkpointing
        self.compute_all_grads = compute_all_grads
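
    # --- Illustrative usage (added annotation, not part of the original file) ---
    # A minimal sketch of constructing a Flamingo instance by hand. In practice the
    # OpenFlamingo factory (factory.py) builds the encoders, mixes init_flamingo into
    # the language model, and calls this constructor; the exact calls below are
    # assumptions for illustration only.
    #
    #     vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
    #         "ViT-L-14", pretrained="openai"
    #     )  # assumed open_clip vision backbone
    #     lang_encoder = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # assumed LM
    #     # lang_encoder must be extended so it exposes init_flamingo (see factory.py)
    #     model = Flamingo(
    #         vision_encoder=vision_encoder,
    #         lang_encoder=lang_encoder,
    #         eoc_token_id=tokenizer.encode("<|endofchunk|>")[-1],
    #         media_token_id=tokenizer.encode("<image>")[-1],
    #         vis_dim=1024,
    #         cross_attn_every_n_layers=4,
    #     )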

    def forward(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        clear_conditioned_layers: bool = True,
        past_key_values=None,
        use_cache: bool = False,
    ):
        """
        Forward pass of Flamingo.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W) with F=1
            lang_x (torch.Tensor): Language input ids
                shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            labels (torch.Tensor, optional): Labels. Defaults to None.
            clear_conditioned_layers: if True, clear the conditioned layers
                once the forward pass is completed. Set this to False if the
                same set of images will be reused in another subsequent
                forward pass.
            past_key_values: pre-computed values to pass to language model.
                See past_key_values documentation in Hugging Face
                CausalLM models.
            use_cache: whether to use cached key values. See use_cache
                documentation in Hugging Face CausalLM models.
        """
        assert (
            self.lang_encoder.initialized_flamingo
        ), "Flamingo layers are not initialized. Please call `init_flamingo` first."

        assert (
            self.lang_encoder._use_cached_vision_x or vision_x is not None
        ), "Must provide either vision_x or have precached media using cache_media()."

        if self.lang_encoder._use_cached_vision_x:
            # Case: use cached; vision_x should be cached and other
            # vision-related inputs should not be provided.
            assert (
                vision_x is None
            ), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first."
            assert self.lang_encoder.is_conditioned()

        else:
            # Case: do not use caching (i.e. this is a standard forward pass)
            self._encode_vision_x(vision_x=vision_x)
            self._condition_media_locations(input_ids=lang_x)

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()

        return output
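
    # --- Illustrative usage (added annotation, not part of the original file) ---
    # A hedged sketch of a training-style forward pass; shapes and prompts are
    # illustrative. Assumes `model` is a Flamingo instance and `tokenizer` emits
    # <image> / <|endofchunk|> tokens matching media_token_id / eoc_token_id.
    #
    #     vision_x = torch.randn(1, 2, 1, 3, 224, 224)  # (B, T_img, F, C, H, W), F must be 1
    #     batch = tokenizer(
    #         ["<image>A cat.<|endofchunk|><image>A dog.<|endofchunk|>"], return_tensors="pt"
    #     )
    #     out = model(
    #         vision_x=vision_x,
    #         lang_x=batch["input_ids"],
    #         attention_mask=batch["attention_mask"],
    #         labels=batch["input_ids"],
    #     )
    #     loss = out.loss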

    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        num_beams=1,
        min_new_tokens=None,
        max_new_tokens=None,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        no_repeat_ngram_size=0,
        repetition_penalty=1.0,
        prefix_allowed_tokens_fn=None,
        length_penalty=1.0,
        num_return_sequences=1,
        do_sample=False,
        early_stopping=False,
    ):
        """
        Generate text conditioned on vision and language inputs.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                images in the same chunk are collated along T_img, and frames are collated along F
                currently only F=1 is supported (single-frame videos)
            lang_x (torch.Tensor): Language input
                shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            num_beams (int, optional): Number of beams. Defaults to 1.
            min_new_tokens (int, optional): Minimum new tokens. Defaults to None.
            max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
            temperature (float, optional): Temperature. Defaults to 1.0.
            top_k (int, optional): Top k. Defaults to 0.
            top_p (float, optional): Top p. Defaults to 1.0.
            no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
            repetition_penalty (float, optional): Repetition penalty. Defaults to 1.0.
            prefix_allowed_tokens_fn (Callable, optional): Constrained decoding function; see the Hugging Face generate() docs. Defaults to None.
            length_penalty (float, optional): Length penalty. Defaults to 1.0.
            num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
            do_sample (bool, optional): Do sample. Defaults to False.
            early_stopping (bool, optional): Early stopping. Defaults to False.
        Returns:
            torch.Tensor: lang_x with generated tokens appended to it
        """
        if num_beams > 1:
            vision_x = vision_x.repeat_interleave(num_beams, dim=0)

        self.lang_encoder._use_cached_vision_x = True
        self._encode_vision_x(vision_x=vision_x)

        output = self.lang_encoder.generate(
            input_ids=lang_x,
            attention_mask=attention_mask,
            eos_token_id=self.eoc_token_id,
            num_beams=num_beams,
            min_new_tokens=min_new_tokens,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            no_repeat_ngram_size=no_repeat_ngram_size,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_return_sequences,
            do_sample=do_sample,
            early_stopping=early_stopping,
        )

        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_vision_x = False
        return output
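
    # --- Illustrative usage (added annotation, not part of the original file) ---
    # A hedged sketch of open-ended generation; parameter values are arbitrary examples,
    # reusing the `vision_x` / `batch` / `tokenizer` names assumed above.
    #
    #     generated = model.generate(
    #         vision_x=vision_x,                      # (B, T_img, 1, C, H, W)
    #         lang_x=batch["input_ids"],
    #         attention_mask=batch["attention_mask"],
    #         max_new_tokens=20,
    #         num_beams=3,
    #     )
    #     print(tokenizer.decode(generated[0]))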

    def _encode_vision_x(self, vision_x: torch.Tensor):
        """
        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        b, T, F = vision_x.shape[:3]
        assert F == 1, "Only single frame supported"

        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        with torch.set_grad_enabled(self.compute_all_grads):
            vision_x = self.vision_encoder(vision_x)[1]
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
        vision_x = self.perceiver(vision_x)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)
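
    # Shape trace for _encode_vision_x (added annotation, derived from the code above;
    # sizes are illustrative):
    #   vision_x input:          (B, T_img, F, C, H, W)   e.g. (1, 2, 1, 3, 224, 224)
    #   after first rearrange:   (B*T_img*F, C, H, W)
    #   vision_encoder(...)[1]:  (B*T_img*F, v, d)        per-patch tokens; d is the
    #                                                     vision width (== self.vis_dim in the standard setup)
    #   after second rearrange:  (B, T_img, F, v, d)
    #   perceiver(...):          (B, T_img, n, d)         n resampled latent tokens per image
    # Each decoder layer then cross-attends to these (B, T_img, n, d) media tokens.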

    def _get_vision_embedding(self, vision_x: torch.Tensor):
        """
        Compute vision-encoder features from vision input without running the perceiver
        or conditioning the language model. (Not yet checked with the new version.)
        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        b, T, F = vision_x.shape[:3]
        assert F == 1, "Only single frame supported"

        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        with torch.set_grad_enabled(self.compute_all_grads):
            vision_x = self.vision_encoder(vision_x)[1]
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
        return vision_x

    def _encode_vision_embedding(self, vision_x_embedding: torch.Tensor):
        # Encode a vision embedding that has not gone through the perceiver yet.
        vision_x_embedding = self.perceiver(vision_x_embedding)  # reshapes to (b, T, n, d)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x_embedding)
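
    # Note (added annotation): _get_vision_embedding followed by _encode_vision_embedding
    # reproduces _encode_vision_x in two steps, which allows the raw vision-encoder features
    # to be computed or cached separately from the perceiver + conditioning step.
    # A hedged sketch:
    #
    #     feats = model._get_vision_embedding(vision_x)   # (B, T_img, F, v, d)
    #     model._encode_vision_embedding(feats)           # perceiver + condition_vis_x per layer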

    def wrap_fsdp(self, wrapper_kwargs, device_id):
        """
        Manually wraps submodules for FSDP and moves other parameters to device_id.

        Why manually wrap?
        - all parameters within the FSDP wrapper must have the same requires_grad.
            We have a mix of frozen and unfrozen parameters.
        - model.vision_encoder.visual needs to be individually wrapped or encode_vision_x errors
            See: https://github.com/pytorch/pytorch/issues/82461#issuecomment-1269136344

        The rough wrapping structure is:
        - FlamingoModel
            - FSDP(FSDP(vision_encoder))
            - FSDP(FSDP(perceiver))
            - lang_encoder
                - FSDP(FSDP(input_embeddings))
                - FlamingoLayers
                    - FSDP(FSDP(gated_cross_attn_layer))
                    - FSDP(FSDP(decoder_layer))
                - FSDP(FSDP(output_embeddings))
                - other parameters

        Known issues:
        - Our FSDP strategy is not compatible with tied embeddings. If the LM embeddings are tied,
            train with DDP or set the --freeze_lm_embeddings flag to true.
        - With FSDP + gradient checkpointing, one can increase the batch size with seemingly no upper bound.
            Although the training curves look okay, we found that downstream performance dramatically
            degrades if the batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M).

        FAQs about our FSDP wrapping strategy:
        Why double wrap?
            As of torch==2.0.1, FSDP's _post_forward_hook and _post_backward_hook
            only free gathered parameters if the module is NOT FSDP root.

        Why unfreeze the decoder_layers?
            See https://github.com/pytorch/pytorch/issues/95805
            As of torch==2.0.1, FSDP's _post_backward_hook is only registered if the flat param
            requires_grad=True. We need the post-backward hook to fire to avoid OOM.
            To effectively freeze the decoder layers, we exclude them from the optimizer.

        What is assumed to be frozen vs. unfrozen?
            We assume that the model is being trained under normal Flamingo settings
            with these lines being called in factory.py:
            ```
            # Freeze all parameters
            model.requires_grad_(False)
            assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

            # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
            model.perceiver.requires_grad_(True)
            model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
            [optional] model.lang_encoder.get_input_embeddings().requires_grad_(True)
            ```
        """
        # unfreeze the decoder layers
        for block in self.lang_encoder.old_decoder_blocks:
            block.requires_grad_(True)

        # wrap in FSDP
        with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
            self.perceiver = wrap(wrap(self.perceiver))
            self.lang_encoder.old_decoder_blocks = nn.ModuleList(
                wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
            )
            self.lang_encoder.gated_cross_attn_layers = nn.ModuleList(
                wrap(wrap(layer)) if layer is not None else None
                for layer in self.lang_encoder.gated_cross_attn_layers
            )
            self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
            self.lang_encoder.set_input_embeddings(
                wrap(wrap(self.lang_encoder.get_input_embeddings()))
            )
            self.lang_encoder.set_output_embeddings(
                wrap(wrap(self.lang_encoder.get_output_embeddings()))
            )
            self.vision_encoder = wrap(wrap(self.vision_encoder))  # frozen

        # manually move non-FSDP managed parameters to device_id
        # these are all in lang_encoder
        apply_with_stopping_condition(
            module=self.lang_encoder,
            apply_fn=lambda m: m.to(device_id),
            apply_condition=lambda m: len(list(m.children())) == 0,
            stopping_condition=lambda m: isinstance(m, FSDP),
        )

        # exclude the original decoder layers from the optimizer
        for block in self.lang_encoder.old_decoder_blocks:
            for p in block.parameters():
                p.exclude_from_optimizer = True

        # set up clip_grad_norm_ function
        def clip_grad_norm_(max_norm):
            self.perceiver.clip_grad_norm_(max_norm)
            for layer in self.lang_encoder.gated_cross_attn_layers:
                if layer is not None:
                    layer.clip_grad_norm_(max_norm)
            self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm)

        self.clip_grad_norm_ = clip_grad_norm_
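
    # --- Illustrative usage (added annotation, not part of the original file) ---
    # A hedged sketch of FSDP wrapping. `wrapper_kwargs` is forwarded to every FSDP
    # constructor via enable_wrap; the keys shown are standard torch.distributed.fsdp
    # options used as examples, not values prescribed by this file.
    #
    #     from torch.distributed.fsdp import ShardingStrategy
    #     wrapper_kwargs = dict(
    #         device_id=device_id,
    #         sharding_strategy=ShardingStrategy.FULL_SHARD,
    #         sync_module_states=True,
    #     )
    #     model.wrap_fsdp(wrapper_kwargs, device_id)
    #     # Afterwards, build the optimizer over params without p.exclude_from_optimizer set,
    #     # and call model.clip_grad_norm_(max_norm) instead of torch.nn.utils.clip_grad_norm_.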

    def _condition_media_locations(self, input_ids: torch.Tensor):
        """
        Compute the media token locations from lang_x and condition the language model on these.
        Args:
            input_ids (torch.Tensor): Language input
                shape (B, T_txt)
        """
        media_locations = input_ids == self.media_token_id

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_media_locations(media_locations)

    def cache_media(self, input_ids: torch.Tensor, vision_x: torch.Tensor):
        """
        Pre-cache a prompt/sequence of images/text for log-likelihood evaluations.
        All subsequent calls to forward() will generate attending to the LAST
        image in vision_x.
        This is not meant to be used to cache things for generate().
        Args:
            input_ids (torch.Tensor): Language input
                shape (B, T_txt)
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)
        """
        self._encode_vision_x(vision_x=vision_x)
        self._condition_media_locations(input_ids=input_ids)
        self.lang_encoder._use_cached_vision_x = True
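
    # --- Illustrative usage (added annotation, not part of the original file) ---
    # A hedged sketch of the cache_media / uncache_media pattern for log-likelihood
    # scoring: encode the images once, then score several candidate continuations
    # without re-encoding. `context_ids` and `candidates` are assumed token id tensors.
    #
    #     model.cache_media(input_ids=context_ids, vision_x=vision_x)
    #     for candidate_ids in candidates:
    #         out = model(vision_x=None, lang_x=candidate_ids, labels=candidate_ids,
    #                     clear_conditioned_layers=False)
    #         score = -out.loss
    #     model.uncache_media()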

    def uncache_media(self):
        """
        Clear all conditioning.
        """
        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_vision_x = False