from typing import List
from PIL import Image
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from open_flamingo.eval.eval_model import BaseEvalModel
from open_flamingo.eval.models.utils import unwrap_model


class EvalModel(BaseEvalModel):
    """BLIP-2 model evaluation.

    Attributes:
        model (nn.Module): Underlying Torch model.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer for the model.
        device: Index of the GPU to use, or the string "cpu".
    """

    def __init__(self, model_args):
        assert (
            "processor_path" in model_args
            and "lm_path" in model_args
            and "device" in model_args
        ), "BLIP-2 requires processor_path, lm_path, and device arguments to be specified"
        # Use the requested GPU index; fall back to CPU for negative values.
        # The membership check is already enforced by the assert above, and the
        # value is cast to int before comparison so string indices also work.
        self.device = (
            int(model_args["device"]) if int(model_args["device"]) >= 0 else "cpu"
        )
self.processor = Blip2Processor.from_pretrained(model_args["processor_path"])
self.model = Blip2ForConditionalGeneration.from_pretrained(
model_args["lm_path"]
)
self.model.to(self.device)
self.model.eval()
self.processor.tokenizer.padding_side = "left"

    def _prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
        """Preprocess images and stack them.

        Args:
            batch: A list of lists of images.

        Returns:
            A Tensor of shape
            (batch_size, channels, height, width).
        """
        assert all(
            len(example) == 1 for example in batch
        ), "BLIP-2 only supports one image per example"
        # Preprocess every example and stack the pixel values along the batch
        # dimension in a single concatenation.
        batch_images = torch.cat(
            [
                self.processor.image_processor(example, return_tensors="pt")[
                    "pixel_values"
                ]
                for example in batch
            ],
            dim=0,
        )
        return batch_images

    def get_outputs(
        self,
        batch_text: List[str],
        batch_images: List[List[Image.Image]],
        max_generation_length: int,
        num_beams: int,
        length_penalty: float,
    ) -> List[str]:
        """Generate text for a batch of prompts and their images."""
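        # Tokenize the prompts; padding_side="left" (set in __init__) keeps each
        # prompt right-aligned so generated tokens continue directly after it.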
encodings = self.processor.tokenizer(
batch_text,
padding="longest",
truncation=True,
return_tensors="pt",
max_length=2000,
)
input_ids = encodings["input_ids"]
attention_mask = encodings["attention_mask"]
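        # unwrap_model returns the underlying module in case the model has been
        # wrapped (e.g. for distributed evaluation) before calling generate.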
with torch.inference_mode():
outputs = unwrap_model(self.model).generate(
self._prepare_images(batch_images).to(self.device),
input_ids.to(self.device),
attention_mask=attention_mask.to(self.device),
max_new_tokens=max_generation_length,
min_new_tokens=8,
num_beams=num_beams,
length_penalty=length_penalty,
)
return self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def get_vqa_prompt(self, question, answer=None) -> str:
        return (
            f"Question:{question} Short answer:{answer if answer is not None else ''}"
        )

    def get_caption_prompt(self, caption=None) -> str:
        return f"A photo of {caption if caption is not None else ''}"

    def get_classification_prompt(self, class_str=None) -> str:
        raise NotImplementedError
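

# Minimal usage sketch (illustrative only): the checkpoint name, image path, and
# generation settings below are assumptions, not values prescribed by this module.
#
# from PIL import Image
#
# model = EvalModel(
#     {
#         "processor_path": "Salesforce/blip2-opt-2.7b",  # assumed HF checkpoint
#         "lm_path": "Salesforce/blip2-opt-2.7b",  # assumed HF checkpoint
#         "device": 0,  # GPU index; use -1 to run on CPU
#     }
# )
# captions = model.get_outputs(
#     batch_text=[model.get_caption_prompt()],
#     batch_images=[[Image.open("example.jpg")]],  # hypothetical image file
#     max_generation_length=20,
#     num_beams=3,
#     length_penalty=0.0,
# )
# print(captions)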