diff --git a/RMS/Routines/Image.py b/RMS/Routines/Image.py index 6a659b3d0..9756df95c 100644 --- a/RMS/Routines/Image.py +++ b/RMS/Routines/Image.py @@ -501,85 +501,81 @@ def adjustLevels(img_array, minv, gamma, maxv, nbits=None, scaleto8bits=False): - -class FlatStruct(object): +class FlatStruct: def __init__(self, flat_img, dark=None): - """ Structure containing the flat field. + """ + Structure containing the flat field information. Arguments: - flat_img: [ndarray] Flat field. - + flat_img: [ndarray] Flat field image. + dark: [ndarray] Dark frame to be subtracted from the flat (optional). """ - # Convert the flat to float64 - self.flat_img = flat_img.astype(np.float64) - - # Store the original flat - self.flat_img_raw = np.copy(self.flat_img) - - # Apply the dark, if given - self.applyDark(dark) - - # Compute the flat median - self.computeAverage() + # Determine precision based on input image bit depth + if flat_img.dtype == np.uint8: + self.float_dtype = np.float32 + else: + self.float_dtype = np.float64 - # Fix values close to 0 - self.fixValues() + # Process the flat image and store only essential information + self._process_flat(flat_img, dark) + def _process_flat(self, flat_img, dark): + """ Process the flat image and extract essential information. """ - def applyDark(self, dark): - """ Apply a dark to the flat. """ + # Convert to appropriate float type for processing + flat = flat_img.astype(self.float_dtype) - # Apply a dark frame to the flat, if given + # Apply dark subtraction if provided if dark is not None: - self.flat_img = applyDark(self.flat_img_raw, dark) - self.dark_applied = True - - else: - self.flat_img = np.copy(self.flat_img_raw) - self.dark_applied = False + flat = np.maximum(flat - dark.astype(self.float_dtype), 0) - # Compute flat median - self.computeAverage() + # Compute and store the average + self.flat_avg = self._compute_average(flat) - # Fix values close to 0 - self.fixValues() + # Compute and store the inverse of the flat for faster division later + self.flat_inverse = np.where(flat > 0, 1.0 / flat, self.flat_avg) + # Store the shape for validation + self.shape = flat.shape - def computeAverage(self): + def _compute_average(self, flat): """ Compute the reference level. """ - # Bin the flat by a factor of 4 using the average method - flat_binned = binImage(self.flat_img, 4, method='avg') + flat_binned = self._bin_image(flat, 4) - # Take the maximum average level of pixels that are in a square of 1/4*height from the centre - radius = flat_binned.shape[0]//4 - img_h_half = flat_binned.shape[0]//2 - img_w_half = flat_binned.shape[1]//2 - self.flat_avg = np.max(flat_binned[img_h_half-radius:img_h_half+radius, \ - img_w_half-radius:img_w_half+radius]) + # Take the maximum average level of pixels in a square of 1/4*height from the centre + radius = flat_binned.shape[0] // 4 + img_h_half, img_w_half = flat_binned.shape[0] // 2, flat_binned.shape[1] // 2 + avg = np.max(flat_binned[img_h_half-radius:img_h_half+radius, + img_w_half-radius:img_w_half+radius]) - # Make sure the self.flat_avg value is relatively high - if self.flat_avg < 1: - self.flat_avg = 1 + # Make sure the average value is relatively high + return max(avg, 1.0) + def _bin_image(self, image, bin_factor): + """ Bin the image by the given factor. """ + if bin_factor == 1: + return image - def fixValues(self): - """ Handle values close to 0 on flats. """ + h, w = image.shape + return image.reshape(h//bin_factor, bin_factor, w//bin_factor, bin_factor).mean(axis=(1,3)) - # Make sure there are no values close to 0, as images are divided by flats - self.flat_img[(self.flat_img < self.flat_avg/10) | (self.flat_img < 10)] = self.flat_avg + def apply_flat(self, img): + """ Apply the flat field correction to the image. """ + if img.shape != self.shape: + raise ValueError("Image shape does not match flat field shape") - def binFlat(self, binning_factor, binning_method): - """ Bin the flat. """ + # Convert image to appropriate float type for calculations + img = img.astype(self.float_dtype) - # Bin the processed flat - self.flat_img = binImage(self.flat_img, binning_factor, binning_method) + # Apply the flat field correction + img *= self.flat_inverse + img *= self.flat_avg - # Bin the raw flat image - self.flat_img_raw = binImage(self.flat_img_raw, binning_factor, binning_method) + return img @@ -620,12 +616,9 @@ def loadFlat(dir_path, file_name, dtype=None, byteswap=False, dark=None): return flat_struct - - - -def applyFlat(img, flat_struct): - """ Apply a flat field to the image. - +def apply_flat(img, flat_struct): + """ Wrapper function to apply flat field correction. + Arguments: img: [ndarray] Image to flat field. flat_struct: [Flat struct] Structure containing the flat field. @@ -633,26 +626,18 @@ def applyFlat(img, flat_struct): Return: [ndarray] Flat corrected image. - """ - # Check that the input image and the flat have the same dimensions, otherwise do not apply it - if img.shape != flat_struct.flat_img.shape: - return img - input_type = img.dtype - # Apply the flat - img = flat_struct.flat_avg*img.astype(np.float64)/flat_struct.flat_img + # Apply the flat field correction + img = flat_struct.apply_flat(img) - # Limit the image values to image type range - dtype_info = np.iinfo(input_type) - img = np.clip(img, dtype_info.min, dtype_info.max) + # Clip the values to the input type's range + np.clip(img, np.iinfo(input_type).min, np.iinfo(input_type).max, out=img) # Make sure the output array is the same as the input type - img = img.astype(input_type) - - return img + return img.astype(input_type)