Skip to content

Commit

Permalink
add bigbigan env, wrapper and translation config
Browse files Browse the repository at this point in the history
  • Loading branch information
pesser committed Dec 9, 2020
1 parent 9ce9b31 commit 8d631d3
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 0 deletions.
39 changes: 39 additions & 0 deletions configs/translation/sbert-to-bigbigan.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Net2Net translation config: a conditional flow (Net2NetFlow) that connects a
# sentence-embedding model (cond stage) with BigBiGAN (first stage), trained
# on COCO images and captions.
# NOTE(review): the nesting indentation is not visible in this view; the
# comments below describe keys in the order they appear.
model:
base_learning_rate: 4.5e-6
# `target` entries are dotted import paths; `params` presumably holds the
# constructor keyword arguments for that target — confirm against the loader.
target: net2net.models.flows.flow.Net2NetFlow
params:
# Presumably the batch-dict keys fed to the first stage / cond stage — confirm.
first_stage_key: "image"
cond_stage_key: "caption"
flow_config:
target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow
params:
# presumably the sentence-embedding size of the SBERT model below — TODO confirm
conditioning_dim: 1024
embedding_dim: 256
conditioning_depth: 2
n_flows: 24
# 120 matches the 120-D BigBiGAN latent documented in
# net2net/modules/gan/bigbigan.py.
in_channels: 120
hidden_dim: 1024
hidden_depth: 2
activation: "none"
conditioner_use_bn: True

# Conditioning model: sentence embedder for the captions.
cond_stage_config:
target: net2net.modules.sbert.model.SentenceEmbedder
params:
version: "bert-large-nli-stsb-mean-tokens"

# First-stage model: BigBiGAN wrapper; no params given, so its constructor
# defaults apply.
first_stage_config:
target: net2net.modules.gan.bigbigan.BigBiGAN

data:
target: translation.DataModuleFromConfig
params:
batch_size: 16
train:
target: net2net.data.coco.CocoImagesAndCaptionsTrain
params:
# presumably the image side length; 256 matches the 256x256 encoder input
# documented for BigBiGAN — confirm against the dataset class
size: 256
validation:
target: net2net.data.coco.CocoImagesAndCaptionsValidation
params:
size: 256
27 changes: 27 additions & 0 deletions env_bigbigan.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Conda environment for running Net2Net together with the BigBiGAN wrapper.
# NOTE(review): the nesting indentation is not visible in this view.
name: net2net_bigbigan
channels:
- pytorch
- defaults
dependencies:
# Packages installed by conda.
- python=3.7
- pip=19.3
- cudatoolkit=10.1
- cudnn=7.6.5
- pytorch=1.6
- torchvision=0.7
- numpy=1.18
# Packages installed by pip inside the conda environment.
- pip:
- albumentations==0.4.3
- opencv-python==4.1.2.30
- pudb==2019.2
- imageio==2.9.0
- imageio-ffmpeg==0.4.2
- plotly==4.9.0
- pytorch-lightning==0.9.0
- omegaconf==2.0.0
- streamlit==0.71.0
- test-tube>=0.7.5
- sentence-transformers>=0.3.8
# TensorFlow and TF Hub are imported by net2net/modules/gan/bigbigan.py,
# which loads BigBiGAN from a TF Hub module.
- tensorflow==2.3.1
- tensorflow-hub==0.10.0
# Editable install of this repository itself.
- -e .
88 changes: 88 additions & 0 deletions net2net/modules/gan/bigbigan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import torch
import numpy as np

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

import tensorflow_hub as hub


