import React, { useState, useEffect } from 'react';
import { Helmet } from 'react-helmet';

import CodeBlock from '../../components/CodeBlock'
import Title from '../../components/Title'
import Paragraph from '../../components/Paragraph'
import SubTitle from '../../components/SubTitle'
import ImageBlock from '../../components/ImageBlock'
import DownSpace from '../../components/DownSpace'
import ColabButton from '../../components/ColabButton'

// Article snippet (notebook cell): installs opendatasets and imports every module the later snippets use.
// `PIL.Image` is imported explicitly because a bare `import PIL` does not load the Image submodule that
// CustomDataset calls (PIL.Image.open); relying on torchvision to import it as a side effect is fragile.
const imports = `!pip install opendatasets
import torch, opendatasets, torchvision, PIL.Image, os, numpy, tqdm, matplotlib.pyplot`

// Article snippet (Python): CustomDataset.
// Loads every file directly inside `root_dir` whose name ends in .jpg/.jpeg/.png
// (case-sensitive check) and converts it to RGB. It then resizes to 256x256, applies a random
// horizontal flip, converts to a tensor, and normalizes each channel with mean 0.5 / std 0.5,
// so pixel values end up in [-1, 1].
// NOTE(review): the file list comes from os.listdir, whose order is arbitrary. The random flip is
// also applied when this class is reused for the validation set (see `dataval`).
const datasetcomponent = `class CustomDataset(torch.utils.data.Dataset):
  def __init__(self, root_dir):
    self.root_dir = root_dir
    self.image_paths = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith(('.jpg', '.jpeg', '.png'))]
    self.transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

  def __len__(self):
    return len(self.image_paths)

  def __getitem__(self, idx):
    img_path = self.image_paths[idx]
    image = PIL.Image.open(img_path).convert("RGB")
    image = self.transform(image)
    return image`

// Article snippet (Python): thin wrapper around torch.nn.GroupNorm with 32 groups, eps=1e-6 and
// learnable affine parameters. torch.nn.GroupNorm requires `channels` to be divisible by 32; every
// caller in these snippets passes 256 or 512 channels.
const vqgan1 = `class GroupNorm(torch.nn.Module):
  def __init__(self, channels):
    super(GroupNorm, self).__init__()

    self.gn = torch.nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)

  def forward(self, x):
    return self.gn(x)`

// Article snippet (Python): the Swish activation, x * sigmoid(x), built line by line.
const vqgan2 = [
  'class Swish(torch.nn.Module):',
  '  def forward(self, x):',
  '    return x * torch.sigmoid(x)',
].join('\n')

// Article snippet (Python): pre-activation residual block made of two (GroupNorm -> Swish -> 3x3 conv)
// stages.
// When in_channels != out_channels, the skip path goes through a 1x1 convolution (`channel_up`)
// so its shape matches the block output. Otherwise the input is added unchanged.
const vqgan3 = `class ResidualBlock(torch.nn.Module):
  def __init__(self, in_channels, out_channels):
    super(ResidualBlock, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.block = torch.nn.Sequential(
        GroupNorm(in_channels),
        Swish(),
        torch.nn.Conv2d(in_channels, out_channels, 3, 1, 1),
        GroupNorm(out_channels),
        Swish(),
        torch.nn.Conv2d(out_channels, out_channels, 3, 1, 1)
    )

    if in_channels != out_channels:
        self.channel_up = torch.nn.Conv2d(in_channels, out_channels, 1, 1, 0)

  def forward(self, x):
    if self.in_channels != self.out_channels:
        return self.channel_up(x) + self.block(x)
    else:
        return x + self.block(x)`

// Article snippet (Python): doubles H and W with torch.nn.functional.interpolate (default mode
// 'nearest'), then applies a 3x3 convolution that keeps the channel count.
const vqgan4 = `class UpSampleBlock(torch.nn.Module):
  def __init__(self, channels):
    super(UpSampleBlock, self).__init__()

    self.conv = torch.nn.Conv2d(channels, channels, 3, 1, 1)

  def forward(self, x):
    x = torch.nn.functional.interpolate(x, scale_factor=2.0)
    return self.conv(x)`

// Article snippet (Python): zero-pads one pixel on the right and bottom (pad = (0, 1, 0, 1)). It then
// applies a 3x3 stride-2 convolution with no padding, which halves even H/W and keeps the channel count.
const vqgan5 = `class DownSampleBlock(torch.nn.Module):
  def __init__(self, channels):
    super(DownSampleBlock, self).__init__()

    self.conv = torch.nn.Conv2d(channels, channels, 3, 2, 0)

  def forward(self, x):
    pad = (0, 1, 0, 1)
    x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
    return self.conv(x)`

// Article snippet (Python): single-head spatial self-attention.
// Queries, keys and values are 1x1 convolutions of the group-normalized input. The attention weights
// are softmax(q·k / sqrt(c)) over all h*w positions, and the attended values are added back to the
// input as a residual.
// NOTE(review): `proj_out` is created (so its weights are part of the state_dict) but never used in
// forward. Do not simply start applying it: the checkpoint loaded by the `load` snippet was presumably
// trained with this exact forward, so its proj_out weights would be untrained.
const vqgan6 = `class NonLocalBlock(torch.nn.Module):
  def __init__(self, channels):
    super(NonLocalBlock, self).__init__()

    self.in_channels = channels

    self.gn = GroupNorm(channels)
    self.q = torch.nn.Conv2d(channels, channels, 1, 1, 0)
    self.k = torch.nn.Conv2d(channels, channels, 1, 1, 0)
    self.v = torch.nn.Conv2d(channels, channels, 1, 1, 0)
    self.proj_out = torch.nn.Conv2d(channels, channels, 1, 1, 0)

  def forward(self, x):
    h_ = self.gn(x)
    q  = self.q(h_)
    k  = self.k(h_)
    v  = self.v(h_)

    b, c, h, w = q.shape

    q = q.reshape(b, c, h*w)
    q = q.permute(0, 2, 1)
    k = k.reshape(b, c, h*w)
    v = v.reshape(b, c, h*w)

    attn = torch.bmm(q, k)
    attn = attn * (int(c)**(-0.5))
    attn = torch.nn.functional.softmax(attn, dim=2)
    attn = attn.permute(0, 2, 1)

    A = torch.bmm(v, attn)
    A = A.reshape(b, c, h, w)

    return x + A`

// Article snippet (Python): Encoder.
// - A 3x3 input conv.
// - Three stages of 3 ResidualBlocks (256, 512, 512 channels), with a DownSampleBlock after the
//   first two stages, so H and W shrink by 4 (e.g. 256x256 -> 64x64).
// - ResidualBlock / NonLocalBlock / ResidualBlock, then GroupNorm, Swish and a 3x3 conv to
//   `latent_dim` channels.
// NOTE: inside the loop `resolution` only takes the values 256, 128 and 64, so attn_resolutions=[16]
// never matches. The only attention layer is the NonLocalBlock appended after the loop.
const vqgan7 = `class Encoder(torch.nn.Module):
  def __init__(self, image_channels=3, latent_dim=256):
    super(Encoder, self).__init__()
    channels = [256, 256, 512, 512]
    attn_resolutions = [16]
    num_res_blocks = 3
    resolution = 256
    layers = [torch.nn.Conv2d(image_channels, channels[0], 3, 1, 1)]
    for i in range(len(channels)-1):
        in_channels = channels[i]
        out_channels = channels[i + 1]
        for j in range(num_res_blocks):
            layers.append(ResidualBlock(in_channels, out_channels))
            in_channels = out_channels
            if resolution in attn_resolutions:
                layers.append(NonLocalBlock(in_channels))
        if i != len(channels)-2:
            layers.append(DownSampleBlock(channels[i+1]))
            resolution //= 2
    layers.append(ResidualBlock(channels[-1], channels[-1]))
    layers.append(NonLocalBlock(channels[-1]))
    layers.append(ResidualBlock(channels[-1], channels[-1]))
    layers.append(GroupNorm(channels[-1]))
    layers.append(Swish())
    layers.append(torch.nn.Conv2d(channels[-1], latent_dim, 3, 1, 1))
    self.model = torch.nn.Sequential(*layers)

  def forward(self, x):
    return self.model(x)`

// Article snippet (Python): vector-quantization codebook.
// Holds `num_codebook_vectors` learnable vectors of size `latent_dim`, initialized uniformly in
// [-1/N, 1/N].
// forward():
// - moves channels last;
// - replaces every latent vector with its nearest codebook vector (squared Euclidean distance,
//   expanded as |z|^2 + |e|^2 - 2 z·e);
// - returns (z_q permuted back to channels-first, flat code indices, VQ loss).
// The VQ loss is mean((sg[z_q] - z)^2) + beta * mean((z_q - sg[z])^2). z_q goes through a
// straight-through estimator (z + (z_q - z).detach()) so gradients reach the encoder.
// NOTE(review): beta multiplies the codebook term here, whereas the VQ-VAE paper puts beta on the
// commitment term. Left as-is to stay compatible with the trained checkpoint; confirm if intentional.
const vqgan8 = `class Codebook(torch.nn.Module):
  def __init__(self, num_codebook_vectors=1024, latent_dim=256, beta=0.25):
    super(Codebook, self).__init__()

    self.num_codebook_vectors = num_codebook_vectors
    self.latent_dim = latent_dim
    self.beta = beta

    self.embedding = torch.nn.Embedding(self.num_codebook_vectors, self.latent_dim)
    self.embedding.weight.data.uniform_(-1.0 / self.num_codebook_vectors, 1.0 / self.num_codebook_vectors)

  def forward(self, z):
    z = z.permute(0, 2, 3, 1).contiguous()
    z_flattened = z.view(-1, self.latent_dim)

    d = torch.sum(z_flattened**2, dim=1, keepdim=True) + torch.sum(self.embedding.weight**2, dim=1) - 2*(torch.matmul(z_flattened, self.embedding.weight.t()))

    min_encoding_indices = torch.argmin(d, dim=1)
    z_q = self.embedding(min_encoding_indices).view(z.shape)

    loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)

    z_q = z + (z_q - z).detach()

    z_q = z_q.permute(0, 3, 1, 2)

    return z_q, min_encoding_indices, loss`

// Article snippet (Python): Decoder.
// - A 3x3 conv from `latent_dim` to 512 channels, then ResidualBlock / NonLocalBlock / ResidualBlock.
// - Three stages of 4 ResidualBlocks (512, 512, 256 channels), with an UpSampleBlock after the
//   second and third stages, so H and W grow by 4.
// - Finally GroupNorm, Swish and a 3x3 conv to `image_channels`.
// NOTE: `resolution` is only a label. It starts at 16 and is not doubled until after the second
// stage, so a NonLocalBlock follows every ResidualBlock of the first two stages. With the Encoder in
// these snippets, those stages operate on the 64x64 latent grid, not at 16x16.
// The last layer must stay the output Conv2d: VQGAN.calculate_lambda reads self.decoder.model[-1].
const vqgan9 = `class Decoder(torch.nn.Module):
  def __init__(self, image_channels=3, latent_dim=256):
    super(Decoder, self).__init__()
    channels = [512, 512, 256]
    attn_resolutions = [16]
    num_res_blocks = 4
    resolution = 16

    in_channels = channels[0]
    layers = [torch.nn.Conv2d(latent_dim, in_channels, 3, 1, 1),
              ResidualBlock(in_channels, in_channels),
              NonLocalBlock(in_channels),
              ResidualBlock(in_channels, in_channels)]

    for i in range(len(channels)):
        out_channels = channels[i]
        for j in range(num_res_blocks):
            layers.append(ResidualBlock(in_channels, out_channels))
            in_channels = out_channels
            if resolution in attn_resolutions:
                layers.append(NonLocalBlock(in_channels))
        if i != 0:
            layers.append(UpSampleBlock(in_channels))
            resolution *= 2

    layers.append(GroupNorm(in_channels))
    layers.append(Swish())
    layers.append(torch.nn.Conv2d(in_channels, image_channels, 3, 1, 1))
    self.model = torch.nn.Sequential(*layers)

  def forward(self, x):
    return self.model(x)`

// Article snippet (Python): the full VQGAN.
// Pipeline: Encoder -> 1x1 quant_conv -> Codebook -> 1x1 post_quant_conv -> Decoder, with a
// 1024-dimensional latent and 4096 codebook vectors.
// - forward() returns (reconstruction, code indices, VQ loss).
// - encode() and decode() expose the two halves separately.
// - calculate_lambda() computes the adaptive GAN weight. It takes the ratio of the gradient norms of
//   its first argument (the reconstruction/perceptual loss in the training loop) and the GAN loss
//   w.r.t. the decoder's last conv weight, clamps it to [0, 1e4], detaches it and scales it by 0.8.
// Relies on a notebook-global `device`, which must be defined before the model is instantiated.
const vqgan10 = `class VQGAN(torch.nn.Module):
  def __init__(self):
    super(VQGAN, self).__init__()

    self.encoder = Encoder(image_channels=3, latent_dim=1024).to(device=device)

    self.codebook = Codebook(num_codebook_vectors=4096, latent_dim=1024, beta=0.25).to(device=device)

    self.decoder = Decoder(image_channels=3, latent_dim=1024).to(device=device)

    self.quant_conv      = torch.nn.Conv2d(1024, 1024, 1).to(device=device)
    self.post_quant_conv = torch.nn.Conv2d(1024, 1024, 1).to(device=device)

  def forward(self, imgs):
    encoded_images = self.encoder(imgs)
    quant_conv_encoded_images = self.quant_conv(encoded_images)
    codebook_mapping, codebook_indices, q_loss = self.codebook(quant_conv_encoded_images)
    post_quant_conv_mapping = self.post_quant_conv(codebook_mapping)
    decoded_images = self.decoder(post_quant_conv_mapping)
    return decoded_images, codebook_indices, q_loss

  def encode(self, imgs):
    encoded_images = self.encoder(imgs)
    quant_conv_encoded_images = self.quant_conv(encoded_images)
    codebook_mapping, codebook_indices, q_loss = self.codebook(quant_conv_encoded_images)
    return codebook_mapping, codebook_indices, q_loss

  def decode(self, z):
    post_quant_conv_mapping = self.post_quant_conv(z)
    decoded_images = self.decoder(post_quant_conv_mapping)
    return decoded_images

  def calculate_lambda(self, perceptual_loss, gan_loss):
    last_layer = self.decoder.model[-1]
    last_layer_weight = last_layer.weight
    perceptual_loss_grads = torch.autograd.grad(perceptual_loss, last_layer_weight, retain_graph=True)[0]
    gan_loss_grads = torch.autograd.grad(gan_loss, last_layer_weight, retain_graph=True)[0]

    λ = torch.norm(perceptual_loss_grads) / (torch.norm(gan_loss_grads) + 1e-4)
    λ = torch.clamp(λ, 0, 1e4).detach()
    return 0.8 * λ`

// Article snippet (Python): DCGAN-style initialization.
// - Conv weights: N(0, 0.02).
// - BatchNorm: weights N(1, 0.02), bias 0.
// The single module argument `m` suggests it is meant for Module.apply(weights_init).
// NOTE(review): the `models` snippet in this file does not call `.apply(weights_init)`.
const vqgan11 = `def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv') != -1:
    torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
  elif classname.find('BatchNorm') != -1:
    torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
    torch.nn.init.constant_(m.bias.data, 0)`


// Article snippet (Python): PatchGAN-style discriminator.
// - A 4x4 stride-2 conv + LeakyReLU.
// - `n_layers` blocks of 4x4 conv (stride 2, except stride 1 for the last block) + BatchNorm +
//   LeakyReLU, with 64 * min(2**i, 8) channels.
// - A final 4x4 conv to a single channel.
// The output is a spatial map of scores, not a single scalar: 30x30 for a 256x256 input with the
// defaults.
const discriminator = `class Discriminator(torch.nn.Module):
  def __init__(self, image_channels=3, num_filters_last=64, n_layers=3):
    super(Discriminator, self).__init__()

    layers = [torch.nn.Conv2d(image_channels, num_filters_last, 4, 2, 1), torch.nn.LeakyReLU(0.2)]
    num_filters_mult = 1

    for i in range(1, n_layers + 1):
        num_filters_mult_last = num_filters_mult
        num_filters_mult = min(2 ** i, 8)
        layers += [
            torch.nn.Conv2d(num_filters_last * num_filters_mult_last, num_filters_last * num_filters_mult, 4,
                      2 if i < n_layers else 1, 1, bias=False),
            torch.nn.BatchNorm2d(num_filters_last * num_filters_mult),
            torch.nn.LeakyReLU(0.2, True)
        ]

    layers.append(torch.nn.Conv2d(num_filters_last * num_filters_mult, 1, 4, 1, 1))
    self.model = torch.nn.Sequential(*layers)

  def forward(self, x):
    return self.model(x)`

// Article snippet (Python): perceptual loss.
// - Computes the MSE between features from a frozen, pretrained VGG19 truncated to features[:36]
//   (everything up to the last ReLU, without the final max-pool).
// - Uses the torchvision >= 0.13 weights API; VGG19_Weights.DEFAULT is the same ImageNet checkpoint
//   that the deprecated `pretrained=True` flag loaded.
// NOTE(review): VGG19 was trained on ImageNet-normalized inputs, but CustomDataset feeds images
// normalized to [-1, 1]. Consider re-normalizing if the model is retrained.
const perceptualLoss = `class VGGLoss(torch.nn.Module):
  def __init__(self):
    super().__init__()

    self.vgg  = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT).features[:36].eval().to(device)
    self.loss = torch.nn.MSELoss()

    for param in self.vgg.parameters():
        param.requires_grad = False

  def forward(self, input, target):
    vgg_input_features = self.vgg(input)
    vgg_target_features = self.vgg(target)
    return self.loss(vgg_input_features, vgg_target_features)`

// Article snippet (Python): instantiates the VQGAN and the discriminator on the notebook-global `device`.
const models = [
  'vqgan = VQGAN().to(device)',
  '',
  'discriminator = Discriminator().to(device)',
].join('\n')

// Article snippet (Python): creates the VGG-based perceptual loss used by the training loop.
const perceptual = 'perceptual_loss = VGGLoss()'

// Article snippet (Python): optimizers.
// - Adam over all VQGAN sub-modules (encoder, decoder, codebook, quant_conv, post_quant_conv) with
//   betas=(0, 0.999).
// - A separate Adam for the discriminator with betas=(0.5, 0.9).
// Both use lr=1e-4 and eps=1e-8.
const learningrates = `lr = 0.0001

opt_vq = torch.optim.Adam(
    list(vqgan.encoder.parameters()) +
    list(vqgan.decoder.parameters()) +
    list(vqgan.codebook.parameters()) +
    list(vqgan.quant_conv.parameters()) +
    list(vqgan.post_quant_conv.parameters()),
    lr=lr, eps=1e-08, betas=(0, 0.999)
)

opt_disc = torch.optim.Adam(discriminator.parameters(),
                            lr=lr, eps=1e-08, betas=(0.5, 0.9))`


// Article snippet (Python): total epochs, and `disc_start`, the number of epochs trained before the
// adversarial losses are enabled (the training loop checks `epoch + 1 > disc_start`).
const epochs = [
  'num_epochs = 100',
  '',
  'disc_start = 5',
].join('\n')

// Article snippet (Python): loss-weight hyperparameters used by the training-loop snippet.
const factors = [
  'disc_factor = 1',
  '',
  'rec_loss_factor = 1',
  '',
  'perceptual_loss_factor = 1',
].join('\n')

// Article snippet (Python): one training pass per epoch over `dataloader`.
// - VQGAN (generator) loss: L1 reconstruction + perceptual + codebook loss. Once
//   `epoch + 1 > disc_start`, it also includes the adversarial term weighted by λ
//   (VQGAN.calculate_lambda).
// - Discriminator loss: hinge loss on real images and on *detached* reconstructions. This way
//   gan_loss.backward() only produces discriminator gradients. Otherwise it would push the VQGAN
//   toward looking fake right before opt_vq.step().
// - The warm-up gate is stored in `adv_factor`, so the `disc_factor` hyperparameter set in `factors`
//   is honored instead of being overwritten.
// - Generator and discriminator losses have separate graphs, so vq_loss.backward() needs no
//   retain_graph.
const trainingloop = `for epoch in range(num_epochs):
  # The adversarial terms are switched on only after disc_start epochs.
  adv_factor = disc_factor if epoch + 1 > disc_start else 0

  with tqdm.tqdm(enumerate(dataloader), total=len(dataloader)) as pbar:
    for i, imgs in pbar:
        imgs = imgs.to(device)
        decoded_images, _, q_loss = vqgan(imgs)

        # VQGAN (generator) loss
        perceptual_loss_val = perceptual_loss(imgs, decoded_images)
        rec_loss = torch.nn.functional.l1_loss(decoded_images, imgs)
        perceptual_rec_loss = perceptual_loss_factor * perceptual_loss_val + rec_loss_factor * rec_loss
        perceptual_rec_loss = perceptual_rec_loss.mean()
        g_loss = -torch.mean(discriminator(decoded_images))

        λ = vqgan.calculate_lambda(perceptual_rec_loss, g_loss)
        vq_loss = perceptual_rec_loss + q_loss + adv_factor * λ * g_loss

        # Discriminator loss: the reconstructions are detached so that this
        # loss does not send gradients into the VQGAN.
        disc_real = discriminator(imgs)
        disc_fake = discriminator(decoded_images.detach())

        d_loss_real = torch.mean(torch.nn.functional.relu(1. - disc_real))
        d_loss_fake = torch.mean(torch.nn.functional.relu(1. + disc_fake))
        gan_loss = adv_factor * 0.5*(d_loss_real + d_loss_fake)

        opt_vq.zero_grad()
        vq_loss.backward()

        # Clears the discriminator gradients produced by vq_loss.backward().
        opt_disc.zero_grad()
        gan_loss.backward()

        opt_vq.step()
        opt_disc.step()

        pbar.set_postfix(
            VQ_Loss=numpy.round(vq_loss.item(), 5),
            GAN_Loss=numpy.round(gan_loss.item(), 3)
        )
`

// Article snippet (Python): reconstructs the first image of a validation batch and plots it next to
// the original.
// - Runs under torch.no_grad(), since no gradients are needed and no autograd graph is built.
// - The decoder ends in a plain Conv2d (no tanh), so the de-normalized output can fall outside
//   [0, 1]. It is clamped before imshow to avoid matplotlib's clipping warning.
const autoencode = `images_batch = next(iter(val_dataloader))
image1 = images_batch[0]

with torch.no_grad():
    encoded_image1, _, _ = vqgan.encode(image1.unsqueeze(0).to(device))
    decoded_image1 = vqgan.decode(encoded_image1).mul(0.5).add(0.5).clamp(0, 1).squeeze().cpu()

fig, axes = matplotlib.pyplot.subplots(1, 2, figsize=(7, 5))
axes[0].imshow(image1.mul(0.5).add(0.5).permute(1, 2, 0))
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(decoded_image1.permute(1, 2, 0))
axes[1].set_title('Outputted Image')
axes[1].axis('off')

matplotlib.pyplot.tight_layout()
matplotlib.pyplot.show()`

// Article snippet (Python): latent interpolation.
// - Encodes the first two validation images.
// - Linearly interpolates between their quantized latents in `num_interpolates` steps
//   (t = 0, 1/6, ..., 5/6), so the second image's latent itself is not included.
// - Decodes all steps in one batch and plots them.
const interpolate = `num_interpolates = 6
with torch.no_grad():
    images_batch = next(iter(val_dataloader))
    image1, image2 = images_batch[:2]

    z1, _, _ = vqgan.encode(image1.unsqueeze(0).to(device))
    z2, _, _ = vqgan.encode(image2.unsqueeze(0).to(device))

    interpolation_vectors = torch.stack([z1 + (i/num_interpolates) * (z2 - z1) for i in range(num_interpolates)])
    interpolated_images = vqgan.decode(interpolation_vectors.squeeze(1)).mul(0.5).add(0.5)
    
    fig, axes = matplotlib.pyplot.subplots(1, num_interpolates, figsize=(15, 5))
    for i in range(num_interpolates):
        axes[i].imshow(interpolated_images[i].cpu().detach().permute(1, 2, 0))
        axes[i].axis('off')
        axes[i].set_title(f'Interpolation {i+1}')
    matplotlib.pyplot.show()`

// Article snippet (notebook shell commands): installs gdown and downloads two files from Google Drive
// by file ID. These are presumably the checkpoints read by the `load` snippet.
const download = [
  '!pip install gdown',
  '',
  '!gdown 1fu6YRv4bBtEFWsF5v0AtGMaGcrqYwGzG',
  '!gdown 1FbpM5MET6vTaL6c_5UhgsCJN9R0HLHJx',
].join('\n')

// Article snippet (Python): loads pretrained weights into both models from .pth files in the working
// directory.
const load = [
  "discriminator.load_state_dict(torch.load('discriminator_model.pth'))",
  "vqgan.load_state_dict(torch.load('vqgan_model.pth'))",
].join('\n')


// Article snippet (Python): validation dataset and loader over the AFHQ validation cat images.
// Uses the same loader settings as the training loader (batch_size=6, shuffle=True, num_workers=2,
// pin_memory=True).
const dataval = [
  'val_dataset = CustomDataset(root_dir="./animal-faces/afhq/val/cat") ',
  '',
  'val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=6, shuffle=True, num_workers=2, pin_memory=True)',
].join('\n')

const VQGAN = () => {

  return (
    <div className='custom-font pt-9'>
        <Helmet>
          <title>State-of-the-Art Latent Space Quantization with VQ-GANs from Scratch Using Pytorch</title>
          <meta name="author" content="Dinis Martinho" />
          <meta name="description" content="In this article, I will guide you through understanding and implementing a Vector Quantized Generative Adversarial Network (VQGAN) within a 
        single Jupyter notebook. I won't delve deeply into the specifics of how and why VQGANs work; instead, I will provide a high-level overview of all the 
        components involved." />
        </Helmet>

        <Title Title="State-of-the-Art Latent Space Quantization with VQ-GANs from Scratch Using Pytorch" date='30 Apr 2024' />
        <br />
        <ColabButton notebookUrl="" paperURL="https://arxiv.org/abs/2012.09841" />

        <SubTitle Title="Introduction" noMarginTop={true} />
        <br />
        <Paragraph text="In this article, I will guide you through understanding and implementing a Vector Quantized Generative Adversarial Network (VQGAN) within a 
        single Jupyter notebook. I won't delve deeply into the specifics of how and why VQGANs work; instead, I will provide a high-level overview of all the 
        components involved." />
        <br />


        <div className='flex justify-center items-center'>
          <img className='max-w-full h-32 sm:h-36 md:h-40 lg:h-40 mx-5' src={process.env.PUBLIC_URL + "/State-of-the-Art-Latent-Space-Quantization-with-VQ-GANs-from-Scratch-In-Pytorch-I/interpolation_0.gif"} />
        </div>

        <br />
        <Paragraph text="VQGANs, short for Vector Quantized Generative Adversarial Networks, have surged in popularity due to their ability to generate high-quality 
        images with remarkable detail and diversity from a discrete set of latent variables. By integrating components of variational autoencoders and generative 
        adversarial networks, VQGANs have introduced a novel approach that has revolutionized the landscape of AI-generated content. Since generative AI is one of 
        my favorite topics and I have recently been researching continuous latent spaces, I knew I had to implement my own VQGAN to fully understand the concepts 
        behind it. This article is intended to both deepen my understanding of the topic and to help anyone trying to implement their own VQGAN. Much of the code 
        used in my implementation can be attributed to the 'VQGAN-pytorch' repository by Dominic Rampas on GitHub. Additionally, I would like to reference the YouTube 
        video by 'Aleksa Gordić - The AI Epiphany' titled 'VQ-GAN: Taming Transformers for High-Resolution Image Synthesis | Paper Explained.' References to both of 
        these resources can be found in the references section at the bottom of this article." />
        <br />
        <Paragraph text="Vector Quantized Generative Adversarial Networks (VQGANs) were introduced in the paper 'Taming Transformers for High-Resolution Image Synthesis' by Patrick Esser, Robin Rombach, and Björn Ommer, first released on arXiv in December 2020 and presented at CVPR 2021. VQGANs revolutionized the field of image generation by combining a vector-quantized autoencoder with adversarial and perceptual losses. These networks encode images into discrete latent representations, which help capture and preserve meaningful semantic information. VQGANs are not primarily used to generate images directly; instead, they facilitate the generation process by encoding and decoding images into and from compact latent spaces that are easier for other models to handle, such as the autoregressive transformer used in the original paper or latent diffusion models." />
        <br />
        <Paragraph text="Modern VQGANs have been modified to handle various types of data, including video and 3D data. This versatility has demonstrated that VQGANs can work effectively with different kinds of data, further establishing them as a significant advancement in generative modeling and a powerful tool in the realm of AI-generated content." />
        <br />

        <SubTitle Title="Understanding VQGANs from a practical point of view" />
        <br />
        <Paragraph text="From a practical point of view, a VQGAN discretizes the latent space by representing continuous-valued latent vectors with discrete codes. This is achieved through quantization, where continuous latent vectors are mapped to a finite set of codebook entries. Each vector is then assigned to the nearest codebook entry, effectively discretizing the latent space." />
        <br />
        <Paragraph text="This approach enables efficient training and generation of high-quality images. By quantizing the latent space, VQGANs can better capture complex data distributions while reducing computational complexity compared to traditional continuous latent representations." />
        <br />
        <Paragraph text="The concept of discretizing the latent space originated from the paper 'Neural Discrete Representation Learning' by van den Oord et al. (2017). This paper introduced the Vector Quantized Variational Autoencoder (VQ-VAE), which employed discrete latent variables for data representation. VQ-VAEs have since served as a foundation for various models, including VQGANs, which integrate discrete latent representations into Generative Adversarial Networks (GANs) for image generation tasks." />
        <br />
        <Paragraph text="Having discrete latent spaces allows for more structured and interpretable representations of the data. In the context of image generation, discrete latent spaces facilitate the encoding of images into a finite set of distinct elements, capturing essential features and patterns more effectively. This structured representation helps generative models learn and reproduce complex details, leading to higher quality and more diverse images." />
        <br />
        <Paragraph text="Discrete latent spaces enable models to separate and manipulate distinct attributes of an image, such as color, texture, and shape, in a more controlled manner. This separation allows for finer control over the generation process, as each element in the latent space corresponds to a specific characteristic of the image. As a result, models can generate images that are not only realistic but also exhibit a wide range of variations." />
        <br />
        <Paragraph text="Furthermore, working with discrete latent spaces can improve the stability and convergence of generative models. By reducing the complexity of the latent space, models can focus on learning the most critical features without being overwhelmed by noise and irrelevant details. This focused learning process leads to more efficient training and better overall performance in image generation tasks." />
        <br />

        <SubTitle Title="Implementation methodology" />
        <br />
        <Paragraph text="The implementation of this project will be divided into six distinct sections. The process will begin with importing all the necessary libraries, followed by data handling and processing. Subsequently, I will implement the individual building blocks for both the VQGAN and discriminator models. Finally, I will proceed with training and testing the models." />
        <br />
        <Paragraph text="When implementing the VQGAN model, I will prioritize smaller encoder and decoder models and a larger, higher-dimensional latent space. This approach contrasts with the conventional method that favors larger encoder and decoder models with smaller latent spaces. The decision to use a larger latent space is due to my limited computational resources." />
        <br />

        <SubTitle Title="Importing the necessary libraries" />
        <br />
        <CodeBlock code={imports} />
        <br />

        <SubTitle Title="Downloading the dataset from Kaggle" />
        <br />
        <Paragraph text="For this project, I have opted to train my VQGAN using cat faces sourced from the 'Animal Faces' dataset on Kaggle. The dataset will be 
        downloaded utilizing the `opendatasets` library, which requires an API key for access. A reference to this dataset can be located in the references section 
        of this article." />
        <br />
        <CodeBlock code='opendatasets.download("https://www.kaggle.com/datasets/andrewmvd/animal-faces")' />
        <br />

        <SubTitle Title="Creating our custom dataset and data loader components" />
        <br />
        <Paragraph text="I've chosen to develop a straightforward custom dataset component that retrieves images from a designated folder path. The dataset is 
        organized into separate validation and training sets, enabling us to utilize the component twice to handle both categories effectively." />
        <br />
        <CodeBlock code={datasetcomponent} />
        <br />
        <Paragraph text="For the initial phase of implementation, I'll invoke it solely on the training subset. During the testing phase, I'll utilize it again, this time focusing on the validation subset." />
        <br />
        <CodeBlock code='dataset = CustomDataset(root_dir="./animal-faces/afhq/train/cat")' />
        <br />
        <CodeBlock code='dataloader = torch.utils.data.DataLoader(dataset, batch_size=6, shuffle=True, num_workers=2, pin_memory=True)' />
        <br />

        <SubTitle Title="Implementing the VQGAN architecture" />
        <br />
        <Paragraph text="As mentioned earlier, a significant portion of the code utilized in this implementation is sourced from the 'VQGAN-pytorch' repository 
        authored by Dominic Rampas on GitHub. A reference to this repository is provided in the references section located at the bottom of this article. However, 
        I have made some modifications to the code to enhance its efficiency, tailored specifically to accommodate our dataset and my available computational 
        resources. While the specifics of each module are not thoroughly explained, I will attempt to outline their potential functions and roles within the VQGAN context." />
        <br />
        <CodeBlock code={vqgan1} />
        <br />
        <Paragraph text="Batch Normalization `(BatchNorm)` and Group Normalization `(GroupNorm)` are both techniques used to normalize the activations of neural 
        networks, but they differ in their approach. `BatchNorm` normalizes each channel using the mean and standard deviation computed across the batch and spatial 
        dimensions. In contrast, `GroupNorm` divides the channels of each sample into groups and normalizes each group separately, so its statistics do not depend on 
        the batch size. While `BatchNorm` is effective with larger batch sizes, `GroupNorm` is advantageous for scenarios with small batch sizes (this project uses a 
        batch size of 6) or data with high variance across samples, like this one." />
        <br />
        <CodeBlock code={vqgan2} />
        <br />
        <Paragraph text="The Swish activation function `(Swish)` is a smooth, non-monotonic function defined as the input multiplied by its sigmoid, `x * sigmoid(x)`. 
        It behaves almost linearly for large positive inputs while retaining the non-linearity of the sigmoid. This encourages more complex feature representations, 
        helps mitigate vanishing gradients, and generally improves training convergence and model performance." />
        <br />
        <CodeBlock code={vqgan3} />
        <br />
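        <Paragraph text="The `ResidualBlock` class defines a building block that is used to facilitate learning deeper representations while 
        avoiding the vanishing gradient problem. It consists of two main components. The first is a sequence of group normalization, Swish activation, and convolutional layers. 
        The second is a shortcut connection that adds the block's input to its output; when the number of input and output channels differs, the input first passes through a 1x1 convolution." />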
        <br />
        <CodeBlock code={vqgan4} />
        <br />
        <Paragraph text="This `UpSampleBlock` class uses nearest-neighbor interpolation to enlarge the spatial dimensions of the input tensor by a factor of 2. The block 
        includes a single convolutional layer that maintains the same number of input and output channels. In the forward pass, the input tensor is first interpolated, 
        doubling its height and width, and then passed through the convolutional layer." />
        <br />
        <CodeBlock code={vqgan5} />
        <br />
        <Paragraph text="The `DownSampleBlock` class serves to downscale the spatial dimensions of input tensors. It pads the input by one pixel on the right and bottom 
        and then applies a single 3x3 convolution with a stride of 2. This halves the input's spatial dimensions while maintaining the same number of input and output channels." />
        <br />
        <CodeBlock code={vqgan6} />
        <br />
        <Paragraph text="This `NonLocalBlock` class implements an attention mechanism within our VQGAN. This mechanism enhances the model's ability to capture long-range dependencies and contextual information across spatial dimensions. " />
        <br />
        <CodeBlock code={vqgan7} />
        <br />
        <Paragraph text="The `Encoder` class reduces the spatial resolution of input images by a factor of four along each dimension, in this case encoding `256x256` images 
        into `64x64` latent feature maps. It progressively downsamples the input while extracting meaningful features. To do so, it employs a series of convolutional layers and residual blocks interleaved with two downsampling blocks, 
        followed by an attention (`NonLocalBlock`) layer at the lowest resolution. By incorporating residual learning and attention mechanisms, this architecture effectively captures hierarchical features, enhancing 
        feature representation and extraction in the encoded latent space." />
        <br />
        <CodeBlock code={vqgan8} />
        <br />
        <Paragraph text="This `Codebook` module discretizes continuous latent representations into discrete codes. It maintains a dictionary of learnable codebook 
        vectors, each of which can be thought of as the center of a region of the latent feature space. The encoder first produces continuous latent vectors, and 
        each of them is replaced by its nearest codebook vector in Euclidean distance. During training, the codebook vectors are pulled towards the encoder outputs assigned to them, 
        while a commitment term keeps the encoder outputs close to their chosen codes. A straight-through estimator lets gradients flow from the decoder back to the encoder 
        despite the non-differentiable lookup. This quantization enables efficient representation and compression of information. Unlike conventional autoencoders, it allows for faithful reconstructions while 
        facilitating efficient manipulation and compression of latent representations." />
        <br />
        <CodeBlock code={vqgan9} />
        <br />
        <Paragraph text="The `Decoder` class reconstructs images from compressed latent representations, gradually upscaling them to the original size while preserving 
        essential features. It utilizes convolutional layers, residual blocks, and attention mechanisms to enhance feature representation. Starting with a convolutional 
        layer to expand the latent dimension, the decoder iterates over the specified channels, applying additional residual and upsampling blocks to increase image 
        resolution. " />
        <br />
        <CodeBlock code={vqgan10} />
        <br />
        <Paragraph text="Now, we assemble the VQGAN, composed of an encoder, a codebook, and a decoder, with 1x1 convolutions applied just before and after quantization. 
        The encoder compresses input images into latent representations, which are then quantized using the codebook, yielding a compact, discrete representation. 
        The decoder reconstructs images from the quantized latent representations. The `calculate_lambda` method computes an adaptive weight `λ` from the ratio between 
        the gradient norms of the perceptual/reconstruction loss and of the GAN loss with respect to the decoder's last layer. This keeps the two losses in balance 
        during training. With these components in place, the VQGAN is ready to be trained to produce visually appealing and semantically meaningful images." />
        <br />

        <SubTitle Title="Implementing the discriminator's architecture" />
        <br />
        <Paragraph text="This `Discriminator` class was designed to introduce the adversarial component to our VQGAN. It features a simple architecture aimed at 
        distinguishing between real and synthetic data. Through a sequence of convolutional layers, batch normalization, and leaky ReLU activations, it endeavors 
        to discern the authenticity of input images within the GAN framework. " />
        <br />
        <CodeBlock code={discriminator} />
        <br />

        <SubTitle Title="Implementing the perceptual loss module" />
        <br />
        <Paragraph text="This is a simple `VGGLoss` module that calculates the Mean Squared Error `(MSE)` loss between the features of the input and target images 
        extracted from a pretrained `VGG19` network. Only the first 36 modules of the network's convolutional feature extractor (`features[:36]`) are kept, that is, every layer up to 
        and including the last ReLU, without the final max-pooling layer. These frozen layers are used for feature extraction." />
        <br />
        <CodeBlock code={perceptualLoss} />
        <br />

        <SubTitle Title="Preparing to train and setting hyperparameters" />
        <br />
        <Paragraph text="Now, we can prepare to train our VQGAN. Choosing the hyperparameters carefully is crucial. I've achieved satisfactory results 
        with the following choices, but different hyperparameters might yield better outcomes. The `disc_start` parameter indicates the epoch 
        from which the adversarial loss begins to influence the training process. In my experiments, I've observed that pre-training the model 
        with only a perceptual loss and a Mean Absolute Error `(MAE)` loss, without the adversarial loss, tends to yield good results more quickly." />
        <br />
        <CodeBlock code='device = "cuda"' />
        <br />
        <CodeBlock code={models} />
        <br />
        <CodeBlock code={perceptual} />
        <br />
        <CodeBlock code={learningrates} />
        <br />
        <CodeBlock code={factors} />
        <br />
        <CodeBlock code={epochs} />
        <br />

        <SubTitle Title="Implementing the VQGAN training loop" />
        <br />
        <Paragraph text="In this training loop, we iterate over each epoch, processing batches of images from the dataset. For each batch, we reconstruct the images 
        with the `VQGAN` model and compute the quantization loss. We then compare the reconstructed and original images, initially using only the perceptual 
        loss and the Mean Absolute Error `(MAE)` loss. Once training reaches a designated epoch `(disc_start)`, the adversarial component is activated, adding the 
        adversarial loss to the training objective. From then on, both the `VQGAN` model and the `discriminator` are trained together to improve the quality of the 
        generated images. Throughout training, the balance between the perceptual, MAE, and adversarial losses is adjusted dynamically to optimize the model's performance." />
        <br />
        <CodeBlock code={trainingloop} />
        <br />
        <Paragraph text="Training a VQGAN can be very computationally expensive. If you only want to test the model without training it, you can download my pre-trained weights with the following code blocks:" />
        <br />
        <CodeBlock code={download} />
        <br />
        <CodeBlock code={load} />
        <br />

        <SubTitle Title="Testing our results and exploring the discrete latent space" />
        <br />
        <Paragraph text="We will conduct our tests on the validation subset of our dataset. To do this, we will reuse the `CustomDataset` class we created previously." />
        <br />
        <CodeBlock code={dataval} />
        <br />
        <Paragraph text="The simplest test we can perform on our VQGAN is to compare the inputs to the outputs." />
        <br />
        <CodeBlock code={autoencode} />
        <div className='flex justify-center'>
          <img className='max-w-full h-48 sm:h-52 md:h-56 lg:h-56 mx-5' src={process.env.PUBLIC_URL + "/State-of-the-Art-Latent-Space-Quantization-with-VQ-GANs-from-Scratch-In-Pytorch-I/comparison.jpeg"} />
        </div>
        <div className='flex justify-center'>
          <img className='max-w-full h-48 sm:h-52 md:h-56 lg:h-56 mx-5' src={process.env.PUBLIC_URL + "/State-of-the-Art-Latent-Space-Quantization-with-VQ-GANs-from-Scratch-In-Pytorch-I/comparison_2.jpeg"} />
        </div>

        <br />
        <Paragraph text="Comparing the inputs with the outputs, we can observe a slight loss of detail in our VQGAN model's reconstructions. Some details, such as colors and certain 
        characteristics of the images, were not reproduced well. It's common for these types of models to exhibit some information loss. Further training or increasing 
        the network's size may address such flaws." />
        <br />
        <Paragraph text="We can explore the discrete latent space of our model further by interpolating between the latents of two images and assessing whether the 
        resulting intermediate outputs make sense. This illustrates how effectively a VQGAN captures latent representations 
        and how those representations can be manipulated." />
        <br />

        <CodeBlock code={interpolate} />
        <div className='flex justify-center'>
          <img className='max-w-full h-auto' src={process.env.PUBLIC_URL + "/State-of-the-Art-Latent-Space-Quantization-with-VQ-GANs-from-Scratch-In-Pytorch-I/steps_interpolation.jpeg"} />
        </div>
        <div className='flex justify-center'>
          <img className='max-w-full h-auto' src={process.env.PUBLIC_URL + "/State-of-the-Art-Latent-Space-Quantization-with-VQ-GANs-from-Scratch-In-Pytorch-I/steps_interpolation_1.jpeg"} />
        </div>
        <div className='flex justify-center'>
          <img className='max-w-full h-auto' src={process.env.PUBLIC_URL + "/State-of-the-Art-Latent-Space-Quantization-with-VQ-GANs-from-Scratch-In-Pytorch-I/steps_interpolation_2.jpeg"} />
        </div>
        <div className='flex justify-center'>
          <img className='max-w-full h-auto' src={process.env.PUBLIC_URL + "/State-of-the-Art-Latent-Space-Quantization-with-VQ-GANs-from-Scratch-In-Pytorch-I/steps_interpolation_3.jpeg"} />
        </div>
        <div className='flex justify-center'>
          <img className='max-w-full h-auto' src={process.env.PUBLIC_URL + "/State-of-the-Art-Latent-Space-Quantization-with-VQ-GANs-from-Scratch-In-Pytorch-I/steps_interpolation_4.jpeg"} />
        </div>
        <br />
        <Paragraph text="If we regard each interpolation step as a frame, we can construct a compelling animation illustrating the interpolation process within the 
        latent space. We can see that the interpolation isn't perfectly smooth, which could be due either to an overly large latent dimension or to the small amount 
        of training data available. VQGANs are typically trained on large and diverse datasets, which allows their latent space to become well-structured and meaningful." />
        <br />
        <div className='flex justify-center items-center'>
          <img className='max-w-full h-32 sm:h-36 md:h-40 lg:h-40 mx-5' src={process.env.PUBLIC_URL + "/State-of-the-Art-Latent-Space-Quantization-with-VQ-GANs-from-Scratch-In-Pytorch-I/interpolation_0.gif"} />
          <img className='max-w-full h-32 sm:h-36 md:h-40 lg:h-40 mx-5' src={process.env.PUBLIC_URL + "/State-of-the-Art-Latent-Space-Quantization-with-VQ-GANs-from-Scratch-In-Pytorch-I/interpolation_1.gif"} />
        </div>
        <br />

        <SubTitle Title="Summative insights and future considerations" />
        <br />
        <Paragraph text="You're now equipped to implement a VQGAN, which discretizes the latent space and can help you with future generative tasks. You 
        have also gained a high-level understanding of how these networks operate. However, it's worth noting that while VQGANs are considered state-of-the-art, this 
        specific implementation may not reach that level of quality, as it lacks several advanced techniques found in modern VQGANs. To delve deeper into VQGANs, 
        I recommend exploring the more advanced implementations available on Hugging Face." />
        <Paragraph text="You could use this VQGAN implementation to build a latent GAN or a latent diffusion model. Additionally, you could explore image-to-image techniques within the latent space and compare them against techniques operating in image space." />
        <br />

        <SubTitle Title="Resources used" />
        <br />
        <Paragraph text="▸ [1] Esser, P., Rombach, R., & Ommer, B. (2020). Taming Transformers for High-Resolution Image Synthesis. arXiv preprint `arXiv:2012.09841`. Retrieved from `https://arxiv.org/abs/2012.09841`" />
        <br />
        <Paragraph text="▸ [2] Rampas, D. VQGAN-pytorch. GitHub. Retrieved from `https://github.com/dome272/VQGAN-pytorch`" />
        <br />
        <Paragraph text="▸ [3] van den Oord, A., Vinyals, O., & Kavukcuoglu, K. (2017). Neural Discrete Representation Learning. arXiv preprint `arXiv:1711.00937`. Retrieved from `https://arxiv.org/abs/1711.00937`" />
        <br />
        <Paragraph text="▸ [4] andrewmvd. Animal Faces. Kaggle. Retrieved from `https://www.kaggle.com/datasets/andrewmvd/animal-faces`" />
        <br />
        <Paragraph text="▸ [5] Gordić, A. VQ-GAN: Taming transformers for high-resolution image synthesis | Paper explained. YouTube. The AI Epiphany. Retrieved from `https://www.youtube.com/watch?v=FVxrHWuvWiA`" />
        <br />

        <DownSpace />
    </div>
  );
}

export default VQGAN;
