Commit: Add files via upload

jshermeyer authored Oct 25, 2018
1 parent e66dd70 commit a3ff3f5
Showing 8 changed files with 385 additions and 65 deletions.
70 changes: 70 additions & 0 deletions Create_SR.py
@@ -0,0 +1,70 @@
import os
import sys
import glob

import numpy as np
import tensorflow as tf
import gdal  # GDAL < 3 import style; newer installs expose this as "from osgeo import gdal"
from tqdm import tqdm

import data
import predict

# Usage: python3 Create_SR.py "input/data/" "/output/data/" 2


def SR_it(input_dir, output_dir, scaling_factor):
    base_dir = os.getcwd()
    file_names = []
    projs = []
    geos = []
    SF = scaling_factor
    # Derive the dataset name from the last component of the input path
    if input_dir.endswith("/"):
        O = input_dir.split("/")[-2]
    else:
        O = input_dir.split("/")[-1]
    with tf.Session() as session:
        network = predict.load_model(session)

        driver = gdal.GetDriverByName("GTiff")
        os.chdir(input_dir)
        images = glob.glob('*.tif')
        # First pass: record each image's projection and a geotransform whose
        # pixel width/height are divided by the scaling factor
        for image_path in tqdm(images):
            image = gdal.Open(image_path)
            geo = image.GetGeoTransform()
            pixW = float(geo[1]) / SF
            pixH = float(geo[5]) / SF
            geo = [geo[0], pixW, geo[2], geo[3], geo[4], pixH]
            proj = image.GetProjection()
            projs.append(proj)
            geos.append(geo)

        os.chdir(base_dir)
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)

        for file_name in tqdm(os.listdir(input_dir)):
            file_names.append(file_name)

        # Second pass: super-resolve each image and write a georeferenced GeoTIFF
        for set_name in [O]:
            for scaling_factor in [SF]:
                dataset = data.SR_Run(set_name, scaling_factors=[scaling_factor])
                for I, proj, geo, file_name in tqdm(zip(dataset.images, projs, geos, file_names)):
                    Im = [I]
                    prediction = predict.predict(Im, session, network, targets=None, border=scaling_factor)
                    prediction = prediction[0]
                    # Reorder (height, width, bands) to GDAL's band-first (bands, height, width)
                    prediction = np.swapaxes(prediction, -1, 0)
                    prediction = np.swapaxes(prediction, -1, 1)
                    out = os.path.join(output_dir, str(file_name))
                    DataSet = driver.Create(out, prediction.shape[2], prediction.shape[1], prediction.shape[0], gdal.GDT_Byte)
                    for i, band in enumerate(prediction, 1):
                        DataSet.GetRasterBand(i).WriteArray(band)
                    DataSet.SetProjection(proj)
                    DataSet.SetGeoTransform(geo)
                    del DataSet  # close the dataset so GDAL flushes it to disk


if __name__ == "__main__":
    SR_it(sys.argv[1], sys.argv[2], int(sys.argv[3]))
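
A note on the geotransform handling above: only the two pixel-size terms are divided by the scaling factor; the origin and rotation terms carry over unchanged, which is what keeps the super-resolved raster aligned with the source. A minimal standalone sketch of that math (the helper name is illustrative, not part of this commit):

def upscale_geotransform(geo, sf):
    # A GDAL geotransform is (origin_x, pixel_w, row_rot, origin_y, col_rot, pixel_h);
    # super-resolving by sf shrinks only the two pixel dimensions.
    return [geo[0], geo[1] / sf, geo[2], geo[3], geo[4], geo[5] / sf]

# Example: a 0.5 m/pixel image enhanced 2x becomes 0.25 m/pixel:
# upscale_geotransform([500000.0, 0.5, 0.0, 4100000.0, 0.0, -0.5], 2)
#   -> [500000.0, 0.25, 0.0, 4100000.0, 0.0, -0.25]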
58 changes: 58 additions & 0 deletions Create_SR_NoGEO.py
@@ -0,0 +1,58 @@
import os
import sys

import numpy as np
import tensorflow as tf
import gdal  # GDAL < 3 import style; newer installs expose this as "from osgeo import gdal"
from tqdm import tqdm

import data
import predict


# Usage: python3 Create_SR_NoGEO.py "input/data/" "/output/data/" 2


def SR_it(input_dir, output_dir, scaling_factor):
    file_names = []
    SF = scaling_factor
    # Derive the dataset name from the last component of the input path
    if input_dir.endswith("/"):
        O = input_dir.split("/")[-2]
    else:
        O = input_dir.split("/")[-1]
    with tf.Session() as session:
        network = predict.load_model(session)
        driver = gdal.GetDriverByName("GTiff")

        if not os.path.exists(output_dir):
            os.mkdir(output_dir)

        for file_name in tqdm(os.listdir(input_dir)):
            file_names.append(file_name)

        # Super-resolve each image and write a plain (non-georeferenced) GeoTIFF
        for set_name in [O]:
            for scaling_factor in [SF]:
                dataset = data.SR_Run(set_name, scaling_factors=[scaling_factor])
                for I, file_name in tqdm(zip(dataset.images, file_names)):
                    Im = [I]
                    prediction = predict.predict(Im, session, network, targets=None, border=scaling_factor)
                    prediction = prediction[0]
                    # Reorder (height, width, bands) to GDAL's band-first (bands, height, width)
                    prediction = np.swapaxes(prediction, -1, 0)
                    prediction = np.swapaxes(prediction, -1, 1)
                    out = os.path.join(output_dir, str(file_name))
                    DataSet = driver.Create(out, prediction.shape[2], prediction.shape[1], prediction.shape[0], gdal.GDT_Byte)
                    for i, band in enumerate(prediction, 1):
                        DataSet.GetRasterBand(i).WriteArray(band)
                    del DataSet  # close the dataset so GDAL flushes it to disk


if __name__ == "__main__":
    SR_it(sys.argv[1], sys.argv[2], int(sys.argv[3]))
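
Both scripts reorder the network output with two np.swapaxes calls before writing; this amounts to a single transpose from (height, width, bands) to GDAL's band-first (bands, height, width) layout. A quick sketch (not part of the commit) confirming the equivalence:

import numpy as np

hwc = np.arange(24).reshape(2, 4, 3)                      # (height, width, bands)
chw = np.swapaxes(np.swapaxes(hwc, -1, 0), -1, 1)         # as in the scripts above
assert chw.shape == (3, 2, 4)                             # (bands, height, width)
assert np.array_equal(chw, np.transpose(hwc, (2, 0, 1)))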
153 changes: 104 additions & 49 deletions data.py
@@ -1,49 +1,63 @@
import os
import zipfile
import numpy as np
+import math
+import cv2
+from tqdm import tqdm

from scipy import misc
from skimage import color
from urllib.request import urlretrieve


-DATA_PATH = os.path.join(os.path.dirname(__file__), 'data')
+DATA_PATH = "/Set/To/Data/Path"


class TrainSet:
-    def __init__(self, benchmark, batch_size=64, patch_size=41, scaling_factors=(2, 3, 4)):
+    def __init__(self, benchmark, batch_size=64, patch_size=41, scaling_factors=(2, 4, 8)):
        self.benchmark = benchmark
        self.batch_size = batch_size
        self.patch_size = patch_size
        self.scaling_factors = scaling_factors
        self.images_completed = 0
        self.epochs_completed = 0
-        self.root_path = os.path.join(DATA_PATH, 'train', benchmark)
+        self.root_path = os.path.join(DATA_PATH, 'TRAIN_SUBSET', self.benchmark)
        self.images = []
        self.targets = []

-        if not os.path.exists(self.root_path):
-            download()

        for file_name in os.listdir(self.root_path):
+            # Read in the image
            image = misc.imread(os.path.join(self.root_path, file_name))

-            if len(image.shape) == 3:
-                image = color.rgb2ycbcr(image)[:, :, 0].astype(np.uint8)

-            width, height = image.shape
+            # Crop to an area divisible by 12
+            width, height = image.shape[0], image.shape[1]
            width = width - width % 12
            height = height - height % 12
            n_horizontal_patches = width // patch_size
            n_vertical_patches = height // patch_size
            image = image[:width, :height]

+            # For each level of enhancement...
            for scaling_factor in scaling_factors:
-                downscaled = misc.imresize(image, 1 / scaling_factor, 'bicubic', mode='L')
-                rescaled = misc.imresize(downscaled, float(scaling_factor), 'bicubic', mode='L')
+                # Conditional blur: the sigma scales with the enhancement level
+                blur_level = scaling_factor / 2
+                blurred = cv2.GaussianBlur(image, (0, 0), blur_level, blur_level, 0)
+                # Pull out the luminance component of YCbCr for the HR and blurred images
+                if len(image.shape) == 3:
+                    blurred = color.rgb2ycbcr(blurred)[:, :, 0].astype(np.uint8)
+                    image = color.rgb2ycbcr(image)[:, :, 0].astype(np.uint8)
+
+                # Downscale the blurred component
+                downscaled = cv2.resize(blurred, (0, 0), fx=float(1 / scaling_factor), fy=float(1 / scaling_factor), interpolation=cv2.INTER_AREA)
+                # Rescale the blurred component back to the original size
+                rescaled = misc.imresize(downscaled, (image.shape[0], image.shape[1]), 'bicubic', mode='L')
+                # Save the luminance component of the original image as the HR target
+                high_res_image = image.astype(np.float32) / 255
+                # Save the blurred, downscaled, and rescaled image as the LR input
+                low_res_image = np.clip(rescaled.astype(np.float32) / 255, 0.0, 1.0)

+                # Create patches and apply data augmentation for training
                for horizontal_patch in range(n_horizontal_patches):
                    for vertical_patch in range(n_vertical_patches):
                        h_start = horizontal_patch * patch_size
@@ -98,40 +112,51 @@ def shuffle(self):
        self.targets = self.targets[indices]


class TestSet:
-    def __init__(self, benchmark, scaling_factors=(2, 3, 4)):
+    def __init__(self, benchmark, scaling_factors=(2, 4, 8)):
        self.benchmark = benchmark
        self.scaling_factors = scaling_factors
        self.images_completed = 0
-        self.root_path = os.path.join(DATA_PATH, 'test', self.benchmark)
+        self.root_path = os.path.join(DATA_PATH, 'TEST', self.benchmark)
        self.file_names = os.listdir(self.root_path)
        self.images = []
        self.targets = []

-        if not os.path.exists(self.root_path):
-            download()

-        for file_name in os.listdir(self.root_path):
+        for file_name in tqdm(os.listdir(self.root_path)):
            image = misc.imread(os.path.join(self.root_path, file_name))

            width, height = image.shape[0], image.shape[1]
            width = width - width % 12
            height = height - height % 12
            image = image[:width, :height]

-            if len(image.shape) == 3:
-                ycbcr = color.rgb2ycbcr(image)
-                y = ycbcr[:, :, 0].astype(np.uint8)
-            else:
-                y = image

+            # For each enhancement level...
            for scaling_factor in self.scaling_factors:
-                downscaled = misc.imresize(y, 1 / scaling_factor, 'bicubic', mode='L')
-                rescaled = misc.imresize(downscaled, float(scaling_factor), 'bicubic', mode='L')
+                # Conditional blur: the sigma scales with the enhancement level
+                blur_level = scaling_factor / 2
+                blurred = cv2.GaussianBlur(image, (0, 0), blur_level, blur_level, 0)
+
+                if len(image.shape) == 3:
+                    # Pull out all of the original YCbCr components
+                    ycbcr = color.rgb2ycbcr(blurred)
+                    y = ycbcr[:, :, 0].astype(np.uint8)
+                    b = ycbcr[:, :, 1].astype(np.uint8)
+                    r = ycbcr[:, :, 2].astype(np.uint8)
+                else:
+                    y = blurred
+
+                # Downscale the luminance, then rescale it back to the original size
+                downscaled = cv2.resize(y, (0, 0), fx=float(1 / scaling_factor), fy=float(1 / scaling_factor), interpolation=cv2.INTER_AREA)
+                rescaled = misc.imresize(downscaled, (y.shape[0], y.shape[1]), 'bicubic', mode='L')
+
+                if len(image.shape) == 3:
+                    # Do the same for the chroma channels (kept inside this branch so
+                    # grayscale input never touches the undefined b and r arrays)
+                    d_b = cv2.resize(b, (0, 0), fx=float(1 / scaling_factor), fy=float(1 / scaling_factor), interpolation=cv2.INTER_AREA)
+                    d_r = cv2.resize(r, (0, 0), fx=float(1 / scaling_factor), fy=float(1 / scaling_factor), interpolation=cv2.INTER_AREA)
+                    r_b = misc.imresize(d_b, (y.shape[0], y.shape[1]), 'bicubic', mode='L')
+                    r_r = misc.imresize(d_r, (y.shape[0], y.shape[1]), 'bicubic', mode='L')
+
+                    # Recombine the channels into the LR image the network will enhance
+                    low_res_image = ycbcr
+                    low_res_image[:, :, 0] = rescaled
+                    low_res_image[:, :, 1] = r_b
+                    low_res_image[:, :, 2] = r_r
+                    low_res_image = color.ycbcr2rgb(low_res_image)
+                    low_res_image = (np.clip(low_res_image, 0.0, 1.0) * 255).astype(np.uint8)
+                else:
@@ -151,20 +176,50 @@ def fetch(self):
        return self.images[self.images_completed - 1], self.targets[self.images_completed - 1]


-def download():
-    if not os.path.exists(DATA_PATH):
-        os.mkdir(DATA_PATH)
-
-    for partition in ['train', 'test']:
-        partition_path = os.path.join(DATA_PATH, partition)
-        zip_path = os.path.join(partition_path, '%s_data.zip' % partition)
-        url = 'http://cv.snu.ac.kr/research/VDSR/%s_data.zip' % partition
-
-        if not os.path.exists(partition_path):
-            os.mkdir(partition_path)
-
-        if not os.path.exists(zip_path):
-            urlretrieve(url, zip_path)
-
-        with zipfile.ZipFile(zip_path) as f:
-            f.extractall(partition_path)
+class SR_Run:
+    def __init__(self, benchmark, scaling_factors=(2, 4, 8)):
+        self.benchmark = benchmark
+        self.scaling_factors = scaling_factors
+        self.images_completed = 0
+        self.root_path = os.path.join(DATA_PATH, self.benchmark)
+        self.file_names = os.listdir(self.root_path)
+        self.images = []
+        self.targets = []
+
+        for file_name in tqdm(os.listdir(self.root_path)):
+            image = misc.imread(os.path.join(self.root_path, file_name))
+
+            for scaling_factor in self.scaling_factors:
+                if len(image.shape) == 3:
+                    # Split out the YCbCr channels of the (already low resolution) input
+                    ycbcr = color.rgb2ycbcr(image)
+                    downscaled = ycbcr[:, :, 0].astype(np.uint8)
+                    d_b = ycbcr[:, :, 1].astype(np.uint8)
+                    d_r = ycbcr[:, :, 2].astype(np.uint8)
+                else:
+                    downscaled = image
+
+                # Bicubic-upsample the luminance by the scaling factor
+                rescaled = misc.imresize(downscaled, float(scaling_factor), 'bicubic', mode='L')
+
+                if len(image.shape) == 3:
+                    # Upsample the chroma channels and recombine into an RGB image
+                    r_b = misc.imresize(d_b, float(scaling_factor), 'bicubic', mode='L')
+                    r_r = misc.imresize(d_r, float(scaling_factor), 'bicubic', mode='L')
+                    low_res_image = np.stack([rescaled, r_b, r_r], axis=2)
+                    low_res_image = low_res_image.astype(np.float64)
+                    low_res_image = color.ycbcr2rgb(low_res_image)
+                    low_res_image = (np.clip(low_res_image, 0.0, 1.0) * 255).astype(np.uint8)
+                else:
+                    low_res_image = rescaled
+
+                self.images.append(low_res_image)
+                self.targets.append(image)
+
+        self.length = len(self.images)
+
+    def fetch(self):
+        if self.images_completed >= self.length:
+            return None
+        else:
+            self.images_completed += 1
+
+            return self.images[self.images_completed - 1], self.targets[self.images_completed - 1]
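
One caveat for anyone reviving data.py later: scipy.misc.imread and scipy.misc.imresize were deprecated and have since been removed from SciPy, so the blur/downscale/rescale pipeline added here would need a substitute. A rough cv2-only sketch (not part of the commit) of the same degradation, assuming cv2.INTER_CUBIC is an acceptable stand-in for scipy's 'bicubic' mode:

import cv2

def degrade(y, scaling_factor):
    # Conditional blur, matching the sigma = scaling_factor / 2 used above
    sigma = scaling_factor / 2.0
    blurred = cv2.GaussianBlur(y, (0, 0), sigma)
    h, w = y.shape[:2]
    # Area-downsample, then bicubic-upsample back to the original size
    down = cv2.resize(blurred, (w // scaling_factor, h // scaling_factor), interpolation=cv2.INTER_AREA)
    return cv2.resize(down, (w, h), interpolation=cv2.INTER_CUBIC)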
4 changes: 2 additions & 2 deletions params.json
@@ -10,6 +10,6 @@
   "learning_rate_decay": 0.1,
   "learning_rate_decay_step": 20,
   "gradient_clipping": 0.5,
-  "train_set": "291",
-  "validation_set": "Set5"
+  "train_set": "TRAIN_images",
+  "validation_set": "TEST_images"
 }
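
Read together with the DATA_PATH and root_path changes in data.py, these two renamed settings imply a directory layout along these lines (inferred from the os.path.join calls, not stated explicitly in the commit):

/Set/To/Data/Path/
    TRAIN_SUBSET/
        TRAIN_images/    <- "train_set"
    TEST/
        TEST_images/     <- "validation_set"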