# -*- coding: utf-8 -*-
"""stable_diffusion_analysis.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1jwMrZfiJqWxh8vthqy1e1vmX7c8-y4rN
Analysis of image generation with Stable Diffusion
Links:
https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb
https://colab.research.google.com/drive/1dlgggNa5Mz8sEAGU0wFCHhGLFooW_pf1?usp=sharing
https://colab.research.google.com/drive/1qXCgWuH8VWldCszISA58SKp7w0w7OlJZ?usp=sharing
Install and import models
"""
# Stable Diffusion
# Source : https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb
print("Install packages")
!pip install diffusers==0.11.1
!pip install transformers scipy ftfy accelerate
print("Imports")
import torch
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
repo_id = "google/ddpm-church-256"
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DModel, UNet2DConditionModel, PNDMScheduler
# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
# 3. The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
# model = UNet2DModel.from_pretrained(repo_id)
from diffusers import LMSDiscreteScheduler
scheduler = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
from diffusers import DDPMScheduler
# scheduler = DDPMScheduler.from_config(repo_id)
vae = vae.to(torch_device)
text_encoder = text_encoder.to(torch_device)
unet = unet.to(torch_device)
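"""As a quick sanity check, here is a minimal sketch that counts the parameters of the three components loaded above, to get a feel for their relative sizes before looking at the theory."""
for name, module in [("vae", vae), ("text_encoder", text_encoder), ("unet", unet)]:
    n_params = sum(p.numel() for p in module.parameters())
    print(f"{name}: {n_params / 1e6:.1f} M parameters")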
"""Stable Diffusion is based on a particular type of diffusion model called **Latent Diffusion**, proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752).
General diffusion models are machine learning systems that are trained to *denoise* random gaussian noise step by step, to get to a sample of interest, such as an *image*. For a more detailed overview of how they work, check [this colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb).
Latent diffusion can reduce the memory and compute complexity by applying the diffusion process over a lower dimensional _latent_ space, instead of using the actual pixel space. This is the key difference between standard diffusion and latent diffusion models: **in latent diffusion the model is trained to generate latent (compressed) representations of the images.**
There are three main components in latent diffusion.
1. An autoencoder (VAE).
2. A [U-Net](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb#scrollTo=wW8o1Wp0zRkq).
3. A text-encoder, *e.g.* [CLIP's Text Encoder](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
**1. The autoencoder (VAE)**
The VAE model has two parts, an encoder and a decoder. The encoder is used to convert the image into a low dimensional latent representation, which will serve as the input to the *U-Net* model.
The decoder, conversely, transforms the latent representation back into an image.
During latent diffusion _training_, the encoder is used to get the latent representations (_latents_) of the images for the forward diffusion process, which applies more and more noise at each step. During _inference_, the denoised latents generated by the reverse diffusion process are converted back into images using the VAE decoder. As we will see, during inference we **only need the VAE decoder**.
**2. The U-Net**
The U-Net has an encoder part and a decoder part, both composed of ResNet blocks.
The encoder compresses an image representation into a lower resolution image representation and the decoder decodes the lower resolution image representation back to the original higher resolution image representation that is supposedly less noisy.
More specifically, the U-Net output predicts the noise residual which can be used to compute the predicted denoised image representation.
**3. The Text-encoder**
The text-encoder is responsible for transforming the input prompt, *e.g.* "An astronaut riding a horse", into an embedding space that can be understood by the U-Net. It is usually a simple *transformer-based* encoder that maps a sequence of input tokens to a sequence of latent text-embeddings.
A typical Stable Diffusion script goes through four steps: prompt encoding, initial noise generation, the denoising loop, and VAE decoding.
First we encode the prompt: it is tokenized, i.e. converted into a vector of 77 integers, then embedded, i.e. converted into a 77 x 768 matrix of floats.
"""
prompt = ["An astronaut riding a horse"]
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
print("text_input:")
print(text_input)
print(f"{len(text_input.input_ids[0])} tokens")
with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
print("text_embeddings:")
print(text_embeddings.shape)
print(text_embeddings)
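"""To see what those 77 integers correspond to, here is a minimal inspection sketch: the prompt tokens are surrounded by a start-of-text token and padded with the tokenizer's padding token up to the model's maximum length of 77."""
tokens = tokenizer.convert_ids_to_tokens(text_input.input_ids[0].tolist())
print(tokens[:12])  # start-of-text, the prompt tokens, then padding tokens
print(len(tokens))  # 77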
"""Then we generate the initial random noise.
"""
batch_size = 1
height = 512
width = 512
generator = torch.manual_seed(13) # Seed generator to create the initial latent noise
latents = torch.randn(
    (batch_size, unet.in_channels, height // 8, width // 8),
    generator=generator,
)
latents = latents.to(torch_device)
print(latents.shape)
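"""Before running the denoising loop, a minimal shape check (a sketch, using an arbitrary timestep value): a single U-Net call takes the noisy latents and the text embeddings and returns a noise prediction with the same shape as the latents."""
with torch.no_grad():
    test_pred = unet(latents, 999, encoder_hidden_states=text_embeddings).sample
print(test_pred.shape)  # same shape as the latents: (1, 4, 64, 64)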
"""Now we run the denoising loop.
"""
from tqdm.auto import tqdm
from torch import autocast
num_inference_steps = 20
scheduler.set_timesteps(num_inference_steps)
latents = latents * scheduler.init_noise_sigma
for t in tqdm(scheduler.timesteps):
    latent_model_input = latents
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)
    # predict the noise residual
    with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    # compute the previous noisy sample x_t -> x_t-1
    latents = scheduler.step(noise_pred, t, latents).prev_sample
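"""To see what the loop above actually iterated over, we can print the scheduler's timesteps and, for LMSDiscreteScheduler, the associated noise levels (a short inspection sketch)."""
print(scheduler.timesteps)  # the 20 timesteps, from high noise to low noise
print(scheduler.sigmas)     # the corresponding noise scales used by the LMS scheduler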
"""Finally we decode and display the image."""
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
with torch.no_grad():
    image = vae.decode(latents).sample
from PIL import Image
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
display(pil_images[0])
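"""The decoder just turned a (1, 4, 64, 64) latent into a 512 x 512 image. As a sketch of the opposite direction (the encoder, which is not needed for inference), we can round-trip the decoded image back through the VAE encoder and check that we land on latents of the same shape; the 0.18215 factor is the same scaling used above."""
with torch.no_grad():
    image_tensor = torch.from_numpy(image).permute(0, 3, 1, 2).to(torch_device) * 2 - 1  # back to [-1, 1]
    reencoded = vae.encode(image_tensor).latent_dist.sample() * 0.18215
print(reencoded.shape)  # torch.Size([1, 4, 64, 64])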
"""Here is the full process in one cell :"""
# Prompt encoding
prompt = ["An astronaut riding a horse"]
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
# Initial random noise
batch_size = 1
height = 512
width = 512
generator = torch.manual_seed(13) # Seed generator to create the initial latent noise
latents = torch.randn(
    (batch_size, unet.in_channels, height // 8, width // 8),
    generator=generator,
)
latents = latents.to(torch_device)
# Denoising loop
from tqdm.auto import tqdm
from torch import autocast
num_inference_steps = 20
scheduler.set_timesteps(num_inference_steps)
latents = latents * scheduler.init_noise_sigma
for t in tqdm(scheduler.timesteps):
    latent_model_input = latents
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)
    # predict the noise residual
    with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    # compute the previous noisy sample x_t -> x_t-1
    latents = scheduler.step(noise_pred, t, latents).prev_sample
# Decode and display the image
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
with torch.no_grad():
    image = vae.decode(latents).sample
from PIL import Image
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
display(pil_images[0])
"""**Prompt amplification** (or classifier-free guidance)
To get more relevant results, we can "amplify" the prompt by computing the result with an empty prompt, and extrapolating from this result the result with the given prompt by a given factor (variable guidance_scale).
"""
guidance_scale = 6
# Prompt encoding
prompt = ["An astronaut riding a horse"]
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
# Initial random noise
batch_size = 1
height = 512
width = 512
generator = torch.manual_seed(13) # Seed generator to create the initial latent noise
latents = torch.randn(
    (batch_size, unet.in_channels, height // 8, width // 8),
    generator=generator,
)
latents = latents.to(torch_device)
# Prompt amplification:
# The text embeddings are the concatenation of the text embeddings for an empty prompt and the text embeddings for the given prompt
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
    [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# Denoising loop
from tqdm.auto import tqdm
from torch import autocast
num_inference_steps = 20
scheduler.set_timesteps(num_inference_steps)
latents = latents * scheduler.init_noise_sigma
for t in tqdm(scheduler.timesteps):
    # Prompt amplification
    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)
    # predict the noise residual
    with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    # Prompt amplification
    # perform guidance
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    # compute the previous noisy sample x_t -> x_t-1
    latents = scheduler.step(noise_pred, t, latents).prev_sample
# Decode and display the image
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
with torch.no_grad():
    image = vae.decode(latents).sample
from PIL import Image
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
display(pil_images[0])
"""We can display the intermediate steps of denoising to see the image appearing progressively :
"""
guidance_scale = 6
# Prompt encoding
prompt = ["An astronaut riding a horse"]
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
# Initial random noise
batch_size = 1
height = 512
width = 512
generator = torch.manual_seed(13) # Seed generator to create the initial latent noise
latents = torch.randn(
    (batch_size, unet.in_channels, height // 8, width // 8),
    generator=generator,
)
latents = latents.to(torch_device)
# Prompt amplification:
# The text embeddings are the concatenation of the text embeddings for an empty prompt and the text embeddings for the given prompt
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
    [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# Denoising loop
from tqdm.auto import tqdm
from torch import autocast
num_inference_steps = 20
scheduler.set_timesteps(num_inference_steps)
latents = latents * scheduler.init_noise_sigma
def display_sample(sample, i):
    # scale and decode the image latents with vae
    sample = 1 / 0.18215 * sample
    with torch.no_grad():
        image = vae.decode(sample).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    # pil_images[0].save("astronaut_riding_horse.png")
    display(f"Image at step {i}")
    display(pil_images[0])
step = 0
for t in tqdm(scheduler.timesteps):
    step += 1
    # Prompt amplification
    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)
    # predict the noise residual
    with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    # Prompt amplification
    # perform guidance
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    # compute the previous noisy sample x_t -> x_t-1
    latents = scheduler.step(noise_pred, t, latents).prev_sample
    display_sample(latents, step)
"""We define a function generating an image from the following parameters :
- The guidance scale (prompt amplification factor)
- The number of denoising steps
- The seed for generating the initial random noise
- The prompt
"""
from tqdm.auto import tqdm
from torch import autocast
from PIL import Image
def generate(guidance_scale, num_inference_steps, seed, prompt):
    # Prompt encoding
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
    # Initial random noise
    batch_size = 1
    height = 512
    width = 512
    generator = torch.manual_seed(seed) # Seed generator to create the initial latent noise
    latents = torch.randn(
        (batch_size, unet.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(torch_device)
    # Prompt amplification:
    # The text embeddings are the concatenation of the text embeddings for an empty prompt and the text embeddings for the given prompt
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    # Denoising loop
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma
    for t in tqdm(scheduler.timesteps):
        # Prompt amplification
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        # predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        # Prompt amplification
        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    # Decode the image
    # scale and decode the image latents with vae
    latents = 1 / 0.18215 * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images[0]
image = generate(6, 20, 13, "An astronaut riding a horse")
display(image)
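"""The experiments below display many images one after the other. To compare them side by side instead, here is a small helper sketch (a helper of this kind also appears in the Hugging Face notebook linked above) that pastes equally sized PIL images into a grid."""
def image_grid(imgs, rows, cols):
    # paste a list of rows * cols equally sized PIL images into one grid image
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

# Example: four seeds side by side
# display(image_grid([generate(6, 20, seed, "An astronaut riding a horse") for seed in range(4)], rows=1, cols=4))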
"""Here is the result with different prompt amplification values :
"""
for guidance_scale in range(1, 16):
    image = generate(guidance_scale, 20, 13, "An astronaut riding a horse")
    print(f"Result with guidance_scale = {guidance_scale}:")
    display(image)
"""Here are the results with different numbers of denoising steps : """
for l in range(6):
    num_inference_steps = 2 ** l
    image = generate(6, num_inference_steps, 13, "An astronaut riding a horse")
    print(f"Result with num_inference_steps = {num_inference_steps}:")
    display(image)
"""The image generated depends on the initial random noise, which depends on the seed.
"""
guidance_scale = 8
num_inference_steps = 20
prompt = "An astronaut riding a horse"
for seed in range(16):
    image = generate(guidance_scale, num_inference_steps, seed, prompt)
    print(f"Result with seed = {seed}:")
    display(image)
"""Without classifier-free guidance, we can notice some similarities (shapes, colors, style, elements...) between images generated with the same seed and different prompts.
"""
guidance_scale = 1
seeds = [1, 3, 4, 5, 6, 10]
num_inference_steps = 20
prompts = ["An astronaut riding a horse", "A garden"]
for seed in seeds:
    for prompt in prompts:
        image = generate(guidance_scale, num_inference_steps, seed, prompt)
        print(f"Result with seed = {seed} and prompt = {prompt}:")
        display(image)
"""There are less similarities with classifier-free guidance.
"""
guidance_scale = 4
seeds = [1, 3, 4, 5, 6, 10]
num_inference_steps = 20
prompts = ["An astronaut riding a horse", "A garden"]
for seed in seeds:
    for prompt in prompts:
        image = generate(guidance_scale, num_inference_steps, seed, prompt)
        print(f"Result with seed = {seed} and prompt = {prompt}:")
        display(image)