
# -*- coding: utf-8 -*-
"""image_caption_with_flamingo.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1LHNo5LhdE18qhj9lWffDX93LsgHlKM7F

https://colab.research.google.com/github/githubpradeep/notebooks/blob/main/meme_creator.ipynb
"""

import torch

# !pip install --upgrade transformers
!pip install transformers==4.26.0  # pinned; newer transformers releases may break flamingo-mini
# !pip install --upgrade flamingo-mini

!pip install --upgrade flamingo-pytorch

!pip install vit-pytorch

!git clone https://github.com/dhansmair/flamingo-mini.git

# Commented out IPython magic to ensure Python compatibility.
# %cd flamingo-mini
!pip install .

from flamingo_mini import FlamingoConfig, FlamingoModel, FlamingoProcessor

model = FlamingoModel.from_pretrained('dhansmair/flamingo-mini')  # or 'dhansmair/flamingo-tiny'
processor = FlamingoProcessor(model.config)

model = model.to('cuda')
model.eval()
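
# Device-agnostic variant (a sketch using only standard torch APIs), in case
# no GPU runtime is attached:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# model = model.to(device).eval()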

from PIL import Image
import requests
url = 'https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/main/Images/CONCEPTUAL_02.jpg'
im = Image.open(requests.get(url, stream=True).raw)
im  # the cell's last expression renders the image in Colab

# torch, FlamingoModel and FlamingoProcessor are already imported above
from flamingo_mini.utils import load_url

device = 'cuda'

# load and process an example image
print('loading image and generating caption...')
image = load_url(url)
caption = model.generate_captions(processor, images=[image], device=device)
print('generated caption:', caption[0])
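
# generate_captions() takes a list, so several images can be captioned in one
# call — a sketch; the second URL is a hypothetical placeholder:
# more_images = [image, load_url('https://example.com/other.jpg')]
# captions = model.generate_captions(processor, images=more_images, device=device)
# for c in captions:
#     print('caption:', c)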

from einops import repeat

def multimodal_prompt(model: FlamingoModel, processor: FlamingoProcessor, prompt: str, images: list, device: torch.device) -> str:
    """Few-shot multimodal prompting.

    The pixel_values passed to model.generate() need shape
    [b=batch size, N=number of images, T=1 (number of frames), c=channels, h=height, w=width];
    the repeat() below builds this from the processor output.
    (I haven't checked if it works with videos.)
    """
    input_ids, media_locations, attention_mask = processor.encode_text(prompt, device=device)

    # preprocess the images, then add batch (b=1) and frame (T=1) axes
    pixels = processor(images, device=device)['pixel_values']
    pixels = repeat(pixels, 'N c h w -> b N T c h w', b=1, T=1)

    output = model.generate(
        inputs=input_ids,
        media_locations=media_locations,
        attention_mask=attention_mask,
        pixel_values=pixels,
        max_length=150,
        use_cache=True,
        early_stopping=True,
        bos_token_id=model.flamingo.lm.config.bos_token_id,
        eos_token_id=model.flamingo.lm.config.eos_token_id,
        pad_token_id=model.flamingo.lm.config.eos_token_id
    )

    response = processor.tokenizer.batch_decode(output, skip_special_tokens=True)
    return response[0]

prompt = " A photo of"
response = multimodal_prompt(model, processor, prompt, images=[image], device=device)
print('prompt:', prompt)
print('output:', response)
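
# Few-shot variant (a sketch, not from the original notebook): interleave one
# '<image>' token per image; '<EOC>' as a chunk separator is an assumption
# about flamingo-mini's special tokens, and 'image2' plus its URL are
# hypothetical placeholders.
# few_shot_prompt = "<image> Output: two dogs in the snow.<EOC><image> Output:"
# image2 = load_url('https://example.com/second_image.jpg')
# print(multimodal_prompt(model, processor, few_shot_prompt,
#                         images=[image, image2], device=device))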