Making a Class-Conditioned Diffusion Model

In this notebook we're going to illustrate one way to add conditioning information to a diffusion model. Specifically, we'll train a class-conditioned diffusion model on MNIST, following on from the 'from-scratch' example in Unit 1, so that we can specify which digit we'd like the model to generate at inference time.

As mentioned in the introduction to this unit, this is just one of many ways we could add conditioning information to a diffusion model; it was chosen for its relative simplicity. Just like the 'from-scratch' notebook in Unit 1, this notebook is mostly for illustrative purposes and you can safely skip it if you'd like.

Setup and Data Prep

[ ]
[ ]
Using device: cuda
[ ]
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting mnist/MNIST/raw/train-images-idx3-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting mnist/MNIST/raw/train-labels-idx1-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw

Input shape: torch.Size([8, 1, 28, 28])
Labels: tensor([8, 1, 5, 9, 7, 6, 2, 2])

Creating a Class-Conditioned UNet

The way we'll feed in the class conditioning is as follows:

  • Create a standard UNet2DModel with some additional input channels
  • Map the class label to a learned vector of shape (class_emb_size) via an embedding layer
  • Concatenate this information as extra channels for the internal UNet input with net_input = torch.cat((x, class_cond), 1)
  • Feed this net_input (which has (class_emb_size+1) channels in total) into the UNet to get the final prediction

In this example I've set class_emb_size to 4, but this is completely arbitrary: you could try size 1 (to see if it still works), size 10 (to match the number of classes), or replace the learned nn.Embedding with a simple one-hot encoding of the class label.
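As a rough sketch of the one-hot alternative mentioned above: the helper name one_hot_condition is hypothetical, and the batch of four labels is made up purely for the shape check. The one-hot vector is broadcast over the image so it can be concatenated as extra channels, which means the network would need 1 + num_classes input channels.

```python
import torch
import torch.nn.functional as F


def one_hot_condition(class_labels, num_classes, height, width):
    """Turn integer labels into fixed (not learned) conditioning channels."""
    one_hot = F.one_hot(class_labels, num_classes=num_classes).float()  # (bs, num_classes)
    # Add two spatial dims, then broadcast each value over the whole image
    return one_hot[:, :, None, None].expand(-1, -1, height, width)      # (bs, num_classes, h, w)


labels = torch.tensor([8, 1, 5, 9])            # illustrative labels only
x = torch.randn(4, 1, 28, 28)                  # a batch of MNIST-sized inputs
cond = one_hot_condition(labels, num_classes=10, height=28, width=28)
net_input = torch.cat((x, cond), 1)            # image channel + 10 one-hot channels
print(net_input.shape)                         # torch.Size([4, 11, 28, 28])
```

Unlike nn.Embedding, nothing here is trained, so the number of conditioning channels is fixed by the number of classes.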

This is what the implementation looks like:

[ ]
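A minimal sketch of a module that follows the steps listed above, under a few assumptions: the backbone is passed in rather than built inside the class, and the names ClassConditionedWrapper and TinyBackbone are hypothetical. TinyBackbone is a single convolution that ignores the timestep; it stands in for the UNet2DModel only so the shapes can be checked quickly.

```python
import torch
from torch import nn


class ClassConditionedWrapper(nn.Module):
    """Feeds a learned class embedding to a backbone as extra input channels."""

    def __init__(self, backbone, num_classes=10, class_emb_size=4):
        super().__init__()
        # Learned lookup table: class label -> vector of length class_emb_size
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        # The backbone must accept (image channels + class_emb_size) input channels, e.g.
        # diffusers' UNet2DModel(sample_size=28, in_channels=1 + class_emb_size, out_channels=1)
        self.backbone = backbone

    def forward(self, x, t, class_labels):
        bs, ch, h, w = x.shape
        class_cond = self.class_emb(class_labels)                        # (bs, class_emb_size)
        class_cond = class_cond[:, :, None, None].expand(-1, -1, h, w)   # (bs, class_emb_size, h, w)
        net_input = torch.cat((x, class_cond), 1)                        # (bs, ch + class_emb_size, h, w)
        out = self.backbone(net_input, t)
        # UNet2DModel returns an object with a .sample field; plain tensors pass through
        return out.sample if hasattr(out, "sample") else out


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for the UNet: one conv layer, timestep ignored."""

    def __init__(self, in_channels, out_channels=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        return self.conv(x)


net = ClassConditionedWrapper(TinyBackbone(in_channels=1 + 4), class_emb_size=4)
x = torch.randn(8, 1, 28, 28)
t = torch.randint(0, 1000, (8,))
y = torch.randint(0, 10, (8,))
pred = net(x, t, y)
print(pred.shape)  # torch.Size([8, 1, 28, 28])
```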

If any of the shapes or transforms are confusing, add in print statements to show the relevant shapes and check that they match your expectations. I've also annotated the shapes of some intermediate variables in the hopes of making things clearer.

Training and Sampling

Where previously we'd do something like prediction = unet(x, t), we'll now pass the correct labels as a third argument during training (prediction = unet(x, t, y)). At inference we can pass whatever labels we want and, if all goes well, the model should generate images that match. Here y holds the labels of the MNIST digits, with values from 0 to 9.

The training loop is very similar to the example from Unit 1. We're now predicting the noise (rather than the denoised image as in Unit 1) to match the objective expected by the default DDPMScheduler, which we use both to add noise during training and to generate samples at inference time. Training takes a while. Speeding this up could be a fun mini-project, but most of you can probably just skim the code (and indeed this whole notebook) without running it, since we're just illustrating an idea.
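One training step with the noise-prediction objective might look roughly like the sketch below. The function name training_step is hypothetical, and the mean-squared-error loss and the default of 1000 timesteps are assumptions rather than details stated in the text. The scheduler is only expected to provide add_noise(clean, noise, timesteps), which is the call diffusers' DDPMScheduler offers.

```python
import torch
from torch import nn


def training_step(net, noise_scheduler, x, y, num_train_timesteps=1000):
    """Compute the noise-prediction loss for one batch of images x with labels y."""
    # Sample the noise we want the network to recover
    noise = torch.randn_like(x)
    # Pick a random timestep for every image in the batch
    timesteps = torch.randint(
        0, num_train_timesteps, (x.shape[0],), device=x.device
    ).long()
    # Corrupt the clean images according to the scheduler's noise schedule
    noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
    # The labels go in as the third argument
    pred = net(noisy_x, timesteps, y)
    # Assumed loss: how close is the prediction to the true noise?
    return nn.functional.mse_loss(pred, noise)
```

In a full loop, the returned loss would be followed by the usual opt.zero_grad(), loss.backward(), and opt.step() calls for whichever optimizer is in use.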

[ ]
[ ]
Finished epoch 0. Average of the last 100 loss values: 0.052451
Finished epoch 1. Average of the last 100 loss values: 0.045999
Finished epoch 2. Average of the last 100 loss values: 0.043344
Finished epoch 3. Average of the last 100 loss values: 0.042347
Finished epoch 4. Average of the last 100 loss values: 0.041174
Finished epoch 5. Average of the last 100 loss values: 0.040736
Finished epoch 6. Average of the last 100 loss values: 0.040386
Finished epoch 7. Average of the last 100 loss values: 0.039372
Finished epoch 8. Average of the last 100 loss values: 0.039056
Finished epoch 9. Average of the last 100 loss values: 0.039024
[Figure: plot of the training loss]

Once training finishes, we can sample some images, feeding in different labels as our conditioning:

[ ]
[Figure: grid of generated images, sampled with different class labels]
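The sampling step described above can be sketched as follows. The function name sample_with_labels is hypothetical, and the choice of eight samples per digit is an assumption made for illustration. The scheduler is expected to expose timesteps and step(model_output, t, sample).prev_sample, as diffusers' DDPMScheduler does.

```python
import torch


@torch.no_grad()
def sample_with_labels(net, noise_scheduler, labels, image_shape=(1, 28, 28)):
    """Start from pure noise and denoise step by step, conditioning on labels."""
    x = torch.randn(len(labels), *image_shape, device=labels.device)
    for t in noise_scheduler.timesteps:
        # Predict the noise for the current sample, given the requested labels
        residual = net(x, t, labels)
        # Let the scheduler take one denoising step
        x = noise_scheduler.step(residual, t, x).prev_sample
    return x


# Assumed layout: eight samples for each digit 0..9 (80 labels in total)
labels = torch.arange(10).repeat_interleave(8)
print(labels[:10])  # tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
```

Because the labels are just an input tensor, any mix of digits can be requested in a single batch.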

There we go! We can now have some control over what images are produced.

I hope you've enjoyed this example. As always, feel free to ask questions in the Discord.
