diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0e0d672..ee849c9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,11 +28,11 @@ repos: hooks: - id: validate-pyproject - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.10.0 - hooks: - - id: mypy - files: "^src/" + # - repo: https://github.com/pre-commit/mirrors-mypy + # rev: v1.10.0 + # hooks: + # - id: mypy + # files: "^src/" # # you have to add the things you want to type check against here # additional_dependencies: # - numpy diff --git a/pyproject.toml b/pyproject.toml index 407fd01..c401bd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ sources = ["src"] # https://github.com/charliermarsh/ruff [tool.ruff] -line-length = 88 +line-length = 100 target-version = "py38" src = ["src"] # https://beta.ruff.rs/docs/rules/ @@ -80,7 +80,7 @@ select = [ "E", # style errors "W", # style warnings "F", # flakes - "D", # pydocstyle + # "D", # pydocstyle "I", # isort "UP", # pyupgrade "C4", # flake8-comprehensions diff --git a/src/dlmbl_unet/unet.py b/src/dlmbl_unet/unet.py index ef18b6a..064af36 100644 --- a/src/dlmbl_unet/unet.py +++ b/src/dlmbl_unet/unet.py @@ -1,10 +1,9 @@ -import math import torch -import torch.nn as nn -import numpy as np class ConvBlock(torch.nn.Module): + """A convolution block for a U-Net. Contains two convolutions, each followed by a ReLU.""" + def __init__( self, in_channels: int, @@ -12,9 +11,10 @@ def __init__( kernel_size: int, padding: str = "same", ): - """ A convolution block for a U-Net. Contains two convolutions, each followed by a ReLU. + """A convolution block for a U-Net. Contains two convolutions, each followed by a ReLU. Args: + ---- in_channels (int): The number of input channels for this conv block. Depends on the layer and side of the U-Net and the hyperparameters. out_channels (int): The number of output channels for this conv block. Depends on @@ -60,13 +60,10 @@ def __init__(self, downsample_factor: int): self.downsample_factor = downsample_factor - self.down = torch.nn.MaxPool2d( - downsample_factor - ) + self.down = torch.nn.MaxPool2d(downsample_factor) def check_valid(self, image_size: tuple[int, int]) -> bool: - """Check if the downsample factor evenly divides each image dimension - """ + """Check if the downsample factor evenly divides each image dimension.""" for dim in image_size: if dim % self.downsample_factor != 0: return False @@ -75,8 +72,7 @@ def check_valid(self, image_size: tuple[int, int]) -> bool: def forward(self, x): if not self.check_valid(tuple(x.size()[-2:])): raise RuntimeError( - "Can not downsample shape %s with factor %s" - % (x.size(), self.downsample_factor) + f"Can not downsample shape {x.size()} with factor {self.downsample_factor}" ) return self.down(x) @@ -85,7 +81,6 @@ def forward(self, x): class CropAndConcat(torch.nn.Module): def crop(self, x, y): """Center-crop x to match spatial dimensions given by y.""" - x_target_size = x.size()[:-2] + y.size()[-2:] offset = tuple((a - b) // 2 for a, b in zip(x.size(), x_target_size)) @@ -99,15 +94,26 @@ def forward(self, encoder_output, upsample_output): return torch.cat([encoder_cropped, upsample_output], dim=1) + class OutputConv(torch.nn.Module): def __init__( self, in_channels: int, out_channels: int, - activation: str | None = None, # Accepts the name of any torch activation function (e.g., ``ReLU`` for ``torch.nn.ReLU``). + activation: str | None = None, # . ): + """A convolutional block that applies a torch activation function. + + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + activation (str | None, optional): Accepts the name of any torch activation function + (e.g., ``ReLU`` for ``torch.nn.ReLU``). Defaults to None. + """ super().__init__() - self.final_conv = torch.nn.Conv2d(in_channels, out_channels, 1, padding=0) # leave this out + self.final_conv = torch.nn.Conv2d( + in_channels, out_channels, 1, padding=0 + ) # leave this out if activation is None: self.activation = None else: @@ -136,7 +142,9 @@ def __init__( ): """A U-Net for 2D input that expects tensors shaped like:: ``(batch, channels, height, width)``. + Args: + ---- depth: The number of levels in the U-Net. 2 is the smallest that really makes sense for the U-Net architecture, as a one layer U-Net is @@ -166,7 +174,6 @@ def __init__( The upsampling mode to pass to torch.nn.Upsample. Usually "nearest" or "bilinear." Defaults to "nearest." """ - super().__init__() self.depth = depth @@ -185,12 +192,7 @@ def __init__( for level in range(self.depth): fmaps_in, fmaps_out = self.compute_fmaps_encoder(level) self.left_convs.append( - ConvBlock( - fmaps_in, - fmaps_out, - self.kernel_size, - self.padding - ) + ConvBlock(fmaps_in, fmaps_out, self.kernel_size, self.padding) ) # right convolutional passes @@ -208,9 +210,9 @@ def __init__( self.downsample = Downsample(self.downsample_factor) self.upsample = torch.nn.Upsample( - scale_factor=self.downsample_factor, - mode=self.upsample_mode, - ) + scale_factor=self.downsample_factor, + mode=self.upsample_mode, + ) self.crop_and_concat = CropAndConcat() self.final_conv = OutputConv( self.compute_fmaps_decoder(0)[1], self.out_channels, self.final_activation @@ -221,6 +223,7 @@ def compute_fmaps_encoder(self, level: int) -> tuple[int, int]: a conv block at a given level of the UNet encoder (left side). Args: + ---- level (int): The level of the U-Net which we are computing the feature maps for. Level 0 is the input level, level 1 is the first downsampled layer, and level=depth - 1 is the bottom layer. @@ -243,6 +246,7 @@ def compute_fmaps_decoder(self, level: int) -> tuple[int, int]: so this function is only valid up to depth - 2. Args: + ---- level (int): The level of the U-Net which we are computing the feature maps for. Level 0 is the input level, level 1 is the first downsampled layer, and level=depth - 1 is the bottom layer. @@ -250,7 +254,9 @@ def compute_fmaps_decoder(self, level: int) -> tuple[int, int]: Output (tuple[int, int]): The number of input and output feature maps of the encoder convolutional pass in the given level. """ - fmaps_out = self.num_fmaps * self.fmap_inc_factor ** (level) # Leave out function + fmaps_out = self.num_fmaps * self.fmap_inc_factor ** ( + level + ) # Leave out function concat_fmaps = self.compute_fmaps_encoder(level)[ 1 ] # The channels that come from the skip connection @@ -273,7 +279,7 @@ def forward(self, x): layer_input = conv_out # right - for i in range(0, self.depth-1)[::-1]: # leave out center of for loop + for i in range(0, self.depth - 1)[::-1]: # leave out center of for loop upsampled = self.upsample(layer_input) concat = self.crop_and_concat(convolution_outputs[i], upsampled) conv_output = self.right_convs[i](concat)