# -*- coding: utf-8 -*-
"""image_caption_with_flamingo.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1LHNo5LhdE18qhj9lWffDX93LsgHlKM7F
    https://colab.research.google.com/github/githubpradeep/notebooks/blob/main/meme_creator.ipynb
"""
# Pin transformers; newer releases may not be compatible with flamingo-mini.
# !pip install --upgrade transformers
!pip install transformers==4.26.0
!pip install --upgrade flamingo-pytorch
!pip install vit-pytorch

# Install flamingo-mini from source (a PyPI package also exists:
# `pip install --upgrade flamingo-mini`). Installing by path avoids the
# `%cd flamingo-mini` notebook magic, which does not carry over to a plain script.
!git clone https://github.com/dhansmair/flamingo-mini.git
!pip install ./flamingo-mini
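# Quick check that the pinned transformers release is the one actually in use.
import transformers
print('transformers version:', transformers.__version__)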
import torch
from flamingo_mini import FlamingoConfig, FlamingoModel, FlamingoProcessor

# Load the pretrained captioning model and its matching processor.
model = FlamingoModel.from_pretrained('dhansmair/flamingo-mini')  # or 'dhansmair/flamingo-tiny'
processor = FlamingoProcessor(model.config)

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU if no GPU
model = model.to(device)
model.eval()
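# Optional sanity check: rough model size (flamingo-tiny is a smaller
# checkpoint with the same API if this one is too heavy for the runtime).
n_params = sum(p.numel() for p in model.parameters())
print(f'parameters: {n_params / 1e6:.1f}M')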
from PIL import Image
import requests

# Fetch an example image.
url = 'https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/main/Images/CONCEPTUAL_02.jpg'
im = Image.open(requests.get(url, stream=True).raw)
im  # displayed inline when run as a notebook cell
from flamingo_mini.utils import load_url

# Load the image with the library helper and generate a caption.
print('loading image and generating caption...')
image = load_url(url)
caption = model.generate_captions(processor, images=[image], device=device)
print('generated caption:', caption[0])
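# generate_captions takes a list of images, so several can be captioned in a
# single call. A minimal sketch, reusing the example image twice for illustration:
captions = model.generate_captions(processor, images=[image, image], device=device)
for c in captions:
    print('caption:', c)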
from einops import repeat


def multimodal_prompt(model: FlamingoModel, processor: FlamingoProcessor,
                      prompt: str, images: list, device: torch.device) -> str:
    """Few-shot multimodal prompting.

    Shape of the visual features tensor: [b=batch size, N=number of images,
    T=1 (number of frames), v=number of visual features, d=dimensionality of
    a visual feature]. (Untested with videos.)
    """
    # Tokenize the prompt; media_locations marks where the images sit in the text.
    input_ids, media_locations, attention_mask = processor.encode_text(prompt, device=device)

    # Preprocess the images and add batch (b) and frame (T) axes.
    pixels = processor(images, device=device)['pixel_values']
    pixels = repeat(pixels, 'N c h w -> b N T c h w', b=1, T=1)

    output = model.generate(
        inputs=input_ids,
        media_locations=media_locations,
        attention_mask=attention_mask,
        pixel_values=pixels,
        max_length=150,
        use_cache=True,
        early_stopping=True,
        bos_token_id=model.flamingo.lm.config.bos_token_id,
        eos_token_id=model.flamingo.lm.config.eos_token_id,
        pad_token_id=model.flamingo.lm.config.eos_token_id,
    )
    response = processor.tokenizer.batch_decode(output, skip_special_tokens=True)
    return response[0]
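# Single-image prompt: the model continues the text conditioned on the image.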
prompt = " A photo of"
response = multimodal_prompt(model, processor, prompt, images=[image], device=device)
print('prompt:', prompt)
print('output:', response)
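# The same helper supports few-shot prompting with multiple images. A minimal
# sketch, reusing the example image twice; the "<image>" marker here is an
# assumption about how the processor denotes image positions in the text:
few_shot_prompt = "<image> A photo of a dog. <image> A photo of"
few_shot_response = multimodal_prompt(model, processor, few_shot_prompt,
                                      images=[image, image], device=device)
print('few-shot output:', few_shot_response)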