Skip to content

Commit

Permalink
Run precommit checks on existing code
Browse files Browse the repository at this point in the history
  • Loading branch information
msschwartz21 committed Jun 11, 2024
1 parent 290c1bb commit 08ac6b2
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 33 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ repos:
hooks:
- id: validate-pyproject

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.0
hooks:
- id: mypy
files: "^src/"
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.10.0
# hooks:
# - id: mypy
# files: "^src/"
# # you have to add the things you want to type check against here
# additional_dependencies:
# - numpy
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,15 @@ sources = ["src"]

# https://github.com/charliermarsh/ruff
[tool.ruff]
line-length = 88
line-length = 100
target-version = "py38"
src = ["src"]
# https://beta.ruff.rs/docs/rules/
select = [
"E", # style errors
"W", # style warnings
"F", # flakes
"D", # pydocstyle
# "D", # pydocstyle
"I", # isort
"UP", # pyupgrade
"C4", # flake8-comprehensions
Expand Down
58 changes: 32 additions & 26 deletions src/dlmbl_unet/unet.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import math
import torch
import torch.nn as nn
import numpy as np


class ConvBlock(torch.nn.Module):
"""A convolution block for a U-Net. Contains two convolutions, each followed by a ReLU."""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
padding: str = "same",
):
""" A convolution block for a U-Net. Contains two convolutions, each followed by a ReLU.
"""A convolution block for a U-Net. Contains two convolutions, each followed by a ReLU.
Args:
----
in_channels (int): The number of input channels for this conv block. Depends on
the layer and side of the U-Net and the hyperparameters.
out_channels (int): The number of output channels for this conv block. Depends on
Expand Down Expand Up @@ -60,13 +60,10 @@ def __init__(self, downsample_factor: int):

self.downsample_factor = downsample_factor

self.down = torch.nn.MaxPool2d(
downsample_factor
)
self.down = torch.nn.MaxPool2d(downsample_factor)

def check_valid(self, image_size: tuple[int, int]) -> bool:
"""Check if the downsample factor evenly divides each image dimension
"""
"""Check if the downsample factor evenly divides each image dimension."""
for dim in image_size:
if dim % self.downsample_factor != 0:
return False
Expand All @@ -75,8 +72,7 @@ def check_valid(self, image_size: tuple[int, int]) -> bool:
def forward(self, x):
if not self.check_valid(tuple(x.size()[-2:])):
raise RuntimeError(
"Can not downsample shape %s with factor %s"
% (x.size(), self.downsample_factor)
f"Can not downsample shape {x.size()} with factor {self.downsample_factor}"
)

return self.down(x)
Expand All @@ -85,7 +81,6 @@ def forward(self, x):
class CropAndConcat(torch.nn.Module):
def crop(self, x, y):
"""Center-crop x to match spatial dimensions given by y."""

x_target_size = x.size()[:-2] + y.size()[-2:]

offset = tuple((a - b) // 2 for a, b in zip(x.size(), x_target_size))
Expand All @@ -99,15 +94,26 @@ def forward(self, encoder_output, upsample_output):

return torch.cat([encoder_cropped, upsample_output], dim=1)


class OutputConv(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
activation: str | None = None, # Accepts the name of any torch activation function (e.g., ``ReLU`` for ``torch.nn.ReLU``).
activation: str | None = None, # .
):
"""A convolutional block that applies a torch activation function.
Args:
in_channels (int): Number of input channels
out_channels (int): Number of output channels
activation (str | None, optional): Accepts the name of any torch activation function
(e.g., ``ReLU`` for ``torch.nn.ReLU``). Defaults to None.
"""
super().__init__()
self.final_conv = torch.nn.Conv2d(in_channels, out_channels, 1, padding=0) # leave this out
self.final_conv = torch.nn.Conv2d(
in_channels, out_channels, 1, padding=0
) # leave this out
if activation is None:
self.activation = None
else:
Expand Down Expand Up @@ -136,7 +142,9 @@ def __init__(
):
"""A U-Net for 2D input that expects tensors shaped like::
``(batch, channels, height, width)``.
Args:
----
depth:
The number of levels in the U-Net. 2 is the smallest that really
makes sense for the U-Net architecture, as a one layer U-Net is
Expand Down Expand Up @@ -166,7 +174,6 @@ def __init__(
The upsampling mode to pass to torch.nn.Upsample. Usually "nearest"
or "bilinear." Defaults to "nearest."
"""

super().__init__()

self.depth = depth
Expand All @@ -185,12 +192,7 @@ def __init__(
for level in range(self.depth):
fmaps_in, fmaps_out = self.compute_fmaps_encoder(level)
self.left_convs.append(
ConvBlock(
fmaps_in,
fmaps_out,
self.kernel_size,
self.padding
)
ConvBlock(fmaps_in, fmaps_out, self.kernel_size, self.padding)
)

# right convolutional passes
Expand All @@ -208,9 +210,9 @@ def __init__(

self.downsample = Downsample(self.downsample_factor)
self.upsample = torch.nn.Upsample(
scale_factor=self.downsample_factor,
mode=self.upsample_mode,
)
scale_factor=self.downsample_factor,
mode=self.upsample_mode,
)
self.crop_and_concat = CropAndConcat()
self.final_conv = OutputConv(
self.compute_fmaps_decoder(0)[1], self.out_channels, self.final_activation
Expand All @@ -221,6 +223,7 @@ def compute_fmaps_encoder(self, level: int) -> tuple[int, int]:
a conv block at a given level of the UNet encoder (left side).
Args:
----
level (int): The level of the U-Net which we are computing
the feature maps for. Level 0 is the input level, level 1 is
the first downsampled layer, and level=depth - 1 is the bottom layer.
Expand All @@ -243,14 +246,17 @@ def compute_fmaps_decoder(self, level: int) -> tuple[int, int]:
so this function is only valid up to depth - 2.
Args:
----
level (int): The level of the U-Net which we are computing
the feature maps for. Level 0 is the input level, level 1 is
the first downsampled layer, and level=depth - 1 is the bottom layer.
Output (tuple[int, int]): The number of input and output feature maps
of the encoder convolutional pass in the given level.
"""
fmaps_out = self.num_fmaps * self.fmap_inc_factor ** (level) # Leave out function
fmaps_out = self.num_fmaps * self.fmap_inc_factor ** (
level
) # Leave out function
concat_fmaps = self.compute_fmaps_encoder(level)[
1
] # The channels that come from the skip connection
Expand All @@ -273,7 +279,7 @@ def forward(self, x):
layer_input = conv_out

# right
for i in range(0, self.depth-1)[::-1]: # leave out center of for loop
for i in range(0, self.depth - 1)[::-1]: # leave out center of for loop
upsampled = self.upsample(layer_input)
concat = self.crop_and_concat(convolution_outputs[i], upsampled)
conv_output = self.right_convs[i](concat)
Expand Down

0 comments on commit 08ac6b2

Please sign in to comment.