from contextlib import suppress
from typing import List

import torch
from PIL import Image

from open_flamingo.eval.eval_model import BaseEvalModel
from open_flamingo.eval.models.utils import unwrap_model
from open_flamingo.src.factory import create_model_and_transforms


class EvalModel(BaseEvalModel):
    """OpenFlamingo model evaluation.

    Attributes:
        model (nn.Module): Underlying Torch model.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer for the model.
        device: Index of the GPU to use, or the string "cpu".
    """

    def __init__(self, model_args):
        assert (
            "vision_encoder_path" in model_args
            and "lm_path" in model_args
            and "checkpoint_path" in model_args
            and "lm_tokenizer_path" in model_args
            and "cross_attn_every_n_layers" in model_args
            and "vision_encoder_pretrained" in model_args
            and "precision" in model_args
        ), (
            "OpenFlamingo requires vision_encoder_path, lm_path, checkpoint_path, "
            "lm_tokenizer_path, cross_attn_every_n_layers, vision_encoder_pretrained, "
            "and precision arguments to be specified"
        )
        # "device" is optional: a non-negative GPU index, otherwise fall back to CPU.
        self.device = (
            model_args["device"]
            if ("device" in model_args and model_args["device"] >= 0)
            else "cpu"
        )
        (
            self.model,
            self.image_processor,
            self.tokenizer,
        ) = create_model_and_transforms(
            model_args["vision_encoder_path"],
            model_args["vision_encoder_pretrained"],
            model_args["lm_path"],
            model_args["lm_tokenizer_path"],
            cross_attn_every_n_layers=int(model_args["cross_attn_every_n_layers"]),
        )
        checkpoint = torch.load(model_args["checkpoint_path"], map_location=self.device)
        if "model_state_dict" in checkpoint:
            checkpoint = checkpoint["model_state_dict"]
        # Strip the "module." prefix left behind by DistributedDataParallel checkpoints.
        checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
        self.model.load_state_dict(checkpoint, strict=False)
        self.model.to(self.device)
        self.model.eval()
        # Left padding keeps prompt tokens adjacent to the generated tokens.
        self.tokenizer.padding_side = "left"

        # autocast
        self.autocast = get_autocast(model_args["precision"])
        self.cast_dtype = get_cast_dtype(model_args["precision"])

    def _prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
        """Preprocess images and stack them.

        Args:
            batch: A list of lists of PIL images.

        Returns:
            A Tensor of shape
            (batch_size, images_per_example, frames, channels, height, width).
        """
        images_per_example = max(len(x) for x in batch)
        batch_images = None
        for iexample, example in enumerate(batch):
            for iimage, image in enumerate(example):
                preprocessed = self.image_processor(image)
                if batch_images is None:
                    # Allocate a zero-padded tensor once the preprocessed shape is known.
                    batch_images = torch.zeros(
                        (len(batch), images_per_example, 1) + preprocessed.shape,
                        dtype=preprocessed.dtype,
                    )
                batch_images[iexample, iimage, 0] = preprocessed
        return batch_images
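
    # Shape sketch (illustrative, assuming a CLIP-style processor that yields
    # 3x224x224 tensors): a batch of 2 examples with at most 3 images each
    # becomes a zero-padded tensor of shape (2, 3, 1, 3, 224, 224), where the
    # singleton axis is the frames dimension expected by the vision encoder.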

    def get_outputs(
        self,
        batch_text: List[str],
        batch_images: List[List[Image.Image]],
        min_generation_length: int,
        max_generation_length: int,
        num_beams: int,
        length_penalty: float,
    ) -> List[str]:
        encodings = self.tokenizer(
            batch_text,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        input_ids = encodings["input_ids"]
        attention_mask = encodings["attention_mask"]

        with torch.inference_mode():
            with self.autocast():
                outputs = unwrap_model(self.model).generate(
                    self._prepare_images(batch_images).to(
                        self.device, dtype=self.cast_dtype, non_blocking=True
                    ),
                    # input_ids and attention_mask are integer tensors; casting
                    # them to a float dtype would break the embedding lookup,
                    # so they are only moved to the target device.
                    input_ids.to(self.device, non_blocking=True),
                    attention_mask=attention_mask.to(self.device, non_blocking=True),
                    min_new_tokens=min_generation_length,
                    max_new_tokens=max_generation_length,
                    num_beams=num_beams,
                    length_penalty=length_penalty,
                )

        # With left padding, every prompt in the batch has the same length,
        # so slicing it off leaves only the newly generated tokens.
        outputs = outputs[:, len(input_ids[0]) :]
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def get_logits(
        self,
        lang_x: torch.Tensor,
        vision_x: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        past_key_values: torch.Tensor = None,
        clear_conditioned_layers: bool = False,
    ):
        with torch.inference_mode():
            with self.autocast():
                outputs = self.model(
                    vision_x=vision_x,
                    lang_x=lang_x,
                    attention_mask=attention_mask,
                    clear_conditioned_layers=clear_conditioned_layers,
                    past_key_values=past_key_values,
                    use_cache=(past_key_values is not None),
                )
        return outputs
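
    # Cached-scoring sketch (assumed usage, not enforced here): call
    # cache_media(...) once with the full prompt and vision tensors, then call
    # get_logits(...) with past_key_values for each candidate continuation so
    # the image features are not re-encoded per candidate.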

    def encode_vision_x(self, image_tensor: torch.Tensor):
        unwrap_model(self.model)._encode_vision_x(image_tensor.to(self.device))

    def uncache_media(self):
        unwrap_model(self.model).uncache_media()

    def cache_media(self, input_ids, vision_x):
        unwrap_model(self.model).cache_media(input_ids=input_ids, vision_x=vision_x)

    def get_vqa_prompt(self, question, answer=None) -> str:
        return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"

    def get_caption_prompt(self, caption=None) -> str:
        return f"<image>Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"


def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == "bf16":
        cast_dtype = torch.bfloat16
    elif precision == "fp16":
        cast_dtype = torch.float16
    return cast_dtype


def get_autocast(precision):
    if precision == "amp":
        return torch.cuda.amp.autocast
    elif precision == "amp_bfloat16" or precision == "amp_bf16":
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    else:
        # No-op context manager for full-precision runs.
        return suppress
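

# Minimal usage sketch (not part of the original module): every path, model
# name, and hyperparameter below is a placeholder; substitute real checkpoint
# locations before running.
if __name__ == "__main__":
    image = Image.open("example.jpg")  # placeholder image path
    model_args = {
        "vision_encoder_path": "ViT-L-14",  # placeholder OpenCLIP architecture
        "vision_encoder_pretrained": "openai",  # placeholder pretrained tag
        "lm_path": "facebook/opt-1.3b",  # placeholder language model
        "lm_tokenizer_path": "facebook/opt-1.3b",
        "checkpoint_path": "checkpoint.pt",  # placeholder checkpoint path
        "cross_attn_every_n_layers": 1,
        "precision": "fp32",  # full precision: autocast becomes a no-op
        "device": -1,  # negative index falls back to CPU
    }
    eval_model = EvalModel(model_args)
    captions = eval_model.get_outputs(
        batch_text=[eval_model.get_caption_prompt()],
        batch_images=[[image]],
        min_generation_length=0,
        max_generation_length=20,
        num_beams=3,
        length_penalty=0.0,
    )
    print(captions)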