## Introduction

Over the past year we have seen the rise of generative AI that has mainly come in two forms:

- Text-to-image generation powered by diffusion models like Stable Diffusion and DALL-E 2
- Language models like ChatGPT and LLaMA-2

It turns out that one of the key ingredients in the mainstream success of language models is Reinforcement Learning from Human Feedback (RLHF), where language models are trained with human feedback to produce outputs that users are more likely to prefer. This has made these language models better at following instructions and therefore significantly more accessible. A natural question is whether RLHF can also be applied to diffusion models, since text-to-image diffusion models also struggle to follow prompts and tend to need prompt engineering skills to get the desired results. A paper in May 2023 by the Levine Lab at UC Berkeley explored how the RLHF paradigm can be applied to diffusion models, resulting in an algorithm called DDPO. Here we’ll walk through a simple implementation of this DDPO algorithm. Let’s get started!

First let’s start with some basic imports:

```
import os
import requests
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import clip # pip install git+https://github.com/openai/CLIP.git
import torch
import random
import math
import wandb
from torch import nn
from diffusers import StableDiffusionPipeline, DDIMScheduler
from PIL import Image
from fastprogress import progress_bar, master_bar
```

Let’s load our Stable Diffusion model. Let’s also enable some performance optimizations (TF32 support, attention slicing, memory-efficient xformers attention) that will make it faster to work with our Stable Diffusion model for training.

```
torch.backends.cuda.matmul.allow_tf32 = True
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
pipe.enable_attention_slicing()
pipe.enable_xformers_memory_efficient_attention()
pipe.text_encoder.requires_grad_(False)
pipe.vae.requires_grad_(False)
```

We’re using the diffusers library, which provides a simple-to-use interface for sampling the Stable Diffusion model using their pipeline:

```
prompt = "a photograph of an astronaut riding a horse"
img = pipe(prompt).images[0]
```

`100%|██████████| 50/50 [00:03<00:00, 15.99it/s]`

Okay then, we want to improve the images coming out of our model. To do so we need some sort of score for an image that we can later optimize. This score could represent how aesthetic the image is. This is frankly quite subjective, and there is no mathematical equation for how aesthetic an image is. Instead we will use LAION’s aesthetic predictor, which was trained on thousands of human aesthetic ratings of AI-generated images and is a linear model on top of CLIP features. Below is the standard inference code for the aesthetic predictor model:

```
class MLP(nn.Module):
    def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            #nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            #nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            #nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            #nn.ReLU(),
            nn.Linear(16, 1)
        )

    def forward(self, x):
        return self.layers(x)
```

```
def load_aesthetic_model_weights(cache="."):
    weights_fname = "sac+logos+ava1-l14-linearMSE.pth"
    loadpath = os.path.join(cache, weights_fname)
    if not os.path.exists(loadpath):
        url = (
            "https://github.com/christophschuhmann/"
            f"improved-aesthetic-predictor/blob/main/{weights_fname}?raw=true"
        )
        r = requests.get(url)
        with open(loadpath, "wb") as f:
            f.write(r.content)
    weights = torch.load(loadpath, map_location=torch.device("cpu"))
    return weights
```

```
def aesthetic_model_normalize(a, axis=-1, order=2):
    l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
    l2[l2 == 0] = 1
    return a / np.expand_dims(l2, axis)
```

We need the CLIP model, whose features will be passed into our aesthetic predictor:

`clip_model, preprocess = clip.load("ViT-L/14", device="cuda")`

`aesthetic_model = MLP(768)`

```
aesthetic_model.load_state_dict(load_aesthetic_model_weights())
aesthetic_model.cuda().eval()  # eval mode disables the Dropout layers so scores are deterministic
```

```
MLP(
(layers): Sequential(
(0): Linear(in_features=768, out_features=1024, bias=True)
(1): Dropout(p=0.2, inplace=False)
(2): Linear(in_features=1024, out_features=128, bias=True)
(3): Dropout(p=0.2, inplace=False)
(4): Linear(in_features=128, out_features=64, bias=True)
(5): Dropout(p=0.1, inplace=False)
(6): Linear(in_features=64, out_features=16, bias=True)
(7): Linear(in_features=16, out_features=1, bias=True)
)
)
```

```
image = preprocess(img).unsqueeze(0).cuda()
with torch.no_grad(): image_features = clip_model.encode_image(image)
```

```
im_emb_arr = aesthetic_model_normalize(image_features.cpu().detach().numpy())
prediction = aesthetic_model(torch.from_numpy(im_emb_arr).float().cuda())
```

`print(f'Aesthetic score: {prediction}')`

`Aesthetic score: tensor([[5.1999]], device='cuda:0', grad_fn=<AddmmBackward0>)`

Just like that, we get the aesthetic score from this predictor. Let’s package this code into a function:

```
def aesthetic_scoring(img, preprocess, clip_model, aesthetic_model_normalize, aesthetic_model):
    image = preprocess(img).unsqueeze(0).cuda()
    with torch.no_grad(): image_features = clip_model.encode_image(image)
    im_emb_arr = aesthetic_model_normalize(image_features.cpu().detach().numpy())
    prediction = aesthetic_model(torch.from_numpy(im_emb_arr).float().cuda())
    return prediction
```

```
prompt = "a horse"
img = pipe(prompt).images[0]
print(f'Aesthetic score: {aesthetic_scoring(img, preprocess, clip_model, aesthetic_model_normalize, aesthetic_model)[0][0]}')
img
```

`100%|██████████| 50/50 [00:02<00:00, 19.03it/s]`

`Aesthetic score: 5.421084403991699`

```
prompt = "a beautiful, exquisite portrait of a horse, 4k, unreal engine"
img = pipe(prompt).images[0]
print(f'Aesthetic score: {aesthetic_scoring(img, preprocess, clip_model, aesthetic_model_normalize, aesthetic_model)[0][0]}')
img
```

`100%|██████████| 50/50 [00:02<00:00, 19.17it/s]`

`Aesthetic score: 5.927961826324463`

```
prompt = "a very ugly photograph of a donkey"
img = pipe(prompt).images[0]
print(f'Aesthetic score: {aesthetic_scoring(img, preprocess, clip_model, aesthetic_model_normalize, aesthetic_model)[0][0]}')
img
```

`100%|██████████| 50/50 [00:02<00:00, 19.16it/s]`

`Aesthetic score: 5.116610050201416`

You can see if you prompt for something “ugly” we do get an image with a lower score. But to be honest, it’s not that much lower. So I found another “ugly donkey” image online to score instead:

```
import requests
from io import BytesIO
img = Image.open(BytesIO(requests.get("https://i.redd.it/8wbqtdequzv41.jpg").content)).resize((512,512))
print(f'Aesthetic score: {aesthetic_scoring(img, preprocess, clip_model, aesthetic_model_normalize, aesthetic_model)[0][0]}')
img
```

`Aesthetic score: 4.846365928649902`

Okay that’s definitely much lower.

As you can see, for Stable Diffusion-generated images, the score doesn’t vary much. There seems to be a sort of average aesthetic score that most images fluctuate around (~5.3). Can we increase the average aesthetic score of the images produced by Stable Diffusion? Reinforcement learning (RL) will help us do this.

## What if we could optimize the aesthetic score?

Now that we have some sort of measure of quality of our image, our aesthetic score, we can optimize for it. In the RL literature, this measure of quality that we are optimizing for is referred to as the **reward**. The goal of RL algorithms is to optimize the reward. We will see how DDPO does this for diffusion models.

Before we go down the RL route though, it is worth examining if there are alternative approaches. Diffusion models, after all, are an extremely versatile framework, and people have been incorporating different constraints and forms of guidance during sampling in order to achieve desired results. Let’s do a quick refresher about diffusion models and how guidance is applied.

### Diffusion model refresher

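A diffusion model is described by a forward and a reverse process. The forward process is when we start out with a clean image $\mathbf{x}_0$ and repeatedly add Gaussian noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ ($\mathbf{I}$ being the identity matrix) to give us noisier and noisier images $\mathbf{x}_t$. This is described by the following:

$$ q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) = \mathcal{N}\left(\mathbf{x}_t; \sqrt{1-\beta_t}\,\mathbf{x}_{t-1}, \beta_t \mathbf{I}\right) $$

where $\beta_t$ is a predefined monotonically increasing variance schedule. The forward process runs for a total of $T$ timesteps and finally ends with pure noise $\mathbf{x}_T$. The reverse process starts with pure noise $\mathbf{x}_T$ and uses a neural network to repeatedly denoise the image, giving us $\mathbf{x}_{t-1}$ from $\mathbf{x}_t$. The end of the reverse process gives us back our samples $\mathbf{x}_0$. This is described as follows:

$$ p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t) = \mathcal{N}\left(\mathbf{x}_{t-1}; \mu_\theta(\mathbf{x}_t, t), \tilde{\beta}_t \mathbf{I}\right) $$

where $\tilde{\beta}_t$ is the variance schedule for the reverse process and $\mu_\theta$ is the denoising neural network. Note that the denoising neural network can be reparameterized to predict the noise $\epsilon$ in the image. So instead of predicting the denoised image $\hat{\mathbf{x}}_0$ directly, we can predict the noise in the image and subtract it out to get $\hat{\mathbf{x}}_0$. We train the reparameterized denoising neural network $\epsilon_\theta$ in the reverse diffusion process with a simple MSE loss:

$$ L_{\text{simple}} = \mathbb{E}_{t, \mathbf{x}_0, \epsilon}\left[ \lVert \epsilon - \epsilon_\theta(\mathbf{x}_t, t) \rVert^2 \right] $$

In practice, training and sampling are quite simple. During each training step, a random image $\mathbf{x}_0$ and timestep $t$ are selected, the forward process is run from $\mathbf{x}_0$ until timestep $t$ to get $\mathbf{x}_t$ using the noise $\epsilon$, this is passed into our denoising model, and the MSE between the $\epsilon$ used to calculate $\mathbf{x}_t$ and the predicted $\epsilon_\theta(\mathbf{x}_t, t)$ is optimized. During sampling, we start out with random Gaussian noise $\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and the denoising neural network is repeatedly applied to give us $\mathbf{x}_{t-1}$ until we reach our sample $\mathbf{x}_0$.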
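To make the forward process concrete, here is a minimal pure-Python sketch of noising a tiny "image" (a list of three floats). It relies on the standard closed-form property of this Gaussian forward process, which lets us jump directly to timestep $t$ via $\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$ with $\bar{\alpha}_t = \prod_{s \le t}(1-\beta_s)$. The linear schedule endpoints and the helper names are illustrative assumptions, not values taken from Stable Diffusion.

```python
import math
import random

def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Monotonically increasing variance schedule (endpoints are illustrative)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]

def cumulative_alphas(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars

def noise_sample(x0, t, alpha_bars, rng):
    """Jump straight to step t: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps."""
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    a_bar = alpha_bars[t]
    x_t = [math.sqrt(a_bar) * x + math.sqrt(1.0 - a_bar) * e for x, e in zip(x0, eps)]
    return x_t, eps

betas = linear_beta_schedule(1000)
alpha_bars = cumulative_alphas(betas)
rng = random.Random(0)
x0 = [1.0, -0.5, 0.25]
x_t, eps = noise_sample(x0, 999, alpha_bars, rng)  # last step: almost pure noise
```

The pair `(x_t, eps)` is exactly what a training step needs: the model sees `x_t` and is trained to predict `eps`.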

It is also worth noting the score matching intuition behind diffusion models. Specifically, $\epsilon_\theta(\mathbf{x}_t, t)$ turns out to actually be an estimate (up to a factor) of the “score function” $\nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t)$. Basically, what this means is that when we sample from a diffusion model, we are iteratively taking steps in the direction of this score function, which is the gradient of the log-likelihood. So the sampling is very much like an optimization problem.

If any of this is unfamiliar to you, I recommend checking out the fast.ai course on the subject (I am somewhat biased though given I co-taught the class!).

Let’s now discuss how additional constraints and guidance can be added during diffusion sampling. Basically, we want to model $p(\mathbf{x} \mid y)$ where $y$ is some sort of condition or constraint (for example a class condition). In diffusion models, we could instead try to estimate $\nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t \mid y)$ and use this in sampling. This can be expressed differently using Bayes’ Rule:

$$ \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t \mid y) = \nabla_{\mathbf{x}_t} \log p(y \mid \mathbf{x}_t) + \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t) $$

The second term is our score function that is already being estimated by our diffusion model $\epsilon_\theta$. The first term, however, is the gradient of the log likelihood of a classifier with respect to $\mathbf{x}_t$. What this means is that during diffusion model sampling, if we use a modified $\epsilon_\theta$ with the classifier gradient added to it, we can get samples that adhere to the desired condition.
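As a small sanity check of this decomposition, here is a toy one-dimensional example where every term is available in closed form. It assumes a standard normal prior $p(x) = \mathcal{N}(0, 1)$ and a Gaussian "classifier" likelihood $p(y \mid x) = \mathcal{N}(y; x, \sigma^2)$; both choices are purely illustrative and have nothing to do with a real image classifier.

```python
def prior_score(x):
    """d/dx log N(x; 0, 1)."""
    return -x

def classifier_score(x, y, sigma):
    """d/dx log N(y; x, sigma^2): the 'classifier gradient' term."""
    return (y - x) / sigma ** 2

def posterior_score(x, y, sigma):
    """d/dx log p(x | y), using the closed-form Gaussian posterior."""
    mean = y / (1.0 + sigma ** 2)
    var = sigma ** 2 / (1.0 + sigma ** 2)
    return -(x - mean) / var

x, y, sigma = 0.3, 2.0, 0.5
guided = classifier_score(x, y, sigma) + prior_score(x)  # right-hand side of Bayes' rule
exact = posterior_score(x, y, sigma)                     # left-hand side
```

Here `guided` and `exact` agree, which is the same mechanism classifier guidance relies on: adding the classifier gradient to the unconditional score yields the conditional score.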

More broadly, losses can be applied to the noisy images and their gradients can be added to the score function/denoising network output to try to obtain images that better adhere to some desired constraints. This is the idea behind CLIP-guided diffusion, for example, where the similarity between the CLIP text embedding of a prompt and the CLIP image embedding of the images during sampling is maximized.

Can we use guidance to get diffusion models to generate more aesthetic images that better adhere to user prompts? Potentially yes, but there are many challenges that may make it undesirable.

Strictly speaking, the proper way to perform classifier guidance is to use a classifier trained on either the noisy images or the predicted denoised images (which tend to be blurry, especially early on in sampling), which is what the original classifier guidance paper demonstrates. Note, it is possible to use classifiers and other models trained on regular images and get reasonable results for guidance, even though noisy or blurry images will likely be out-of-distribution. For example CLIP isn’t trained on noisy images but is used in CLIP-guided diffusion. But often to get reasonable results, various hacks and tricks are required (in the case of CLIP-guided diffusion, the use of this technique fell out of popularity once diffusion models properly conditioned on CLIP features like Stable Diffusion were developed, and CLIP guidance applied on top of models like Stable Diffusion often showed minimal benefit).

Additionally, note that guidance requires the calculation of whatever guidance loss we have and autograd of that loss at each step in the sampling process. This can add a significant overhead to the sampling time compared to guidance-free sampling. The situation is made worse with latent diffusion models like Stable Diffusion, where the latents often need to be decoded to full images in order to apply the classifier/loss, resulting in additional computational overhead.

For these reasons, we will instead apply reinforcement learning to obtain a diffusion model after optimizing arbitrary constraints (the reward function), such as aesthetic scores. As we will see, the generated images from the starting diffusion model are passed into the reward function, so there is no concern of images being out-of-distribution. Additionally, we will obtain a diffusion model that provides higher scoring images directly, not through sampling changes like guidance.

Okay let’s proceed with trying to apply RL to diffusion models. First we’ll create a dataset generator - Stable Diffusion generated images given some prompts. We’ll use animal prompts:

`!wget https://raw.githubusercontent.com/formigone/tf-imagenet/master/LOC_synset_mapping.txt`

```
--2023-06-13 09:54:06-- https://raw.githubusercontent.com/formigone/tf-imagenet/master/LOC_synset_mapping.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 31675 (31K) [text/plain]
Saving to: ‘LOC_synset_mapping.txt.14’
LOC_synset_mapping. 100%[===================>] 30.93K --.-KB/s in 0.006s
2023-06-13 09:54:06 (5.24 MB/s) - ‘LOC_synset_mapping.txt.14’ saved [31675/31675]
```

```
synsets = {k:v for k,v in [o.split(',')[0].split(' ', maxsplit=1) for o in Path('LOC_synset_mapping.txt').read_text().splitlines()]}
imagenet_classes = list(synsets.values())
```

```
def imagenet_animal_prompts():
    animal = random.choice(imagenet_classes[:397])
    prompts = f'{animal}'
    return prompts
```

`imagenet_animal_prompts()`

Put into a dataset class:

```
class PromptDataset(torch.utils.data.Dataset):
    def __init__(self, prompt_fn, num):
        super().__init__()
        self.prompt_fn = prompt_fn
        self.num = num

    def __len__(self): return self.num

    def __getitem__(self, x): return self.prompt_fn()
```

Next let’s set up our sampling loop. For simplicity, we’ll just use the DDIM scheduler:

```
pipe.scheduler = DDIMScheduler(
    num_train_timesteps=pipe.scheduler.num_train_timesteps,
    beta_start=pipe.scheduler.beta_start,
    beta_end=pipe.scheduler.beta_end,
    beta_schedule=pipe.scheduler.beta_schedule,
    trained_betas=pipe.scheduler.trained_betas,
    clip_sample=pipe.scheduler.clip_sample,
    set_alpha_to_one=pipe.scheduler.set_alpha_to_one,
    steps_offset=pipe.scheduler.steps_offset,
    prediction_type=pipe.scheduler.prediction_type
)
```

Below we have a sampling function that also gives us the intermediate timesteps. Again, this is a pretty standard diffusion sampling loop; check out the HuggingFace blog post for more information.

```
@torch.no_grad()
def sd_sample(prompts, pipe, height, width, guidance_scale, num_inference_steps, eta, device):
    scheduler = pipe.scheduler
    unet = pipe.unet
    text_embeddings = pipe._encode_prompt(prompts,device, 1, do_classifier_free_guidance=guidance_scale > 1.0)
    scheduler.set_timesteps(num_inference_steps, device=device)
    latents = torch.randn((len(prompts), unet.in_channels, height//8, width//8)).to(device)
    all_step_preds = []
    for i, t in enumerate(progress_bar(scheduler.timesteps)):
        input = torch.cat([latents] * 2)
        input = scheduler.scale_model_input(input, t)
        # predict the noise residual
        pred = unet(input, t, encoder_hidden_states=text_embeddings).sample
        # perform guidance
        pred_uncond, pred_text = pred.chunk(2)
        pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
        # compute the "previous" noisy sample
        scheduler_output = scheduler.step(pred, t, latents, eta)
        all_step_preds.append(scheduler_output)
        latents = scheduler_output.prev_sample
    return latents, all_step_preds
```

`preds, all_step_preds = sd_sample([prompt]*2, pipe, 512, 512, 7.5, 50, 1, 'cuda')`

The sampling function only gives us latents, and they need to be decoded by a VAE to get the final images:

```
@torch.no_grad()
def decoding_fn(latents,pipe):
    images = pipe.vae.decode(1 / 0.18215 * latents.cuda()).sample
    images = (images / 2 + 0.5).clamp(0, 1)
    images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (images * 255).round().astype("uint8")
    return images
```

`Image.fromarray(decoding_fn(preds,pipe)[0])`

Let’s again calculate the aesthetic score. We have to make a slight modification to our `aesthetic_scoring` function so it can handle batches.

```
def aesthetic_scoring(imgs, preprocess, clip_model, aesthetic_model_normalize, aesthetic_model):
    imgs = torch.stack([preprocess(Image.fromarray(img)).cuda() for img in imgs])
    with torch.no_grad(): image_features = clip_model.encode_image(imgs)
    im_emb_arr = aesthetic_model_normalize(image_features.cpu().detach().numpy())
    prediction = aesthetic_model(torch.from_numpy(im_emb_arr).float().cuda())
    return prediction
```

```
imgs = decoding_fn(preds,pipe)
aesthetic_scoring(imgs, preprocess, clip_model, aesthetic_model_normalize, aesthetic_model)
```

```
tensor([[5.1635],
[5.4475]], device='cuda:0', grad_fn=<AddmmBackward0>)
```

Now we can initialize our dataset generator, which provides prompts. We then sample with those prompts to generate images and pass the images into the aesthetic predictor, all in just a few lines of code:

```
train_set = PromptDataset(imagenet_animal_prompts, 1000)
train_dl = torch.utils.data.DataLoader(train_set, batch_size=2, shuffle=True, num_workers=0)
```

```
prompts = next(iter(train_dl))
preds, all_step_preds = sd_sample(prompts, pipe, 512, 512, 7.5, 50, 1, 'cuda')
imgs = decoding_fn(preds,pipe)
rewards = aesthetic_scoring(imgs, preprocess, clip_model, aesthetic_model_normalize, aesthetic_model)
```

```
index = torch.where(rewards == rewards.min())[0][0]
print(prompts[index])
Image.fromarray(imgs[index])
```

```
index = torch.where(rewards == rewards.max())[0][0]
print(prompts[index])
Image.fromarray(imgs[index])
```

Once again, the aesthetic predictor provides *rewards* and these are used in optimization with RL. Our goal is to maximize the reward!

For stability during RL training, the rewards are usually normalized. There are different ways of doing this, but the DDPO paper utilizes a simple approach, which is to normalize based on unique prompts. Basically, a queue is set up with prompts and the corresponding rewards, with some given buffer size. When a new (prompt, reward) pair is obtained, it is added to the queue. If the queue for that unique prompt is not long enough, the reward is just normalized over the whole batch. Otherwise, it is normalized based on the statistics for that specific prompt. This is implemented below:

```
from collections import deque

class PerPromptStatTracker:
    def __init__(self, buffer_size, min_count):
        self.buffer_size = buffer_size
        self.min_count = min_count
        self.stats = {}

    def update(self, prompts, rewards):
        unique = np.unique(prompts)
        advantages = np.empty_like(rewards)
        for prompt in unique:
            prompt_rewards = rewards[prompts == prompt]
            if prompt not in self.stats:
                self.stats[prompt] = deque(maxlen=self.buffer_size)
            self.stats[prompt].extend(prompt_rewards)
            if len(self.stats[prompt]) < self.min_count:
                mean = np.mean(rewards)
                std = np.std(rewards) + 1e-6
            else:
                mean = np.mean(self.stats[prompt])
                std = np.std(self.stats[prompt]) + 1e-6
            advantages[prompts == prompt] = (prompt_rewards - mean) / std
        return advantages
```

`per_prompt_stat_tracker = PerPromptStatTracker(buffer_size=32, min_count=16)`

`tensor([5.4786, 5.5764], grad_fn=<ToCopyBackward0>)`

`advantages = per_prompt_stat_tracker.update(np.array(prompts), rewards.squeeze().cpu().detach().numpy())`

`array([-0.99998444, 0.9999747 ], dtype=float32)`

`per_prompt_stat_tracker.stats`

```
{'garter snake': deque([5.576433], maxlen=32),
'polecat': deque([5.478563], maxlen=32)}
```

We have the reward normalization set up. Let’s look at some of the intermediate timesteps:

`Image.fromarray(decoding_fn(all_step_preds[0].prev_sample, pipe)[0])`

`Image.fromarray(decoding_fn(all_step_preds[30].prev_sample, pipe)[0])`

As expected, we get pure noise (in the latent space; it looks somewhat different when decoded by the Stable Diffusion VAE) at the beginning of sampling, and as sampling progresses, the image starts to take shape.

## The DDPO algorithm

Okay, let’s now dig into how reinforcement learning works and derive the DDPO objective. A reminder that what we want to do is to maximize the reward signal. We can mathematically express this as follows:

$$ \theta^\star = \arg\max_\theta \; \mathbb{E}_{\mathbf{c} \sim p(\mathbf{c}),\, \mathbf{x}_0 \sim p_\theta(\cdot \mid \mathbf{c})}\left[ r(\mathbf{x}_0, \mathbf{c}) \right] $$

where $\theta$ is the weights of our diffusion model, $\mathbf{c}$ is some conditioning for the diffusion model, and $r(\cdot)$ is our reward function.

It would be nice to directly maximize for $r(\cdot)$, and if our model was a single evaluation of a neural network, we could simply backpropagate through the neural network and use an optimizer to update the weights. But that’s not what happens with a diffusion model! We have multiple timesteps for which we apply our denoising neural network. This constructs a **trajectory**, as it’s known in the RL literature. In standard RL literature, our trajectory is composed of **states** and **actions**. A model that we are optimizing provides the next action given the current state, and this model is referred to as the **policy**. This framework is known as a **Markov Decision Process (MDP)**. Note that in the general MDP framework, a reward is usually given after each action, and we optimize over the sum of rewards over the whole trajectory.

We can easily describe diffusion models as an MDP, which will allow us to use standard results in RL for diffusion model optimization.

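- $\mathbf{s}_t \triangleq (\mathbf{c}, t, \mathbf{x}_t)$ is the state, which is just the current noisy image (along with the timestep and condition info).
- $\mathbf{a}_t \triangleq \mathbf{x}_{t-1}$ is the action, which is the slightly less noisy image $\mathbf{x}_{t-1}$.
- $\pi_\theta(\mathbf{a}_t \mid \mathbf{s}_t) \triangleq p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{c})$ is the policy that takes the current state and provides the next action, which is our diffusion model $p_\theta$.
- $\rho_0(\mathbf{s}_0) \triangleq \left(p(\mathbf{c}), \delta_T, \mathcal{N}(\mathbf{0}, \mathbf{I})\right)$ is the distribution of the initial states, which in this case is the same distribution as for $\mathbf{x}_T$, a standard isotropic normal distribution, with the timestep always being $T$ and the conditioning having whatever prior distribution it has in the dataset.
- $P(\mathbf{s}_{t+1} \mid \mathbf{s}_t, \mathbf{a}_t) \triangleq \left(\delta_{\mathbf{c}}, \delta_{t-1}, \delta_{\mathbf{x}_{t-1}}\right)$ gives $\mathbf{s}_{t+1}$ given the current state and action, and is basically just saying it always goes to the state associated with $\mathbf{x}_{t-1}$.
- $R(\mathbf{s}_t, \mathbf{a}_t) \triangleq r(\mathbf{x}_0, \mathbf{c})$ if $t = 0$ and $0$ otherwise is the reward after each action/state (which is zero until the very last step when the generation is complete).

The entire trajectory can be represented as $\tau$ and the probability density for trajectories as $p_\theta(\tau)$. Note this is different from the diffusion model $p_\theta$ as it is a function of $\tau$, a bit confusing! So going forward the diffusion model is referred to as $p^{\text{diffusion}}_\theta$.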

Okay, with this framework in place it becomes trivial to apply standard **policy gradient** optimization methods like **REINFORCE** and **proximal policy optimization (PPO)**. Let’s go over these two algorithms now. I have also written a more principled from-scratch derivation over [here]

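We’ll start by looking at the general case of taking the gradient of the expectation over $p_\theta(x)$ of some function $f(x)$:

$$ \nabla_\theta \mathbb{E}_{x \sim p_\theta}\left[f(x)\right] = \nabla_\theta \int p_\theta(x) f(x)\,dx = \int p_\theta(x)\, \nabla_\theta \log p_\theta(x)\, f(x)\,dx = \mathbb{E}_{x \sim p_\theta}\left[f(x)\, \nabla_\theta \log p_\theta(x)\right] $$

This is referred to as the *score function gradient estimator*.

Let’s think more about what this is doing. We want to calculate the gradient of $\mathbb{E}_{x \sim p_\theta}[f(x)]$ with respect to $\theta$. That is, we want to know how we can change $\theta$ such that we get samples from $p_\theta$ that on average give higher $f(x)$ values. What this estimator says is that we can take the gradient $\nabla_\theta \log p_\theta(x)$ (which tells you how to change $\theta$ to increase the likelihood of $x$ under your distribution $p_\theta$), and weight it with $f(x)$. So if this is being used in gradient ascent for example, we are placing more weight on updates that make high-scoring samples more likely under our model $p_\theta$.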
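A quick Monte Carlo check can make the estimator less abstract. The sketch below assumes a toy setup that is not part of the DDPO pipeline: $p_\theta = \mathcal{N}(\theta, 1)$ and $f(x) = x^2$, so that $\mathbb{E}[f(x)] = \theta^2 + 1$ and the true gradient is $2\theta$. For this distribution $\nabla_\theta \log p_\theta(x) = x - \theta$.

```python
import random

def score_function_gradient(theta, f, num_samples, rng):
    """Estimate d/dtheta E_{x ~ N(theta, 1)}[f(x)] from samples only."""
    total = 0.0
    for _ in range(num_samples):
        x = rng.gauss(theta, 1.0)
        grad_log_prob = x - theta        # d/dtheta log N(x; theta, 1)
        total += f(x) * grad_log_prob    # weight the score by f(x)
    return total / num_samples

rng = random.Random(0)
theta = 1.5
estimate = score_function_gradient(theta, lambda x: x * x, 200_000, rng)
exact = 2 * theta  # analytic gradient of theta**2 + 1
print(f"estimate={estimate:.3f}, exact={exact:.3f}")
```

Note that `f` is never differentiated; only its values are used. That is what allows a black-box reward such as an aesthetic predictor to be optimized this way.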

When we apply this to the MDP framework and simplify, we can get our **policy gradient** for the expected total reward $\mathcal{J}(\theta) = \mathbb{E}_{\tau \sim p_\theta(\tau)}\left[\sum_t R(\mathbf{s}_t, \mathbf{a}_t)\right]$:

$$ \nabla_\theta \mathcal{J}(\theta) = \mathbb{E}_{\tau \sim p_\theta(\tau)}\left[\left(\sum_t \nabla_\theta \log \pi_\theta(\mathbf{a}_t \mid \mathbf{s}_t)\right)\left(\sum_t R(\mathbf{s}_t, \mathbf{a}_t)\right)\right] $$

This gradient is referred to as the **REINFORCE** gradient and is only one type of policy gradient that could be used. Of course, this policy gradient is then used to update the weights of our model using gradient ascent:

$$ \theta \leftarrow \theta + \alpha \nabla_\theta \mathcal{J}(\theta) $$

where $\alpha$ is some learning rate.

One implementation point is that the expectation is over the trajectories, but of course we can’t sum over all possible trajectories. The expectation is estimated with just the sampled trajectories in the current batch. One other implementation point to mention: we could calculate our gradient and then pass that gradient to our optimizer, or we could let autograd handle the calculation of the gradient by constructing a loss function and treating it as a standard training loop. The latter is what is done in practice even though it is not often mentioned explicitly in the papers. So our loss function is:

$$ L(\theta) = -\mathbb{E}_{\tau \sim p_\theta(\tau)}\left[\left(\sum_t \log \pi_\theta(\mathbf{a}_t \mid \mathbf{s}_t)\right)\left(\sum_t R(\mathbf{s}_t, \mathbf{a}_t)\right)\right] $$

Again, let’s reinforce the intuition behind the REINFORCE gradient/loss function (pun definitely intended). You can see the loss looks very much like a negative log-likelihood loss, with the actions as our target. The difference here is that it is weighted by the reward. What this loss is doing is trying to make high-reward trajectories more likely and low-reward trajectories less likely.
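To see the "reward-weighted negative log-likelihood" intuition in action, here is a minimal REINFORCE sketch on a two-armed bandit (a one-step trajectory). The softmax policy, the fixed rewards, and the hyperparameters are illustrative assumptions; the gradient of the loss with respect to the logits is written out by hand instead of using autograd.

```python
import math
import random

def softmax(logits):
    peak = max(logits)
    exps = [math.exp(l - peak) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def reinforce_bandit(rewards, steps, lr, rng):
    """Gradient ascent on E[reward] using one sampled action per step."""
    logits = [0.0] * len(rewards)
    for _ in range(steps):
        probs = softmax(logits)
        action = rng.choices(range(len(rewards)), weights=probs)[0]
        reward = rewards[action]
        for k in range(len(logits)):
            # d/dlogit_k log pi(action) = 1[k == action] - pi(k)
            grad_log_prob = (1.0 if k == action else 0.0) - probs[k]
            # ascent on reward * log pi(action), i.e. descent on the REINFORCE loss
            logits[k] += lr * reward * grad_log_prob
    return softmax(logits)

final_probs = reinforce_bandit(rewards=[0.0, 1.0], steps=2000, lr=0.1, rng=random.Random(0))
print(final_probs)
```

After training, almost all probability mass sits on the rewarded arm: actions that led to high reward were made more likely, which is the behaviour we want from the diffusion model too.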

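Okay, so we can simply plug in our diffusion model terms based on how it fits into the MDP framework, which we described earlier. We get:

$$ \nabla_\theta \mathcal{J}_{\text{SF}} = \mathbb{E}\left[\sum_{t=1}^{T} \nabla_\theta \log p^{\text{diffusion}}_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{c}) \; r(\mathbf{x}_0, \mathbf{c})\right] $$

where the expectation is over prompts $\mathbf{c}$ and trajectories sampled from the current model.

This objective and gradient estimator is referred to in the paper as $\text{DDPO}_{\text{SF}}$.

One challenge with this approach is that for each optimization step, sampling from the current iteration of the model needs to be performed: we need to re-sample the trajectory (every $\mathbf{x}_t$), as it must come from the current version of the model. This can be computationally expensive and wasteful, as the samples collected with previous iterations of the model cannot be used to learn.

One trick to address this is known as importance sampling. This relies on the following identity (trivial to demonstrate based on the expectation definition):

$$ \mathbb{E}_{x \sim p(x)}\left[f(x)\right] = \mathbb{E}_{x \sim q(x)}\left[\frac{p(x)}{q(x)} f(x)\right] $$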
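The importance sampling identity is easy to verify numerically. The sketch below assumes two unit-variance Gaussians, a target $p = \mathcal{N}(0, 1)$ and a sampling distribution $q = \mathcal{N}(0.5, 1)$, and estimates $\mathbb{E}_{p}[x^2] = 1$ using samples drawn only from $q$. These distributions are chosen purely for illustration.

```python
import math
import random

def normal_log_pdf(x, mean, std):
    return -((x - mean) ** 2) / (2 * std ** 2) - math.log(std) - 0.5 * math.log(2 * math.pi)

def importance_sampled_mean(f, p_mean, q_mean, num_samples, rng):
    """Estimate E_{x ~ N(p_mean, 1)}[f(x)] using samples from N(q_mean, 1)."""
    total = 0.0
    for _ in range(num_samples):
        x = rng.gauss(q_mean, 1.0)
        # importance ratio p(x) / q(x), computed in log space for stability
        log_ratio = normal_log_pdf(x, p_mean, 1.0) - normal_log_pdf(x, q_mean, 1.0)
        total += math.exp(log_ratio) * f(x)
    return total / num_samples

estimate = importance_sampled_mean(lambda x: x * x, 0.0, 0.5, 200_000, random.Random(0))
print(f"importance-sampled estimate of E_p[x^2]: {estimate:.3f} (exact value is 1)")
```

In the RL setting, $q$ plays the role of the old policy that generated the stored trajectories and $p$ the current policy being updated.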

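Applying importance sampling gives us the following, where $\theta_{\text{old}}$ denotes the weights that were used to collect the trajectories:

$$ \nabla_\theta \mathcal{J}(\theta) = \mathbb{E}_{\tau \sim p_{\theta_{\text{old}}}(\tau)}\left[\sum_t \frac{\pi_\theta(\mathbf{a}_t \mid \mathbf{s}_t)}{\pi_{\theta_{\text{old}}}(\mathbf{a}_t \mid \mathbf{s}_t)} \nabla_\theta \log \pi_\theta(\mathbf{a}_t \mid \mathbf{s}_t) \left(\sum_{t'} R(\mathbf{s}_{t'}, \mathbf{a}_{t'})\right)\right] $$

Again this can be written down as a loss function that we perform gradient descent on (sometimes referred to as the **surrogate loss**):

$$ L(\theta) = -\mathbb{E}_{\tau \sim p_{\theta_{\text{old}}}(\tau)}\left[\sum_t \frac{\pi_\theta(\mathbf{a}_t \mid \mathbf{s}_t)}{\pi_{\theta_{\text{old}}}(\mathbf{a}_t \mid \mathbf{s}_t)} \left(\sum_{t'} R(\mathbf{s}_{t'}, \mathbf{a}_{t'})\right)\right] $$

Again, we can plug in the diffusion model terms based on the MDP framework and get this loss function:

$$ L_{\text{IS}}(\theta) = -\mathbb{E}\left[\sum_{t=1}^{T} \frac{p^{\text{diffusion}}_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{c})}{p^{\text{diffusion}}_{\theta_{\text{old}}}(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{c})} \; r(\mathbf{x}_0, \mathbf{c})\right] $$

Minimization of this loss function is equivalent to gradient ascent with the following gradient:

$$ \nabla_\theta \mathcal{J}_{\text{IS}} = \mathbb{E}\left[\sum_{t=1}^{T} \frac{p^{\text{diffusion}}_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{c})}{p^{\text{diffusion}}_{\theta_{\text{old}}}(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{c})} \nabla_\theta \log p^{\text{diffusion}}_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{c}) \; r(\mathbf{x}_0, \mathbf{c})\right] $$

Note that the reward is usually normalized, and the normalized reward is referred to as the advantage $A(\mathbf{x}_0, \mathbf{c})$. So the advantage can be negative if the reward is less than average.

Note that we don’t want the current policy to diverge too much from the previous policy, otherwise we may end up with a bad policy. To help address this, we can apply clipping to the importance sampling ratio in the loss function:

$$ L(\theta) = -\mathbb{E}\left[\sum_{t=1}^{T} \min\left(\frac{p^{\text{diffusion}}_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{c})}{p^{\text{diffusion}}_{\theta_{\text{old}}}(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{c})} A(\mathbf{x}_0, \mathbf{c}),\; \text{clip}\left(\frac{p^{\text{diffusion}}_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{c})}{p^{\text{diffusion}}_{\theta_{\text{old}}}(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{c})}, 1 - \epsilon_{\text{clip}}, 1 + \epsilon_{\text{clip}}\right) A(\mathbf{x}_0, \mathbf{c})\right)\right] $$

where $\epsilon_{\text{clip}}$ is a small clipping range.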

So if the policy diverges too much (the ratio is either much larger or much smaller than 1) the loss function is clipped to a certain value and the gradient will be zero and no updates will be made. The below diagram (taken from here) clarifies this further:

Note that DDPO also clips the advantages themselves, $A(\mathbf{x}_0, \mathbf{c})$, in its implementation, but this is not described in the paper, so I have not included it in the loss function.
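The effect of clipping the importance ratio is easiest to see on scalar values. The sketch below evaluates a clipped surrogate loss for a single (ratio, advantage) pair, written as the maximum of the unclipped and clipped negative terms; the function name and the clip range of 0.2 are illustrative assumptions.

```python
def clipped_surrogate_loss(ratio, advantage, clip_ratio):
    """PPO-style clipped surrogate loss for one sample (lower is better)."""
    clipped_ratio = min(max(ratio, 1.0 - clip_ratio), 1.0 + clip_ratio)
    unclipped_loss = -advantage * ratio
    clipped_loss = -advantage * clipped_ratio
    # taking the max of the losses is the pessimistic choice (min of the objectives)
    return max(unclipped_loss, clipped_loss)

# Positive advantage, ratio already above the range: loss is capped at the clipped value
loss_a = clipped_surrogate_loss(ratio=1.5, advantage=1.0, clip_ratio=0.2)
# Positive advantage, ratio below the range: unclipped term is used, gradient still flows
loss_b = clipped_surrogate_loss(ratio=0.5, advantage=1.0, clip_ratio=0.2)
# Negative advantage, ratio already below the range: loss is capped at the clipped value
loss_c = clipped_surrogate_loss(ratio=0.5, advantage=-1.0, clip_ratio=0.2)
# Negative advantage, ratio above the range: unclipped term is used, gradient still flows
loss_d = clipped_surrogate_loss(ratio=1.5, advantage=-1.0, clip_ratio=0.2)
```

In cases `loss_a` and `loss_c` the selected term is constant in the ratio, so it contributes no gradient: the policy has already moved far enough in the direction the advantage favors.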

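The loss function can be written in a way that’s numerically easier to calculate/more stable (using logs, ignoring the clipping for now):

$$ L(\theta) = -\mathbb{E}\left[\sum_{t=1}^{T} \exp\left(\log p^{\text{diffusion}}_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{c}) - \log p^{\text{diffusion}}_{\theta_{\text{old}}}(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{c})\right) A(\mathbf{x}_0, \mathbf{c})\right] $$

The objective and gradient estimator described here is referred to in the paper as $\text{DDPO}_{\text{IS}}$. It is pretty much the same as **proximal policy optimization (PPO)**, applied to diffusion models.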

For a more complete derivation of the DDPO objective from scratch, see here.

In order to start implementing this loss function, let’s calculate the log probs, which is easy for a normal distribution:

```
def calculate_log_probs(prev_sample, prev_sample_mean, std_dev_t):
    std_dev_t = torch.clip(std_dev_t, 1e-6)
    log_probs = -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * std_dev_t ** 2) - torch.log(std_dev_t) - math.log(math.sqrt(2 * math.pi))
    return log_probs
```
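As a sanity check on the formula itself, here is a scalar pure-Python version of the same Gaussian log-density, compared against the standard library's `statistics.NormalDist`. The function name `gaussian_log_prob` and the test values are illustrative, and this sketch omits the clipping of the standard deviation.

```python
import math
from statistics import NormalDist

def gaussian_log_prob(x, mean, std):
    """log N(x; mean, std^2), term by term as in the tensor version."""
    return -((x - mean) ** 2) / (2 * std ** 2) - math.log(std) - math.log(math.sqrt(2 * math.pi))

value = gaussian_log_prob(0.3, 0.0, 0.5)
reference = math.log(NormalDist(mu=0.0, sigma=0.5).pdf(0.3))
print(value, reference)
```

The two numbers agree up to floating-point error, so the three terms are the quadratic term, the log of the standard deviation, and the constant $\log\sqrt{2\pi}$.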

We need to get those log probs of the original model so our sampling function should return that. Let’s update our sampling function to do that:

```
@torch.no_grad()
def sd_sample(prompts, pipe, height, width, guidance_scale, num_inference_steps, eta, device):
    scheduler = pipe.scheduler
    unet = pipe.unet
    text_embeddings = pipe._encode_prompt(prompts,device, 1, do_classifier_free_guidance=guidance_scale > 1.0)
    scheduler.set_timesteps(num_inference_steps, device=device)
    latents = torch.randn((len(prompts), unet.in_channels, height//8, width//8)).to(device)
    all_step_preds, log_probs = [latents], []
    for i, t in enumerate(progress_bar(scheduler.timesteps)):
        input = torch.cat([latents] * 2)
        input = scheduler.scale_model_input(input, t)
        # predict the noise residual
        pred = unet(input, t, encoder_hidden_states=text_embeddings).sample
        # perform guidance
        pred_uncond, pred_text = pred.chunk(2)
        pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
        # compute the "previous" noisy sample mean and variance, and get log probs
        scheduler_output = scheduler.step(pred, t, latents, eta, variance_noise=0)
        t_1 = t - scheduler.config.num_train_timesteps // num_inference_steps
        variance = scheduler._get_variance(t, t_1)
        std_dev_t = eta * variance ** (0.5)
        prev_sample_mean = scheduler_output.prev_sample # this is the mean and not full sample since variance is 0
        prev_sample = prev_sample_mean + torch.randn_like(prev_sample_mean) * std_dev_t # get full sample by adding noise
        log_probs.append(calculate_log_probs(prev_sample, prev_sample_mean, std_dev_t).mean(dim=tuple(range(1, prev_sample_mean.ndim))))
        all_step_preds.append(prev_sample)
        latents = prev_sample
    return latents, torch.stack(all_step_preds), torch.stack(log_probs)
```

We can get everything we need for the loss function now (intermediate timesteps, log_probs, rewards):

```
per_prompt_stat_tracker = PerPromptStatTracker(buffer_size=32, min_count=16)
prompts = next(iter(train_dl))
pipe.text_encoder.to('cuda')
pipe.vae.to('cuda')
preds, all_step_preds, log_probs = sd_sample(prompts, pipe, 512, 512, 7.5, 50, 1, 'cuda')
imgs = decoding_fn(preds,pipe)
rewards = aesthetic_scoring(imgs, preprocess, clip_model, aesthetic_model_normalize, aesthetic_model)
advantages = torch.from_numpy(per_prompt_stat_tracker.update(np.array(prompts), rewards.squeeze().cpu().detach().numpy())).float().to('cuda')
```

Here’s a function to compute the loss function:

```
def compute_loss(x_t, original_log_probs, advantages, clip_advantages, clip_ratio, prompts, pipe, num_inference_steps, guidance_scale, eta, device):
    scheduler = pipe.scheduler
    unet = pipe.unet
    text_embeddings = pipe._encode_prompt(prompts,device, 1, do_classifier_free_guidance=guidance_scale > 1.0).detach()
    scheduler.set_timesteps(num_inference_steps, device=device)
    loss_value = 0.
    for i, t in enumerate(progress_bar(scheduler.timesteps)):
        clipped_advantages = torch.clip(advantages, -clip_advantages, clip_advantages).detach()
        input = torch.cat([x_t[i].detach()] * 2)
        input = scheduler.scale_model_input(input, t)
        # predict the noise residual
        pred = unet(input, t, encoder_hidden_states=text_embeddings).sample
        # perform guidance
        pred_uncond, pred_text = pred.chunk(2)
        pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
        # compute the "previous" noisy sample mean and variance, and get log probs
        scheduler_output = scheduler.step(pred, t, x_t[i].detach(), eta, variance_noise=0)
        t_1 = t - scheduler.config.num_train_timesteps // num_inference_steps
        variance = scheduler._get_variance(t, t_1)
        std_dev_t = eta * variance ** (0.5)
        prev_sample_mean = scheduler_output.prev_sample
        current_log_probs = calculate_log_probs(x_t[i+1].detach(), prev_sample_mean, std_dev_t).mean(dim=tuple(range(1, prev_sample_mean.ndim)))
        # calculate loss
        ratio = torch.exp(current_log_probs - original_log_probs[i].detach()) # this is the importance ratio of the new policy to the old policy
        unclipped_loss = -clipped_advantages * ratio # this is the surrogate loss
        clipped_loss = -clipped_advantages * torch.clip(ratio, 1. - clip_ratio, 1. + clip_ratio) # this is the surrogate loss, but with artificially clipped ratios
        loss = torch.max(unclipped_loss, clipped_loss).mean() # we take the max of the clipped and unclipped surrogate losses, and take the mean over the batch
        loss.backward() # perform backward here, gets accumulated for all the timesteps
        loss_value += loss.item()
    return loss_value
```

```
loss = compute_loss(all_step_preds, log_probs, advantages, 10, 1e-4, prompts, pipe, 50, 7.5, 1, 'cuda')
print(loss)
```

100.00% [50/50 00:29<00:00]
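To make the clipping behaviour concrete, here is a minimal sketch of the same per-sample, per-timestep computation using plain Python floats instead of tensors. The function name `clipped_surrogate_loss` and the example log-prob values are illustrative assumptions; the default `clip_advantages` and `clip_ratio` values match the call above.

```python
import math

def clipped_surrogate_loss(new_log_prob, old_log_prob, advantage,
                           clip_advantages=10.0, clip_ratio=1e-4):
    """Scalar version of the loss inside compute_loss (illustrative only)."""
    # clip the advantage to [-clip_advantages, clip_advantages]
    adv = max(-clip_advantages, min(clip_advantages, advantage))
    # importance ratio of the new policy to the old policy
    ratio = math.exp(new_log_prob - old_log_prob)
    clipped_ratio = max(1.0 - clip_ratio, min(1.0 + clip_ratio, ratio))
    unclipped = -adv * ratio
    clipped = -adv * clipped_ratio
    # taking the max of the two losses is the pessimistic (PPO-style) choice
    return max(unclipped, clipped)

# hypothetical values: same policy, policy moved "too far" up, policy moved down
same_policy = clipped_surrogate_loss(0.0, 0.0, advantage=2.0)
moved_up = clipped_surrogate_loss(0.1, 0.0, advantage=2.0)
moved_down = clipped_surrogate_loss(-0.1, 0.0, advantage=2.0)
huge_advantage = clipped_surrogate_loss(0.0, 0.0, advantage=50.0)
print(same_policy, moved_up, moved_down, huge_advantage)
```

With a positive advantage, once the ratio exceeds `1 + clip_ratio` the loss stops depending on the ratio, so there is no gradient pushing the policy further in that direction. If the ratio instead drops below `1 - clip_ratio`, the unclipped term is selected and the gradient still pulls it back up.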

## Complete training loop

Now that we can calculate the loss function, we can construct the full training loop. For a single epoch we:

Sample from the diffusion model `num_samples_per_epoch` times, collecting the intermediate noisy images and log probs.

Pass the samples to the reward model to get rewards, which we normalize to get advantages.

For `num_inner_epochs` times, we go over each sample, compute the loss, backpropagate, and update our diffusion model.

Let’s define all our hyperparameters. We will be training Stable Diffusion v1.4 on ImageNet animal prompts as defined earlier, using the LAION Aesthetic classifier as our reward model.

Note that if we set `num_inner_epochs` to a high value, training would be very data-efficient since we would repeatedly reuse the previously generated trajectories, but the model would probably diverge significantly from the original policy that generated those trajectories (or at least the ratio would get clipped frequently in the loss). So we set `num_inner_epochs=1`. This is still pretty efficient with DDPO_{IS}, because with DDPO_{SF} you’d need to resample after every iteration (whenever the model is updated) instead of after every `num_samples_per_epoch=128` samples as we do here.

```
num_samples_per_epoch = 128
num_epochs = 50
num_inner_epochs = 1
num_timesteps = 50
batch_size = 4
img_size = 512
lr = 5e-6
clip_advantages = 10.0
clip_ratio = 1e-4
cfg = 5.0
```
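As a rough sanity check on what these settings imply, the arithmetic below works out how many sampling batches and optimizer steps one epoch contains. This is only a back-of-the-envelope sketch: `sampling_batch_size = 32` is taken from the `DataLoader` created in the setup code, and the derived variable names are made up for illustration.

```python
num_samples_per_epoch = 128
num_epochs = 50
num_inner_epochs = 1
num_timesteps = 50
batch_size = 4
clip_ratio = 1e-4
sampling_batch_size = 32  # assumption: the DataLoader batch size used for sampling

# sampling phase: how many calls to the sampler per epoch
sampling_batches_per_epoch = num_samples_per_epoch // sampling_batch_size

# training phase: one optimizer step per chunk of `batch_size` trajectories
optimizer_steps_per_epoch = num_inner_epochs * (num_samples_per_epoch // batch_size)

# each optimizer step accumulates gradients over every timestep of the trajectory
unet_backward_passes_per_step = num_timesteps

total_optimizer_steps = num_epochs * optimizer_steps_per_epoch
total_images_sampled = num_epochs * num_samples_per_epoch

# the importance ratio is only allowed to move within this window before clipping
ratio_window = (1.0 - clip_ratio, 1.0 + clip_ratio)

print(sampling_batches_per_epoch, optimizer_steps_per_epoch,
      total_optimizer_steps, total_images_sampled, ratio_window)
```

So each epoch samples once and then takes 32 optimizer steps on those samples, whereas resampling after every update would need a fresh sampling pass per step.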

Okay let’s set everything up:

```
# group all reward function stuff
def reward_fn(imgs, device):
    clip_model.to(device)
    aesthetic_model.to(device)
    rewards = aesthetic_scoring(imgs, preprocess, clip_model, aesthetic_model_normalize, aesthetic_model)
    clip_model.to('cpu')
    aesthetic_model.to('cpu')
    return rewards
```

```
# a function to sample from the model and calculate rewards
def sample_and_calculate_rewards(prompts, pipe, image_size, cfg, num_timesteps, decoding_fn, reward_fn, device):
    preds, all_step_preds, log_probs = sd_sample(prompts, pipe, image_size, image_size, cfg, num_timesteps, 1, device)
    imgs = decoding_fn(preds,pipe)
    rewards = reward_fn(imgs, device)
    return imgs, rewards, all_step_preds, log_probs
```

Here we create our dataset, which is just randomly chosen prompts:

```
train_set = PromptDataset(imagenet_animal_prompts, num_samples_per_epoch)
train_dl = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=0)
per_prompt_stat_tracker = PerPromptStatTracker(buffer_size=32, min_count=16)
sample_prompts = next(iter(train_dl)) # sample a batch of prompts to use for visualization
```

`pipe.unet.enable_gradient_checkpointing() # more performance optimization`

`optimizer = torch.optim.AdamW(pipe.unet.parameters(), lr=lr, weight_decay=1e-4) # optimizer`

Now that we are set up, let’s start training! You can see the training loop is quite simple!

```
for epoch in master_bar(range(num_epochs)):
    all_step_preds, log_probs, advantages, all_prompts, all_rewards = [], [], [], [], []
    # sampling `num_samples_per_epoch` images and calculating rewards
    for i, prompts in enumerate(progress_bar(train_dl)):
        batch_imgs, rewards, batch_all_step_preds, batch_log_probs = sample_and_calculate_rewards(prompts, pipe, img_size, cfg, num_timesteps, decoding_fn, reward_fn, 'cuda')
        batch_advantages = torch.from_numpy(per_prompt_stat_tracker.update(np.array(prompts), rewards.squeeze().cpu().detach().numpy())).float().to('cuda')
        all_step_preds.append(batch_all_step_preds)
        log_probs.append(batch_log_probs)
        advantages.append(batch_advantages)
        all_prompts += prompts
        all_rewards.append(rewards)
    all_step_preds = torch.cat(all_step_preds, dim=1)
    log_probs = torch.cat(log_probs, dim=1)
    advantages = torch.cat(advantages)
    all_rewards = torch.cat(all_rewards)
    # inner loop
    for inner_epoch in progress_bar(range(num_inner_epochs)):
        # chunk them into batches
        all_step_preds_chunked = torch.chunk(all_step_preds, num_samples_per_epoch // batch_size, dim=1)
        log_probs_chunked = torch.chunk(log_probs, num_samples_per_epoch // batch_size, dim=1)
        advantages_chunked = torch.chunk(advantages, num_samples_per_epoch // batch_size, dim=0)
        # chunk the prompts (list of strings) into batches
        all_prompts_chunked = [all_prompts[i:i + batch_size] for i in range(0, len(all_prompts), batch_size)]
        for i in progress_bar(range(len(all_step_preds_chunked))):
            optimizer.zero_grad()
            loss = compute_loss(all_step_preds_chunked[i], log_probs_chunked[i],
                advantages_chunked[i], clip_advantages, clip_ratio, all_prompts_chunked[i], pipe, num_timesteps, cfg, 1, 'cuda'
            ) # loss.backward happens inside
            torch.nn.utils.clip_grad_norm_(pipe.unet.parameters(), 1.0) # gradient clipping
            optimizer.step()
```
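One detail worth double-checking in the inner loop is that the tensors are split with `torch.chunk` (by number of chunks) while the prompts are split by list slicing (by batch size). The two only line up when `num_samples_per_epoch` is divisible by `batch_size`. The sketch below mimics the chunk sizes in plain Python; the helper names are hypothetical, and `torch_chunk_sizes` is my reading of the documented `torch.chunk` behaviour (equal chunks of size `ceil(n / chunks)`, with a smaller last chunk and possibly fewer chunks than requested), not code from the training script.

```python
import math

def torch_chunk_sizes(n, chunks):
    """Chunk sizes torch.chunk is documented to produce along a dimension of length n."""
    split_size = math.ceil(n / chunks)
    sizes, remaining = [], n
    while remaining > 0:
        sizes.append(min(split_size, remaining))
        remaining -= sizes[-1]
    return sizes

def slice_sizes(n, batch_size):
    """Chunk sizes produced by slicing a list in steps of batch_size."""
    return [min(batch_size, n - start) for start in range(0, n, batch_size)]

def chunks_align(num_samples, batch_size):
    """True when tensor chunks and prompt chunks have identical sizes."""
    tensor_chunks = torch_chunk_sizes(num_samples, num_samples // batch_size)
    return tensor_chunks == slice_sizes(num_samples, batch_size)

print(chunks_align(128, 4))  # the settings used in this post
print(chunks_align(130, 4))  # a hypothetical non-divisible setting
```

With the hyperparameters used here both splits give 32 chunks of 4, so prompts and trajectories stay paired. For a non-divisible setting such as 130 samples, the tensor chunks would have 5 samples each while the prompt chunks would have 4, so it seems safest to keep `num_samples_per_epoch` a multiple of `batch_size`.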

That’s it! Let’s see what results we get with this code.

## Results

(These results were obtained with this script that included W&B tracking)

Here is the loss curve from training Stable Diffusion v1.4 with the LAION Aesthetic classifier reward model on ImageNet animal prompts:

As you can see, it’s quite noisy but clearly does decrease. That said, I’ve observed that typically the loss curve can be quite uninformative. Instead, as expected, the reward curve is a better indicator of the performance:

Here you can see a clear increase in average reward over training, which is what we want! So what do the samples look like? Let’s see:

I’d argue these images are definitely more aesthetic! It works! A few observations:

Sometimes the prompt isn’t being followed, as seen with the wolf spider example. This happens because the reward model does not take the prompt into consideration and only looks at the image, so if generating something slightly unrelated to the prompt gives a better reward score, the model will do so. To fix this, we would need a reward model that explicitly takes in the prompt as well and rewards prompt alignment. One such model that could be used is PickScore.

Sometimes the RL-trained diffusion model generates pencil/charcoal drawings of the animals, which was observed in the original DDPO paper as well.

Another common pattern in the RL-trained diffusion model’s outputs is narrow depth-of-focus images, which tend to look more artistic and aesthetic.

## Conclusion

Note that much of what we discuss here for how RL is applied to diffusion models also applies to language models. Specifically, the autoregressive generation of text from a language model can be viewed as a trajectory from an MDP. The state is the previous tokens, the action is the next token to be predicted, and the policy is of course the language model. The PPO algorithm that we described is the most common RL algorithm for RLHF training of language models. Overall, this hints at a deeper connection between diffusion models and autoregressive language models and how ideas can be transferred from one domain to another. For example, recently it was demonstrated how classifier-free guidance could be applied to language models. There may continue to be interesting ideas to apply from diffusion models to language models or vice versa.
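To illustrate the MDP view of text generation, here is a toy sketch in plain Python. Everything in it is hypothetical: `TOY_POLICY` is a made-up three-token "language model" whose next-token distribution depends only on the last token, standing in for a real model. The point is only that sampling yields a trajectory of states and actions plus one log prob per step, the same ingredients we collected from the diffusion sampler.

```python
import math
import random

# hypothetical toy policy: next-token probabilities given the last token
TOY_POLICY = {
    "<s>": {"a": 0.5, "b": 0.5},
    "a": {"a": 0.1, "b": 0.6, "</s>": 0.3},
    "b": {"a": 0.6, "b": 0.1, "</s>": 0.3},
}

def sample_trajectory(policy, rng, max_len=10):
    """Sample tokens one at a time, recording the log prob of each action."""
    state = ["<s>"]   # the state is the sequence of previous tokens
    log_probs = []    # one entry per action, like one entry per denoising step
    while len(state) <= max_len and state[-1] != "</s>":
        dist = policy[state[-1]]
        tokens = list(dist)
        # the action is the next token, drawn from the policy's distribution
        action = rng.choices(tokens, weights=[dist[tok] for tok in tokens])[0]
        log_probs.append(math.log(dist[action]))
        state.append(action)
    return state, log_probs

tokens, log_probs = sample_trajectory(TOY_POLICY, random.Random(0))
print(tokens, sum(log_probs))
```

A reward computed on the finished sequence, together with these per-step log probs, is all that a PPO-style update like the one in this post needs.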

This paper is only the start of applying RL to diffusion models. DDPO_{IS} is only a baseline, and there are changes that can easily be made to improve performance, such as value function baselines, reward discounting, and KL regularization. Additionally, alternative RLHF algorithms like direct preference optimization could be explored. We plan to explore these directions further via DRLX.

Finally, I just want to say that implementing my very first RL algorithm (and building a library based on it too) was an interesting learning experience. It took a lot of reading RL course material and blog posts, looking at code implementations, and wrestling with very subtle bugs, but after about 2 weeks I managed to get it working. As many people in the RL field have experienced, it didn’t take much time for me to form a love-hate relationship with RL 😂 but I still think it’s a very interesting field and I’m excited to explore it further.

## Acknowledgements

Thank you to Costa Huang, who helped me debug my code and provided feedback on the blog post. Thank you to Jonathan Whitaker and inox/hayley for providing feedback on my blog post. Thank you to Shahbuland Matiana, who I worked closely with on DRLX.

]]>