diff --git a/Create_SR.py b/Create_SR.py new file mode 100644 index 0000000..31dff66 --- /dev/null +++ b/Create_SR.py @@ -0,0 +1,70 @@ +import data +import predict +import numpy as np +import tensorflow as tf +from scipy import misc +from skimage import color +import os +import sys +import gdal +import glob +from tqdm import tqdm + +#python3 Create_SR.py "input/data/" "/output/data/" 2 + + +def SR_it(input_dir,output_dir,scaling_factor): + base_dir=os.getcwd() + file_names = [] + projs=[] + geos=[] + SF=scaling_factor + if input_dir.endswith("/"): + O=input_dir.split("/")[-2] + else: + O=input_dir.split("/")[-1] + with tf.Session() as session: + network = predict.load_model(session) + + driver = gdal.GetDriverByName("GTiff") + os.chdir(input_dir) + images = glob.glob('*.tif') + for image in tqdm(images): + image=gdal.Open(image) + geo = image.GetGeoTransform() + pixW=float(geo[1])/SF + pixH=float(geo[5])/SF + geo=[geo[0],pixW,geo[2],geo[3],geo[4],pixH] + #print(geo) + proj = image.GetProjection() + projs.append(proj) + geos.append(geo) + + + os.chdir(base_dir) + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + for file_name in tqdm(os.listdir(input_dir)): + file_names.append(file_name) + + for set_name in [O]: + for scaling_factor in [SF]: + dataset = data.SR_Run(set_name, scaling_factors=[scaling_factor]) + for I, proj, geo, file_name in tqdm(zip(dataset.images,projs,geos,file_names)): + Im=[I] + prediction = predict.predict(Im, session, network, targets=None, border=scaling_factor) + prediction=prediction[0] + prediction=np.swapaxes(prediction,-1,0) + prediction=np.swapaxes(prediction,-1,1) + out=output_dir+str(file_name) + DataSet = driver.Create(out, prediction.shape[2], prediction.shape[1], prediction.shape[0], gdal.GDT_Byte) + for i, image in enumerate(prediction, 1): + DataSet.GetRasterBand(i).WriteArray( image ) + DataSet.SetProjection(proj) + DataSet.SetGeoTransform(geo) + #DataSet.SetNoDataValue(0) + del DataSet + +if __name__ == "__main__": 
+ SR_it(sys.argv[1],sys.argv[2],int(sys.argv[3])) \ No newline at end of file diff --git a/Create_SR_NoGEO.py b/Create_SR_NoGEO.py new file mode 100644 index 0000000..2e5cdf9 --- /dev/null +++ b/Create_SR_NoGEO.py @@ -0,0 +1,58 @@ +import data +import predict +import numpy as np +import tensorflow as tf +from scipy import misc +from skimage import color +import os +import sys +import gdal +import glob +from tqdm import tqdm + + +#python3 Create_SR_NoGEO.py "input/data/" "/output/data/" 2 + + +def SR_it(input_dir,output_dir,scaling_factor): + base_dir=os.getcwd() + file_names = [] + projs=[] + geos=[] + SF=scaling_factor + if input_dir.endswith("/"): + O=input_dir.split("/")[-2] + else: + O=input_dir.split("/")[-1] + with tf.Session() as session: + network = predict.load_model(session) + + driver = gdal.GetDriverByName("GTiff") + os.chdir(input_dir) + images = glob.glob('*.tif') + + + os.chdir(base_dir) + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + for file_name in tqdm(os.listdir(input_dir)): + file_names.append(file_name) + + for set_name in [O]: + for scaling_factor in [SF]: + dataset = data.SR_Run(set_name, scaling_factors=[scaling_factor]) + for I, file_name in tqdm(zip(dataset.images,file_names)): + Im=[I] + prediction = predict.predict(Im, session, network, targets=None, border=scaling_factor) + prediction=prediction[0] + prediction=np.swapaxes(prediction,-1,0) + prediction=np.swapaxes(prediction,-1,1) + out=output_dir+str(file_name) + DataSet = driver.Create(out, prediction.shape[2], prediction.shape[1], prediction.shape[0], gdal.GDT_Byte) + for i, image in enumerate(prediction, 1): + DataSet.GetRasterBand(i).WriteArray( image ) + del DataSet + +if __name__ == "__main__": + SR_it(sys.argv[1],sys.argv[2],int(sys.argv[3])) \ No newline at end of file diff --git a/data.py b/data.py index ff4b322..07e8a20 100644 --- a/data.py +++ b/data.py @@ -1,49 +1,63 @@ import os import zipfile import numpy as np +import math +import cv2 +from tqdm import 
tqdm from scipy import misc from skimage import color from urllib.request import urlretrieve -DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') - +DATA_PATH = "/Set/To/Data/Path" class TrainSet: - def __init__(self, benchmark, batch_size=64, patch_size=41, scaling_factors=(2, 3, 4)): + def __init__(self, benchmark, batch_size=64, patch_size=41, scaling_factors=(2, 4, 8)): self.benchmark = benchmark self.batch_size = batch_size self.patch_size = patch_size self.scaling_factors = scaling_factors self.images_completed = 0 self.epochs_completed = 0 - self.root_path = os.path.join(DATA_PATH, 'train', benchmark) + self.root_path = os.path.join(DATA_PATH, 'TRAIN_SUBSET', self.benchmark) self.images = [] self.targets = [] - if not os.path.exists(self.root_path): - download() for file_name in os.listdir(self.root_path): + #Read in image image = misc.imread(os.path.join(self.root_path, file_name)) - - if len(image.shape) == 3: - image = color.rgb2ycbcr(image)[:, :, 0].astype(np.uint8) - - width, height = image.shape + #Crop to an area divisible by 12 + width, height = image.shape[0], image.shape[1] width = width - width % 12 height = height - height % 12 n_horizontal_patches = width // patch_size n_vertical_patches = height // patch_size - image = image[:width, :height] - + image= image[:width,:height] + + #For each level of enhancement for scaling_factor in scaling_factors: - downscaled = misc.imresize(image, 1 / scaling_factor, 'bicubic', mode='L') - rescaled = misc.imresize(downscaled, float(scaling_factor), 'bicubic', mode='L') + #Conditional blur + blur_level=scaling_factor/2 + blurred = cv2.GaussianBlur(image, (0, 0), blur_level, blur_level, 0) + #Pull out the luminance component of ycbcr for the HR and blurred images + if len(image.shape) == 3: + blurred = color.rgb2ycbcr(blurred)[:, :, 0].astype(np.uint8) + image = color.rgb2ycbcr(image)[:, :, 0].astype(np.uint8) + + + + #downscale the blurred component + downscaled=cv2.resize(blurred, (0,0), fx=float(1 /
scaling_factor),fy=float(1 / scaling_factor), interpolation=cv2.INTER_AREA) + #rescale the blurred component + rescaled = misc.imresize(downscaled, (image.shape[0],image.shape[1]), 'bicubic', mode='L') + #Save the luminance component of the original image as an HR target high_res_image = image.astype(np.float32) / 255 + #Save the blurred, downscaled/rescaled as a LR target low_res_image = np.clip(rescaled.astype(np.float32) / 255, 0.0, 1.0) - + + #Create patches and data aug for training for horizontal_patch in range(n_horizontal_patches): for vertical_patch in range(n_vertical_patches): h_start = horizontal_patch * patch_size @@ -98,40 +112,51 @@ def shuffle(self): self.targets = self.targets[indices] + class TestSet: - def __init__(self, benchmark, scaling_factors=(2, 3, 4)): + def __init__(self, benchmark, scaling_factors=(2, 4, 8)): self.benchmark = benchmark self.scaling_factors = scaling_factors self.images_completed = 0 - self.root_path = os.path.join(DATA_PATH, 'test', self.benchmark) + self.root_path = os.path.join(DATA_PATH, 'TEST', self.benchmark) self.file_names = os.listdir(self.root_path) self.images = [] self.targets = [] - if not os.path.exists(self.root_path): - download() - - for file_name in os.listdir(self.root_path): + for file_name in tqdm(os.listdir(self.root_path)): image = misc.imread(os.path.join(self.root_path, file_name)) - width, height = image.shape[0], image.shape[1] - width = width - width % 12 - height = height - height % 12 - image = image[:width, :height] - - if len(image.shape) == 3: - ycbcr = color.rgb2ycbcr(image) - y = ycbcr[:, :, 0].astype(np.uint8) - else: - y = image - + #For each enhancement level... 
for scaling_factor in self.scaling_factors: - downscaled = misc.imresize(y, 1 / scaling_factor, 'bicubic', mode='L') - rescaled = misc.imresize(downscaled, float(scaling_factor), 'bicubic', mode='L') - + #Conditional Blur + blur_level=scaling_factor/2 + blurred = cv2.GaussianBlur(image, (0, 0), blur_level, blur_level, 0) + + if len(image.shape) == 3: + #Pull out all the original ycbcr components + ycbcr = color.rgb2ycbcr(blurred) + y = ycbcr[:, :, 0].astype(np.uint8) + b = ycbcr[:, :, 1].astype(np.uint8) + r = ycbcr[:, :, 2].astype(np.uint8) + else: + y = blurred + + #Downscale them + downscaled=cv2.resize(y, (0,0), fx=float(1 / scaling_factor),fy=float(1 / scaling_factor), interpolation=cv2.INTER_AREA) + d_b=cv2.resize(b, (0,0), fx=float(1 / scaling_factor),fy=float(1 / scaling_factor), interpolation=cv2.INTER_AREA) + d_r=cv2.resize(r, (0,0), fx=float(1 / scaling_factor),fy=float(1 / scaling_factor), interpolation=cv2.INTER_AREA) + + #rescale them + rescaled = misc.imresize(downscaled, (y.shape[0],y.shape[1]), 'bicubic', mode='L') + r_b = misc.imresize(d_b, (y.shape[0],y.shape[1]), 'bicubic', mode='L') + d_r = misc.imresize(d_r, (y.shape[0],y.shape[1]), 'bicubic', mode='L') + + #Create the LR image to convert to HR if len(image.shape) == 3: low_res_image = ycbcr low_res_image[:, :, 0] = rescaled + low_res_image[:, :, 1] = r_b + low_res_image[:, :, 2] = d_r low_res_image = color.ycbcr2rgb(low_res_image) low_res_image = (np.clip(low_res_image, 0.0, 1.0) * 255).astype(np.uint8) else: @@ -151,20 +176,50 @@ def fetch(self): return self.images[self.images_completed - 1], self.targets[self.images_completed - 1] -def download(): - if not os.path.exists(DATA_PATH): - os.mkdir(DATA_PATH) +class SR_Run: + def __init__(self, benchmark, scaling_factors=(2, 4, 8)): + self.benchmark = benchmark + self.scaling_factors = scaling_factors + self.images_completed = 0 + self.root_path = os.path.join(DATA_PATH, self.benchmark) + self.file_names = os.listdir(self.root_path) + 
self.images = [] + self.targets = [] + + for file_name in tqdm(os.listdir(self.root_path)): + image = misc.imread(os.path.join(self.root_path, file_name)) + + for scaling_factor in self.scaling_factors: + if len(image.shape) == 3: + ycbcr = color.rgb2ycbcr(image) + downscaled = ycbcr[:, :, 0].astype(np.uint8) + d_b = ycbcr[:, :, 1].astype(np.uint8) + d_r = ycbcr[:, :, 2].astype(np.uint8) + else: + y = image - for partition in ['train', 'test']: - partition_path = os.path.join(DATA_PATH, partition) - zip_path = os.path.join(partition_path, '%s_data.zip' % partition) - url = 'http://cv.snu.ac.kr/research/VDSR/%s_data.zip' % partition + rescaled = misc.imresize(downscaled, float(scaling_factor), 'bicubic', mode='L') + r_b = misc.imresize(d_b, float(scaling_factor), 'bicubic', mode='L') + d_r = misc.imresize(d_r, float(scaling_factor), 'bicubic', mode='L') - if not os.path.exists(partition_path): - os.mkdir(partition_path) - if not os.path.exists(zip_path): - urlretrieve(url, zip_path) + if len(image.shape) == 3: + low_res_image = np.stack([rescaled,r_b,d_r],axis=2) + low_res_image=low_res_image.astype(np.float64) + low_res_image = color.ycbcr2rgb(low_res_image) + low_res_image = (np.clip(low_res_image, 0.0, 1.0) * 255).astype(np.uint8) + else: + low_res_image = rescaled - with zipfile.ZipFile(zip_path) as f: - f.extractall(partition_path) + self.images.append(low_res_image) + self.targets.append(image) + + self.length = len(self.images) + + def fetch(self): + if self.images_completed >= self.length: + return None + else: + self.images_completed += 1 + + return self.images[self.images_completed - 1], self.targets[self.images_completed - 1] diff --git a/params.json b/params.json index 3524941..730b9d0 100644 --- a/params.json +++ b/params.json @@ -10,6 +10,6 @@ "learning_rate_decay": 0.1, "learning_rate_decay_step": 20, "gradient_clipping": 0.5, - "train_set": "291", - "validation_set": "Set5" + "train_set": "TRAIN_images", + "validation_set": "TEST_images" } \ No 
newline at end of file diff --git a/predict.py b/predict.py index 2544ee6..26069e0 100644 --- a/predict.py +++ b/predict.py @@ -8,6 +8,8 @@ from scipy import misc from skimage import color +from skimage.measure import compare_ssim +from tqdm import tqdm def load_model(session): @@ -40,8 +42,9 @@ def predict(images, session=None, network=None, targets=None, border=0): if targets is not None: psnr = [] + ssim=[] - for i in range(len(images)): + for i in (range(len(images))): image = images[i] if len(image.shape) == 3: @@ -60,14 +63,18 @@ def predict(images, session=None, network=None, targets=None, border=0): target_y = color.rgb2ycbcr(targets[i])[:, :, 0] else: target_y = targets[i].copy() + + psnr_calc=utils.psnr(prediction[border:-border, border:-border, 0], + target_y[border:-border, border:-border], maximum=255.0) + #print(psnr_calc) - psnr.append(utils.psnr(prediction[border:-border, border:-border, 0], - target_y[border:-border, border:-border], maximum=255.0)) + psnr.append(psnr_calc) + ssim.append(compare_ssim(target_y[border:-border, border:-border], prediction[border:-border, border:-border, 0], data_range=prediction.max() - prediction.min())) - if len(image.shape) == 3: - prediction = color.ycbcr2rgb(np.concatenate((prediction, image_ycbcr[:, :, 1:3]), axis=2)) * 255 - else: - prediction = prediction[:, :, 0] + #if len(image.shape) == 3: + #prediction = color.ycbcr2rgb(np.concatenate((prediction, image_ycbcr[:, :, 1:3]), axis=2)) * 255 + #else: + #prediction = prediction[:, :, 0] prediction = np.clip(prediction, 0, 255).astype(np.uint8) predictions.append(prediction) @@ -76,7 +83,7 @@ def predict(images, session=None, network=None, targets=None, border=0): session.close() if targets is not None: - return predictions, psnr + return predictions, psnr, ssim else: return predictions diff --git a/predict2.py b/predict2.py new file mode 100644 index 0000000..5c49181 --- /dev/null +++ b/predict2.py @@ -0,0 +1,114 @@ +import model +import utils +import os +import 
json +import argparse +import numpy as np +import tensorflow as tf + +from scipy import misc +from skimage import color +from skimage.measure import compare_ssim +from tqdm import tqdm + + +def load_model(session): + checkpoint_path = os.path.join(os.path.dirname(__file__), 'model') + + assert os.path.exists(checkpoint_path) + + with open(os.path.join(os.path.dirname(__file__), 'params.json')) as f: + params = json.load(f) + + input = tf.placeholder(tf.float32) + network = model.Model(input, params['n_layers'], params['kernel_size'], params['n_filters']) + checkpoint = tf.train.get_checkpoint_state(checkpoint_path) + saver = tf.train.Saver() + saver.restore(session, checkpoint.model_checkpoint_path) + + return network + + +def predict(images, session=None, network=None, targets=None, border=0): + session_passed = session is not None + + if not session_passed: + session = tf.Session() + + if network is None: + network = load_model(session) + + predictions = [] + + if targets is not None: + psnr = [] + ssim=[] + #print(len(images),"num images") + for i in tqdm(range(len(images))): + image = images[i] + + if len(image.shape) == 3: + image_ycbcr = color.rgb2ycbcr(image) + image_y = image_ycbcr[:, :, 0] + else: + image_y = image.copy() + + image_y = image_y.astype(np.float) / 255 + reshaped_image_y = np.array([np.expand_dims(image_y, axis=2)]) + prediction = network.output.eval(feed_dict={network.input: reshaped_image_y}, session=session)[0] + prediction *= 255 + + if targets is not None: + if len(targets[i].shape) == 3: + target_y = color.rgb2ycbcr(targets[i])[:, :, 0] + else: + target_y = targets[i].copy() + + psnr.append(utils.psnr(prediction[border:-border, border:-border, 0], + target_y[border:-border, border:-border], maximum=255.0)) + ssim.append(compare_ssim(target_y[border:-border, border:-border], prediction[border:-border, border:-border, 0], data_range=prediction.max() - prediction.min())) + + if len(image.shape) == 3: + prediction = 
color.ycbcr2rgb(np.concatenate((prediction, image_ycbcr[:, :, 1:3]), axis=2)) * 255 + else: + prediction = prediction[:, :, 0] + + prediction = np.clip(prediction, 0, 255).astype(np.uint8) + predictions.append(prediction) + + if not session_passed: + session.close() + + if targets is not None: + return predictions, psnr, ssim + else: + return predictions + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-in', help='a path of the input image or a directory of the input images', required=True) + parser.add_argument('-out', help='a path for the output image or a directory for the output images', required=True) + args = vars(parser.parse_args()) + + if os.path.isfile(args['in']): + image = misc.imread(args['in']) + prediction = predict([image])[0] + misc.imsave(args['out'], prediction) + elif os.path.isdir(args['in']): + images = [] + file_names = [] + + for file_name in os.listdir(args['in']): + images.append(misc.imread(os.path.join(args['in'], file_name))) + file_names.append(file_name) + + predictions = predict(images) + + if not os.path.exists(args['out']): + os.mkdir(args['out']) + + for file_name, prediction in zip(file_names, predictions): + misc.imsave(os.path.join(args['out'], file_name), prediction) + else: + raise ValueError('Incorrect input path.') diff --git a/test.py b/test.py index e505b95..32132a9 100644 --- a/test.py +++ b/test.py @@ -1,16 +1,17 @@ import data -import predict +import predict2 import numpy as np import tensorflow as tf with tf.Session() as session: - network = predict.load_model(session) + network = predict2.load_model(session) - for set_name in ['Set5', 'Set14', 'B100', 'Urban100']: - for scaling_factor in [2, 3, 4]: + for set_name in ['val_images_544']: + for scaling_factor in [2,4,8]: dataset = data.TestSet(set_name, scaling_factors=[scaling_factor]) - predictions, psnr = predict.predict(dataset.images, session, network, targets=dataset.targets, - border=scaling_factor) + predictions, psnr, 
ssim = predict2.predict(dataset.images, session, network, targets=dataset.targets, + border=int(scaling_factor)) print('Dataset "%s", scaling factor = %d. Mean PSNR = %.2f.' % (set_name, scaling_factor, np.mean(psnr))) + print("SSIM:",np.mean(ssim)) diff --git a/trainingdataTest.py b/trainingdataTest.py new file mode 100644 index 0000000..ac24270 --- /dev/null +++ b/trainingdataTest.py @@ -0,0 +1,15 @@ +import data +#import model +#import utils +#import predict +import os +import json +import numpy as np +#import tensorflow as tf + + +with open(os.path.join(os.path.dirname(__file__), 'params.json')) as f: + params = json.load(f) + +train_set = data.TrainSet(params['train_set'], params['batch_size'], params['patch_size']) +validation_set = data.TestSet(params['validation_set'])