Rejection Sampling IMLE: Designing Priors for Better Few-Shot Image Synthesis

Simon Fraser University
European Conference on Computer Vision (ECCV) 2024


TLDR: We identify an issue with current IMLE-based methods and propose a novel approach to address it. We achieve state-of-the-art performance on few-shot image generation tasks.

Abstract

An emerging area of research aims to learn deep generative models with limited training data. Implicit Maximum Likelihood Estimation (IMLE), a recent technique, successfully addresses the mode collapse issue of GANs and has been adapted to the few-shot setting, achieving state-of-the-art performance. However, current IMLE-based approaches encounter challenges due to inadequate correspondence between the latent codes selected for training and those drawn during inference. This results in suboptimal test-time performance. We theoretically show a way to address this issue and propose RS-IMLE, a novel approach that changes the prior distribution used for training. This leads to substantially higher quality image generation compared to existing GAN and IMLE-based methods, as validated by comprehensive experiments conducted on nine few-shot image datasets.

Why is Few-Shot Generation Challenging?

Generative models that perform well in the large-scale setting do not perform well in the few-shot setting.

Diffusion models

In diffusion models, the marginal likelihood under the forward process is a mixture of isotropic Gaussians. This modeling assumption smooths out the learned manifold along all directions, including those that are orthogonal to the actual data manifold. This becomes particularly problematic when there are a limited number of training examples.
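As a toy illustration of this point (the 1-D dataset and noise level below are hypothetical, chosen only for clarity), the forward process maps each training point to its own isotropic Gaussian, so the marginal over noised inputs is a mixture with one component per training example. With few examples, these components are sparse and the smoothing spreads probability mass far from the data:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 1-D dataset: with few points, the mixture components are far apart.
data = np.array([-2.0, 0.0, 2.0])

# DDPM-style forward process: x_t = sqrt(abar)*x_0 + sqrt(1-abar)*eps,
# so q(x_t) is a mixture of isotropic Gaussians, one per training example.
alpha_bar = 0.5

def sample_forward(x0, n):
    eps = rng.standard_normal(n)
    return np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * eps

# Draw noisy samples for each data point; their union is the smoothed marginal.
samples = np.concatenate([sample_forward(x0, 1000) for x0 in data])
print(samples.shape)  # (3000,)
```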
Suppose we have a dataset of 10K examples. We can train a diffusion model on this dataset and sample from it.

Dataset with 10K points

Denoising process for 10K data points

Consider what happens when we have the same shape but with just 20 data points. Now we train the same diffusion model on this dataset.
The performance of diffusion models degrades significantly in the few-shot setting: the model is not able to learn the data manifold and generates samples that lie far from it.

Dataset with 20 points

Denoising process for 20 data points

Generative Adversarial Network

GANs suffer from mode collapse and training instability. When mode collapse occurs, the modelled distribution only covers a subset of the modes.
Implicit Maximum Likelihood Estimation is an alternative to the GAN objective and has shown promising results in addressing mode collapse.
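To make the IMLE objective concrete, here is a minimal sketch of one selection step on a toy problem (the linear generator, dataset, and sample count are illustrative assumptions, not the paper's setup): for each data point, draw latent codes from the standard normal prior, keep the generated sample nearest to that point, and use its distance as the loss.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup (hypothetical): a linear "generator" mapping 2-D latents to
# 2-D outputs, and a handful of training points.
W = rng.standard_normal((2, 2))
data = rng.standard_normal((5, 2))

def generator(z):
    return z @ W.T

def imle_nearest_samples(data, n_samples=100):
    """One IMLE selection step: for each data point, draw latent codes from
    the standard normal prior and keep the sample closest to that point."""
    z = rng.standard_normal((n_samples, 2))
    samples = generator(z)
    # Pairwise squared distances between data points and generated samples.
    d2 = ((data[:, None, :] - samples[None, :, :]) ** 2).sum(-1)
    nearest = d2.argmin(axis=1)          # index of closest sample per point
    return z[nearest], d2.min(axis=1)    # selected codes and their losses

codes, losses = imle_nearest_samples(data)
print(codes.shape, losses.shape)  # (5, 2) (5,)
```

In the full method the generator is a deep network and the selected (latent code, data point) pairs supervise a regression loss, which guarantees every data point pulls some sample toward it and thereby avoids mode collapse.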

GAN training: mode collapse

IMLE training

The Misalignment Issue

In existing IMLE-based methods, we observe that the distribution of latent codes selected by the training objective differs from the distribution of latent codes encountered at test time. Consider an illustrative example where the latent space is two-dimensional. We train a simple generative model using IMLE on a two-dimensional toy dataset. The latent codes selected over the course of training are illustrated below.

Latent codes selected by IMLE objective during training

Latent codes sampled at test time

We notice that the latent codes belonging to the same data point (denoted by the same colour) form well-separated, tight bands in the latent space. We also observe large gaps between these bands, indicating that these regions of the latent space are consistently overlooked during training.
Since at test time we sample from the same standard normal distribution, these unsupervised regions of the latent space produce arbitrary outputs, which results in bad samples.

Methodology

In our approach, we propose to change the prior distribution used for training the model. We design the prior such that every sample drawn from it is guaranteed to be at least a distance ε away from all data points.
Since we reject samples that fall too close to the data points, we call our method Rejection Sampling IMLE (RS-IMLE).
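A minimal sketch of the rejection step (the identity generator, single data point, and ε value are hypothetical, for illustration only): keep drawing latent codes from the standard normal prior and accept one only if its generated sample is at least ε away from every data point.

```python
import numpy as np

rng = np.random.default_rng(0)

def rs_prior_samples(generator, data, eps, n, max_tries=1000):
    """Sketch of the RS-IMLE prior: keep only latent codes whose generated
    sample is at least eps away from EVERY data point (in output space)."""
    accepted = []
    for _ in range(max_tries):
        z = rng.standard_normal(data.shape[1])
        x = generator(z)
        if np.linalg.norm(data - x, axis=1).min() >= eps:
            accepted.append(z)
            if len(accepted) == n:
                break
    return np.array(accepted)

# Toy example with an identity "generator" (assumption, for illustration).
data = np.zeros((1, 2))                      # single data point at the origin
codes = rs_prior_samples(lambda z: z, data, eps=1.0, n=50)
# Every accepted sample lies outside the eps-ball around the data point.
print((np.linalg.norm(codes, axis=1) >= 1.0).all())  # True
```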
Note: 🟩 denotes data points and 🟠 denote samples from the model.

IMLE

RS-IMLE

Our method ensures that the loss for each data point is always sufficiently high (indicated by the long arrows), resulting in meaningful updates to the model parameters. Below, we show videos of the training process for models trained with IMLE and with RS-IMLE.

IMLE

RS-IMLE

We can also analyze the latent space of the models trained with the respective objectives. We observe that over the course of training, our method selects latent codes whose distribution more faithfully matches the one encountered at test time.


Latent codes selected by IMLE objective during training

Latent codes selected by RS-IMLE objective during training

Frechet Inception Distance

We present the FID scores computed across all datasets for each method. A lower FID score indicates that the distribution of generated images is closer to the distribution of real images.
Our method performs significantly better than the baselines, with an average improvement of 45.9% over the best baseline.
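For reference, FID fits a Gaussian to the Inception features of real and generated images and measures the distance between the two fits. The sketch below shows the simplified diagonal-covariance case (an assumption for brevity; the general formula requires a matrix square root of the covariance product):

```python
import numpy as np

def fid_diagonal(mu1, sig1, mu2, sig2):
    """FID between two Gaussians with diagonal covariances (simplified;
    the general case needs a matrix square root of sig1 @ sig2)."""
    diff = ((mu1 - mu2) ** 2).sum()
    covmean = np.sqrt(sig1 * sig2)           # valid because diagonals commute
    return diff + (sig1 + sig2 - 2 * covmean).sum()

mu = np.zeros(4)
sig = np.ones(4)
print(fid_diagonal(mu, sig, mu, sig))  # 0.0 for identical distributions
```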

| Dataset | FastGAN | FakeCLR | FreGAN | ReGAN | AdaIMLE | RS-IMLE |
|---|---|---|---|---|---|---|
| Obama | 41.1 | 29.9 | 33.4 | 45.7 | 25.0 | 14.0 |
| Grumpy Cat | 26.6 | 20.6 | 24.9 | 27.3 | 19.1 | 11.5 |
| Panda | 10.0 | 8.8 | 9.0 | 12.6 | 7.6 | 3.5 |
| FFHQ-100 | 54.2 | 62.1 | 50.5 | 87.4 | 33.2 | 12.9 |
| Cat | 35.1 | 27.4 | 31.0 | 42.1 | 24.9 | 15.9 |
| Dog | 50.7 | 44.4 | 47.9 | 57.2 | 43.0 | 23.1 |
| Anime | 69.8 | 77.7 | 59.8 | 110.8 | 65.8 | 35.8 |
| Skulls | 109.6 | 106.5 | 163.3 | 130.7 | 81.9 | 51.1 |
| Shells | 120.9 | 148.4 | 169.3 | 236.1 | 108.5 | 55.4 |

Visual Recall Test

First column is the query image from the dataset. Subsequent columns are the samples produced by different methods that are closest to the query image in LPIPS feature space.
Note that the samples from our method are:

  1. Realistic (indicating high precision)
  2. Closer to the query (indicating high recall)
  3. Diverse (indicating that the model is not overfitting)
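The retrieval behind this test is nearest-neighbour search in a perceptual feature space. A sketch follows, with a flattening placeholder standing in for the pretrained LPIPS feature extractor (all names and array shapes here are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)

def features(img):
    """Placeholder perceptual feature extractor (assumption); the actual
    test uses LPIPS features from a pretrained network."""
    return img.reshape(-1)

def visual_recall(query, samples):
    """Return indices of generated samples sorted by perceptual distance
    to the query image (closest first)."""
    q = features(query)
    d = np.array([np.linalg.norm(features(s) - q) for s in samples])
    return np.argsort(d)

query = rng.random((8, 8, 3))
samples = rng.random((10, 8, 8, 3))
order = visual_recall(query, samples)
print(order[:3])  # indices of the three nearest samples
```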

You can find more examples in the main paper and the supplementary material.

Video Presentation

BibTeX


        @inproceedings{vashist2024rejectionsamplingimledesigning,
          title = {Rejection Sampling IMLE: Designing Priors for Better Few-Shot Image Synthesis},
          author = {Chirag Vashist and Shichong Peng and Ke Li},
          booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
          year = {2024}
        }