From 24571fbf9ee447ccaf677c2f51c0b3d8a1d68cc3 Mon Sep 17 00:00:00 2001
From: Richard Shin
Date: Thu, 29 Jun 2017 18:35:59 -0700
Subject: [PATCH] Add image augmentation for CIFAR-10

Add common_layers.cifar_image_augmentation, which crops-or-pads the
input to 40x40, takes a random 32x32x3 crop and applies a random
left-right flip (the scheme referenced in
https://arxiv.org/pdf/1608.06993v3.pdf, page 5).

data_reader applies it to examples["inputs"] when the data file pattern
contains "image_cifar10" and the mode is TRAIN.

---
NOTE(review): this patch was restored from a copy in which all line
breaks and indentation had been collapsed. Line structure follows the
hunk headers (counts verified: 11 + 6 = 17 insertions). Indentation is
reconstructed assuming the repository's 2-space style, and assuming the
"image_cifar10" elif pairs with an inner image-type check rather than
with the outer elif "audio" branch -- TODO confirm against the tree, or
apply with `git apply --ignore-whitespace`.

 tensor2tensor/models/common_layers.py | 11 +++++++++++
 tensor2tensor/utils/data_reader.py    |  6 ++++++
 2 files changed, 17 insertions(+)

diff --git a/tensor2tensor/models/common_layers.py b/tensor2tensor/models/common_layers.py
index 3ef84f27c..36d9b0b51 100644
--- a/tensor2tensor/models/common_layers.py
+++ b/tensor2tensor/models/common_layers.py
@@ -132,6 +132,17 @@ def image_augmentation(images, do_colors=False):
   return images
 
 
+def cifar_image_augmentation(images):
+  """Image augmentation suitable for CIFAR-10/100.
+
+  As described in https://arxiv.org/pdf/1608.06993v3.pdf (page 5)."""
+  images = tf.image.resize_image_with_crop_or_pad(
+      images, 40, 40)
+  images = tf.random_crop(images, [32, 32, 3])
+  images = tf.image.random_flip_left_right(images)
+  return images
+
+
 def flatten4d3d(x):
   """Flatten a 4d-tensor into a 3d-tensor by joining width and height."""
   xshape = tf.shape(x)
diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py
index 0ba62ec9f..88b45db9d 100644
--- a/tensor2tensor/utils/data_reader.py
+++ b/tensor2tensor/utils/data_reader.py
@@ -203,6 +203,12 @@ def preprocess(img):
             lambda img=inputs: resize(img))
       else:
         examples["inputs"] = tf.to_int64(resize(inputs))
+
+    elif ("image_cifar10" in data_file_pattern
+          and mode == tf.contrib.learn.ModeKeys.TRAIN):
+      examples["inputs"] = common_layers.cifar_image_augmentation(
+          examples["inputs"])
+
   elif "audio" in data_file_pattern:
     # Reshape audio to proper shape
     sample_count = tf.to_int32(examples.pop("audio/sample_count"))