# Source: src/colab/stable_diffusion_gen.py

# -*- coding: utf-8 -*-
"""stable_diffusion_gen.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1cKc_YoLbyLNP_2TausAHM_BDZIP1TwjS

Links:

- https://towardsdatascience.com/stable-diffusion-using-hugging-face-501d8dbdd8

- https://colab.research.google.com/github/fastai/diffusion-nbs/blob/master/Stable%20Diffusion%20Deep%20Dive.ipynb#scrollTo=650e1224-d757-45e9-a7cc-67d608bd73bd

- https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb

- https://colab.research.google.com/drive/1dlgggNa5Mz8sEAGU0wFCHhGLFooW_pf1?usp=sharing

- https://colab.research.google.com/drive/1qXCgWuH8VWldCszISA58SKp7w0w7OlJZ?usp=sharing

Stable Diffusion is based on a particular type of diffusion model called **Latent Diffusion**, proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752).

General diffusion models are machine learning systems that are trained to *denoise* random gaussian noise step by step, to get to a sample of interest, such as an *image*. For a more detailed overview of how they work, check [this colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb).

Latent diffusion can reduce the memory and compute complexity by applying the diffusion process over a lower dimensional _latent_ space, instead of using the actual pixel space. This is the key difference between standard diffusion and latent diffusion models: **in latent diffusion the model is trained to generate latent (compressed) representations of the images.**

There are three main components in latent diffusion.

1. An autoencoder (VAE).
2. A [U-Net](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb#scrollTo=wW8o1Wp0zRkq).
3. A text-encoder, *e.g.* [CLIP's Text Encoder](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).

**1. The autoencoder (VAE)**

The VAE model has two parts, an encoder and a decoder. The encoder is used to convert the image into a low dimensional latent representation, which will serve as the input to the *U-Net* model.
The decoder, conversely, transforms the latent representation back into an image.

During latent diffusion _training_, the encoder is used to get the latent representations (_latents_) of the images for the forward diffusion process, which applies more and more noise at each step. During _inference_, the denoised latents generated by the reverse diffusion process are converted back into images using the VAE decoder. For plain text-to-image inference we therefore **only need the VAE decoder**; the image-to-image examples below also use the encoder to turn the input image into latents.

**2. The U-Net**

The U-Net has an encoder part and a decoder part, both composed of ResNet blocks.
The encoder compresses an image representation into a lower resolution image representation and the decoder decodes the lower resolution image representation back to the original higher resolution image representation that is supposedly less noisy.
More specifically, the U-Net output predicts the noise residual which can be used to compute the predicted denoised image representation.

**3. The Text-encoder**

The text-encoder is responsible for transforming the input prompt, *e.g.* "An astronaut riding a horse", into an embedding space that can be understood by the U-Net. It is usually a simple *transformer-based* encoder that maps a sequence of input tokens to a sequence of latent text-embeddings.

Here is a typical flow chart of a Python script using Stable Diffusion:

![stable_diffusion.png]()

Installations and imports
"""

# Select which diffusers API the rest of the notebook targets.
# True  -> install diffusers 0.2.4 and use the old API paths below
#          (``latent.mode()``, integer step index passed to ``scheduler.step``,
#          manual multiplication by ``scheduler.sigmas``).
# False -> install the latest diffusers and use the newer API paths
#          (``.latent_dist``, ``.sample`` outputs, tensor timesteps).
diffusers_old_version = True

# IPython shell escapes: this cell only runs inside a notebook (Colab/Jupyter).
if diffusers_old_version:
  !pip install -q --upgrade transformers diffusers==0.2.4 ftfy
else:
  !pip install -q --upgrade transformers diffusers ftfy

from base64 import b64encode

import numpy
import torch
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from huggingface_hub import notebook_login

# For video display:
from IPython.display import HTML
from matplotlib import pyplot as plt
from pathlib import Path
from PIL import Image
from torch import autocast
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, logging

# Make the notebook's random draws reproducible, and log in to the Hugging Face
# Hub only when no token has been cached yet under ~/.huggingface/token.
torch.manual_seed(1)
if not Path.home().joinpath('.huggingface', 'token').exists():
    notebook_login()

# Silence the unnecessary warnings transformers emits when loading CLIPTextModel.
logging.set_verbosity_error()

# Prefer the GPU whenever one is available.
torch_device = "cpu"
if torch.cuda.is_available():
    torch_device = "cuda"

# VAE: decodes latents back to pixels (and encodes images for image-to-image).
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

