Skip to content

Commit

Permalink
PSNR based only on the Y channel
Browse files Browse the repository at this point in the history
  • Loading branch information
michalkoziarski committed Sep 25, 2017
1 parent 0fa2cbf commit 692e0ed
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 15 deletions.
23 changes: 20 additions & 3 deletions predict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import model
import utils
import os
import json
import argparse
Expand Down Expand Up @@ -26,7 +27,7 @@ def load_model(session):
return network


def predict(images, session=None, network=None):
def predict(images, session=None, network=None, targets=None, border=0):
session_passed = session is not None

if not session_passed:
Expand All @@ -37,7 +38,12 @@ def predict(images, session=None, network=None):

predictions = []

for image in images:
if targets is not None:
psnr = []

for i in range(len(images)):
image = images[i]

if len(image.shape) == 3:
image_ycbcr = color.rgb2ycbcr(image)
image_y = image_ycbcr[:, :, 0]
Expand All @@ -49,6 +55,14 @@ def predict(images, session=None, network=None):
prediction = network.output.eval(feed_dict={network.input: reshaped_image_y}, session=session)[0]
prediction *= 255

if targets is not None:
if len(targets[i].shape) == 3:
target_y = color.rgb2ycbcr(targets[i])[:, :, 0]
else:
target_y = targets[i].copy()

psnr.append(utils.psnr(prediction[:, :, 0], target_y, maximum=255.0))

if len(image.shape) == 3:
prediction = color.ycbcr2rgb(np.concatenate((prediction, image_ycbcr[:, :, 1:3]), axis=2)) * 255
else:
Expand All @@ -60,7 +74,10 @@ def predict(images, session=None, network=None):
if not session_passed:
session.close()

return predictions
if targets is not None:
return predictions, psnr
else:
return predictions


if __name__ == '__main__':
Expand Down
8 changes: 3 additions & 5 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import data
import utils
import predict
import numpy as np
import tensorflow as tf
Expand All @@ -11,8 +10,7 @@
for set_name in ['Set5', 'Set14', 'B100', 'Urban100']:
for scaling_factor in [2, 3, 4]:
dataset = data.TestSet(set_name, scaling_factors=[scaling_factor])
predictions = predict.predict(dataset.images, session, network)
score = np.mean([utils.psnr(target.astype(np.float32), prediction.astype(np.float32), maximum=255).eval()
for target, prediction in zip(dataset.targets, predictions)])
predictions, psnr = predict.predict(dataset.images, session, network, targets=dataset.targets,
border=scaling_factor)

print('Dataset "%s", scaling factor = %d. Mean PSNR = %.2f.' % (set_name, scaling_factor, score))
print('Dataset "%s", scaling factor = %d. Mean PSNR = %.2f.' % (set_name, scaling_factor, np.mean(psnr)))
6 changes: 2 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,8 @@
if batch * params['batch_size'] % train_set.length == 0:
print('Processing epoch #%d...' % (epoch + 1))

predictions = predict.predict(validation_set.images, session, network)
score = np.mean([utils.psnr(target.astype(np.float32), prediction.astype(np.float32), maximum=255).eval()
for target, prediction in zip(validation_set.targets, predictions)])
feed_dict[validation_score] = score
predictions, psnr = predict.predict(validation_set.images, session, network, targets=validation_set.targets)
feed_dict[validation_score] = np.mean(psnr)

_, summary = session.run([train_step, summary_step], feed_dict=feed_dict)
saver.save(session, model_path)
Expand Down
5 changes: 2 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import tensorflow as tf


def psnr(x, y, maximum=1.0):
return 20 * tf.log(maximum / tf.sqrt(tf.reduce_mean(tf.pow(x - y, 2)))) / np.log(10)
def psnr(x, y, maximum=255.0):
return 20 * np.log10(maximum / np.sqrt(np.mean((x - y) ** 2)))

0 comments on commit 692e0ed

Please sign in to comment.