piqa.fid

Fréchet Inception Distance (FID)

This module implements the FID in PyTorch.

Original

https://github.com/bioinf-jku/TTUR

Wikipedia

https://wikipedia.org/wiki/Frechet_inception_distance

References

GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium (Heusel et al., 2017)

Functions

frechet_distance

Returns the Fréchet distance between two multivariate Gaussian distributions.

sqrtm

Returns the square root of a positive semi-definite matrix.

Classes

FID

Measures the FID between two sets of inception features.

InceptionV3

Pretrained Inception-v3 network.

Descriptions

piqa.fid.sqrtm(sigma)

Returns the square root of a positive semi-definite matrix.

\[\sqrt{\Sigma} = Q \sqrt{\Lambda} Q^T\]

where \(Q \Lambda Q^T\) is the eigendecomposition of \(\Sigma\).

Parameters:

sigma (Tensor) – A positive semi-definite matrix, \((*, D, D)\).

Example

>>> V = torch.randn(4, 4, dtype=torch.double)
>>> A = V @ V.T
>>> B = sqrtm(A @ A)
>>> torch.allclose(A, B)
True
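
For intuition, the eigendecomposition formula above can be sketched directly with `torch.linalg.eigh`. This is a minimal sketch, not necessarily how piqa implements `sqrtm`: the name `sqrtm_sketch` is hypothetical, and clamping slightly negative eigenvalues to zero is an assumption made to absorb floating-point round-off.

```python
import torch


def sqrtm_sketch(sigma: torch.Tensor) -> torch.Tensor:
    """Square root of a (batch of) symmetric PSD matrices, shape (*, D, D)."""
    # Symmetric eigendecomposition: sigma = Q diag(L) Q^T
    L, Q = torch.linalg.eigh(sigma)
    # Round-off can make tiny eigenvalues slightly negative; clamp them to 0 (assumption)
    L = L.clamp(min=0)
    # Q * sqrt(L) scales the columns of Q, i.e. Q diag(sqrt(L)), without forming diag
    return (Q * L.sqrt().unsqueeze(-2)) @ Q.transpose(-1, -2)
```

Scaling the columns of \(Q\) by broadcasting avoids building a dense diagonal matrix and works unchanged for batched inputs of shape \((*, D, D)\).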

piqa.fid.frechet_distance(mu_x, sigma_x, mu_y, sigma_y)

Returns the Fréchet distance between two multivariate Gaussian distributions.

\[d^2 = \left\| \mu_x - \mu_y \right\|_2^2 + \operatorname{tr} \left( \Sigma_x + \Sigma_y - 2 \sqrt{\Sigma_y^{\frac{1}{2}} \Sigma_x \Sigma_y^{\frac{1}{2}}} \right)\]

Wikipedia

https://wikipedia.org/wiki/Frechet_distance

Parameters:
  • mu_x (Tensor) – The mean \(\mu_x\) of the first distribution, \((*, D)\).

  • sigma_x (Tensor) – The covariance \(\Sigma_x\) of the first distribution, \((*, D, D)\).

  • mu_y (Tensor) – The mean \(\mu_y\) of the second distribution, \((*, D)\).

  • sigma_y (Tensor) – The covariance \(\Sigma_y\) of the second distribution, \((*, D, D)\).
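
The formula above can be evaluated with symmetric eigendecompositions only: \(\Sigma_y^{1/2} \Sigma_x \Sigma_y^{1/2}\) is symmetric positive semi-definite, so the trace of its square root equals the sum of the square roots of its eigenvalues. The following is a minimal sketch of that approach, assuming both covariances are symmetric PSD; `frechet_distance_sketch` and `_sqrtm_psd` are hypothetical names, and this is not necessarily how piqa computes `frechet_distance`.

```python
import torch


def _sqrtm_psd(sigma: torch.Tensor) -> torch.Tensor:
    # Q diag(sqrt(L)) Q^T, clamping round-off negatives to 0 (assumption)
    L, Q = torch.linalg.eigh(sigma)
    return (Q * L.clamp(min=0).sqrt().unsqueeze(-2)) @ Q.transpose(-1, -2)


def frechet_distance_sketch(mu_x, sigma_x, mu_y, sigma_y):
    """Squared Frechet distance between N(mu_x, sigma_x) and N(mu_y, sigma_y)."""
    # ||mu_x - mu_y||_2^2
    mean_term = (mu_x - mu_y).square().sum(dim=-1)

    # S = sigma_y^{1/2} sigma_x sigma_y^{1/2}, symmetrized against round-off
    root_y = _sqrtm_psd(sigma_y)
    s = root_y @ sigma_x @ root_y
    s = (s + s.transpose(-1, -2)) / 2

    # tr(sqrt(S)) = sum of the square roots of the eigenvalues of S
    trace_sqrt = torch.linalg.eigvalsh(s).clamp(min=0).sqrt().sum(dim=-1)

    trace_x = torch.diagonal(sigma_x, dim1=-2, dim2=-1).sum(dim=-1)
    trace_y = torch.diagonal(sigma_y, dim1=-2, dim2=-1).sum(dim=-1)
    return mean_term + trace_x + trace_y - 2 * trace_sqrt
```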

Example

>>> mu_x = torch.arange(3).float()
>>> sigma_x = torch.eye(3)
>>> mu_y = 2 * mu_x + 1
>>> sigma_y = 2 * sigma_x + 1
>>> frechet_distance(mu_x, sigma_x, mu_y, sigma_y)
tensor(15.8710)
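
The value in this example can be checked by hand. Here \(\mu_x = (0, 1, 2)\), \(\mu_y = (1, 3, 5)\), \(\Sigma_x = I\) and \(\Sigma_y = 2I + \mathbf{1}\mathbf{1}^T\), where \(\mathbf{1}\) is the all-ones vector, so \(\Sigma_y\) has eigenvalues 5, 2 and 2 and the square-root term reduces to \(\operatorname{tr}(\sqrt{\Sigma_y})\):

\[\begin{aligned}
\left\| \mu_x - \mu_y \right\|_2^2 &= 1 + 4 + 9 = 14 \\
\operatorname{tr}(\Sigma_x) + \operatorname{tr}(\Sigma_y) &= 3 + 9 = 12 \\
\operatorname{tr} \left( \sqrt{\Sigma_y^{\frac{1}{2}} \Sigma_x \Sigma_y^{\frac{1}{2}}} \right) &= \operatorname{tr} \left( \sqrt{\Sigma_y} \right) = \sqrt{5} + 2 \sqrt{2} \approx 5.0645 \\
d^2 &= 14 + 12 - 2 \times 5.0645 \approx 15.8710
\end{aligned}\]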

class piqa.fid.InceptionV3(logits=True)

Pretrained Inception-v3 network.

References

Rethinking the Inception Architecture for Computer Vision (Szegedy et al., 2015)

Parameters:

logits (bool) – Whether to return the class logits or the last pooling features.

Example

>>> x = torch.randn(5, 3, 256, 256)
>>> inception = InceptionV3()
>>> logits = inception(x)
>>> logits.shape
torch.Size([5, 1000])

class piqa.fid.FID

Measures the FID between two sets of inception features.

Note

See FID.features for how to get inception features.

Example

>>> criterion = FID()
>>> x = torch.randn(1024, 256)
>>> y = torch.randn(2048, 256)
>>> l = criterion(x, y)
>>> l.shape
torch.Size([])

features(x, no_grad=True)

Returns the inception features of an input.

Tip

If you cannot compute the inception features of your whole input at once, for instance because of memory limitations, you can split it into smaller batches and concatenate the outputs afterwards.
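
A minimal sketch of this tip, assuming only that the feature extractor maps a batch along its first dimension independently. The helper `batched_features` is hypothetical; with piqa one would pass a bound method such as `FID().features` as `extract`, while the check below uses a stand-in linear extractor so the sketch runs without pretrained weights.

```python
from typing import Callable

import torch


def batched_features(
    extract: Callable[[torch.Tensor], torch.Tensor],
    x: torch.Tensor,
    batch_size: int = 64,
) -> torch.Tensor:
    """Applies `extract` to `x` chunk by chunk and concatenates the results."""
    # torch.split yields views of at most `batch_size` samples along dim 0
    chunks = torch.split(x, batch_size, dim=0)
    # Each chunk is processed separately, which bounds peak activation memory
    # when gradients are disabled
    return torch.cat([extract(chunk) for chunk in chunks], dim=0)
```

The chunk size trades memory for throughput; it does not change the result as long as the extractor treats samples independently, as a network in evaluation mode does.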

Parameters:
  • x (Tensor) – An input tensor, \((N, 3, H, W)\).

  • no_grad (bool) – Whether to disable gradients or not.

Returns:

The features, \((N, 2048)\).

Return type:

Tensor


forward(x, y)

Returns the FID between two sets of inception features.

Parameters:
  • x (Tensor) – An input tensor, \((M, D)\).

  • y (Tensor) – A target tensor, \((N, D)\).

Returns:

The FID, \(()\).

Return type:

Tensor
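
For reference, the standard FID fits a Gaussian to each feature set and evaluates the Fréchet distance between the two fits. The sketch below assumes that `forward` follows this definition with the sample mean and the unbiased sample covariance (`torch.cov`); the name `fid_sketch` is hypothetical and piqa's exact estimator may differ.

```python
import torch


def fid_sketch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """FID between feature sets x of shape (M, D) and y of shape (N, D)."""
    # Gaussian fit of each set: mean (D,) and covariance (D, D)
    mu_x, mu_y = x.mean(dim=0), y.mean(dim=0)
    # torch.cov expects variables in rows and observations in columns, hence .T
    sigma_x, sigma_y = torch.cov(x.T), torch.cov(y.T)

    # sigma_y^{1/2} via eigendecomposition, clamping round-off negatives to 0
    L, Q = torch.linalg.eigh(sigma_y)
    root_y = (Q * L.clamp(min=0).sqrt()) @ Q.T

    # tr(sqrt(sigma_y^{1/2} sigma_x sigma_y^{1/2})) = sum of sqrt of its eigenvalues
    s = root_y @ sigma_x @ root_y
    s = (s + s.T) / 2  # symmetrize against round-off
    trace_sqrt = torch.linalg.eigvalsh(s).clamp(min=0).sqrt().sum()

    mean_term = (mu_x - mu_y).square().sum()
    return mean_term + sigma_x.trace() + sigma_y.trace() - 2 * trace_sqrt
```

Identical feature sets give a value close to 0, and shifting every feature by a constant \(c\) only adds \(D c^2\) through the mean term.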