# CLIP tokenizer + text encoder: turn prompts into conditioning embeddings.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# U-Net: predicts the noise residual in latent space.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# LMS noise scheduler using the "scaled_linear" beta schedule.
scheduler = LMSDiscreteScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
)

# Move all three models onto the selected device (same order as they were loaded).
vae, text_encoder, unet = (model.to(torch_device) for model in (vae, text_encoder, unet))

"""Functions"""

def pil_to_latent(input_im):
    """Encode a PIL image into a scaled Stable Diffusion latent.

    The image is converted to RGB, turned into a batch-of-one tensor, mapped
    from [0, 1] to [-1, 1] and passed through the VAE encoder. The latent is
    multiplied by 0.18215, the scaling factor this file also undoes in
    ``latents_to_pil``.

    Args:
        input_im: a PIL image of any mode. It is converted to RGB because the
            VAE encoder takes exactly three input channels; RGBA, grayscale or
            palette images would otherwise produce the wrong channel count.

    Returns:
        A latent tensor with batch size 1 on ``torch_device``
        (1, 4, 64, 64 for a 512x512 input).
    """
    rgb_image = input_im.convert("RGB")
    # ToTensor yields values in [0, 1]; the VAE works on [-1, 1].
    pixels = tfms.ToTensor()(rgb_image).unsqueeze(0).to(torch_device) * 2 - 1
    with torch.no_grad():
        latent = vae.encode(pixels)
    if diffusers_old_version:
        # diffusers 0.2.x: encode() returns the distribution object directly.
        return 0.18215 * latent.mode()
    # Newer diffusers wrap the distribution in an output object.
    return 0.18215 * latent.latent_dist.sample()

def latents_to_pil(latents):
    """Decode a batch of scaled latents into a list of PIL images.

    Inverse of ``pil_to_latent``: undo the 0.18215 latent scaling, run the VAE
    decoder, map its output from [-1, 1] to 8-bit RGB and build one PIL image
    per batch element.
    """
    unscaled = latents * (1 / 0.18215)
    with torch.no_grad():
        decoded = vae.decode(unscaled)
        if not diffusers_old_version:
            # Newer diffusers return an output object holding the tensor in ``.sample``.
            decoded = decoded.sample
    # [-1, 1] -> [0, 1], then NCHW -> NHWC on the CPU for PIL.
    normalized = (decoded / 2 + 0.5).clamp(0, 1)
    hwc_batch = normalized.detach().cpu().permute(0, 2, 3, 1).numpy()
    uint8_batch = (hwc_batch * 255).round().astype("uint8")
    return list(map(Image.fromarray, uint8_batch))

def _weighted_text_embeddings(weighted_prompts, batch_size=1):
    """Return the weighted sum of CLIP text embeddings for ``(text, weight)`` pairs.

    Each text is tokenized (padded and truncated to ``tokenizer.model_max_length``),
    encoded with ``text_encoder``, multiplied by its weight, and the results are
    summed. Used both for the conditional prompts and for the unconditional /
    negative prompts, so the two always have the same shape.

    Raises:
        ValueError: if ``weighted_prompts`` is empty.
    """
    total = None
    for pair in weighted_prompts:
        text, weight = pair[0], pair[1]
        tokens = tokenizer(
            [text] * batch_size,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,  # long (negative) prompts must not overflow the context
            return_tensors="pt",
        )
        with torch.no_grad():
            embedding = text_encoder(tokens.input_ids.to(torch_device))[0]
        weighted = embedding * weight
        total = weighted if total is None else total + weighted
    if total is None:
        raise ValueError("expected at least one (text, weight) pair")
    return total


