piqa.lpips#

Learned Perceptual Image Patch Similarity (LPIPS)

This module implements the LPIPS metric in PyTorch.

Original

https://github.com/richzhang/PerceptualSimilarity

References

The Unreasonable Effectiveness of Deep Features as a Perceptual Metric (Zhang et al., 2018)

Functions#

get_weights

Returns the official LPIPS weights for network.

Classes#

LPIPS

Measures the LPIPS between an input and a target.

Perceptual

Perceptual network that intercepts and returns the output of target layers within its forward pass.

Descriptions#

piqa.lpips.get_weights(network='alex', version='v0.1')#

Returns the official LPIPS weights for network.

Parameters:
  • network (str) – Specifies the perception network that is used: 'alex', 'squeeze' or 'vgg'.

  • version (str) – Specifies the official version release: 'v0.0' or 'v0.1'.

class piqa.lpips.Perceptual(layers, targets)#

Perceptual network that intercepts and returns the output of target layers within its forward pass.

Parameters:
  • layers (List[Module]) – A list of layers.

  • targets (List[int]) – A list of target layer indices.
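The interception described above can be sketched without PyTorch: run the input through the layers in order and keep the output at each target index. The sketch below uses plain callables in place of `Module` objects and a float in place of a tensor. The helper name `run_and_intercept` is hypothetical and not part of `piqa`; this illustrates the behaviour under those assumptions and is not the library's implementation.

```python
from typing import Callable, List, Sequence


def run_and_intercept(
    layers: Sequence[Callable[[float], float]],
    targets: Sequence[int],
    x: float,
) -> List[float]:
    """Apply `layers` in order and return the outputs at indices `targets`.

    Hypothetical stand-in for a perceptual network: plain callables
    replace torch modules, and a float replaces a tensor.
    """
    wanted = set(targets)
    outputs = []
    for i, layer in enumerate(layers):
        x = layer(x)  # feed the previous output into the next layer
        if i in wanted:
            outputs.append(x)  # keep the intermediate activation
    return outputs


# Three toy "layers": add 1, double, square.
toy_layers = [lambda v: v + 1, lambda v: v * 2, lambda v: v ** 2]
features = run_and_intercept(toy_layers, targets=[0, 2], x=3.0)
```

Here `features` holds the outputs of the first and third toy layers, in layer order; the output of the second layer is computed but not returned.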

class piqa.lpips.LPIPS(network='alex', epsilon=1e-10, reduction='mean')#

Measures the LPIPS between an input and a target.

\[\text{LPIPS}(x, y) = \sum_{l \, \in \, \mathcal{F}} w_l \cdot \text{MSE}(\phi_l(x), \phi_l(y))\]

where \(\phi_l\) represents the normalized output of an intermediate layer \(l\) in a perceptual network \(\mathcal{F}\) and \(w_l\) are the official weights of Zhang et al. (2018).
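To make the formula concrete, here is a small plain-Python sketch of the weighted sum for a single image pair. It assumes, as an interpretation of the formula rather than a confirmed detail, that "normalized" means scaling the channel vector at each spatial position to unit L2 norm, that the MSE is averaged over spatial positions per channel, and that \(w_l\) holds one weight per channel. The names `normalize` and `lpips_sketch`, the placement of `epsilon` inside the norm, and the toy numbers are all hypothetical.

```python
import math
from typing import List

Feature = List[List[float]]  # feature[c][p]: channel c, spatial position p


def normalize(feature: Feature, epsilon: float = 1e-10) -> Feature:
    """Scale the channel vector at each position to (almost) unit L2 norm."""
    channels, positions = len(feature), len(feature[0])
    norms = [
        math.sqrt(sum(feature[c][p] ** 2 for c in range(channels))) + epsilon
        for p in range(positions)
    ]
    return [
        [feature[c][p] / norms[p] for p in range(positions)]
        for c in range(channels)
    ]


def lpips_sketch(
    feats_x: List[Feature],
    feats_y: List[Feature],
    weights: List[List[float]],
    epsilon: float = 1e-10,
) -> float:
    """Sum over layers of the per-channel weighted MSE of normalized features."""
    total = 0.0
    for fx, fy, w in zip(feats_x, feats_y, weights):
        fx, fy = normalize(fx, epsilon), normalize(fy, epsilon)
        for c, w_c in enumerate(w):
            # Mean squared difference over spatial positions for channel c.
            mse = sum((a - b) ** 2 for a, b in zip(fx[c], fy[c])) / len(fx[c])
            total += w_c * mse
    return total


# One layer, two channels, one spatial position (toy values).
feats_x = [[[3.0], [4.0]]]  # normalizes to (0.6, 0.8)
feats_y = [[[1.0], [0.0]]]  # normalizes to (1.0, 0.0)
weights = [[1.0, 0.5]]      # made-up weights, not the official ones
score = lpips_sketch(feats_x, feats_y, weights)
```

With these toy values the score is approximately 1.0 · 0.16 + 0.5 · 0.64 = 0.48, and identical feature lists give a score of zero.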

Parameters:
  • network (str) – Specifies the perceptual network \(\mathcal{F}\) to use: 'alex', 'squeeze' or 'vgg'.

  • epsilon (float) – A numerical stability term.

  • reduction (str) – Specifies the reduction to apply to the output: 'none', 'mean' or 'sum'.

Example

>>> import torch
>>> from piqa.lpips import LPIPS
>>> criterion = LPIPS()
>>> x = torch.rand(5, 3, 256, 256, requires_grad=True)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = criterion(x, y)
>>> l.shape
torch.Size([])
>>> l.backward()

forward(x, y)#
Parameters:
  • x (Tensor) – An input tensor, \((N, 3, H, W)\).

  • y (Tensor) – A target tensor, \((N, 3, H, W)\).

Returns:

The LPIPS tensor, of shape \((N,)\) if reduction is 'none', or \(()\) if it is 'mean' or 'sum'.

Return type:

Tensor
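The effect of reduction on the returned shape can be illustrated with a plain-Python sketch over a list of per-sample scores. The helper `reduce_scores` is hypothetical and only mirrors the three documented options; it is not the function used by the library.

```python
from typing import List, Union


def reduce_scores(
    scores: List[float],
    reduction: str = "mean",
) -> Union[List[float], float]:
    """Reduce per-sample scores (length N) according to `reduction`."""
    if reduction == "none":
        return list(scores)  # one value per sample, shape (N,)
    if reduction == "mean":
        return sum(scores) / len(scores)  # single scalar, shape ()
    if reduction == "sum":
        return sum(scores)  # single scalar, shape ()
    raise ValueError(f"unknown reduction: {reduction!r}")


per_sample = [0.25, 0.75]  # toy scores for a batch of N = 2
kept = reduce_scores(per_sample, "none")
mean = reduce_scores(per_sample, "mean")
total = reduce_scores(per_sample, "sum")
```

Only 'none' keeps one value per sample; 'mean' and 'sum' collapse the batch to a single scalar, which is what a call to `backward()` on the result expects.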