Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #75 from rshin/master
Browse files Browse the repository at this point in the history
Add image augmentation for CIFAR-10
  • Loading branch information
lukaszkaiser authored Jun 30, 2017
2 parents f95b7c9 + 24571fb commit aae9966
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tensor2tensor/models/common_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,17 @@ def image_augmentation(images, do_colors=False):
return images


def cifar_image_augmentation(images):
  """Randomly crop and mirror an image, as commonly done for CIFAR-10/100.

  The input is first brought to 40x40 with
  `tf.image.resize_image_with_crop_or_pad`, then a random 32x32x3 window is
  cut out of it, and finally the result is randomly flipped left-to-right.
  The scheme follows https://arxiv.org/pdf/1608.06993v3.pdf (page 5).

  Args:
    images: image tensor to augment. The fixed crop size [32, 32, 3] suggests
      a single height x width x 3 image -- confirm against the caller.

  Returns:
    The augmented tensor produced by the random crop and random flip.
  """
  # Enlarge (or center-crop) to a 40x40 canvas so the crop below can shift.
  padded = tf.image.resize_image_with_crop_or_pad(images, 40, 40)
  # Pick a random 32x32 window with all 3 channels.
  cropped = tf.random_crop(padded, [32, 32, 3])
  # Mirror horizontally at random.
  return tf.image.random_flip_left_right(cropped)


def flatten4d3d(x):
"""Flatten a 4d-tensor into a 3d-tensor by joining width and height."""
xshape = tf.shape(x)
Expand Down
6 changes: 6 additions & 0 deletions tensor2tensor/utils/data_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ def preprocess(img):
lambda img=inputs: resize(img))
else:
examples["inputs"] = tf.to_int64(resize(inputs))

elif ("image_cifar10" in data_file_pattern
and mode == tf.contrib.learn.ModeKeys.TRAIN):
examples["inputs"] = common_layers.cifar_image_augmentation(
examples["inputs"])

elif "audio" in data_file_pattern:
# Reshape audio to proper shape
sample_count = tf.to_int32(examples.pop("audio/sample_count"))
Expand Down

0 comments on commit aae9966

Please sign in to comment.