This library implements some of the most common (Variational) Autoencoder models under a unified implementation. In particular, it makes it possible to run benchmark experiments and comparisons by training the models with the same autoencoding neural network architecture. The "make your own autoencoder" feature allows you to train any of these models with your own data and your own encoder and decoder neural networks. It integrates experiment monitoring tools such as wandb, mlflow or comet-ml 🧪 and allows model sharing and loading from the HuggingFace Hub 🤗 in a few lines of code.
- Installation
- Implemented models / Implemented samplers
- Reproducibility statement / Results flavor
- Model training / Data generation / Custom network architectures
- Model sharing with 🤗 Hub / Experiment tracking with wandb / Experiment tracking with mlflow / Experiment tracking with comet_ml
- Tutorials / Documentation
- Contributing 🚀 / Issues 🛠️
- Citing this repository
To install the latest stable release of this library, run the following using pip:
$ pip install pythae
To install the latest GitHub version of this library, run the following using pip:
$ pip install git+https://github.com/clementchadebec/benchmark_VAE.git
or alternatively, you can clone the GitHub repo to access the tests, tutorials and scripts:
$ git clone https://github.com/clementchadebec/benchmark_VAE.git
and install the library:
$ cd benchmark_VAE
$ pip install -e .
Below is the list of the models currently implemented in the library.
Models | Training example | Paper | Official Implementation |
---|---|---|---|
Autoencoder (AE) | |||
Variational Autoencoder (VAE) | link | ||
Beta Variational Autoencoder (BetaVAE) | link | ||
VAE with Linear Normalizing Flows (VAE_LinNF) | link | ||
VAE with Inverse Autoregressive Flows (VAE_IAF) | link | link | |
Disentangled Beta Variational Autoencoder (DisentangledBetaVAE) | link | ||
Disentangling by Factorising (FactorVAE) | link | ||
Beta-TC-VAE (BetaTCVAE) | link | link | |
Importance Weighted Autoencoder (IWAE) | link | link | |
Multiply Importance Weighted Autoencoder (MIWAE) | link | ||
Partially Importance Weighted Autoencoder (PIWAE) | link | ||
Combination Importance Weighted Autoencoder (CIWAE) | link | ||
VAE with perceptual metric similarity (MSSSIM_VAE) | link | ||
Wasserstein Autoencoder (WAE) | link | link | |
Info Variational Autoencoder (INFOVAE_MMD) | link | ||
VAMP Autoencoder (VAMP) | link | link | |
Hyperspherical VAE (SVAE) | link | link | |
Poincaré Disk VAE (PoincareVAE) | link | link | |
Adversarial Autoencoder (Adversarial_AE) | link | ||
Variational Autoencoder GAN (VAEGAN) 🥗 | link | link | |
Vector Quantized VAE (VQVAE) | link | link | |
Hamiltonian VAE (HVAE) | link | link | |
Regularized AE with L2 decoder param (RAE_L2) | link | link | |
Regularized AE with gradient penalty (RAE_GP) | link | link | |
Riemannian Hamiltonian VAE (RHVAE) | link | link | |
See reconstruction and generation results for all aforementioned models.
Below is the list of the samplers currently implemented in the library.
Samplers | Models | Paper | Official Implementation |
---|---|---|---|
Normal prior (NormalSampler) | all models | link | |
Gaussian mixture (GaussianMixtureSampler) | all models | link | link |
Two stage VAE sampler (TwoStageVAESampler) | all VAE based models | link | link |
Unit sphere uniform sampler (HypersphereUniformSampler) | SVAE | link | link |
Poincaré Disk sampler (PoincareDiskSampler) | PoincareVAE | link | link |
VAMP prior sampler (VAMPSampler) | VAMP | link | link |
Manifold sampler (RHVAESampler) | RHVAE | link | link |
Masked Autoregressive Flow Sampler (MAFSampler) | all models | link | link |
Inverse Autoregressive Flow Sampler (IAFSampler) | all models | link | link |
PixelCNN (PixelCNNSampler) | VQVAE | link | |
We validate the implementations by reproducing some results presented in the original publications when the official code has been released or when enough details about the experimental section of the papers were available. See reproducibility for more details.
To launch a model training, you only need to call a TrainingPipeline instance.
>>> from pythae.pipelines import TrainingPipeline
>>> from pythae.models import VAE, VAEConfig
>>> from pythae.trainers import BaseTrainerConfig
>>> # Set up the training configuration
>>> my_training_config = BaseTrainerConfig(
... output_dir='my_model',
... num_epochs=50,
... learning_rate=1e-3,
... batch_size=200,
... steps_saving=None
... )
>>> # Set up the model configuration
>>> my_vae_config = VAEConfig(
... input_dim=(1, 28, 28),
... latent_dim=10
... )
>>> # Build the model
>>> my_vae_model = VAE(
... model_config=my_vae_config
... )
>>> # Build the Pipeline
>>> pipeline = TrainingPipeline(
... training_config=my_training_config,
... model=my_vae_model
... )
>>> # Launch the Pipeline
>>> pipeline(
... train_data=your_train_data, # must be torch.Tensor, np.array or torch datasets
... eval_data=your_eval_data # must be torch.Tensor, np.array or torch datasets
... )
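The pipeline accepts torch.Tensor, np.array or torch datasets. As a rough sketch, assuming your data is a NumPy array of images shaped (N, 1, 28, 28) to match the input_dim used above, you could hold out an evaluation split and check the sample shape before launching the pipeline. The helper name split_train_eval, the 80/20 ratio and the random toy data are illustrative choices, not part of Pythae.

```python
import numpy as np


def split_train_eval(data, input_dim, eval_fraction=0.2, seed=0):
    """Shuffle `data` and split it into train/eval arrays (hypothetical helper)."""
    data = np.asarray(data)
    # Each sample must have the same shape as the model's input_dim
    if tuple(data.shape[1:]) != tuple(input_dim):
        raise ValueError(
            f"sample shape {data.shape[1:]} does not match input_dim {tuple(input_dim)}"
        )
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(data))  # random order of sample indices
    n_eval = int(len(data) * eval_fraction)
    eval_idx, train_idx = indices[:n_eval], indices[n_eval:]
    return data[train_idx], data[eval_idx]


# Toy stand-in for MNIST-like data: 1000 random "images" with values in [0, 1)
fake_images = np.random.default_rng(1).random((1000, 1, 28, 28), dtype=np.float32)
your_train_data, your_eval_data = split_train_eval(fake_images, input_dim=(1, 28, 28))
print(your_train_data.shape, your_eval_data.shape)  # (800, 1, 28, 28) (200, 1, 28, 28)
```

The two resulting arrays can then be passed as train_data and eval_data to the pipeline call above.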
At the end of training, the best model weights, model configuration and training configuration are stored in a final_model folder available in my_model/MODEL_NAME_training_YYYY-MM-DD_hh-mm-ss (with my_model being the output_dir argument of the BaseTrainerConfig). If you further set the steps_saving argument to a certain value, folders named checkpoint_epoch_k containing the best model weights, optimizer, scheduler, configuration and training configuration at epoch k will also appear in my_model/MODEL_NAME_training_YYYY-MM-DD_hh-mm-ss.
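If you train several times with the same output_dir, you may want to pick up the most recent run programmatically. The sketch below assumes the folder layout described above (MODEL_NAME_training_YYYY-MM-DD_hh-mm-ss, containing final_model and optional checkpoint_epoch_k folders) and parses the timestamp to find the latest run; the helper names latest_run and latest_checkpoint are hypothetical and not part of Pythae.

```python
import re
from datetime import datetime
from pathlib import Path

# Assumed name of a training folder: MODEL_NAME_training_YYYY-MM-DD_hh-mm-ss
_RUN_PATTERN = re.compile(
    r"^(?P<model>.+)_training_(?P<stamp>\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2})$"
)


def latest_run(output_dir, model_name=None):
    """Return the most recent training folder in output_dir (hypothetical helper)."""
    runs = []
    for path in Path(output_dir).iterdir():
        match = _RUN_PATTERN.match(path.name)
        if path.is_dir() and match and model_name in (None, match["model"]):
            stamp = datetime.strptime(match["stamp"], "%Y-%m-%d_%H-%M-%S")
            runs.append((stamp, path))
    if not runs:
        raise FileNotFoundError(f"no training folder found in {output_dir}")
    return max(runs)[1]  # latest timestamp wins


def latest_checkpoint(run_dir):
    """Return the checkpoint_epoch_k folder with the largest k, or None."""
    checkpoints = [
        (int(p.name.rsplit("_", 1)[1]), p)
        for p in Path(run_dir).glob("checkpoint_epoch_*")
        if p.name.rsplit("_", 1)[1].isdigit()
    ]
    return max(checkpoints)[1] if checkpoints else None
```

For instance, latest_run("my_model") / "final_model" points to the folder you would later reload with AutoModel.load_from_folder.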
We also provide a training script example here that can be used to train the models on benchmark datasets (mnist, cifar10, celeba ...). The script can be launched with the following command line:
python training.py --dataset mnist --model_name ae --model_config 'configs/ae_config.json' --training_config 'configs/base_training_config.json'
See README.md for further details on this script.
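To run the script on several benchmark datasets in a row, one option is to build the command lines programmatically. This sketch only reuses the flags shown above; the build_commands helper, the dataset list and the choice of reusing the same config files for every dataset are illustrative assumptions, and nothing is executed unless you uncomment the launch line.

```python
import shlex


def build_commands(datasets, model_name="ae",
                   model_config="configs/ae_config.json",
                   training_config="configs/base_training_config.json"):
    """Return one training.py argument list per dataset (hypothetical helper)."""
    return [
        ["python", "training.py",
         "--dataset", dataset,
         "--model_name", model_name,
         "--model_config", model_config,
         "--training_config", training_config]
        for dataset in datasets
    ]


commands = build_commands(["mnist", "cifar10", "celeba"])
for cmd in commands:
    print(shlex.join(cmd))  # show the exact shell command
    # subprocess.run(cmd, check=True)  # after `import subprocess`, launches the training
```

Passing argument lists (rather than one string) to subprocess.run avoids shell-quoting issues with config paths.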
The easiest way to launch data generation from a trained model is to use the built-in GenerationPipeline provided in Pythae. Say you want to generate 100 samples using a MAFSampler: all you have to do is 1) reload the trained model, 2) define the sampler's configuration and 3) create and launch the GenerationPipeline as follows:
>>> from pythae.models import AutoModel
>>> from pythae.samplers import MAFSamplerConfig
>>> from pythae.pipelines import GenerationPipeline
>>> from pythae.trainers import BaseTrainerConfig
>>> # Retrieve the trained model
>>> my_trained_vae = AutoModel.load_from_folder(
... 'path/to/your/trained/model'
... )
>>> my_sampler_config = MAFSamplerConfig(
... n_made_blocks=2,
... n_hidden_in_made=3,
... hidden_size=128
... )
>>> # Build the pipeline
>>> pipe = GenerationPipeline(
... model=my_trained_vae,
... sampler_config=my_sampler_config
... )
>>> # Launch data generation
>>> generated_samples = pipe(
... num_samples=100,
... return_gen=True, # If False, returns nothing
... train_data=train_data, # Needed to fit the sampler
... eval_data=eval_data, # Needed to fit the sampler
... training_config=BaseTrainerConfig(num_epochs=200) # TrainingConfig used to fit the sampler
... )
Alternatively, you can launch the data generation process from a trained model directly with the sampler. For instance, to generate new data with a NormalSampler, run the following:
>>> from pythae.models import AutoModel
>>> from pythae.samplers import NormalSampler
>>> # Retrieve the trained model
>>> my_trained_vae = AutoModel.load_from_folder(
... 'path/to/your/trained/model'
... )
>>> # Define your sampler
>>> my_sampler = NormalSampler(
... model=my_trained_vae
... )
>>> # Generate samples
>>> gen_data = my_sampler.sample(
... num_samples=50,
... batch_size=10,
... output_dir=None,
... return_gen=True
... )
If you set output_dir to a specific path, the generated images will be saved as .png files named 00000000.png, 00000001.png ...
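The file names are zero-padded to eight digits, so sorting them numerically recovers the generation order. A minimal sketch for collecting the saved images, assuming they sit directly in the folder you passed as output_dir (the helper names are hypothetical):

```python
from pathlib import Path


def expected_name(index):
    """Eight-digit, zero-padded file name, as in 00000000.png, 00000001.png, ..."""
    return f"{index:08d}.png"


def generated_images(output_dir):
    """Return the numbered .png files in output_dir, sorted by their index."""
    files = [p for p in Path(output_dir).glob("*.png") if p.stem.isdigit()]
    return sorted(files, key=lambda p: int(p.stem))
```

Files whose stem is not a number (for example a hand-made preview.png) are ignored by generated_images.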
The samplers can be used with any model as long as they are compatible. For instance, a GaussianMixtureSampler instance can be used to generate from any model, but a VAMPSampler will only be usable with a VAMP model. Check here to see which ones apply to your model. Be careful: some samplers, such as the GaussianMixtureSampler, may need to be fitted by calling the fit method before use. Below is an example for the GaussianMixtureSampler.
>>> from pythae.models import AutoModel
>>> from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig
>>> # Retrieve the trained model
>>> my_trained_vae = AutoModel.load_from_folder(
... 'path/to/your/trained/model'
... )
>>> # Define your sampler
>>> gmm_sampler_config = GaussianMixtureSamplerConfig(
... n_components=10
... )
>>> my_sampler = GaussianMixtureSampler(
... sampler_config=gmm_sampler_config,
... model=my_trained_vae
... )
>>> # Fit the sampler on your training data
>>> my_sampler.fit(train_dataset)
>>> # Generate samples
>>> gen_data = my_sampler.sample(
... num_samples=50,
... batch_size=10,
... output_dir=None,
... return_gen=True
... )
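The "Models" column of the samplers table above can be encoded as a small lookup so that a sampler/model mismatch is caught before generation. This is a sketch restating that table, not a Pythae function; since the table does not enumerate which models count as VAE-based, the hypothetical is_compatible helper asks the caller for that flag.

```python
# Compatibility restated from the "Models" column of the samplers table.
# "all" means any model, "vae" means any VAE-based model.
SAMPLER_COMPATIBILITY = {
    "NormalSampler": "all",
    "GaussianMixtureSampler": "all",
    "TwoStageVAESampler": "vae",
    "HypersphereUniformSampler": {"SVAE"},
    "PoincareDiskSampler": {"PoincareVAE"},
    "VAMPSampler": {"VAMP"},
    "RHVAESampler": {"RHVAE"},
    "MAFSampler": "all",
    "IAFSampler": "all",
    "PixelCNNSampler": {"VQVAE"},
}


def is_compatible(sampler_name, model_name, vae_based):
    """Return True if the sampler can be used with the model (hypothetical helper).

    vae_based is supplied by the caller, because the table only says
    "all VAE based models" without listing them.
    """
    allowed = SAMPLER_COMPATIBILITY[sampler_name]
    if allowed == "all":
        return True
    if allowed == "vae":
        return vae_based
    return model_name in allowed
```

For example, is_compatible("VAMPSampler", "VAE", vae_based=True) returns False, matching the restriction described above.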
Pythae gives you the possibility to define your own neural networks within the VAE models. For instance, say you want to train a Wasserstein AE with a specific encoder and decoder; you can do the following:
>>> import torch
>>> import torch.nn as nn
>>> from pythae.models.nn import BaseEncoder, BaseDecoder
>>> from pythae.models.base.base_utils import ModelOutput
>>> class My_Encoder(BaseEncoder):
...     def __init__(self, args=None): # args is a ModelConfig instance
...         BaseEncoder.__init__(self)
...         # Illustrative layers for (1, 28, 28) inputs and a 10-dim latent space
...         self.layers = nn.Sequential(
...             nn.Flatten(), nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 10)
...         )
...
...     def forward(self, x: torch.Tensor) -> ModelOutput:
...         out = self.layers(x)
...         output = ModelOutput(
...             embedding=out # Set the output from the encoder in a ModelOutput instance
...         )
...         return output
...
>>> class My_Decoder(BaseDecoder):
...     def __init__(self, args=None):
...         BaseDecoder.__init__(self)
...         # Illustrative layers mapping a 10-dim latent code back to 784 pixels
...         self.layers = nn.Sequential(
...             nn.Linear(10, 512), nn.ReLU(), nn.Linear(512, 784), nn.Sigmoid()
...         )
...
...     def forward(self, z: torch.Tensor) -> ModelOutput:
...         out = self.layers(z).reshape(-1, 1, 28, 28) # back to the input shape
...         output = ModelOutput(
...             reconstruction=out # Set the output from the decoder in a ModelOutput instance
...         )
...         return output
...
>>> my_encoder = My_Encoder()
>>> my_decoder = My_Decoder()
Then build the model:
>>> from pythae.models import WAE_MMD, WAE_MMD_Config
>>> # Set up the model configuration
>>> my_wae_config = WAE_MMD_Config(
... input_dim=(1, 28, 28),
... latent_dim=10
... )
>>> # Build the model
>>> my_wae_model = WAE_MMD(
... model_config=my_wae_config,
... encoder=my_encoder, # pass your encoder as argument when building the model
... decoder=my_decoder # pass your decoder as argument when building the model
... )
Important note 1: For all AE-based models (AE, WAE, RAE_L2, RAE_GP), both the encoder and decoder must return a ModelOutput instance. For the encoder, the ModelOutput instance must contain the embeddings under the key embedding. For the decoder, the ModelOutput instance must contain the reconstructions under the key reconstruction.
Important note 2: For all VAE-based models (VAE, BetaVAE, IWAE, HVAE, VAMP, RHVAE), both the encoder and decoder must return a ModelOutput instance. For the encoder, the ModelOutput instance must contain the embeddings and the log-covariance matrices (of shape batch_size x latent_space_dim) under the keys embedding and log_covariance, respectively. For the decoder, the ModelOutput instance must contain the reconstructions under the key reconstruction.
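These key requirements are easy to get wrong when writing custom networks. Below is a minimal sketch of a sanity check you could run on the outputs of one batch before training; it only relies on the outputs being dict-like (the ModelOutput instances above are built with keyword keys), and the helper names check_encoder_output and check_decoder_output are hypothetical, not part of Pythae.

```python
# Required ModelOutput keys, restated from the two notes above
ENCODER_KEYS = {
    "ae": ("embedding",),
    "vae": ("embedding", "log_covariance"),
}
DECODER_KEYS = ("reconstruction",)


def _missing(output, keys):
    """Return the required keys that are absent from a dict-like output."""
    return [key for key in keys if key not in output]


def check_encoder_output(output, family):
    """Raise if an encoder output breaks the contract for family 'ae' or 'vae'."""
    missing = _missing(output, ENCODER_KEYS[family])
    if missing:
        raise KeyError(f"encoder output is missing {missing}")
    if family == "vae":
        # Both are expected to have shape batch_size x latent_space_dim
        emb, log_cov = output["embedding"], output["log_covariance"]
        if hasattr(emb, "shape") and hasattr(log_cov, "shape"):
            if tuple(emb.shape) != tuple(log_cov.shape):
                raise ValueError(
                    f"embedding shape {tuple(emb.shape)} != "
                    f"log_covariance shape {tuple(log_cov.shape)}"
                )


def check_decoder_output(output):
    """Raise if a decoder output has no reconstruction key."""
    missing = _missing(output, DECODER_KEYS)
    if missing:
        raise KeyError(f"decoder output is missing {missing}")
```

For instance, calling check_encoder_output(my_encoder(x), "ae") on a batch x would flag an encoder that forgot the embedding key.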
You can also find predefined neural network architectures for the most common datasets (e.g. MNIST, CIFAR, CELEBA ...) that can be loaded as follows:
>>> from pythae.models.nn.benchmark.mnist import (
... Encoder_Conv_AE_MNIST, # For AE based model (only return embeddings)
... Encoder_Conv_VAE_MNIST, # For VAE based model (return embeddings and log_covariances)
... Decoder_Conv_AE_MNIST
... )
Replace mnist by cifar or celeba to access the other neural nets.
Pythae also allows you to share your models on the HuggingFace Hub. To do so you need:
- a valid HuggingFace account
- the package huggingface_hub installed in your virtual env. If not, you can install it with
$ python -m pip install huggingface_hub
- to be logged in to your HuggingFace account using
$ huggingface-cli login
Any pythae model can be easily uploaded using the method push_to_hf_hub
>>> my_vae_model.push_to_hf_hub(hf_hub_path="your_hf_username/your_hf_hub_repo")
Note: If your_hf_hub_repo already exists and is not empty, files will be overwritten. If the repo your_hf_hub_repo does not exist, a folder with the same name will be created.
Equivalently, you can download or reload any Pythae model directly from the Hub using the method load_from_hf_hub:
>>> from pythae.models import AutoModel
>>> my_downloaded_vae = AutoModel.load_from_hf_hub(hf_hub_path="path_to_hf_repo")
Pythae also integrates the experiment tracking tool wandb, allowing users to store their configs, monitor their trainings and compare runs through a graphical interface. To be able to use this feature you will need:
- a valid wandb account
- the package wandb installed in your virtual env. If not, you can install it with
$ pip install wandb
- to be logged in to your wandb account using
$ wandb login
Launching experiment monitoring with wandb in pythae is pretty simple. The only thing a user needs to do is create a WandbCallback instance...
>>> # Create your callback
>>> from pythae.trainers.training_callbacks import WandbCallback
>>> callbacks = [] # the TrainingPipeline expects a list of callbacks
>>> wandb_cb = WandbCallback() # Build the callback
>>> # Set up the callback
>>> wandb_cb.setup(
... training_config=your_training_config, # training config
... model_config=your_model_config, # model config
... project_name="your_wandb_project", # specify your wandb project
... entity_name="your_wandb_entity", # specify your wandb entity
... )
>>> callbacks.append(wandb_cb) # Add it to the callbacks list
...and then pass it to the TrainingPipeline.
>>> pipeline = TrainingPipeline(
... training_config=your_training_config,
... model=your_model
... )
>>> pipeline(
... train_data=train_dataset,
... eval_data=eval_dataset,
... callbacks=callbacks # pass the callbacks to the TrainingPipeline and you are done!
... )
>>> # You can log to https://wandb.ai/your_wandb_entity/your_wandb_project to monitor your training
See the detailed tutorial
Pythae also integrates the experiment tracking tool mlflow, allowing users to store their configs, monitor their trainings and compare runs through a graphical interface. To be able to use this feature you will need:
- the package mlflow installed in your virtual env. If not, you can install it with
$ pip install mlflow
Launching experiment monitoring with mlflow in pythae is pretty simple. The only thing a user needs to do is create an MLFlowCallback instance...
>>> # Create your callback
>>> from pythae.trainers.training_callbacks import MLFlowCallback
>>> callbacks = [] # the TrainingPipeline expects a list of callbacks
>>> mlflow_cb = MLFlowCallback() # Build the callback
>>> # Set up the callback
>>> mlflow_cb.setup(
... training_config=your_training_config, # training config
... model_config=your_model_config, # model config
... run_name="mlflow_cb_example", # specify your mlflow run
... )
>>> callbacks.append(mlflow_cb) # Add it to the callbacks list
...and then pass it to the TrainingPipeline.
>>> pipeline = TrainingPipeline(
... training_config=your_training_config,
... model=your_model
... )
>>> pipeline(
... train_data=train_dataset,
... eval_data=eval_dataset,
... callbacks=callbacks # pass the callbacks to the TrainingPipeline and you are done!
... )
You can visualize your metrics by running the following command in the directory containing the ./mlruns folder:
$ mlflow ui
See the detailed tutorial
Pythae also integrates the experiment tracking tool comet_ml, allowing users to store their configs, monitor their trainings and compare runs through a graphical interface. To be able to use this feature you will need:
- the package comet_ml installed in your virtual env. If not, you can install it with
$ pip install comet_ml
Launching experiment monitoring with comet_ml in pythae is pretty simple. The only thing a user needs to do is create a CometCallback instance...
>>> # Create your callback
>>> from pythae.trainers.training_callbacks import CometCallback
>>> callbacks = [] # the TrainingPipeline expects a list of callbacks
>>> comet_cb = CometCallback() # Build the callback
>>> # Set up the callback
>>> comet_cb.setup(
... training_config=your_training_config, # training config
... model_config=your_model_config, # model config
... api_key="your_comet_api_key", # specify your comet api-key
... project_name="your_comet_project", # specify your comet project
... #offline_run=True, # run in offline mode
... #offline_directory='my_offline_runs' # set the directory to store the offline runs
... )
>>> callbacks.append(comet_cb) # Add it to the callbacks list
...and then pass it to the TrainingPipeline.
>>> pipeline = TrainingPipeline(
... training_config=your_training_config,
... model=your_model
... )
>>> pipeline(
... train_data=train_dataset,
... eval_data=eval_dataset,
... callbacks=callbacks # pass the callbacks to the TrainingPipeline and you are done!
... )
>>> # You can log to https://comet.com/your_comet_username/your_comet_project to monitor your training
See the detailed tutorial
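Since each of the three callbacks above needs its tracking package installed, you may want to check which ones are importable before building the callbacks list. The sketch below uses importlib.util.find_spec from the standard library, which locates a package without importing it; the mapping only restates the package and callback names given in the three sections above, and the helper name available_trackers is hypothetical.

```python
import importlib.util

# Tracking package -> pythae callback class named in the sections above
TRACKERS = {
    "wandb": "WandbCallback",
    "mlflow": "MLFlowCallback",
    "comet_ml": "CometCallback",
}


def available_trackers():
    """Return {package: callback class name} for the packages that are installed."""
    return {
        package: callback
        for package, callback in TRACKERS.items()
        if importlib.util.find_spec(package) is not None
    }


print(available_trackers())  # only the trackers installed in this environment
```

You could then create only the returned callbacks, importing each class from pythae.trainers.training_callbacks and calling its setup method as shown above.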
To help you understand the way pythae works and how you can train your models with this library, we also provide tutorials:
- making_your_own_autoencoder.ipynb shows you how to pass your own networks to the models implemented in pythae
- custom_dataset.ipynb shows you how to use custom datasets with any of the models implemented in pythae
- hf_hub_models_sharing.ipynb shows you how to upload models to and download models from the HuggingFace Hub
- wandb_experiment_monitoring.ipynb shows you how to monitor your experiments using wandb
- mlflow_experiment_monitoring.ipynb shows you how to monitor your experiments using mlflow
- comet_experiment_monitoring.ipynb shows you how to monitor your experiments using comet_ml
- the models_training folder provides notebooks showing how to train each implemented model and how to sample from it using pythae.samplers
- the scripts folder provides, in particular, an example of a training script to train the models on benchmark datasets (mnist, cifar10, celeba ...)
If you are experiencing any issues while running the code, or would like to request new features or models, please open an issue on GitHub.
You want to contribute to this library by adding a model, a sampler or simply fixing a bug? That's awesome! Thank you! Please see CONTRIBUTING.md to follow the main contributing guidelines.
First let's have a look at the reconstructed samples taken from the evaluation set.
Here, we show the generated samples using each model implemented in the library and different samplers.
If you find this work useful or use it in your research, please consider citing us:
@article{chadebec2022pythae,
title={Pythae: Unifying Generative Autoencoders in Python -- A Benchmarking Use Case},
author={Chadebec, Clément and Vincent, Louis J. and Allassonnière, Stéphanie},
journal={arXiv preprint arXiv:2206.08309},
url = {https://arxiv.org/abs/2206.08309},
year = {2022}
}