CycleGAN model

Defines the CycleGAN model architecture.

We use the generator and discriminator architectures introduced in the CycleGAN paper.

Generator



convT_norm_relu

 convT_norm_relu (ch_in:int, ch_out:int,
                  norm_layer:torch.nn.modules.module.Module, ks:int=3,
                  stride:int=2, bias:bool=True)


pad_conv_norm_relu

 pad_conv_norm_relu (ch_in:int, ch_out:int, pad_mode:str,
                     norm_layer:torch.nn.modules.module.Module, ks:int=3,
                     bias:bool=True, pad=1, stride:int=1, activ:bool=True,
                     init=<function kaiming_normal_>, init_gain:int=0.02)


ResnetBlock

 ResnetBlock (dim:int, pad_mode:str='reflection',
              norm_layer:torch.nn.modules.module.Module=None,
              dropout:float=0.0, bias:bool=True)

nn.Module for the ResNet Block



resnet_generator

 resnet_generator (ch_in:int, ch_out:int, n_ftrs:int=64,
                   norm_layer:torch.nn.modules.module.Module=None,
                   dropout:float=0.0, n_blocks:int=9,
                   pad_mode:str='reflection')

Test generator

Let’s test for a few things:

1. The generator can indeed be initialized correctly
2. A random image can be passed into the model successfully with the correct size output
3. The CycleGAN generator is equivalent to the original implementation

First let’s create a random batch:

img1 = torch.randn(4,3,256,256)
m = resnet_generator(3,3)
with torch.no_grad():
    out1 = m(img1)
out1.shape
torch.Size([4, 3, 256, 256])
m_junyanz = define_G(3,3,64,'resnet_9blocks', norm='instance')
with torch.no_grad():
    out2 = m_junyanz(img1)
out2.shape
initialize network with normal
torch.Size([4, 3, 256, 256])


compare_networks

 compare_networks (a, b)

A simple function to compare the printed model representations as a proxy for actually comparing two models
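
To illustrate the idea of comparing printed representations, here is a minimal sketch. It is not the library's implementation: `TinyLayer` and `compare_by_repr` are hypothetical names, and plain dataclasses stand in for `nn.Module` objects so the snippet runs without PyTorch. It also shows why this is only a proxy: two objects with different weights still compare equal as long as they print the same.

```python
from dataclasses import dataclass, field


@dataclass
class TinyLayer:
    """Hypothetical stand-in for a layer: hyperparameters show up in repr, weights do not."""
    ch_in: int
    ch_out: int
    weights: list = field(default_factory=list, repr=False)


def compare_by_repr(a, b) -> bool:
    """Return True when the two objects have identical printed representations."""
    return repr(a) == repr(b)


# Same architecture, different weights: the printed forms match.
same_arch = compare_by_repr(TinyLayer(3, 64, [0.1]), TinyLayer(3, 64, [0.9]))
# Different output channels: the printed forms differ.
diff_arch = compare_by_repr(TinyLayer(3, 64), TinyLayer(3, 32))
```

With real PyTorch modules, `repr` prints the nested layer tree with each layer's hyperparameters, so a match suggests the same architecture but says nothing about the parameter values.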

test_eq(out1.shape,img1.shape)
test_eq(out2.shape,img1.shape)
assert compare_networks(list(m_junyanz.children())[0],m)
Passed!
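
The fact that the generator returns an output of the same size as its input can be checked by hand. The sketch below traces the spatial size through the layers, assuming the layer hyperparameters of the reference CycleGAN ResNet generator (a 7x7 stem with padding 3, two stride-2 3x3 convolutions, size-preserving residual blocks, two stride-2 transposed convolutions with padding 1 and output padding 1, and a final 7x7 convolution with padding 3). These values are assumptions and are not read from this library's code; the helper names are hypothetical.

```python
def conv_out(size: int, ks: int, stride: int = 1, pad: int = 0) -> int:
    """Output size of a convolution along one spatial dimension."""
    return (size + 2 * pad - ks) // stride + 1


def conv_transpose_out(size: int, ks: int, stride: int = 1, pad: int = 0, output_pad: int = 0) -> int:
    """Output size of a transposed convolution along one spatial dimension."""
    return (size - 1) * stride - 2 * pad + ks + output_pad


def generator_sizes(size: int = 256) -> list:
    """Trace the spatial size through the assumed generator layout."""
    sizes = [size]
    sizes.append(conv_out(sizes[-1], ks=7, pad=3))  # 7x7 stem, padding 3
    for _ in range(2):  # two stride-2 downsampling convolutions
        sizes.append(conv_out(sizes[-1], ks=3, stride=2, pad=1))
    # The residual blocks keep the spatial size, so they add no entries here.
    for _ in range(2):  # two stride-2 upsampling transposed convolutions
        sizes.append(conv_transpose_out(sizes[-1], ks=3, stride=2, pad=1, output_pad=1))
    sizes.append(conv_out(sizes[-1], ks=7, pad=3))  # final 7x7 convolution
    return sizes


sizes = generator_sizes(256)
```

Under these assumptions, a 256-pixel side is halved twice to 64 and then doubled twice back to 256, which is consistent with the `[4, 3, 256, 256]` output shape shown above.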

Discriminator



conv_norm_lr

 conv_norm_lr (ch_in:int, ch_out:int,
               norm_layer:torch.nn.modules.module.Module=None, ks:int=3,
               bias:bool=True, pad:int=1, stride:int=1, activ:bool=True,
               slope:float=0.2, init=<function normal_>,
               init_gain:int=0.02)


discriminator

 discriminator (ch_in:int, n_ftrs:int=64, n_layers:int=3,
                norm_layer:torch.nn.modules.module.Module=None,
                sigmoid:bool=False)

Test discriminator

Let’s test for similar things:

1. The discriminator can indeed be initialized correctly
2. A random image can be passed into the discriminator successfully with the correct size output
3. The CycleGAN discriminator is equivalent to the original implementation

d = discriminator(3)
with torch.no_grad():
    out1 = d(img1)
out1.shape
torch.Size([4, 1, 30, 30])
img1 = torch.randn(4,3,256,256)
d_junyanz = define_D(3,64,'basic',norm='instance')
with torch.no_grad():
    out2 = d_junyanz(img1)
out2.shape
initialize network with normal
torch.Size([4, 1, 30, 30])
test_eq(out1.shape,torch.Size([4, 1, 30, 30]))
test_eq(out2.shape,torch.Size([4, 1, 30, 30]))
assert compare_networks(list(d_junyanz.children())[0],d)
Passed!
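
The 30x30 output for a 256x256 input can be reproduced with a few lines of arithmetic. This is a sketch that assumes the layout of the original PatchGAN discriminator: every convolution uses kernel size 4 and padding 1, the first `n_layers` convolutions use stride 2, and they are followed by two stride-1 convolutions. Those hyperparameters are assumptions taken from the reference implementation rather than from the signatures above, and the helper names are hypothetical.

```python
def conv_out(size: int, ks: int = 4, stride: int = 1, pad: int = 1) -> int:
    """Output size of a convolution along one spatial dimension."""
    return (size + 2 * pad - ks) // stride + 1


def patchgan_out(size: int, n_layers: int = 3) -> int:
    """Spatial size of the patch-score map under the assumed PatchGAN layout."""
    for _ in range(n_layers):  # stride-2 convolutions halve the size
        size = conv_out(size, stride=2)
    size = conv_out(size, stride=1)  # stride-1 convolution: size shrinks by 1
    size = conv_out(size, stride=1)  # final 1-channel convolution: shrinks by 1 again
    return size


patch_map_size = patchgan_out(256)
```

The trace is 256, 128, 64, 32, 31, 30, so each of the 30x30 output values scores one overlapping patch of the input image.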

Full model

We group the two discriminators and the two generators into a single model; a Callback (defined in 02_cyclegan_training.ipynb) then takes care of training them properly. The model uses PyTorchModelHubMixin to support pushing to and loading from the HuggingFace Hub.



CycleGAN

 CycleGAN (ch_in:int=3, ch_out:int=3, n_features:int=64,
           disc_layers:int=3, gen_blocks:int=9, lsgan:bool=True,
           drop:float=0.0, norm_layer:torch.nn.modules.module.Module=None)

CycleGAN model.

When called, the model takes a batch of real images from both domains and outputs fake images for the opposite domains (using the generators). It also outputs identity images, obtained by passing each image through the generator that outputs its own domain (needed for the identity loss).

Attributes:

G_A (nn.Module): takes real input B and generates fake input A

G_B (nn.Module): takes real input A and generates fake input B

D_A (nn.Module): trained to distinguish between real input A and fake input A

D_B (nn.Module): trained to distinguish between real input B and fake input B



CycleGAN.__init__

 CycleGAN.__init__ (ch_in:int=3, ch_out:int=3, n_features:int=64,
                    disc_layers:int=3, gen_blocks:int=9, lsgan:bool=True,
                    drop:float=0.0,
                    norm_layer:torch.nn.modules.module.Module=None)

Constructor for CycleGAN model.

Arguments:

ch_in (int): Number of input channels (default=3)

ch_out (int): Number of output channels (default=3)

n_features (int): Number of input features (default=64)

disc_layers (int): Number of discriminator layers (default=3)

gen_blocks (int): Number of residual blocks in the generator (default=9)

lsgan (bool): Whether to use the LSGAN training objective, in which case the discriminators output unnormalized floats (default=True)

drop (float): Level of dropout (default=0)

norm_layer (nn.Module): Type of normalization layer to use in the models (default=None)
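
To make the `lsgan` option more concrete, here is a rough sketch of the difference between the two objectives. With LSGAN, the raw (unnormalized) discriminator scores are compared to a target with a mean squared error; the usual alternative squashes the scores with a sigmoid and uses binary cross-entropy. The function names, targets and example scores are illustrative assumptions, and the losses actually used for training are defined elsewhere, not in this model.

```python
import math


def lsgan_loss(scores, target: float) -> float:
    """Least-squares loss: mean squared error between raw scores and a constant target."""
    return sum((s - target) ** 2 for s in scores) / len(scores)


def bce_loss(scores, target: float) -> float:
    """Binary cross-entropy on sigmoid-squashed scores (the non-LSGAN alternative)."""
    total = 0.0
    for s in scores:
        p = 1.0 / (1.0 + math.exp(-s))  # sigmoid maps the raw score to (0, 1)
        total += -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
    return total / len(scores)


raw_scores = [0.9, 1.2, -0.3]  # made-up discriminator outputs for three patches
real_loss_ls = lsgan_loss(raw_scores, target=1.0)   # push scores on real images toward 1
real_loss_bce = bce_loss(raw_scores, target=1.0)
```

Because the least-squares loss works directly on raw scores, the discriminator needs no final sigmoid in that setting.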



CycleGAN.forward

 CycleGAN.forward (input)

Forward function for CycleGAN model. The input is a tuple of a batch of real images from both domains A and B.
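
The data flow of the forward pass can be sketched as follows. This is an illustration based on the attribute descriptions above, not the actual method body: the function name is hypothetical, the generators are replaced by string-tagging functions so the wiring is visible without PyTorch, and the order of the four outputs is an assumption.

```python
def cyclegan_forward(G_A, G_B, real_A, real_B):
    """Return fake and identity images for both domains (output order is assumed)."""
    fake_A = G_A(real_B)  # G_A translates a real B image into domain A
    fake_B = G_B(real_A)  # G_B translates a real A image into domain B
    idt_A = G_A(real_A)   # identity: G_A applied to an image already in domain A
    idt_B = G_B(real_B)   # identity: G_B applied to an image already in domain B
    return fake_A, fake_B, idt_A, idt_B


# Stand-in "generators" that simply record which generator saw which input.
def G_A(x):
    return f"G_A({x})"


def G_B(x):
    return f"G_B({x})"


outputs = cyclegan_forward(G_A, G_B, "real_A", "real_B")
```

With real generators, each of the four outputs has the same shape as the input batches.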


ModelHubMixin.push_to_hub

 ModelHubMixin.push_to_hub (repo_id:str, config:Optional[dict]=None,
                            commit_message:str='Push model using huggingface_hub.',
                            private:bool=False,
                            api_endpoint:Optional[str]=None,
                            token:Optional[str]=None,
                            branch:Optional[str]=None,
                            create_pr:Optional[bool]=None,
                            allow_patterns:Union[List[str],str,NoneType]=None,
                            ignore_patterns:Union[List[str],str,NoneType]=None,
                            delete_patterns:Union[List[str],str,NoneType]=None)

Upload model checkpoint to the Hub.

Use allow_patterns and ignore_patterns to precisely filter which files should be pushed to the hub. Use delete_patterns to delete existing remote files in the same commit. See the upload_folder reference for more details.

Args:

repo_id (str): ID of the repository to push to (example: "username/my-model").

config (dict, optional): Configuration object to be saved alongside the model weights.

commit_message (str, optional): Message to commit while pushing.

private (bool, optional, defaults to False): Whether the repository created should be private.

api_endpoint (str, optional): The API endpoint to use when pushing the model to the hub.

token (str, optional): The token to use as HTTP bearer authorization for remote files. By default, it will use the token cached when running huggingface-cli login.

branch (str, optional): The git branch on which to push the model. This defaults to "main".

create_pr (boolean, optional): Whether or not to create a Pull Request from branch with that commit. Defaults to False.

allow_patterns (List[str] or str, optional): If provided, only files matching at least one pattern are pushed.

ignore_patterns (List[str] or str, optional): If provided, files matching any of the patterns are not pushed.

delete_patterns (List[str] or str, optional): If provided, remote files matching any of the patterns will be deleted from the repo.

Returns: The url of the commit of your model in the given repository.
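
The allow_patterns and ignore_patterns semantics described above can be mimicked locally to preview which files would be selected. This is a sketch using the standard library's `fnmatch`, under the assumption that the patterns are shell-style wildcards; `filter_files` is a hypothetical helper and the Hub client's own matching may differ in details.

```python
from fnmatch import fnmatch


def filter_files(paths, allow_patterns=None, ignore_patterns=None):
    """Keep paths that match at least one allow pattern and no ignore pattern."""
    # Accept a single pattern string as well as a list of patterns.
    if isinstance(allow_patterns, str):
        allow_patterns = [allow_patterns]
    if isinstance(ignore_patterns, str):
        ignore_patterns = [ignore_patterns]

    kept = []
    for path in paths:
        if allow_patterns is not None and not any(fnmatch(path, p) for p in allow_patterns):
            continue  # matches none of the allow patterns
        if ignore_patterns is not None and any(fnmatch(path, p) for p in ignore_patterns):
            continue  # matches an ignore pattern
        kept.append(path)
    return kept


files = ["pytorch_model.bin", "config.json", "README.md", "logs/train.log"]
selected = filter_files(
    files,
    allow_patterns=["*.bin", "*.json", "*.md"],
    ignore_patterns="README*",
)
```

Here the log file is dropped because it matches no allow pattern, and the README is dropped because it matches the ignore pattern.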


ModelHubMixin.from_pretrained

 ModelHubMixin.from_pretrained (cls:Type[~T],
                                pretrained_model_name_or_path:Union[str,pathlib.Path],
                                force_download:bool=False,
                                resume_download:bool=False,
                                proxies:Optional[Dict]=None,
                                token:Union[bool,str,NoneType]=None,
                                cache_dir:Union[pathlib.Path,str,NoneType]=None,
                                local_files_only:bool=False,
                                revision:Optional[str]=None,
                                **model_kwargs)

Download a model from the Huggingface Hub and instantiate it.

Args:

pretrained_model_name_or_path (str, Path):
- Either the model_id (string) of a model hosted on the Hub, e.g. bigscience/bloom.
- Or a path to a directory containing model weights saved using transformers.PreTrainedModel.save_pretrained, e.g., ../path/to/my_model_directory/.

revision (str, optional): Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the latest commit on main branch.

force_download (bool, optional, defaults to False): Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding the existing cache.

resume_download (bool, optional, defaults to False): Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.

proxies (Dict[str, str], optional): A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on every request.

token (str or bool, optional): The token to use as HTTP bearer authorization for remote files. By default, it will use the token cached when running huggingface-cli login.

cache_dir (str, Path, optional): Path to the folder where cached files are stored.

local_files_only (bool, optional, defaults to False): If True, avoid downloading the file and return the path to the local cached file if it exists.

model_kwargs (Dict, optional): Additional kwargs to pass to the model during initialization.

Quick model tests

Again, let’s check that the model can be called successfully and outputs the correct shapes.

cyclegan_model = CycleGAN()
img1 = torch.randn(4,3,256,256)
img2 = torch.randn(4,3,256,256)
with torch.no_grad(): cyclegan_output = cyclegan_model((img1,img2))
CPU times: user 1min 15s, sys: 6.67 s, total: 1min 22s
Wall time: 2.25 s
test_eq(len(cyclegan_output),4)
for output_batch in cyclegan_output:
    test_eq(output_batch.shape,img1.shape)
cyclegan_model.push_to_hub('upit-cyclegan-test')
Cloning https://huggingface.co/tmabraham/upit-cyclegan-test into local empty directory.
To https://huggingface.co/tmabraham/upit-cyclegan-test
   a41e9e0..2331f7d  main -> main
'https://huggingface.co/tmabraham/upit-cyclegan-test/commit/2331f7d345d719ac1fdfb10b2cddf58abd7931bb'
cyclegan_model.from_pretrained('tmabraham/upit-cyclegan-test')