def generate(guidance_scale, num_inference_steps, seed, prompts, nprompts=None,
             input_image=None, start_step=0, start_loop=None):
    """Generate one 512x512 image with classifier-free guided Stable Diffusion.

    Args:
        guidance_scale: classifier-free guidance strength; the noise prediction
            is ``uncond + guidance_scale * (text - uncond)``.
        num_inference_steps: number of scheduler timesteps.
        seed: passed to ``torch.manual_seed`` so the initial noise is reproducible.
        prompts: list of ``(text, weight)`` pairs. Their text embeddings are
            weighted and summed ("prompt hybridization"), e.g.
            ``[("a flower", 0.52), ("a bird", 0.48)]``.
        nprompts: list of ``(text, weight)`` pairs used as the unconditional
            (negative) prompt. Defaults to ``[("", 1)]``, i.e. the empty prompt.
        input_image: optional PIL image for image-to-image generation. It is
            encoded with ``pil_to_latent`` and noised to timestep ``start_step``.
        start_step: scheduler step whose noise level the initial latents get.
        start_loop: first denoising step actually run. Defaults to
            ``start_step`` when ``None``.

    Returns:
        The generated image as a PIL image.

    Raises:
        ValueError: if ``prompts`` or ``nprompts`` is empty.
    """
    if nprompts is None:
        nprompts = [("", 1)]
    if start_loop is None:  # an explicit 0 is a valid choice, not "unset"
        start_loop = start_step

    height = 512  # default height of Stable Diffusion
    width = 512   # default width of Stable Diffusion
    batch_size = 1

    # Seed the RNG used to create the initial latent noise.
    generator = torch.manual_seed(seed)

    # Unconditional embeddings first, conditional second: this order must match
    # the ``noise_pred.chunk(2)`` split in the denoising loop below.
    cond_embeddings = _weighted_text_embeddings(prompts, batch_size)
    uncond_embeddings = _weighted_text_embeddings(nprompts, batch_size)
    text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])

    scheduler.set_timesteps(num_inference_steps)

    if input_image is None:
        # Text-to-image: start from pure noise scaled to the start_step noise level
        # (old diffusers have no ``scheduler.init_noise_sigma``).
        latents = torch.randn(
            (batch_size, unet.in_channels, height // 8, width // 8),
            generator=generator,
        ).to(torch_device)
        latents = latents * scheduler.sigmas[start_step]
    else:
        # Image-to-image: noise the encoded input image to timestep start_step.
        encoded = pil_to_latent(input_image)
        noise = torch.randn_like(encoded)
        timestep = scheduler.timesteps[start_step]
        if diffusers_old_version:
            latents = scheduler.add_noise(encoded, noise, timesteps=int(timestep))
        else:
            latents = scheduler.add_noise(encoded, noise, timesteps=torch.tensor([timestep]))
        latents = latents.to(torch_device)
        if diffusers_old_version:
            latents = latents * scheduler.sigmas[start_step]

    with autocast(torch_device):
        for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
            if i < start_loop:
                continue
            # Duplicate the latents so one U-Net pass covers both the
            # unconditional and the conditional branch of the guidance.
            latent_model_input = torch.cat([latents] * 2)
            sigma = scheduler.sigmas[i]
            # Input preconditioning (diffusers 0.3 and below have no scale_model_input).
            latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            with torch.no_grad():
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # x_t -> x_{t-1}. Old diffusers take the step index; newer ones take the timestep.
            step_arg = i if diffusers_old_version else t
            latents = scheduler.step(noise_pred, step_arg, latents)["prev_sample"]

    return latents_to_pil(latents)[0]

"""Examples of image generation from a simple prompt"""

# Plain text-to-image: a single prompt with weight 1.
prompt = [("A watercolor painting of an otter", 1)]
num_inference_steps = 30            # Number of denoising steps
guidance_scale = 7.5                # Scale for classifier-free guidance
seed = 32
# Outside a notebook cell a bare expression is not rendered, so show the
# result explicitly, like the other examples do.
image = generate(guidance_scale, num_inference_steps, seed, prompt)
display(image)

guidance_scale = 6
num_inference_steps = 20
seed = 13
prompt = [("An astronaut riding a horse", 1)]
image = generate(guidance_scale, num_inference_steps, seed, prompt)
display(image)

"""Prompt hybridization: a flower-bird"""

# Mix the text embeddings: 52% "a flower" + 48% "a bird".
prompt = [("a flower", 0.52), ("a bird", 0.48)]
guidance_scale, num_inference_steps, seed = 8, 12, 32
image = generate(guidance_scale, num_inference_steps, seed, prompt)
display(image)

"""Prompt hybridization: a mouse-leopard"""

# Equal-weight mix of the "A mouse" and "A leopard" text embeddings.
prompt = [("A mouse", 0.5), ("A leopard", 0.5)]
guidance_scale, num_inference_steps, seed = 7.5, 30, 32
image = generate(guidance_scale, num_inference_steps, seed, prompt)
display(image)

"""Image generation from a prompt and another image (image to image)"""

# Download the input image (IPython shell escape: notebook only).
!curl --output house.jpg 'http://log.chez.com/images/house.jpg'

# Input image
img_name = "house.jpg"
input_image = Image.open(img_name).resize((512, 512))
display(input_image)

prompt = [("A drawing of a small house on a green hill with a tree beside", 1)]
nprompt = [("", 1)]
num_inference_steps = 50           # Number of denoising steps
guidance_scale = 7.5               # Scale for classifier-free guidance
seed = 32                          # overwritten by the seed sweep below
start_step = 10                    # noise the encoded image to scheduler step 10 ...
start_loop = 10                    # ... and start denoising from that same step

# Sweep over seeds 30..39. The commented-out loops are alternative parameter
# sweeps kept for experimentation; the body runs once per seed.
for seed in range(30, 40):
  # for guidance_scale in range(1,12):
    #for start_step in range(1,20):
      #for start_loop in range(5, 15):
        # start_loop = start_step + 1
        print(f"seed={seed} start_step={start_step} start_loop={start_loop} scale={guidance_scale}")
        image = generate(guidance_scale, num_inference_steps, seed, prompt, nprompt, input_image, start_step, start_loop)
        display(image)

# Re-download the input image (IPython shell escape: notebook only).
!curl --output house.jpg 'http://log.chez.com/images/house.jpg'

# Input image
img_name = "house.jpg"
input_image = Image.open(img_name).resize((512, 512))
display(input_image)
# input_image = None

# prompt = [("A color drawing of a small house on a green hill with a tree beside", 1)]
prompt = [("A painting of a small house on a green hill with a tree beside", 1)]
nprompt = [("", 1)]
num_inference_steps = 50           # Number of denoising steps
guidance_scale = 2                # Scale for classifier-free guidance (weak: stays close to the input)
seed = 32                          # overwritten by the seed sweep below
start_step = 10                    # noise the encoded image to scheduler step 10 ...
start_loop = 11                    # ... but start denoising one step later

# Sweep over seeds 30..39; the commented-out loops are alternative sweeps.
for seed in range(30, 40):
  # for guidance_scale in range(1,12):
    #for start_step in range(1,20):
      #for start_loop in range(5, 15):
        # start_loop = start_step + 1
        print(f"seed={seed} start_step={start_step} start_loop={start_loop} scale={guidance_scale}")
        image = generate(guidance_scale, num_inference_steps, seed, prompt, nprompt, input_image, start_step, start_loop)
        display(image)

# Download a demo Image (IPython shell escape: notebook only)
!curl --output macaw.jpg 'https://lafeber.com/pet-birds/wp-content/uploads/2018/06/Scarlet-Macaw-2.jpg'
# Load the image with PIL
input_image = Image.open('macaw.jpg').resize((512, 512))
display(input_image)
# Encode to the latent space. The result is not passed to generate(), which
# calls pil_to_latent(input_image) itself; it is only kept for inspection.
encoded = pil_to_latent(input_image)

# Image-to-image settings; start_loop is omitted, so generate() uses start_step.
prompt = [("A colorful dancer, nat geo photo", 1)]
nprompt = [("", 1)]
num_inference_steps = 50            # Number of denoising steps
guidance_scale = 8                  # Scale for classifier-free guidance
seed = 32
start_step = 10

image = generate(guidance_scale, num_inference_steps, seed, prompt, nprompt, input_image, start_step)
display(image)

# ``import urllib`` alone does not load the ``urllib.request`` submodule;
# importing it explicitly also binds the ``urllib`` name.
import urllib.request

guidance_scale = 8
num_inference_steps = 40
seed = 34
prompt = [("A colorful dancer, nat geo photo", 1)]
nprompt = [("", 1)]
input_image_url = "https://img.cutenesscdn.com/630x/clsd/getty/33e691dded7d4ddc87a83613469f897e?type=webp"
urllib.request.urlretrieve(input_image_url, "alter_input.jpg")
# The download may actually be WebP (see the URL), possibly with an alpha
# channel; force 3-channel RGB because the VAE encoder expects it.
input_image = Image.open('alter_input.jpg').convert("RGB").resize((512, 512))
display(input_image)
start_step = 4  # start from a low step: heavy noise, so the prompt dominates

image = generate(guidance_scale, num_inference_steps, seed, prompt, nprompt, input_image, start_step)
display(image)

"""Negative prompt: gardens with few flowers"""

# For each seed, render "A garden" twice: once with the default empty
# unconditional prompt and once with "flower" as the negative prompt, so the
# guidance steers the second image away from flowers.
guidance_scale = 7.5
seeds = range(0, 8)
num_inference_steps = 20
prompt = [("A garden", 1)]
nprompt = [("flower", 1)]

for seed in seeds:
    base_args = (guidance_scale, num_inference_steps, seed, prompt)

    image = generate(*base_args)
    print(f"seed={seed} without negative prompt")
    display(image)

    image = generate(*base_args, nprompt)
    print(f"seed={seed} with negative prompt")
    display(image)