Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize FlatStruct and applyFlat for memory efficiency #393

Draft
wants to merge 2 commits into
base: prerelease
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 57 additions & 72 deletions RMS/Routines/Image.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,85 +501,81 @@ def adjustLevels(img_array, minv, gamma, maxv, nbits=None, scaleto8bits=False):




class FlatStruct(object):
class FlatStruct:
def __init__(self, flat_img, dark=None):
""" Structure containing the flat field.
"""
Structure containing the flat field information.

Arguments:
flat_img: [ndarray] Flat field.

flat_img: [ndarray] Flat field image.
dark: [ndarray] Dark frame to be subtracted from the flat (optional).
"""

# Convert the flat to float64
self.flat_img = flat_img.astype(np.float64)

# Store the original flat
self.flat_img_raw = np.copy(self.flat_img)

# Apply the dark, if given
self.applyDark(dark)

# Compute the flat median
self.computeAverage()
# Determine precision based on input image bit depth
if flat_img.dtype == np.uint8:
self.float_dtype = np.float32
else:
self.float_dtype = np.float64

# Fix values close to 0
self.fixValues()
# Process the flat image and store only essential information
self._process_flat(flat_img, dark)

def _process_flat(self, flat_img, dark):
""" Process the flat image and extract essential information. """

def applyDark(self, dark):
""" Apply a dark to the flat. """
# Convert to appropriate float type for processing
flat = flat_img.astype(self.float_dtype)

# Apply a dark frame to the flat, if given
# Apply dark subtraction if provided
if dark is not None:
self.flat_img = applyDark(self.flat_img_raw, dark)
self.dark_applied = True

else:
self.flat_img = np.copy(self.flat_img_raw)
self.dark_applied = False
flat = np.maximum(flat - dark.astype(self.float_dtype), 0)

# Compute flat median
self.computeAverage()
# Compute and store the average
self.flat_avg = self._compute_average(flat)

# Fix values close to 0
self.fixValues()
# Compute and store the inverse of the flat for faster division later
self.flat_inverse = np.where(flat > 0, 1.0 / flat, self.flat_avg)

# Store the shape for validation
self.shape = flat.shape

def computeAverage(self):
def _compute_average(self, flat):
""" Compute the reference level. """


# Bin the flat by a factor of 4 using the average method
flat_binned = binImage(self.flat_img, 4, method='avg')
flat_binned = self._bin_image(flat, 4)

# Take the maximum average level of pixels that are in a square of 1/4*height from the centre
radius = flat_binned.shape[0]//4
img_h_half = flat_binned.shape[0]//2
img_w_half = flat_binned.shape[1]//2
self.flat_avg = np.max(flat_binned[img_h_half-radius:img_h_half+radius, \
img_w_half-radius:img_w_half+radius])
# Take the maximum average level of pixels in a square of 1/4*height from the centre
radius = flat_binned.shape[0] // 4
img_h_half, img_w_half = flat_binned.shape[0] // 2, flat_binned.shape[1] // 2
avg = np.max(flat_binned[img_h_half-radius:img_h_half+radius,
img_w_half-radius:img_w_half+radius])

# Make sure the self.flat_avg value is relatively high
if self.flat_avg < 1:
self.flat_avg = 1
# Make sure the average value is relatively high
return max(avg, 1.0)

def _bin_image(self, image, bin_factor):
""" Bin the image by the given factor. """
if bin_factor == 1:
return image

def fixValues(self):
""" Handle values close to 0 on flats. """
h, w = image.shape
return image.reshape(h//bin_factor, bin_factor, w//bin_factor, bin_factor).mean(axis=(1,3))

# Make sure there are no values close to 0, as images are divided by flats
self.flat_img[(self.flat_img < self.flat_avg/10) | (self.flat_img < 10)] = self.flat_avg
def apply_flat(self, img):
""" Apply the flat field correction to the image. """

if img.shape != self.shape:
raise ValueError("Image shape does not match flat field shape")

def binFlat(self, binning_factor, binning_method):
""" Bin the flat. """
# Convert image to appropriate float type for calculations
img = img.astype(self.float_dtype)

# Bin the processed flat
self.flat_img = binImage(self.flat_img, binning_factor, binning_method)
# Apply the flat field correction
img *= self.flat_inverse
img *= self.flat_avg

# Bin the raw flat image
self.flat_img_raw = binImage(self.flat_img_raw, binning_factor, binning_method)
return img



Expand Down Expand Up @@ -620,39 +616,28 @@ def loadFlat(dir_path, file_name, dtype=None, byteswap=False, dark=None):
return flat_struct





def applyFlat(img, flat_struct):
""" Apply a flat field to the image.

def apply_flat(img, flat_struct):
""" Wrapper function to apply flat field correction.

Arguments:
img: [ndarray] Image to flat field.
flat_struct: [Flat struct] Structure containing the flat field.


Return:
[ndarray] Flat corrected image.

"""

# Check that the input image and the flat have the same dimensions, otherwise do not apply it
if img.shape != flat_struct.flat_img.shape:
return img

input_type = img.dtype

# Apply the flat
img = flat_struct.flat_avg*img.astype(np.float64)/flat_struct.flat_img
# Apply the flat field correction
img = flat_struct.apply_flat(img)

# Limit the image values to image type range
dtype_info = np.iinfo(input_type)
img = np.clip(img, dtype_info.min, dtype_info.max)
# Clip the values to the input type's range
np.clip(img, np.iinfo(input_type).min, np.iinfo(input_type).max, out=img)

# Make sure the output array is the same as the input type
img = img.astype(input_type)

return img
return img.astype(input_type)



Expand Down