Colorization
This example demonstrates how to colorize a grayscale image with a GenerativeModel from the image_gen package.
Setup
Import the necessary modules and initialize the generative model.
import torch
from torchvision import datasets, transforms
from image_gen import GenerativeModel
# Initialize a generative model with Variance Exploding diffusion and Euler-Maruyama sampler
model = GenerativeModel(diffusion="ve", sampler="euler-maruyama")
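For readers unfamiliar with the two terms in the comment above, the following is a toy sketch of Euler-Maruyama sampling for a Variance Exploding SDE. It is not the image_gen implementation: the noise range, the one-dimensional Gaussian "data", and every function name here are assumptions chosen for illustration, and the exact Gaussian score stands in for the trained network.

```python
import math
import torch

torch.manual_seed(0)

SIGMA_MIN, SIGMA_MAX = 0.01, 10.0   # noise-scale range (illustrative values)
DATA_STD = 1.0                      # toy data distribution: N(0, DATA_STD^2)

def sigma(t):
    """VE noise scale: geometric interpolation between SIGMA_MIN and SIGMA_MAX."""
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t

def diffusion_sq(t):
    """Squared diffusion coefficient g(t)^2 = d[sigma(t)^2]/dt."""
    return 2.0 * math.log(SIGMA_MAX / SIGMA_MIN) * sigma(t) ** 2

def marginal_var(t):
    """Variance of the noised toy data at time t."""
    return DATA_STD ** 2 + sigma(t) ** 2 - SIGMA_MIN ** 2

def score(x, t):
    """Exact score of the Gaussian marginal; a trained network would replace this."""
    return -x / marginal_var(t)

def euler_maruyama_sample(num_samples=5000, num_steps=500):
    dt = 1.0 / num_steps
    # Start from the wide Gaussian prior at t = 1.
    x = torch.randn(num_samples) * math.sqrt(marginal_var(1.0))
    for i in range(num_steps):
        t = 1.0 - i * dt
        g2 = diffusion_sq(t)
        # Reverse-time step: drift g^2 * score, plus fresh Gaussian noise.
        x = x + g2 * score(x, t) * dt + math.sqrt(g2 * dt) * torch.randn_like(x)
    return x

samples = euler_maruyama_sample()
print(samples.std().item())  # should be close to DATA_STD
```

The sampler starts from high-variance noise and integrates the reverse-time SDE back to t = 0, where the samples should roughly follow the data distribution. In the real model the score comes from the trained network rather than a closed-form expression.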
Training
Load CIFAR-10, keep a single class to shorten training, and train the model on that subset.
# Load CIFAR-10 dataset
# ToTensor scales pixels to [0, 1]; Normalize then maps each channel to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
data = datasets.CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=transform
)
# Keep only class 1 (automobile) for faster training
targets = torch.tensor(data.targets)
idx = (targets == 1).nonzero().flatten()
data = torch.utils.data.Subset(data, idx)
# Train the model
model.train(data, epochs=500, batch_size=32)
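The class-filtering step can be tried without downloading CIFAR-10. The sketch below applies the same pattern to a small synthetic dataset; the random tensors and the TensorDataset are stand-ins assumed for illustration, not part of the original example.

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Synthetic stand-in for CIFAR-10: 100 random 3x32x32 images with labels 0..9.
torch.manual_seed(0)
images = torch.randn(100, 3, 32, 32)
labels = torch.arange(100) % 10
dataset = TensorDataset(images, labels)

# Same filtering pattern as above: indices of every sample whose label is 1.
idx = (labels == 1).nonzero().flatten()
subset = Subset(dataset, idx.tolist())

print(len(subset))  # 10
print({int(label) for _, label in subset})  # {1}
```

Subset only stores the index list and looks samples up in the wrapped dataset on access, so filtering this way does not copy any image data.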
Colorization
Colorize a grayscale image using the trained model.
# Generate a base image
generated_image = model.generate(num_samples=1)
# Convert to grayscale by averaging the RGB channels (keepdim keeps a single channel axis)
gray_image = torch.mean(generated_image, dim=1, keepdim=True)
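# Display original and grayscale images
# (display_images is assumed to be a visualization helper shipped with image_gen; import it before running)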
display_images(generated_image)
display_images(gray_image)
# Colorize the grayscale image
colorized = model.colorize(gray_image)
display_images(colorized)
# Generate multiple color variations by colorizing 16 copies of the same grayscale image
gray_batch = gray_image.repeat(16, 1, 1, 1)
colorized_batch = model.colorize(gray_batch)
display_images(colorized_batch)
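The tensor shapes involved in the steps above can be checked without a trained model. In this sketch a random tensor stands in for the output of model.generate, and grayscale_error is a hypothetical helper (not part of image_gen) that measures how closely a colorized result reproduces the grayscale input when its channels are averaged again.

```python
import torch

# Stand-in for model.generate(num_samples=1): a random RGB image in [-1, 1].
torch.manual_seed(0)
generated_image = torch.rand(1, 3, 32, 32) * 2 - 1

# Channel mean with keepdim=True keeps a singleton channel axis.
gray_image = torch.mean(generated_image, dim=1, keepdim=True)
print(gray_image.shape)  # torch.Size([1, 1, 32, 32])

# repeat(16, 1, 1, 1) tiles the batch axis only, giving 16 identical inputs.
gray_batch = gray_image.repeat(16, 1, 1, 1)
print(gray_batch.shape)  # torch.Size([16, 1, 32, 32])

def grayscale_error(colorized, gray):
    """Mean absolute difference between the channel mean of `colorized` and `gray`."""
    return (colorized.mean(dim=1, keepdim=True) - gray).abs().mean().item()

# A trivially consistent "colorization": copy the gray channel into R, G and B.
naive = gray_batch.repeat(1, 3, 1, 1)
print(grayscale_error(naive, gray_batch))
```

Running grayscale_error on the output of model.colorize is one possible sanity check, under the assumption that a good colorization should preserve the intensity of the grayscale input.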