Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable mypy static type checking and reslove errors #51

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ jobs:
commands: |
echo "::add-matcher::.github/workflows/matchers/pylint.json"
tox -e lint
- name: "mypy"
commands: |
echo "::add-matcher::.github/workflows/matchers/mypy.json"
tox -e mypy

steps:
- name: "Harden Runner"
Expand Down
16 changes: 16 additions & 0 deletions .github/workflows/matchers/mypy.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"problemMatcher": [
{
"owner": "mypy",
"pattern": [
{
"regexp": "^(.+):(\\d+):\\s(error|warning):\\s(.+)$",
"file": 1,
"line": 2,
"severity": 3,
"message": 4
}
]
}
]
}
4 changes: 2 additions & 2 deletions fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Implement FMS adapter for INT8xINT8 checkpoints"""

# Standard
from typing import Mapping
from typing import Mapping, MutableMapping

# Third Party
from fms.utils import serialization
Expand Down Expand Up @@ -46,7 +46,7 @@ def _int8_qparams_aiu(


def _add_defaults_and_concat(
new_sd: Mapping[str, torch.Tensor],
new_sd: MutableMapping[str, torch.Tensor],
modules_seen: set,
) -> None:
"""
Expand Down
3 changes: 2 additions & 1 deletion fms_mo/calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,8 @@ def qmodel_calib(
return model

DPorDDPdevices = None
if "qmodel_prep" not in sys._getframe().f_back.f_code.co_name:
f_back = sys._getframe().f_back
if f_back and "qmodel_prep" not in f_back.f_code.co_name:
model.to(currDev)
qcfg["wasDPmodel"] = qcfg.get("wasDPmodel", isinstance(model, nn.DataParallel))
qcfg["wasDDPmodel"] = qcfg.get(
Expand Down
6 changes: 3 additions & 3 deletions fms_mo/custom_ext_kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@
# Third Party
import torch.library as lib

reg_op = partial(lib.custom_op, mutates_args=())
reg_op = partial(lib.custom_op, mutates_args=()) # type: ignore[attr-defined]
reg_op_func = lib.define # NOTE this is func, not decorator
kernel_impl = lib.register_kernel
reg_fake = lib.register_fake
kernel_impl = lib.register_kernel # type: ignore[attr-defined]
reg_fake = lib.register_fake # type: ignore[attr-defined]

else:
raise RuntimeError("Custom Op registration only works for >PT2.1")
Expand Down
8 changes: 5 additions & 3 deletions fms_mo/quant/ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2623,8 +2623,10 @@ def reset_bn(module: nn.BatchNorm2d):
Function not currently used.
"""
if module.track_running_stats:
module.running_mean.zero_()
module.running_var.fill_(1 - module.eps)
if running_mean := module.running_mean:
running_mean.zero_()
if running_var := module.running_var:
running_var.fill_(1 - module.eps)
# we do not reset numer of tracked batches here
if module.affine:
nn.init.ones_(module.weight)
Expand All @@ -2643,7 +2645,7 @@ def reset_bn(module: nn.BatchNorm2d):
bn_affine = True # FrozenBN doesn't have .affine property
except:
BNofInteret = (nn.BatchNorm2d, nn.BatchNorm1d)
AbsorbLayers = (nn.Conv2d, nn.Linear)
AbsorbLayers = (nn.Conv2d, nn.Linear) # type: ignore[assignment]


def search_fold_and_remove_bn(model, mod_folded):
Expand Down
11 changes: 6 additions & 5 deletions fms_mo/quant/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3895,7 +3895,8 @@
self.delta = torch.nn.Parameter(delta)
else:
delta, zero_point = self.init_quantization_scale(x, self.channel_wise)
self.delta.fill_(delta)
if self_data := self.delta:
self_data.fill_(delta)
self.zero_point.fill_(zero_point)
self.inited = True

Expand Down Expand Up @@ -3930,25 +3931,25 @@
else:
if "max" in self.scale_method:
x_min = min(x.min().item(), 0)
x_max = max(x.max().item(), 0)

Check failure on line 3934 in fms_mo/quant/quantizers.py

View workflow job for this annotation

GitHub Actions / lint: mypy

Incompatible types in assignment (expression has type "int | float", variable has type "Tensor") [assignment]
if "scale" in self.scale_method:
x_min = x_min * (self.n_bits + 2) / 8
x_max = x_max * (self.n_bits + 2) / 8

x_absmax = max(abs(x_min), x_max)

Check failure on line 3939 in fms_mo/quant/quantizers.py

View workflow job for this annotation

GitHub Actions / lint: mypy

No overload variant of "max" matches argument types "float", "Tensor" [call-overload]
if self.sym:
x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax

delta = float(x_max - x_min) / (self.n_levels - 1)
if delta < 1e-8:

Check failure on line 3944 in fms_mo/quant/quantizers.py

View workflow job for this annotation

GitHub Actions / lint: mypy

Unsupported operand types for > ("float" and "None") [operator]
logger.info(f"Quantization range close to zero: [{x_min}, {x_max}]")
delta = 1e-8

Check failure on line 3946 in fms_mo/quant/quantizers.py

View workflow job for this annotation

GitHub Actions / lint: mypy

Incompatible types in assignment (expression has type "float", variable has type "Tensor | None") [assignment]

zero_point = round(-x_min / delta)

Check failure on line 3948 in fms_mo/quant/quantizers.py

View workflow job for this annotation

GitHub Actions / lint: mypy

Incompatible types in assignment (expression has type "Any | int", variable has type "Tensor | None") [assignment]

Check failure on line 3948 in fms_mo/quant/quantizers.py

View workflow job for this annotation

GitHub Actions / lint: mypy

Unsupported operand types for / ("int" and "None") [operator]

Check failure on line 3948 in fms_mo/quant/quantizers.py

View workflow job for this annotation

GitHub Actions / lint: mypy

Unsupported operand types for / ("float" and "None") [operator]

elif self.scale_method == "mse":
x_max = x.max()
x_min = x.min()

Check failure on line 3952 in fms_mo/quant/quantizers.py

View workflow job for this annotation

GitHub Actions / lint: mypy

Incompatible types in assignment (expression has type "Tensor", variable has type "int | float") [assignment]
best_score = 1e10
for i in range(80):
new_max = x_max * (1.0 - (i * 0.01))
Expand All @@ -3960,7 +3961,7 @@
if score < best_score:
best_score = score
delta = (new_max - new_min) / (2**self.n_bits - 1)
zero_point = (-new_min / delta).round()
zero_point = (-new_min / delta).round() # type: ignore[union-attr]

Check failure on line 3964 in fms_mo/quant/quantizers.py

View workflow job for this annotation

GitHub Actions / lint: mypy

Unsupported operand types for / ("float" and "None") [operator]
else:
raise NotImplementedError

Expand Down Expand Up @@ -4035,8 +4036,8 @@
self.reset_ReSig_param(multimodal)

self.beta = 2 / 3
self.Wshape = None
self.reshape2 = None
self.Wshape: list[int] = list()
self.reshape2: list[Any] = list()

def forward(self, x):
if self.useSAWB:
Expand Down Expand Up @@ -5389,7 +5390,7 @@
if "e4m3" in q_mode:
self.float8_dtype = torch.float8_e4m3fn
elif "e5m2" in q_mode:
self.float8_dtype = torch.float8_e5m2G
self.float8_dtype = torch.float8_e5m2
else:
raise ValueError("FP8 only supports e4m3 and e5m2")
self.emulate = emulate
Expand Down
4 changes: 2 additions & 2 deletions fms_mo/utils/qconfig_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Standard
from pathlib import Path
from typing import Any
from typing import Any, Dict
import json
import logging
import os
Expand Down Expand Up @@ -149,7 +149,7 @@ def qconfig_init(recipe: str = None, args: Any = None):
otherwise use constantLR as default
"""

qcfg = {}
qcfg: Dict[str, Any] = {}
# 1. create a dict with default values
qcfg["mapping"] = {
nn.Conv2d: {"from": nn.Conv2d, "to": QConv2d, "otherwise": QConv2d},
Expand Down
51 changes: 30 additions & 21 deletions fms_mo/utils/torchscript_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def parse_operation(op_str: str):
operands = op_str[
last_open_parenthesis_index + 1 : last_close_parenthesis_index
].split(",")
operands = [operand.strip() for operand in operands] if operands != [""] else None
# pylint: disable=line-too-long
operands = [operand.strip() for operand in operands] if operands != [""] else None # type: ignore[assignment]
return operator, operands


Expand Down Expand Up @@ -178,9 +179,14 @@ def __init__(self, node_input, dictionary_of_nodes: dict):
)
operator, operands = parse_operation(op_str)
if "aten::_conv" in op_str:
self.ch_in = list(native_torchscript_node.inputs())[0].type().sizes()
# NOTE: Needed for finding shortcut convolutions later
self.ch_out = list(native_torchscript_node.outputs())[0].type().sizes()
if native_torchscript_node:
self.ch_in = (
list(native_torchscript_node.inputs())[0].type().sizes()
)
# NOTE: Needed for finding shortcut convolutions later
self.ch_out = (
list(native_torchscript_node.outputs())[0].type().sizes()
)
else:
node_def = node_input_repr
op_str, operator, operands = None, None, None
Expand All @@ -200,31 +206,34 @@ def __init__(self, node_input, dictionary_of_nodes: dict):
working_str = node_input_repr[start_index:end_index]
start_index = end_index + 2

node_instance.name, node_instance.obj = working_str.split(" : ")
node_instance.name = node_instance.name.strip()
# pylint: disable=line-too-long
node_instance.name, node_instance.obj = working_str.split(" : ") # type: ignore[attr-defined]
node_instance.name = node_instance.name.strip() # type: ignore[attr-defined]
if native_torchscript_outputs:
if node_instance.name not in native_torchscript_outputs:
# pylint: disable=line-too-long
if node_instance.name not in native_torchscript_outputs: # type: ignore[attr-defined]
# pylint: disable=line-too-long
logger.error(
f"Node def {node_instance.name} not in nativeTSoutputs "
f"Node def {node_instance.name} not in nativeTSoutputs " # type: ignore[attr-defined]
f"{native_torchscript_outputs}"
)
node_instance.Op = op_str
node_instance.Op = op_str # type: ignore[attr-defined]
if node_def_in_one_line > 1:
node_instance.unpackIdx = node_index
node_instance.unpackIdx = node_index # type: ignore[attr-defined]
if line_number:
node_instance.lineno = line_number
node_instance.operator = operator
node_instance.lineno = line_number # type: ignore[attr-defined]
node_instance.operator = operator # type: ignore[attr-defined]
# This is the name of parents, not the pointer to the parent nodes
node_instance.parents = operands
node_instance.parents_ptr = []
node_instance.scope = scope_repr
node_instance.modname = module_name
node_instance.children = []
node_instance.children_ptr = []
node_instance.TSparents = native_torchscript_parents
node_instance.TSoutputs = native_torchscript_outputs
node_instance.parents = operands # type: ignore[attr-defined]
node_instance.parents_ptr = [] # type: ignore[attr-defined]
node_instance.scope = scope_repr # type: ignore[attr-defined]
node_instance.modname = module_name # type: ignore[attr-defined]
node_instance.children = [] # type: ignore[attr-defined]
node_instance.children_ptr = [] # type: ignore[attr-defined]
node_instance.TSparents = native_torchscript_parents # type: ignore[attr-defined]
node_instance.TSoutputs = native_torchscript_outputs # type: ignore[attr-defined]
# graph.dictionary_of_nodes will keep a record of all the nodes
dictionary_of_nodes[node_instance.name] = node_instance
dictionary_of_nodes[node_instance.name] = node_instance # type: ignore[attr-defined]

def __repr__(self):
return f"{self.name} "
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ known-local-folder=["fms_mo","tests"]
[tool.mypy]
mypy_path = [""]
packages = ["fms_mo", "tests"]
disable_error_code = []
disable_error_code = ["import-not-found", "import-untyped"]
# TODO: tighten MyPy checks by enabling these checks over time.
check_untyped_defs = false
disallow_incomplete_defs = false
Expand Down
Loading