Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SK Module in Glasses #283

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from

Conversation

rentainhe
Copy link
Contributor

@rentainhe rentainhe commented Oct 6, 2021

Paper

Reference

TODO

  • implement SelectiveKernelAttn Module
  • implement SelectiveKernel Module
  • Finish Doc String



def test_att():
x = torch.rand(1, 48, 8, 8)
x = torch.rand(2, 48, 8, 8)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cuz there is BatchNorm2d in SelectiveKernel, it will turn out some error with batch=1

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes! put the module in .eval() mode

@codecov-commenter
Copy link

codecov-commenter commented Oct 7, 2021

Codecov Report

Merging #283 (61df784) into develop (34300a7) will increase coverage by 0.03%.
The diff coverage is 93.18%.

Impacted file tree graph

@@             Coverage Diff             @@
##           develop     #283      +/-   ##
===========================================
+ Coverage    97.28%   97.31%   +0.03%     
===========================================
  Files           86       87       +1     
  Lines         3056     3203     +147     
===========================================
+ Hits          2973     3117     +144     
- Misses          83       86       +3     
Impacted Files Coverage Δ
glasses/nn/att/CBAM.py 100.00% <ø> (ø)
glasses/nn/att/ECA.py 100.00% <ø> (ø)
glasses/utils/Storage.py 95.40% <ø> (+0.16%) ⬆️
glasses/nn/att/utils.py 93.75% <83.33%> (-6.25%) ⬇️
glasses/nn/att/SK.py 93.10% <93.10%> (ø)
glasses/nn/att/__init__.py 100.00% <100.00%> (ø)
glasses/nn/att/se.py 100.00% <100.00%> (ø)
test/test_att.py 100.00% <100.00%> (ø)
test/test_auto.py 100.00% <0.00%> (ø)
... and 20 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update d0089dc...61df784. Read the comment docs.

@rentainhe
Copy link
Contributor Author

record the older code

import torch
import torch.nn as nn
from typing import Union, List

from glasses.nn.att.utils import make_divisible
from ..blocks import ConvBnAct
from einops.layers.torch import Rearrange, Reduce

def _kernel_valid(k):
    if isinstance(k, (list, tuple)):
        for ki in k:
            return _kernel_valid(ki)
    assert k >=3 and k % 2

class SelectiveKernelAtt(nn.Module):
    def __init__(
        self,
        features: int,
        num_paths: int = 2,
        mid_features: int = 32,
        act_layer: nn.Module = nn.ReLU,
        norm_layer: nn.Module = nn.BatchNorm2d,
    ):
        super().__init__()
        self.num_paths = num_paths
        self.att = nn.Sequential(
            Reduce("b n c h w -> b c h w", reduction="sum"),
            Reduce("b c h w -> b c 1 1", reduction="mean"),
            nn.Conv2d(features, mid_features, kernel_size=1, bias=False),
            norm_layer(mid_features),
            act_layer(inplace=True),
            nn.Conv2d(mid_features, features * num_paths, kernel_size=1, bias=False),
            Rearrange('b (n c) h w -> b n c h w', n=num_paths, c=features),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        assert x.shape[1] == self.num_paths
        x = self.att(x)
        return x


class SelectiveKernel(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int = None,
        kernel_size: Union[List, int] = None,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        reduction: int = 16,
        reduction_divisor: int = 8,
        reduced_features: int = None,
        keep_3x3: bool = True,
        activation: nn.Module = nn.ReLU,
        normalization: nn.Module = nn.BatchNorm2d,
    ):
        super().__init__()
        out_features = out_features or in_features
        kernel_size = kernel_size or [3, 5]
        _kernel_valid(kernel_size)
        if not isinstance(kernel_size, list):
            kernel_size = [kernel_size] * 2
        if keep_3x3:
            dilation = [dilation * (k - 1) // 2 for k in kernel_size]
            kernel_size = [3] * len(kernel_size)
        else:
            dilation = [dilation] * len(kernel_size)
        self.num_paths = len(kernel_size)
        self.in_features = in_features
        self.out_features = out_features,
        groups = min(out_features, groups)
        
        self.paths = nn.ModuleList([
            ConvBnAct(in_features = in_features, 
                      out_features = out_features, 
                      activation = activation, 
                      normalization=normalization,
                      mode = "same",
                      stride=stride,
                      kernel_size=k, 
                      dilation=d)
            for k, d in zip(kernel_size, dilation)
        ])

        attn_features = reduced_features or make_divisible(out_features // reduction, divisor=reduction_divisor)
        self.attn = SelectiveKernelAtt(out_features, self.num_paths, attn_features)
    
    def forward(self, x):
        x_paths = [op(x) for op in self.paths]  # b, c, h, w
        x = torch.stack(x_paths, dim=1)  # b, n, c, h, w
        x_attn = self.attn(x)
        x = x * x_attn
        return torch.sum(x, dim=1)

@FrancescoSaverioZuppichini FrancescoSaverioZuppichini changed the base branch from master to develop October 8, 2021 08:51
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR. Let's

  • add typing
  • remove bad practices such as:
if not isinstance(kernel_size, list):
            kernel_size = [kernel_size] * 2
        if keep_3x3:
            dilation = [1 * (k - 1) // 2 for k in kernel_size]
            kernel_size = [3] * len(kernel_size)
        else:
            dilation = [1 * (k - 1) // 2 for k in kernel_size]
  • decuple each part of the module
  • let the user pass a black, to default ConvBnAct

@rentainhe
Copy link
Contributor Author

Thank you for the PR. Let's

  • add typing
  • remove bad practices such as:
if not isinstance(kernel_size, list):
            kernel_size = [kernel_size] * 2
        if keep_3x3:
            dilation = [1 * (k - 1) // 2 for k in kernel_size]
            kernel_size = [3] * len(kernel_size)
        else:
            dilation = [1 * (k - 1) // 2 for k in kernel_size]
  • decuple each part of the module
  • let the user pass a black, to default ConvBnAct

Sure, I will update my code tonight~, thanks for reviewing

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants