From 326a64ce01ecb06880e9bd2508874128c043fd8f Mon Sep 17 00:00:00 2001 From: Luc Busquin <133058544+Cybis320@users.noreply.github.com> Date: Tue, 27 Aug 2024 14:34:53 -0700 Subject: [PATCH 1/2] Optimize FlatStruct and applyFlat for memory efficiency and bit depth flexibility --- RMS/Routines/Image.py | 59 ++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 23 deletions(-) diff --git a/RMS/Routines/Image.py b/RMS/Routines/Image.py index 6a659b3d0..11ddd4646 100644 --- a/RMS/Routines/Image.py +++ b/RMS/Routines/Image.py @@ -507,15 +507,18 @@ def __init__(self, flat_img, dark=None): """ Structure containing the flat field. Arguments: - flat_img: [ndarray] Flat field. - + flat_img: [ndarray] Flat field (8-bit or 16-bit single channel image). """ + # Determine the appropriate dtype based on input + if flat_img.dtype == np.uint8: + self.dtype = np.uint8 + self.max_value = 255 + else: + self.dtype = np.uint16 + self.max_value = 65535 - # Convert the flat to float64 - self.flat_img = flat_img.astype(np.float64) - - # Store the original flat - self.flat_img_raw = np.copy(self.flat_img) + # Store the flat as the appropriate integer type + self.flat_img_raw = flat_img.astype(self.dtype) # Apply the dark, if given self.applyDark(dark) @@ -532,11 +535,12 @@ def applyDark(self, dark): # Apply a dark frame to the flat, if given if dark is not None: - self.flat_img = applyDark(self.flat_img_raw, dark) + # Ensure dark is the same dtype as flat and apply subtraction + dark = dark.astype(self.dtype) + self.flat_img_raw = np.maximum(self.flat_img_raw.astype(np.int32) - dark, 0).astype(self.dtype) self.dark_applied = True else: - self.flat_img = np.copy(self.flat_img_raw) self.dark_applied = False # Compute flat median @@ -551,7 +555,7 @@ def computeAverage(self): # Bin the flat by a factor of 4 using the average method - flat_binned = binImage(self.flat_img, 4, method='avg') + flat_binned = binImage(self.flat_img_raw, 4, method='avg') # Take the maximum average level of pixels that are in a square of 1/4*height from the centre radius = flat_binned.shape[0]//4 @@ -561,24 +565,26 @@ def computeAverage(self): img_w_half-radius:img_w_half+radius]) # Make sure the self.flat_avg value is relatively high - if self.flat_avg < 1: - self.flat_avg = 1 + self.flat_avg = max(self.flat_avg, 1) def fixValues(self): """ Handle values close to 0 on flats. """ # Make sure there are no values close to 0, as images are divided by flats - self.flat_img[(self.flat_img < self.flat_avg/10) | (self.flat_img < 10)] = self.flat_avg + mask = (self.flat_img_raw < self.flat_avg/10) | (self.flat_img_raw < 10) + self.flat_img_raw[mask] = self.flat_avg + def getFlatImage(self): + """Return the processed flat image as float32 for calculations.""" + return self.flat_img_raw.astype(np.float32) / self.max_value + + def binFlat(self, binning_factor, binning_method): """ Bin the flat. """ - # Bin the processed flat - self.flat_img = binImage(self.flat_img, binning_factor, binning_method) - - # Bin the raw flat image + # Bin the raw flat image in-place self.flat_img_raw = binImage(self.flat_img_raw, binning_factor, binning_method) @@ -636,18 +642,25 @@ def applyFlat(img, flat_struct): """ + # Get the processed flat image + flat_img = flat_struct.getFlatImage() + # Check that the input image and the flat have the same dimensions, otherwise do not apply it - if img.shape != flat_struct.flat_img.shape: + if img.shape != flat_img.shape: + print("Warning: Flat field dimensions do not match the image. Flat field not applied.") return img input_type = img.dtype - # Apply the flat - img = flat_struct.flat_avg*img.astype(np.float64)/flat_struct.flat_img + # Convert image to float32 for calculations + img = img.astype(np.float32) + + # Apply the flat field correction + np.divide(img, flat_img, out=img) + np.multiply(img, flat_struct.flat_avg / flat_struct.max_value, out=img) - # Limit the image values to image type range - dtype_info = np.iinfo(input_type) - img = np.clip(img, dtype_info.min, dtype_info.max) + # Clip the values to the input type's range + np.clip(img, np.iinfo(input_type).min, np.iinfo(input_type).max, out=img) # Make sure the output array is the same as the input type img = img.astype(input_type) From ebd020ad83768015396e9be9851e9847b5a587ef Mon Sep 17 00:00:00 2001 From: Luc Busquin <133058544+Cybis320@users.noreply.github.com> Date: Sat, 31 Aug 2024 12:33:47 -0700 Subject: [PATCH 2/2] Flatstruct only stores the inverse of the flat in float32 for 8-bit img or float64 for 16-bit img --- RMS/Routines/Image.py | 134 +++++++++++++++++------------------------- 1 file changed, 53 insertions(+), 81 deletions(-) diff --git a/RMS/Routines/Image.py b/RMS/Routines/Image.py index 11ddd4646..9756df95c 100644 --- a/RMS/Routines/Image.py +++ b/RMS/Routines/Image.py @@ -501,91 +501,81 @@ def adjustLevels(img_array, minv, gamma, maxv, nbits=None, scaleto8bits=False): - -class FlatStruct(object): +class FlatStruct: def __init__(self, flat_img, dark=None): - """ Structure containing the flat field. + """ + Structure containing the flat field information. Arguments: - flat_img: [ndarray] Flat field (8-bit or 16-bit single channel image). + flat_img: [ndarray] Flat field image. + dark: [ndarray] Dark frame to be subtracted from the flat (optional). """ - # Determine the appropriate dtype based on input + + # Determine precision based on input image bit depth if flat_img.dtype == np.uint8: - self.dtype = np.uint8 - self.max_value = 255 + self.float_dtype = np.float32 else: - self.dtype = np.uint16 - self.max_value = 65535 - - # Store the flat as the appropriate integer type - self.flat_img_raw = flat_img.astype(self.dtype) - - # Apply the dark, if given - self.applyDark(dark) - - # Compute the flat median - self.computeAverage() + self.float_dtype = np.float64 - # Fix values close to 0 - self.fixValues() + # Process the flat image and store only essential information + self._process_flat(flat_img, dark) + def _process_flat(self, flat_img, dark): + """ Process the flat image and extract essential information. """ - def applyDark(self, dark): - """ Apply a dark to the flat. """ + # Convert to appropriate float type for processing + flat = flat_img.astype(self.float_dtype) - # Apply a dark frame to the flat, if given + # Apply dark subtraction if provided if dark is not None: - # Ensure dark is the same dtype as flat and apply subtraction - dark = dark.astype(self.dtype) - self.flat_img_raw = np.maximum(self.flat_img_raw.astype(np.int32) - dark, 0).astype(self.dtype) - self.dark_applied = True + flat = np.maximum(flat - dark.astype(self.float_dtype), 0) - else: - self.dark_applied = False - - # Compute flat median - self.computeAverage() + # Compute and store the average + self.flat_avg = self._compute_average(flat) - # Fix values close to 0 - self.fixValues() + # Compute and store the inverse of the flat for faster division later + self.flat_inverse = np.where(flat > 0, 1.0 / flat, self.flat_avg) + # Store the shape for validation + self.shape = flat.shape - def computeAverage(self): + def _compute_average(self, flat): """ Compute the reference level. """ - # Bin the flat by a factor of 4 using the average method - flat_binned = binImage(self.flat_img_raw, 4, method='avg') + flat_binned = self._bin_image(flat, 4) - # Take the maximum average level of pixels that are in a square of 1/4*height from the centre - radius = flat_binned.shape[0]//4 - img_h_half = flat_binned.shape[0]//2 - img_w_half = flat_binned.shape[1]//2 - self.flat_avg = np.max(flat_binned[img_h_half-radius:img_h_half+radius, \ - img_w_half-radius:img_w_half+radius]) + # Take the maximum average level of pixels in a square of 1/4*height from the centre + radius = flat_binned.shape[0] // 4 + img_h_half, img_w_half = flat_binned.shape[0] // 2, flat_binned.shape[1] // 2 + avg = np.max(flat_binned[img_h_half-radius:img_h_half+radius, + img_w_half-radius:img_w_half+radius]) - # Make sure the self.flat_avg value is relatively high - self.flat_avg = max(self.flat_avg, 1) + # Make sure the average value is relatively high + return max(avg, 1.0) + def _bin_image(self, image, bin_factor): + """ Bin the image by the given factor. """ + if bin_factor == 1: + return image - def fixValues(self): - """ Handle values close to 0 on flats. """ + h, w = image.shape + return image.reshape(h//bin_factor, bin_factor, w//bin_factor, bin_factor).mean(axis=(1,3)) - # Make sure there are no values close to 0, as images are divided by flats - mask = (self.flat_img_raw < self.flat_avg/10) | (self.flat_img_raw < 10) - self.flat_img_raw[mask] = self.flat_avg + def apply_flat(self, img): + """ Apply the flat field correction to the image. """ + if img.shape != self.shape: + raise ValueError("Image shape does not match flat field shape") - def getFlatImage(self): - """Return the processed flat image as float32 for calculations.""" - return self.flat_img_raw.astype(np.float32) / self.max_value - + # Convert image to appropriate float type for calculations + img = img.astype(self.float_dtype) - def binFlat(self, binning_factor, binning_method): - """ Bin the flat. """ + # Apply the flat field correction + img *= self.flat_inverse + img *= self.flat_avg - # Bin the raw flat image in-place - self.flat_img_raw = binImage(self.flat_img_raw, binning_factor, binning_method) + return img @@ -626,12 +616,9 @@ def loadFlat(dir_path, file_name, dtype=None, byteswap=False, dark=None): return flat_struct - - - -def applyFlat(img, flat_struct): - """ Apply a flat field to the image. - +def apply_flat(img, flat_struct): + """ Wrapper function to apply flat field correction. + Arguments: img: [ndarray] Image to flat field. flat_struct: [Flat struct] Structure containing the flat field. @@ -639,33 +626,18 @@ def applyFlat(img, flat_struct): Return: [ndarray] Flat corrected image. - """ - # Get the processed flat image - flat_img = flat_struct.getFlatImage() - - # Check that the input image and the flat have the same dimensions, otherwise do not apply it - if img.shape != flat_img.shape: - print("Warning: Flat field dimensions do not match the image. Flat field not applied.") - return img - input_type = img.dtype - # Convert image to float32 for calculations - img = img.astype(np.float32) - # Apply the flat field correction - np.divide(img, flat_img, out=img) - np.multiply(img, flat_struct.flat_avg / flat_struct.max_value, out=img) + img = flat_struct.apply_flat(img) # Clip the values to the input type's range np.clip(img, np.iinfo(input_type).min, np.iinfo(input_type).max, out=img) # Make sure the output array is the same as the input type - img = img.astype(input_type) - - return img + return img.astype(input_type)