-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add bigbigan env, wrapper and translation config
- Loading branch information
Showing
3 changed files
with
154 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
model:
  base_learning_rate: 4.5e-6
  target: net2net.models.flows.flow.Net2NetFlow
  params:
    first_stage_key: "image"
    cond_stage_key: "caption"
    flow_config:
      target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow
      params:
        conditioning_dim: 1024
        embedding_dim: 256
        conditioning_depth: 2
        n_flows: 24
        in_channels: 120
        hidden_dim: 1024
        hidden_depth: 2
        activation: "none"
        conditioner_use_bn: True

    cond_stage_config:
      target: net2net.modules.sbert.model.SentenceEmbedder
      params:
        version: "bert-large-nli-stsb-mean-tokens"

    first_stage_config:
      target: net2net.modules.gan.bigbigan.BigBiGAN

data:
  target: translation.DataModuleFromConfig
  params:
    batch_size: 16
    train:
      target: net2net.data.coco.CocoImagesAndCaptionsTrain
      params:
        size: 256
    validation:
      target: net2net.data.coco.CocoImagesAndCaptionsValidation
      params:
        size: 256
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
name: net2net_bigbigan
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.7
  - pip=19.3
  - cudatoolkit=10.1
  - cudnn=7.6.5
  - pytorch=1.6
  - torchvision=0.7
  - numpy=1.18
  - pip:
    - albumentations==0.4.3
    - opencv-python==4.1.2.30
    - pudb==2019.2
    - imageio==2.9.0
    - imageio-ffmpeg==0.4.2
    - plotly==4.9.0
    - pytorch-lightning==0.9.0
    - omegaconf==2.0.0
    - streamlit==0.71.0
    - test-tube>=0.7.5
    - sentence-transformers>=0.3.8
    - tensorflow==2.3.1
    - tensorflow-hub==0.10.0
    - -e .
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import torch | ||
import numpy as np | ||
|
||
import tensorflow.compat.v1 as tf | ||
tf.disable_v2_behavior() | ||
|
||
import tensorflow_hub as hub | ||
|
||
|
||
class BigBiGAN(object):
    """Torch-facing wrapper around a BigBiGAN TF Hub module.

    The encoder and generator graphs are built once at construction time and
    executed through a private ``tf.Session``. ``encode`` and ``decode``
    convert between torch tensors and the numpy arrays that session consumes
    and produces.
    """

    def __init__(self,
                 module_path='https://tfhub.dev/deepmind/bigbigan-resnet50/1',
                 allow_growth=True):
        """Initialize a BigBiGAN from the given TF Hub module.

        Args:
          module_path: Handle (URL or path) passed to ``hub.Module``.
          allow_growth: Forwarded to ``tf.GPUOptions(allow_growth=...)`` for
            the session created here.
        """
        self._module = hub.Module(module_path)

        # encode graph
        # Instantiate the encoder signature a single time and read both
        # outputs from the same feature dict; calling the module once per
        # output would add a second copy of the encoder ops to the graph.
        self.enc_ph = self.make_encoder_ph()
        enc_features = self.encode_graph(self.enc_ph, return_all_features=True)
        self.z_sample = enc_features['z_sample']
        self.z_mean = enc_features['z_mean']

        # decode graph
        self.gen_ph = self.make_generator_ph()
        self.gen_samples = self.generate_graph(self.gen_ph, upsample=True)

        # session
        init = tf.global_variables_initializer()
        gpu_options = tf.GPUOptions(allow_growth=allow_growth)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(init)

    def generate_graph(self, z, upsample=False):
        """Run a batch of latents z through the generator to generate images.

        Args:
          z: A batch of 120D Gaussian latents, shape [N, 120].
          upsample: If True, return the module's 'upsampled' output instead
            of its 'default' output.
        Returns: a batch of generated RGB images, shape [N, 128, 128, 3], range
          [-1, 1]. (NOTE(review): the 'upsampled' output is presumably larger
          than 128x128 -- confirm against the hub module documentation.)
        """
        outputs = self._module(z, signature='generate', as_dict=True)
        return outputs['upsampled' if upsample else 'default']

    def make_generator_ph(self):
        """Creates a tf.placeholder with the dtype & shape of generator inputs."""
        info = self._module.get_input_info_dict('generate')['z']
        return tf.placeholder(dtype=info.dtype, shape=info.get_shape())

    def encode_graph(self, x, return_all_features=False):
        """Run a batch of images x through the encoder.

        Args:
          x: A batch of data (256x256 RGB images), shape [N, 256, 256, 3], range
            [-1, 1].
          return_all_features: If True, return all features computed by the
            encoder. Otherwise (default) just return a sample z_hat.
        Returns: the sample z_hat of shape [N, 120] (or a dict of all features
          if return_all_features).
        """
        outputs = self._module(x, signature='encode', as_dict=True)
        return outputs if return_all_features else outputs['z_sample']

    def make_encoder_ph(self):
        """Creates a tf.placeholder with the dtype & shape of encoder inputs."""
        info = self._module.get_input_info_dict('encode')['x']
        return tf.placeholder(dtype=info.dtype, shape=info.get_shape())

    @torch.no_grad()
    def encode(self, x_torch):
        """Encode a torch image batch into latents via the TF session.

        Args:
          x_torch: 4D torch tensor; its axes are permuted (0, 2, 3, 1) before
            being fed to the encoder placeholder, i.e. channels-first input
            is converted to the channels-last layout of ``encode_graph``.
        Returns: a torch tensor on ``x_torch.device`` holding the fetched
          ``z_sample`` with two trailing singleton axes appended.
        """
        x_np = x_torch.detach().permute(0, 2, 3, 1).cpu().numpy()
        feed_dict = {self.enc_ph: x_np}
        z = self.sess.run(self.z_sample, feed_dict=feed_dict)
        z_torch = torch.tensor(z).to(device=x_torch.device)
        return z_torch.unsqueeze(-1).unsqueeze(-1)

    @torch.no_grad()
    def decode(self, z_torch):
        """Decode torch latents into an image batch via the TF session.

        Args:
          z_torch: torch tensor whose two trailing singleton axes (as produced
            by ``encode``) are squeezed away before feeding the generator.
        Returns: a torch tensor on ``z_torch.device`` holding the generated
          images, transposed (0, 3, 1, 2) from channels-last to
          channels-first.
        """
        z_np = z_torch.detach().squeeze(-1).squeeze(-1).cpu().numpy()
        feed_dict = {self.gen_ph: z_np}
        x = self.sess.run(self.gen_samples, feed_dict=feed_dict)
        x = x.transpose(0, 3, 1, 2)
        x_torch = torch.tensor(x).to(device=z_torch.device)
        return x_torch

    def eval(self):
        """Return self unchanged; exists only to satisfy the model interface."""
        # interface requirement
        return self

    def close(self):
        """Close the TF session created in ``__init__``."""
        self.sess.close()