# -*- coding: utf-8 -*-
"""sddd_hybrid.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/17FzKUEG8sjrMiyeaX6Jlu0YP05naPUC8

Links:

- https://towardsdatascience.com/stable-diffusion-using-hugging-face-501d8dbdd8

- https://colab.research.google.com/github/fastai/diffusion-nbs/blob/master/Stable%20Diffusion%20Deep%20Dive.ipynb#scrollTo=650e1224-d757-45e9-a7cc-67d608bd73bd

- https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb

- https://colab.research.google.com/drive/1dlgggNa5Mz8sEAGU0wFCHhGLFooW_pf1?usp=sharing

- https://colab.research.google.com/drive/1qXCgWuH8VWldCszISA58SKp7w0w7OlJZ?usp=sharing

Installations and imports
"""

!pip install -q --upgrade transformers diffusers ftfy


from base64 import b64encode

import numpy
import torch
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from huggingface_hub import notebook_login

from IPython.display import HTML
from matplotlib import pyplot as plt
from pathlib import Path
from PIL import Image
from torch import autocast
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, logging

torch.manual_seed(1)
if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()

# Suppress some unnecessary warnings when loading the CLIPTextModel
logging.set_verbosity_error()

# Set device
torch_device = "cuda" if torch.cuda.is_available() else "cpu"


# Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

# Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# The noise scheduler
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

# To the GPU we go!
vae = vae.to(torch_device)
text_encoder = text_encoder.to(torch_device)
unet = unet.to(torch_device)
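
"""A quick, optional look at the noise schedule (not in the original notebook): after `set_timesteps`, the scheduler exposes the sigma applied at each denoising step, decreasing from `init_noise_sigma` towards zero. Only the scheduler and matplotlib imported above are used, and `generate()` calls `set_timesteps` again, so this peek does not affect later cells."""

scheduler.set_timesteps(15)            # any step count works; generate() resets this anyway
print(scheduler.init_noise_sigma)      # largest sigma, used to scale the initial latent noise
plt.plot(scheduler.sigmas.cpu(), "x-")
plt.title("LMSDiscreteScheduler sigmas per denoising step")
plt.xlabel("step")
plt.ylabel("sigma")
plt.show()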

"""Functions"""

def pil_to_latent(input_im):
    # Single image -> single latent in a batch (so size 1, 4, 64, 64)
    with torch.no_grad():
        latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)*2-1) # Note scaling
    return 0.18215 * latent.latent_dist.sample()

def latents_to_pil(latents):
    # batch of latents -> list of images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
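
"""Optional sanity check (not in the original notebook): round-trip a synthetic image through the two helpers above. The latent is 1x4x64x64, i.e. a ~48x compression of the 3x512x512 input, and the 0.18215 factor only rescales it to roughly unit variance."""

test_im = Image.fromarray(numpy.random.randint(0, 256, (512, 512, 3), dtype=numpy.uint8))
test_latents = pil_to_latent(test_im)
print(test_latents.shape)                      # torch.Size([1, 4, 64, 64])
print(latents_to_pil(test_latents)[0].size)    # (512, 512)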

def generate(guidance_scale, num_inference_steps, seed, prompts, nprompts=[("", 1)], input_image=None, start_step=0, start_sigma=1):

  height = 512                        # default height of Stable Diffusion
  width = 512                         # default width of Stable Diffusion
  batch_size = 1

  generator = torch.manual_seed(seed)   # Seed generator to create the initial latent noise

  # Prep text
  # text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
  # with torch.no_grad():
  #   text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
  # max_length = text_input.input_ids.shape[-1]
  # uncond_input = tokenizer(
  #   [nprompt] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
  # )
  # with torch.no_grad():
  #   uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
  # text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

  # Encode each (prompt, weight) pair and sum the weighted embeddings
  for i, (prompt, weight) in enumerate(prompts):
    text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
      text_embeddings_i = text_encoder(text_input.input_ids.to(torch_device))[0]
    if i == 0:
      text_embeddings = text_embeddings_i * weight
    else:
      text_embeddings += text_embeddings_i * weight

  # Classifier-free guidance ("prompt amplification"): the conditioning passed to the UNet is the
  # concatenation of the weighted negative-prompt embeddings and the weighted prompt embeddings
  max_length = text_input.input_ids.shape[-1]

  for i, (nprompt, weight) in enumerate(nprompts):
    uncond_input = tokenizer(
      [nprompt] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
      uncond_embeddings_i = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    if i == 0:
      uncond_embeddings = uncond_embeddings_i * weight
    else:
      uncond_embeddings += uncond_embeddings_i * weight
  # Concatenate once, after the loop, so that several negative prompts are summed before being
  # stacked with the prompt embeddings (concatenating inside the loop breaks the 2x batch used for guidance)
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])


  # Prep Scheduler
  scheduler.set_timesteps(num_inference_steps)

  # Prep latents
  if input_image is None:
    latents = torch.randn(
      (batch_size, unet.config.in_channels, height // 8, width // 8),
      generator=generator,
    )
    latents = latents.to(torch_device)
    latents = latents * scheduler.init_noise_sigma # Scaling (previous versions did latents = latents * self.scheduler.sigmas[0])
  else:
    # Encode to the latent space
    encoded = pil_to_latent(input_image)
    # start_sigma = scheduler.sigmas[start_step]
    noise = torch.randn_like(encoded)
    latents = scheduler.add_noise(encoded, noise, timesteps=torch.tensor([scheduler.timesteps[start_step]]))
    latents = latents.to(torch_device) # .float()
    latents = latents * start_sigma

  # Loop
  with autocast("cuda"):
    for i, t in tqdm(enumerate(scheduler.timesteps), total=num_inference_steps):
      if i >= start_step:
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        # Scale the latents (preconditioning):
        # latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) # Diffusers 0.3 and below
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        # latents = scheduler.step(noise_pred, i, latents)["prev_sample"] # Diffusers 0.3 and below
        latents = scheduler.step(noise_pred, t, latents).prev_sample

  # scale and decode the image latents with vae
  latents = 1 / 0.18215 * latents
  with torch.no_grad():
    image = vae.decode(latents).sample

  # Display
  image = (image / 2 + 0.5).clamp(0, 1)
  image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
  images = (image * 255).round().astype("uint8")
  pil_images = [Image.fromarray(image) for image in images]
  return pil_images[0]
  # return latents_to_pil(latents)[0]

"""Examples of image generation from a simple prompt"""

prompt = [("A watercolor painting of an otter", 1)]
num_inference_steps = 30            # Number of denoising steps
guidance_scale = 7.5                # Scale for classifier-free guidance
seed = 32
generate(guidance_scale, num_inference_steps, seed, prompt)

guidance_scale = 6
num_inference_steps = 20
seed = 13
prompt = [("An astronaut riding a horse", 1)]
image = generate(guidance_scale, num_inference_steps, seed, prompt)
display(image)

"""Prompt hybridation : a flower-bird"""

guidance_scale = 8
num_inference_steps = 12
seed = 32
prompt = [("a flower", 0.52), ("a bird", 0.48)]
image = generate(guidance_scale, num_inference_steps, seed, prompt)
display(image)

"""Prompt hybridation : a mouse-leopard"""

guidance_scale = 7.5
num_inference_steps = 30
seed = 32
prompt = [("A mouse", 0.5), ("A leopard", 0.5)]
image = generate(guidance_scale, num_inference_steps, seed, prompt)
display(image)
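
"""The relative weights control where the blend sits between the two concepts. A small sweep over the mouse/leopard weights (not in the original notebook) shows the transition:"""

for w in (0.3, 0.5, 0.7):
  prompt = [("A mouse", w), ("A leopard", 1 - w)]
  image = generate(guidance_scale, num_inference_steps, seed, prompt)
  print(f"mouse weight = {w}")
  display(image)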

"""Image generation from a prompt and another image (image to image)"""

# Download a demo Image
!curl --output macaw.jpg 'https://lafeber.com/pet-birds/wp-content/uploads/2018/06/Scarlet-Macaw-2.jpg'
# Load the image with PIL
input_image = Image.open('macaw.jpg').resize((512, 512))
display(input_image)
# Encode to the latent space (shown for illustration; generate() re-encodes the input image itself)
encoded = pil_to_latent(input_image)

# Settings (same as before except for the new prompt)
prompt = [("A colorful dancer, nat geo photo", 1)]
nprompt = [("", 1)]
num_inference_steps = 50            # Number of denoising steps
guidance_scale = 8                  # Scale for classifier-free guidance
seed = 32
start_step = 10
start_sigma = 1

image = generate(guidance_scale, num_inference_steps, seed, prompt, nprompt, input_image, start_step, start_sigma)
display(image)
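
"""The later the denoising loop starts (`start_step`), the less noise is added to the encoded macaw and the fewer steps remain to repaint it, so the result stays closer to the input image. A quick sweep over `start_step` (not in the original notebook):"""

for start_step in (5, 15, 25, 35):
  image = generate(guidance_scale, num_inference_steps, seed, prompt, nprompt, input_image, start_step, start_sigma)
  print(f"start_step={start_step}")
  display(image)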

import urllib.request
guidance_scale = 8
num_inference_steps = 40
seed = 34
prompt = [("A colorful dancer, nat geo photo", 1)]
nprompt = [("", 1)]
input_image_url = "https://img.cutenesscdn.com/630x/clsd/getty/33e691dded7d4ddc87a83613469f897e?type=webp"
urllib.request.urlretrieve(input_image_url, "alter_input.jpg")
input_image = Image.open('alter_input.jpg').resize((512, 512))
display(input_image)
start_step = 4
start_sigma = 1
image = generate(guidance_scale, num_inference_steps, seed, prompt, nprompt, input_image, start_step, start_sigma)
display(image)

"""Negative prompt : gardens with few flowers"""

guidance_scale = 7.5
seeds = range(0, 8)
num_inference_steps = 20
prompt = [("A garden", 1)]
nprompt = [("flower", 1)]

for seed in seeds:
  image = generate(guidance_scale, num_inference_steps, seed, prompt)
  print(f"seed={seed} without negative prompt")
  display(image)
  image = generate(guidance_scale, num_inference_steps, seed, prompt, nprompt)
  print(f"seed={seed} with negative prompt")
  display(image)
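
"""`generate` also accepts several weighted negative prompts, which are summed exactly like the positive ones before the classifier-free guidance split. A variation of the garden example (not in the original notebook):"""

nprompt = [("flower", 0.5), ("rose", 0.5)]
for seed in (0, 1, 2):
  image = generate(guidance_scale, num_inference_steps, seed, prompt, nprompt)
  print(f"seed={seed} with two weighted negative prompts")
  display(image)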