piqa.fid
Fréchet Inception Distance (FID)
This module implements the FID in PyTorch.
Original
https://github.com/bioinf-jku/TTUR
Wikipedia
https://wikipedia.org/wiki/Frechet_inception_distance
References
Functions
frechet_distance | Returns the Fréchet distance between two multivariate Gaussian distributions.
sqrtm | Returns the square root of a positive semi-definite matrix.
Classes
FID | Measures the FID between two sets of inception features.
InceptionV3 | Pretrained Inception-v3 network.
Descriptions
- piqa.fid.sqrtm(sigma)
Returns the square root of a positive semi-definite matrix.
\[\sqrt{\Sigma} = Q \sqrt{\Lambda} Q^T\]
where \(Q \Lambda Q^T\) is the eigendecomposition of \(\Sigma\).
- Parameters:
sigma (Tensor) – A positive semi-definite matrix, \((*, D, D)\).
Example
>>> V = torch.randn(4, 4, dtype=torch.double)
>>> A = V @ V.T
>>> B = sqrtm(A @ A)
>>> torch.allclose(A, B)
True
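The eigendecomposition formula above can be sketched directly in PyTorch. This is a minimal illustration, not the library's implementation: the name `sqrtm_eigh` is hypothetical, and clamping negative eigenvalues to zero is an assumption made here to absorb round-off.

```python
import torch

def sqrtm_eigh(sigma: torch.Tensor) -> torch.Tensor:
    """Square root of a symmetric PSD matrix of shape (*, D, D) (illustrative)."""
    # eigh returns the eigenvalues and orthonormal eigenvectors Q of a symmetric matrix
    eigvals, Q = torch.linalg.eigh(sigma)
    # Assumption: clamp tiny negative eigenvalues caused by round-off before the root
    root = eigvals.clamp(min=0).sqrt()
    # Q diag(sqrt(Lambda)) Q^T, broadcast over any leading batch dimensions
    return (Q * root.unsqueeze(-2)) @ Q.transpose(-1, -2)

torch.manual_seed(0)
V = torch.randn(4, 4, dtype=torch.double)
A = V @ V.T            # symmetric positive semi-definite by construction
B = sqrtm_eigh(A)      # B @ B should reproduce A up to round-off
```

Because `torch.linalg.eigh` accepts batched inputs, the same function handles a stack of matrices of shape \((*, D, D)\) without changes.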
- piqa.fid.frechet_distance(mu_x, sigma_x, mu_y, sigma_y)
Returns the Fréchet distance between two multivariate Gaussian distributions.
\[d^2 = \left\| \mu_x - \mu_y \right\|_2^2 + \operatorname{tr} \left( \Sigma_x + \Sigma_y - 2 \sqrt{\Sigma_y^{\frac{1}{2}} \Sigma_x \Sigma_y^{\frac{1}{2}}} \right)\]
Wikipedia
https://wikipedia.org/wiki/Frechet_distance
- Parameters:
mu_x (Tensor) – The mean \(\mu_x\) of the first distribution, \((*, D)\).
sigma_x (Tensor) – The covariance \(\Sigma_x\) of the first distribution, \((*, D, D)\).
mu_y (Tensor) – The mean \(\mu_y\) of the second distribution, \((*, D)\).
sigma_y (Tensor) – The covariance \(\Sigma_y\) of the second distribution, \((*, D, D)\).
Example
>>> mu_x = torch.arange(3).float()
>>> sigma_x = torch.eye(3)
>>> mu_y = 2 * mu_x + 1
>>> sigma_y = 2 * sigma_x + 1
>>> frechet_distance(mu_x, sigma_x, mu_y, sigma_y)
tensor(15.8710)
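To make the formula concrete, the following sketch evaluates it term by term on the inputs of the example. The helper names `psd_sqrt` and `frechet_distance_sketch` are hypothetical, and the eigendecomposition-based square root is an assumption about one reasonable way to compute \(\sqrt{\cdot}\), not a statement about the library's internals. The value obtained is the quantity written as \(d^2\) in the formula.

```python
import torch

def psd_sqrt(m: torch.Tensor) -> torch.Tensor:
    # Square root of a symmetric PSD matrix via its eigendecomposition
    vals, vecs = torch.linalg.eigh(m)
    return (vecs * vals.clamp(min=0).sqrt().unsqueeze(-2)) @ vecs.transpose(-1, -2)

def frechet_distance_sketch(mu_x, sigma_x, mu_y, sigma_y):
    # Squared distance between the means
    mean_term = (mu_x - mu_y).square().sum(dim=-1)
    # sqrt(Sigma_y^(1/2) Sigma_x Sigma_y^(1/2))
    sqrt_y = psd_sqrt(sigma_y)
    cross = psd_sqrt(sqrt_y @ sigma_x @ sqrt_y)
    # Trace over the last two dimensions, so batched inputs also work
    trace_term = (sigma_x + sigma_y - 2 * cross).diagonal(dim1=-2, dim2=-1).sum(dim=-1)
    return mean_term + trace_term

mu_x = torch.arange(3).double()
sigma_x = torch.eye(3, dtype=torch.double)
mu_y = 2 * mu_x + 1
sigma_y = 2 * sigma_x + 1
d2 = frechet_distance_sketch(mu_x, sigma_x, mu_y, sigma_y)
print(d2)  # approximately 15.8710, matching the example above
```

Here the mean term contributes 14 and the trace term about 1.871, which is how the example's result decomposes.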
- class piqa.fid.InceptionV3(logits=True)
Pretrained Inception-v3 network.
References
Rethinking the Inception Architecture for Computer Vision (Szegedy et al., 2015)
- Parameters:
logits (bool) – If True, the network returns the class logits; otherwise it returns the last pooling features.
Example
>>> x = torch.randn(5, 3, 256, 256)
>>> inception = InceptionV3()
>>> logits = inception(x)
>>> logits.shape
torch.Size([5, 1000])
- class piqa.fid.FID
Measures the FID between two sets of inception features.
Note
See FID.features for how to get inception features.
Example
>>> criterion = FID()
>>> x = torch.randn(1024, 256)
>>> y = torch.randn(2048, 256)
>>> l = criterion(x, y)
>>> l.shape
torch.Size([])
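FID is conventionally defined as the Fréchet distance between Gaussians fitted to the two feature sets. The sketch below shows only that fitting step, turning an \((N, D)\) feature matrix into a mean and a covariance. The helper name `gaussian_stats` is hypothetical, and the unbiased covariance estimator used here is an assumption; this page does not state which estimator the class uses.

```python
import torch

def gaussian_stats(features: torch.Tensor):
    """Mean (D,) and covariance (D, D) of a feature matrix of shape (N, D)."""
    mu = features.mean(dim=0)
    # torch.cov expects variables in rows and observations in columns
    sigma = torch.cov(features.T)
    return mu, sigma

torch.manual_seed(0)
x = torch.randn(1024, 256, dtype=torch.double)
mu_x, sigma_x = gaussian_stats(x)
```

The resulting `mu_x` and `sigma_x` have the shapes that `frechet_distance` documents for its mean and covariance arguments.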
- features(x, no_grad=True)
Returns the inception features of an input.
Tip
If you cannot compute the inception features of your input in a single pass, for instance because of memory limitations, you can split the input into smaller batches and concatenate the outputs afterwards.
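The batching strategy from the tip can be sketched as follows. The helper `batched_features` and the stand-in extractor `toy_extract` are hypothetical; the stand-in replaces the real network so that the example runs without downloading pretrained weights.

```python
import torch

def batched_features(extract, images: torch.Tensor, batch_size: int = 64) -> torch.Tensor:
    """Apply `extract` to `images` chunk by chunk and concatenate the outputs."""
    outputs = []
    # Tensor.split yields consecutive chunks along the first (batch) dimension
    for chunk in images.split(batch_size):
        with torch.no_grad():  # features are not differentiated in this sketch
            outputs.append(extract(chunk))
    return torch.cat(outputs)

def toy_extract(batch: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a feature extractor: global average pooling
    return batch.mean(dim=(-2, -1))

images = torch.randn(10, 3, 8, 8)
feats = batched_features(toy_extract, images, batch_size=4)  # chunks of 4, 4 and 2
```

With the library itself, the `extract` argument would presumably be the `features` method of a `FID` instance, and `batch_size` would be chosen to fit the available memory.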