class BigBiGAN(object):
    """Torch-facing wrapper around the BigBiGAN TF Hub module.

    The TF1-style graph (placeholders plus encoder/generator outputs) is built
    once in the constructor and evaluated through a single ``tf.Session``.
    ``encode`` and ``decode`` convert between torch tensors (channels-first)
    and the channels-last numpy arrays that are fed to the TF graph.

    Call ``close()`` when the model is no longer needed to release the
    TensorFlow session.
    """

    def __init__(self,
                 module_path='https://tfhub.dev/deepmind/bigbigan-resnet50/1',
                 allow_growth=True):
        """Initialize a BigBiGAN from the given TF Hub module.

        Args:
          module_path: URL or path of the TF Hub module to load.
          allow_growth: Passed to ``tf.GPUOptions(allow_growth=...)`` for the
            session created here.
        """
        self._module = hub.Module(module_path)

        # encode graph: apply the module once and read both outputs from the
        # same feature dict, so the encoder ops are not added to the graph
        # twice.
        self.enc_ph = self.make_encoder_ph()
        enc_features = self.encode_graph(self.enc_ph,
                                         return_all_features=True)
        self.z_sample = enc_features['z_sample']
        self.z_mean = enc_features['z_mean']

        # decode graph
        self.gen_ph = self.make_generator_ph()
        self.gen_samples = self.generate_graph(self.gen_ph, upsample=True)

        # session
        init = tf.global_variables_initializer()
        gpu_options = tf.GPUOptions(allow_growth=allow_growth)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(init)

    def close(self):
        """Close the underlying TensorFlow session and free its resources."""
        self.sess.close()

    def generate_graph(self, z, upsample=False):
        """Run a batch of latents z through the generator to generate images.

        Args:
          z: A batch of 120D Gaussian latents, shape [N, 120].
          upsample: If True, return the module's 'upsampled' output instead
            of its 'default' output.

        Returns: a batch of generated RGB images, range [-1, 1]; shape
          [N, 128, 128, 3] for the default output (the upsampled output is
          presumably [N, 256, 256, 3] -- confirm against the TF Hub module
          page).
        """
        outputs = self._module(z, signature='generate', as_dict=True)
        return outputs['upsampled' if upsample else 'default']

    def make_generator_ph(self):
        """Creates a tf.placeholder with the dtype & shape of generator inputs."""
        info = self._module.get_input_info_dict('generate')['z']
        return tf.placeholder(dtype=info.dtype, shape=info.get_shape())

    def encode_graph(self, x, return_all_features=False):
        """Run a batch of images x through the encoder.

        Args:
          x: A batch of data (256x256 RGB images), shape [N, 256, 256, 3],
            range [-1, 1].
          return_all_features: If True, return all features computed by the
            encoder. Otherwise (default) just return a sample z_hat.

        Returns: the sample z_hat of shape [N, 120] (or a dict of all
          features if return_all_features).
        """
        outputs = self._module(x, signature='encode', as_dict=True)
        return outputs if return_all_features else outputs['z_sample']

    def make_encoder_ph(self):
        """Creates a tf.placeholder with the dtype & shape of encoder inputs."""
        info = self._module.get_input_info_dict('encode')['x']
        return tf.placeholder(dtype=info.dtype, shape=info.get_shape())

    @torch.no_grad()
    def encode(self, x_torch):
        """Encode a torch image batch into sampled latents.

        Args:
          x_torch: Channels-first image batch (4D); it is permuted to
            channels-last before being fed to the encoder placeholder.

        Returns: the sampled latents as a torch tensor on ``x_torch``'s
          device, with two trailing singleton dimensions appended.
        """
        # TF expects NHWC numpy input, torch provides NCHW.
        x_np = x_torch.detach().permute(0, 2, 3, 1).cpu().numpy()
        z = self.sess.run(self.z_sample, feed_dict={self.enc_ph: x_np})
        z_torch = torch.tensor(z).to(device=x_torch.device)
        return z_torch.unsqueeze(-1).unsqueeze(-1)

    @torch.no_grad()
    def decode(self, z_torch):
        """Decode latents into images using the upsampled generator output.

        Args:
          z_torch: Latents as returned by ``encode``; trailing singleton
            dimensions are squeezed away before feeding the generator.

        Returns: generated images as a channels-first torch tensor on
          ``z_torch``'s device.
        """
        z_np = z_torch.detach().squeeze(-1).squeeze(-1).cpu().numpy()
        x = self.sess.run(self.gen_samples, feed_dict={self.gen_ph: z_np})
        # NHWC (TF) -> NCHW (torch)
        x = x.transpose(0, 3, 1, 2)
        return torch.tensor(x).to(device=z_torch.device)

    def eval(self):
        """Return self; a no-op kept for interface compatibility."""
        # interface requirement
        return self

0 comments on commit 8d631d3

Please sign in to comment.