Fréchet Inception Distance

Module for Fréchet Inception Distance (FID) metric implementation.

FrechetInceptionDistance

Bases: BaseMetric

Fréchet Inception Distance (FID) for evaluating generative models.

This metric measures the distance between the feature representations of real and generated images using the Inception-v3 model. Lower values indicate better quality and diversity.
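
Formally, FID models both feature sets as multivariate Gaussians and computes the Fréchet distance between them:

$$
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)
$$

where $(\mu_r, \Sigma_r)$ and $(\mu_g, \Sigma_g)$ are the mean and covariance of the Inception features of the real and generated images, respectively. This is exactly the quantity computed by _calculate_fid below.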

Attributes:

Name       Type             Description
model      GenerativeModel  The generative model being evaluated.
dims       int              Dimensionality of Inception features to use.
inception  Inception3       Pretrained Inception-v3 model used for feature extraction.

Source code in image_gen\metrics\fid.py
class FrechetInceptionDistance(BaseMetric):
    """Fréchet Inception Distance (FID) for evaluating generative models.

    This metric measures the distance between the feature representations
    of real and generated images using the Inception-v3 model.
    Lower values indicate better quality and diversity.

    Attributes:
        model: The generative model being evaluated.
        dims: Dimensionality of Inception features to use.
        inception: Pretrained Inception-v3 model used for feature extraction.
    """

    def __init__(self, model: GenerativeModel, dims: int = 2048):
        """Initialize the FID metric with model and feature dimensions.

        Args:
            model: The generative model to be evaluated.
            dims: The feature dimension to use from the Inception model.
        """
        super().__init__(model)
        self.dims = dims
        self.inception = self._get_inception()

    def _get_inception(self):
        """Create and prepare the Inception-v3 model for feature extraction.

        Returns:
            The prepared Inception-v3 model with identity final layer.
        """
        inception = inception_v3(
            weights=Inception_V3_Weights.IMAGENET1K_V1, transform_input=False)
        inception.fc = nn.Identity()  # Remove final FC layer to get features
        inception.eval()
        inception.to(self.model.device)
        return inception

    def _get_activations(self, images: Tensor) -> np.ndarray:
        """Extract Inception features from input images.

        Args:
            images: Batch of images to process.

        Returns:
            NumPy array containing feature activations.
        """
        # Convert grayscale to RGB if needed
        if images.shape[1] == 1:
            images = images.repeat(1, 3, 1, 1)

        # Resize images to the 299x299 Inception input size
        if images.shape[2] != 299 or images.shape[3] != 299:
            images = F.interpolate(
                images,
                size=(299, 299),
                mode='bilinear',
                align_corners=False
            )

        # Heuristic: inputs containing negative values are assumed to lie in
        # [-1, 1]; rescale them to [0, 1]
        if images.min() < 0:
            images = (images + 1) / 2

        # Ensure values are in [0, 1]
        images = torch.clamp(images, 0, 1)

        # Extract features in smaller batches to avoid OOM errors
        batch_size = 32
        activations = []

        for i in range(0, images.shape[0], batch_size):
            batch = images[i:i+batch_size]
            with torch.no_grad():
                try:
                    batch_activations = self.inception(batch)
                    activations.append(batch_activations)
                except Exception as e:
                    print(f"Error processing batch: {e}")
                    # Fallback: return random activations so evaluation can
                    # continue, though the resulting score is not meaningful
                    return np.random.randn(images.shape[0], self.dims)

        if not activations:
            return np.random.randn(images.shape[0], self.dims)

        activations = torch.cat(activations, 0)
        return activations.detach().cpu().numpy()

    def _calculate_fid(
        self,
        real_activations: np.ndarray,
        gen_activations: np.ndarray
    ) -> float:
        """Calculate FID score from feature activations.

        Args:
            real_activations: Feature activations from real images.
            gen_activations: Feature activations from generated images.

        Returns:
            The calculated FID score.
        """
        # Calculate per-set mean and covariance of the activations
        mu1 = real_activations.mean(axis=0)
        sigma1 = np.cov(real_activations, rowvar=False)
        mu2 = gen_activations.mean(axis=0)
        sigma2 = np.cov(gen_activations, rowvar=False)

        # Handle numerical stability
        eps = 1e-6
        sigma1 += np.eye(sigma1.shape[0]) * eps
        sigma2 += np.eye(sigma2.shape[0]) * eps

        # Calculate FID
        ssdiff = np.sum((mu1 - mu2) ** 2)

        # Use scipy's sqrtm with error handling
        try:
            covmean_sqrt = linalg.sqrtm(sigma1.dot(sigma2), disp=False)[0]
            if np.iscomplexobj(covmean_sqrt):
                covmean_sqrt = covmean_sqrt.real
            fid = ssdiff + np.trace(sigma1 + sigma2 - 2 * covmean_sqrt)
        except Exception as e:
            print(f"Error calculating matrix square root: {e}")
            # Fallback omits the cross-covariance term, which overestimates
            # the true FID but keeps the evaluation from crashing
            fid = ssdiff + np.trace(sigma1 + sigma2)

        return float(fid)

    def __call__(
        self,
        real: Union[Tensor, torch.utils.data.Dataset],
        generated: Union[Tensor, torch.utils.data.Dataset],
        *_,
        **__
    ) -> float:
        """Compute the FID score between real and generated images.

        Args:
            real: Tensor or Dataset-like containing real images (B, C, H, W).
            generated: Tensor or Dataset-like containing generated images 
                      (B, C, H, W).

        Returns:
            The computed FID score (lower is better).
        """
        def to_tensor(data):
            """Convert various inputs to tensors.

            Args:
                data: Input data (tensor or dataset).

            Returns:
                Tensor representation of the input data.
            """
            if isinstance(data, Tensor):
                return data
            # Note: only the first DataLoader batch (up to 64 samples) of a
            # Dataset is used
            dataloader = DataLoader(data, batch_size=64, shuffle=False)
            batch = next(iter(dataloader))
            return batch[0] if isinstance(batch, (list, tuple)) else batch

        real = to_tensor(real).to(self.model.device)
        generated = to_tensor(generated).to(self.model.device)

        # Process in batches to avoid memory issues
        real_activations = self._get_activations(real)
        gen_activations = self._get_activations(generated)

        # Ensure we have enough samples for covariance calculation
        if real_activations.shape[0] < 2 or gen_activations.shape[0] < 2:
            print("Warning: Need at least 2 samples for FID calculation")
            return float('nan')

        # If the sample counts differ, truncate both sets to the smaller one
        if real_activations.shape[0] != gen_activations.shape[0]:
            min_size = min(real_activations.shape[0], gen_activations.shape[0])
            real_activations = real_activations[:min_size]
            gen_activations = gen_activations[:min_size]

        return self._calculate_fid(real_activations, gen_activations)

    @property
    def name(self) -> str:
        """Get the name of the metric.

        Returns:
            The name of the metric.
        """
        return "Fréchet Inception Distance"

    @property
    def is_lower_better(self) -> bool:
        """Indicates whether a lower metric value is better.

        Returns:
            True if lower values indicate better performance.
        """
        return True
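
A minimal usage sketch: the GenerativeModel instance here is hypothetical and is only needed so the metric knows which device to run Inception-v3 on, and the random tensors stand in for actual image batches, so the resulting score carries no meaning:

import torch

from image_gen.metrics.fid import FrechetInceptionDistance

# `model` is assumed to be any GenerativeModel with a `.device` attribute.
fid = FrechetInceptionDistance(model, dims=2048)

# Batches of images in (B, C, H, W); grayscale inputs are broadcast to RGB
# and everything is resized to 299x299 internally.
real = torch.rand(64, 3, 64, 64)       # placeholder for real images
generated = torch.rand(64, 3, 64, 64)  # placeholder for model samples

score = fid(real, generated)
print(f"{fid.name}: {score:.2f} (lower is better)")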

is_lower_better property

Indicates whether a lower metric value is better.

Returns:

Type  Description
bool  True if lower values indicate better performance.
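
A harness that evaluates several metrics can use this flag to compare scores without hard-coding each metric's direction; a small sketch (the improved helper is hypothetical):

def improved(metric, old_score: float, new_score: float) -> bool:
    """Return True if new_score beats old_score under this metric."""
    if metric.is_lower_better:
        return new_score < old_score
    return new_score > old_score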

name property

Get the name of the metric.

Returns:

Type  Description
str   The name of the metric.

__call__(real, generated, *_, **__)

Compute the FID score between real and generated images.

Parameters:

Name       Type                    Description                                                        Default
real       Union[Tensor, Dataset]  Tensor or Dataset-like containing real images (B, C, H, W).       required
generated  Union[Tensor, Dataset]  Tensor or Dataset-like containing generated images (B, C, H, W).  required

Returns:

Type   Description
float  The computed FID score (lower is better).

Source code in image_gen\metrics\fid.py
def __call__(
    self,
    real: Union[Tensor, torch.utils.data.Dataset],
    generated: Union[Tensor, torch.utils.data.Dataset],
    *_,
    **__
) -> float:
    """Compute the FID score between real and generated images.

    Args:
        real: Tensor or Dataset-like containing real images (B, C, H, W).
        generated: Tensor or Dataset-like containing generated images 
                  (B, C, H, W).

    Returns:
        The computed FID score (lower is better).
    """
    def to_tensor(data):
        """Convert various inputs to tensors.

        Args:
            data: Input data (tensor or dataset).

        Returns:
            Tensor representation of the input data.
        """
        if isinstance(data, Tensor):
            return data
        # Note: only the first DataLoader batch (up to 64 samples) of a
        # Dataset is used
        dataloader = DataLoader(data, batch_size=64, shuffle=False)
        batch = next(iter(dataloader))
        return batch[0] if isinstance(batch, (list, tuple)) else batch

    real = to_tensor(real).to(self.model.device)
    generated = to_tensor(generated).to(self.model.device)

    # Process in batches to avoid memory issues
    real_activations = self._get_activations(real)
    gen_activations = self._get_activations(generated)

    # Ensure we have enough samples for covariance calculation
    if real_activations.shape[0] < 2 or gen_activations.shape[0] < 2:
        print("Warning: Need at least 2 samples for FID calculation")
        return float('nan')

    # If the sample counts differ, truncate both sets to the smaller one
    if real_activations.shape[0] != gen_activations.shape[0]:
        min_size = min(real_activations.shape[0], gen_activations.shape[0])
        real_activations = real_activations[:min_size]
        gen_activations = gen_activations[:min_size]

    return self._calculate_fid(real_activations, gen_activations)
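
Note that, because of the to_tensor helper above, Dataset inputs contribute only their first DataLoader batch, i.e. at most 64 samples each. A sketch reusing the fid instance from the earlier example (the TensorDataset shapes are illustrative):

import torch
from torch.utils.data import TensorDataset

real_ds = TensorDataset(torch.rand(1000, 3, 64, 64))
gen_ds = TensorDataset(torch.rand(1000, 3, 64, 64))

# Only the first 64 samples of each dataset are actually used.
score = fid(real_ds, gen_ds)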

__init__(model, dims=2048)

Initialize the FID metric with model and feature dimensions.

Parameters:

Name   Type             Description                                              Default
model  GenerativeModel  The generative model to be evaluated.                    required
dims   int              The feature dimension to use from the Inception model.   2048
Source code in image_gen\metrics\fid.py
def __init__(self, model: GenerativeModel, dims: int = 2048):
    """Initialize the FID metric with model and feature dimensions.

    Args:
        model: The generative model to be evaluated.
        dims: The feature dimension to use from the Inception model.
    """
    super().__init__(model)
    self.dims = dims
    self.inception = self._get_inception()
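
The default dims=2048 matches the size of Inception-v3's final pooling layer, which is what the identity fc head installed by _get_inception actually exposes. Passing a different value does not change the extracted feature size; in this implementation dims is only used to shape the random fallback array in _get_activations. A minimal instantiation sketch (model is again a hypothetical GenerativeModel):

# Standard configuration: 2048-d pooled features, the conventional FID setup.
fid = FrechetInceptionDistance(model)

# dims only affects the fallback array shape, not the features themselves,
# so non-default values should be used with care.
fid_alt = FrechetInceptionDistance(model, dims=768)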