Skip to content

Conditional Generation

This example shows how to train a diffusion model on MNIST and then generate images conditioned on class labels.

Setup

Import the necessary modules and initialize the generative model.

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
from image_gen import GenerativeModel

# Build the generative model from three string options:
#   diffusion="vp"            -> Variance Preserving diffusion
#   sampler="exponential"     -> Exponential Integrator sampler
#   noise_schedule="linear"   -> linear noise schedule
model = GenerativeModel(diffusion="vp", sampler="exponential", noise_schedule="linear")

Training

Load the MNIST training set and train the model on it.

# Fetch the MNIST training split into ./data (downloading it on first use);
# ToTensor converts each image to a tensor as it is loaded.
to_tensor = ToTensor()
data = datasets.MNIST(root='data', train=True, download=True, transform=to_tensor)

# Fit the model on the dataset: 50 epochs, mini-batches of 32 images.
model.train(data, epochs=50, batch_size=32)

Conditional Generation

Generate images conditioned on specific class labels.

# Generate 16 samples from class 7
samples = model.generate(num_samples=16, class_labels=7)
display_images(samples)

# Generate specific classes for each sample
labels = torch.repeat_interleave(torch.arange(0, model.num_classes), 2)
samples = model.generate(num_samples=len(labels), class_labels=labels)
display_images(samples)

Visualization

Animate the generation process, then compare the effect of different guidance scales.

import matplotlib.pyplot as plt  # used below but was never imported
from image_gen.visualization import create_evolution_widget

# Create an animation showing the generation process for class 9
animation = create_evolution_widget(model, class_labels=9)
# In Jupyter, a bare expression at the end of a cell is displayed.
animation

# Compare different guidance scales: one class-6 sample per scale, in a 2x4 grid
guidance_scales = [0, 0.5, 1, 2, 3, 5, 7.5, 10]
fig, axs = plt.subplots(2, 4, figsize=(15, 8))
for ax, scale in zip(axs.flat, guidance_scales):
    samples = model.generate(num_samples=1, class_labels=6, guidance_scale=scale)
    # imshow needs a CPU array laid out (H, W) or (H, W, C); the sample is
    # (C, H, W), so permute it and drop a trailing size-1 channel.
    image = samples[0].detach().cpu().permute(1, 2, 0).squeeze(-1)
    ax.imshow(image, cmap="gray")
    ax.set_title(f'Scale={scale}')
plt.tight_layout()
plt.show()