Metrics for unpaired image-to-image translation

Implements evaluation metrics for unpaired image-to-image translation, including common metrics such as FID and KID.
set_seed(999, reproducible=True)

Fréchet Inception Distance

This code is based on two existing open-source FID implementations, adapted to fastai's metric API.
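For reference, FID fits a multivariate Gaussian to the Inception activations of the target images and another to those of the generated images, and measures the Fréchet distance between the two. Below is a minimal numpy/scipy sketch of that distance; the library's calculate_frechet_distance can be expected to add numerical-stability handling on top of this.

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    "Fréchet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2)"
    diff = mu1 - mu2
    # matrix square root of the product of the two covariance matrices
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean): covmean = covmean.real  # discard negligible imaginary parts
    return diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)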



InceptionV3

 InceptionV3 ()

A thin nn.Module wrapper around a pretrained Inception v3 network, used here to extract the 2048-dimensional pooled activation features on which the FID statistics are computed (calc_activations_for_batch below returns a (batch_size, 2048) feature matrix).


FrechetInceptionDistance

 FrechetInceptionDistance (model=None, device='cuda', yb_idx=0,
                           pred_idx=1)

fastai Metric that computes the Fréchet Inception Distance between the generated images and the target-domain images.

The FrechetInceptionDistance metric works by initializing an Inception model, extracting Inception activation features for each batch of predictions and target images, and, at the end, computing the activation statistics and the Fréchet distance between them. A sketch of the fastai Metric pattern it follows is shown below, followed by tests of each of these components.
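fastai calls a Metric's reset at the start of a run, accumulate(learn) after every batch, and reads value at the end. A hypothetical sketch of that pattern for an FID-style metric; the names and the way statistics are deferred to value are assumptions for illustration, not the actual implementation.

import numpy as np
from fastai.learner import Metric

class FIDLikeMetric(Metric):
    "Hypothetical sketch of the accumulate/value pattern; not the actual implementation"
    def __init__(self, feature_extractor, yb_idx=0, pred_idx=1):
        # feature_extractor: callable mapping an image batch to an (N, D) numpy feature array
        self.feature_extractor,self.yb_idx,self.pred_idx = feature_extractor,yb_idx,pred_idx
    def reset(self):
        self.real_feats,self.fake_feats = [],[]
    def accumulate(self, learn):
        # pull the target batch and the generated batch out of the Learner
        self.real_feats.append(self.feature_extractor(learn.yb[self.yb_idx]))
        self.fake_feats.append(self.feature_extractor(learn.pred[self.pred_idx]))
    @property
    def value(self):
        real, fake = np.concatenate(self.real_feats), np.concatenate(self.fake_feats)
        mu_r, sig_r = real.mean(axis=0), np.cov(real, rowvar=False)
        mu_f, sig_f = fake.mean(axis=0), np.cov(fake, rowvar=False)
        return frechet_distance(mu_r, sig_r, mu_f, sig_f)  # e.g. the sketch shown earlier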

fid = FrechetInceptionDistance(device='cpu')
size = (224, 224, 3)
arrays = [np.zeros(size), np.ones(size) * 0.5, np.ones(size)]*2
img_like_tensor = torch.from_numpy(np.array(arrays)).float()
# The Inception feature extractor maps a (N, 3, H, W) batch to (N, 2048) activations
activations = fid.calc_activations_for_batch(img_like_tensor.permute(0,3,1,2), model=fid.model, device='cpu')
test_eq(activations.shape, (img_like_tensor.shape[0], 2048))
# A stand-in "feature extractor" that simply averages over the spatial dimensions,
# so the expected statistics can be worked out by hand
class fake_model(nn.Module):
    def __init__(self): super().__init__()
    def forward(self, x): return x.mean(dim=(2,3))

size = (4, 4, 3)
arrays = [np.zeros(size), np.ones(size) * 0.5, np.ones(size)]
input_tensor = torch.from_numpy(np.array(arrays)).float()

stats = fid.calculate_activation_statistics(fid.calc_activations_for_batch(input_tensor.permute(0,3,1,2), model=fake_model()))
test_eq(stats[0], np.ones((3,)) * 0.5)
test_eq(stats[1], np.ones((3, 3)) * 0.25)
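The two statistics are simply the feature-wise mean and covariance of the activations. A quick numpy check of the expected values above (assuming an unbiased, np.cov-style covariance):

feat_matrix = np.array([[0.0, 0.0, 0.0],
                        [0.5, 0.5, 0.5],
                        [1.0, 1.0, 1.0]])  # what fake_model returns for the three test images
mu, sigma = feat_matrix.mean(axis=0), np.cov(feat_matrix, rowvar=False)
test_eq(mu, np.ones((3,)) * 0.5)
test_eq(sigma, np.ones((3, 3)) * 0.25)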
m1, m2 = np.zeros((2048,)), np.ones((2048,))
sigma = np.eye(2048)
# With equal covariances the trace term Tr(S1 + S2 - 2*sqrtm(S1 @ S2)) vanishes,
# so the FID reduces to the squared norm of the difference of the means
test_eq(fid.calculate_frechet_distance(m1,sigma,m2,sigma), np.sum((m1 - m2)**2))
# Minimal stand-in for a Learner: accumulate reads learn.yb[yb_idx] (target images) and
# learn.pred[pred_idx] (generated images); since both point at the same batch here,
# the resulting FID should be ~0
class FakeLearner():
    def __init__(self):
        self.yb = [img_like_tensor.permute(0,3,1,2)]
        self.pred = [None, img_like_tensor.permute(0,3,1,2)]
learn = FakeLearner()
for i in range(5):
    fid.accumulate(learn)
print(fid.value)
-0.0002026209404561996
CPU times: user 25.9 s, sys: 573 ms, total: 26.5 s
Wall time: 26.4 s

Quick Test

# Quick sanity check: train a CycleGAN on a small subset of the horse2zebra dataset,
# tracking the FID between the generated and target images after each epoch
horse2zebra = untar_data('https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip')
folders = horse2zebra.ls().sorted()
trainA_path = folders[2]
trainB_path = folders[3]
testA_path = folders[0]
testB_path = folders[1]
dls = get_dls(trainA_path, trainB_path, num_A=100)
cycle_gan = CycleGAN(3, 3, 64)
learn = cycle_learner(dls, cycle_gan, metrics=[FrechetInceptionDistance()], show_img_interval=1)
learn.fit_flat_lin(5, 5, 2e-4)
epoch train_loss id_loss_A id_loss_B gen_loss_A gen_loss_B cyc_loss_A cyc_loss_B D_A_loss D_B_loss frechet_inception_distance time
0 10.250998 1.617909 1.524502 0.380783 0.408158 3.353981 3.306977 0.414352 0.414352 90.679044 00:48
1 8.985356 1.236593 1.236445 0.291166 0.296743 2.552362 2.653999 0.261017 0.261017 90.476945 00:48
2 8.218993 1.139003 1.041497 0.287646 0.302079 2.395697 2.288882 0.250520 0.250520 92.738160 00:49
3 7.720056 0.971353 1.124618 0.287192 0.313714 2.050789 2.414080 0.249346 0.249346 93.015558 00:49
4 7.342298 0.989288 0.983895 0.300589 0.326872 2.093399 2.157701 0.241617 0.241617 93.109652 00:50
5 7.034866 1.000397 0.922700 0.315030 0.325621 2.095371 1.945555 0.235398 0.235398 91.938541 00:50
6 6.892154 0.982663 0.920580 0.327344 0.346113 2.068103 2.069340 0.227647 0.227647 91.688272 00:50
7 6.704684 0.983012 0.865727 0.347819 0.364232 1.979038 1.902086 0.219287 0.219287 91.998393 00:50
8 6.627498 0.934832 0.936938 0.339186 0.349980 1.971006 2.006817 0.217617 0.217617 92.316292 00:51
9 6.321739 0.864914 0.824035 0.341803 0.348890 1.770928 1.729100 0.211070 0.211070 91.703895 00:52