ESM openfold_utils type hints (huggingface#20544)
* add type annotations for esm chunk_utils

use the isinstance builtin instead of 'type(x) is y'; add assertions to aid type inference; use bools instead of ints in _get_minimal_slice_set for improved type clarity; refactor to avoid re-assigning to the same variable with a different type (a short sketch of these patterns follows this change list)

* add type annotations for esm data_transforms

refactor to avoid re-assigning to the same variable with a different type

* add type annotations for esm feats utils

refactor to avoid re-assigning to the same variable with a different type

* add type annotations for esm loss utils

* add/fix type annotations for esm rigid_utils

refactor to avoid re-assigning to the same variable with a different type; fix Callable, Tuple type hints; match conditional structure to other methods; fix return type on Rotation.cat and Rotation.unsqueeze

* add type annotations for esm tensor_utils

add an overload for tree_map; use the isinstance builtin instead of 'type(x) is y'; export dict_multimap, flatten_final_dims, permute_final_dims from openfold_utils

* add type annotations for esm protein utils

add FIXME for attempted string mutation; add missing None check in get_pdb_headers; fix potentially unbound variable 'chain_tag' in to_pdb; modify get_pdb_headers return type

* add type annotations for esm residue constants

add hints for the collection constants; remove the magic trailing comma to reduce the number of lines; change list -> tuple for rigid_group_atom_positions for improved hinting

* code style fixup
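
Two patterns recur throughout the diff below: replacing 'type(x) is y' checks with isinstance, and binding converted values to a new name instead of reassigning the old one with a different type. A minimal sketch of both (hypothetical code, loosely modeled on the _fetch_dims and tensor-conversion changes in this commit):

from typing import List, Union

import torch


def leaf_shapes(tree: Union[dict, list, tuple, torch.Tensor]) -> List[torch.Size]:
    """Collect the shapes of every tensor in a nested container."""
    shapes: List[torch.Size] = []
    if isinstance(tree, dict):  # instead of `type(tree) is dict`
        for v in tree.values():
            shapes.extend(leaf_shapes(v))
    elif isinstance(tree, (list, tuple)):
        for t in tree:
            shapes.extend(leaf_shapes(t))
    elif isinstance(tree, torch.Tensor):
        shapes.append(tree.shape)
    else:
        raise ValueError(f"Unsupported type: {type(tree)}")
    return shapes


def as_float_mask(flags: List[bool]) -> torch.Tensor:
    # Reassigning `flags = torch.tensor(flags)` would give one name two types
    # (List[bool], then Tensor); a separate name keeps both annotations honest.
    mask = torch.tensor(flags, dtype=torch.float32)
    return mask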

Co-authored-by: Matt <[email protected]>
2 people authored and amyeroberts committed Dec 7, 2022
1 parent dcf3c21 commit c4653b0
Showing 9 changed files with 491 additions and 760 deletions.
1 change: 1 addition & 0 deletions src/transformers/models/esm/openfold_utils/__init__.py
@@ -6,3 +6,4 @@
from .protein import Protein as OFProtein
from .protein import to_pdb
from .rigid_utils import Rigid, Rotation
from .tensor_utils import dict_multimap, flatten_final_dims, permute_final_dims
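
The three helpers added to this export list are plain tensor utilities. A hypothetical usage sketch, assuming the usual OpenFold semantics (permute_final_dims reorders the last len(inds) dimensions, flatten_final_dims collapses the last no_dims dimensions into one); neither implementation is shown in this excerpt:

import torch

from transformers.models.esm.openfold_utils import flatten_final_dims, permute_final_dims

x = torch.randn(2, 8, 16, 32)          # e.g. [batch, heads, seq, channels]
y = permute_final_dims(x, [1, 2, 0])   # assumed result shape: [2, 16, 32, 8]
z = flatten_final_dims(x, 2)           # assumed result shape: [2, 8, 512]
print(y.shape, z.shape)
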
110 changes: 54 additions & 56 deletions src/transformers/models/esm/openfold_utils/chunk_utils.py
@@ -14,23 +14,22 @@
import logging
import math
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch

from .tensor_utils import tensor_tree_map, tree_map


def _fetch_dims(tree):
def _fetch_dims(tree: Union[dict, list, tuple, torch.Tensor]) -> List[Tuple[int, ...]]:
shapes = []
tree_type = type(tree)
if tree_type is dict:
if isinstance(tree, dict):
for v in tree.values():
shapes.extend(_fetch_dims(v))
elif tree_type is list or tree_type is tuple:
elif isinstance(tree, (list, tuple)):
for t in tree:
shapes.extend(_fetch_dims(t))
elif tree_type is torch.Tensor:
elif isinstance(tree, torch.Tensor):
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
@@ -39,10 +38,7 @@ def _fetch_dims(tree):


@torch.jit.ignore
def _flat_idx_to_idx(
flat_idx: int,
dims: Tuple[int],
) -> Tuple[int]:
def _flat_idx_to_idx(flat_idx: int, dims: Tuple[int, ...]) -> Tuple[int, ...]:
idx = []
for d in reversed(dims):
idx.append(flat_idx % d)
@@ -55,10 +51,10 @@ def _flat_idx_to_idx(
def _get_minimal_slice_set(
start: Sequence[int],
end: Sequence[int],
dims: int,
dims: Sequence[int],
start_edges: Optional[Sequence[bool]] = None,
end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[int]]:
) -> List[Tuple[slice, ...]]:
"""
Produces an ordered sequence of tensor slices that, when used in sequence on a tensor with shape dims, yields
tensors that contain every leaf in the contiguous range [start, end]. Care is taken to yield a short sequence of
@@ -69,11 +65,11 @@ def _get_minimal_slice_set(
# start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree
def reduce_edge_list(l):
tally = 1
def reduce_edge_list(l: List[bool]) -> None:
tally = True
for i in range(len(l)):
reversed_idx = -1 * (i + 1)
l[reversed_idx] *= tally
l[reversed_idx] &= tally
tally = l[reversed_idx]

if start_edges is None:
@@ -90,48 +86,54 @@ def reduce_edge_list(l):
elif len(start) == 1:
return [(slice(start[0], end[0] + 1),)]

slices = []
path = []
slices: List[Tuple[slice, ...]] = []
path_list: List[slice] = []

# Dimensions common to start and end can be selected directly
for s, e in zip(start, end):
if s == e:
path.append(slice(s, s + 1))
path_list.append(slice(s, s + 1))
else:
break

path = tuple(path)
path: Tuple[slice, ...] = tuple(path_list)
divergence_idx = len(path)

# start == end, and we're done
if divergence_idx == len(dims):
return [tuple(path)]
return [path]

def upper() -> Tuple[Tuple[slice, ...], ...]:
assert start_edges is not None
assert end_edges is not None

def upper():
sdi = start[divergence_idx]
return [
return tuple(
path + (slice(sdi, sdi + 1),) + s
for s in _get_minimal_slice_set(
start[divergence_idx + 1 :],
[d - 1 for d in dims[divergence_idx + 1 :]],
dims[divergence_idx + 1 :],
start_edges=start_edges[divergence_idx + 1 :],
end_edges=[1 for _ in end_edges[divergence_idx + 1 :]],
end_edges=[True for _ in end_edges[divergence_idx + 1 :]],
)
]
)

def lower() -> Tuple[Tuple[slice, ...], ...]:
assert start_edges is not None
assert end_edges is not None

def lower():
edi = end[divergence_idx]
return [
return tuple(
path + (slice(edi, edi + 1),) + s
for s in _get_minimal_slice_set(
[0 for _ in start[divergence_idx + 1 :]],
end[divergence_idx + 1 :],
dims[divergence_idx + 1 :],
start_edges=[1 for _ in start_edges[divergence_idx + 1 :]],
start_edges=[True for _ in start_edges[divergence_idx + 1 :]],
end_edges=end_edges[divergence_idx + 1 :],
)
]
)

# If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once
@@ -156,16 +158,11 @@ def lower():
slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),))
slices.extend(lower())

return [tuple(s) for s in slices]
return slices


@torch.jit.ignore
def _chunk_slice(
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
) -> torch.Tensor:
def _chunk_slice(t: torch.Tensor, flat_start: int, flat_end: int, no_batch_dims: int) -> torch.Tensor:
"""
Equivalent to
@@ -232,7 +229,7 @@ def chunk_layer(
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

def _prep_inputs(t):
def _prep_inputs(t: torch.Tensor) -> torch.Tensor:
if not low_mem:
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
@@ -241,7 +238,7 @@ def _prep_inputs(t):
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
return t

prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
prepped_inputs: Dict[str, Any] = tensor_tree_map(_prep_inputs, inputs)
prepped_outputs = None
if _out is not None:
prepped_outputs = tensor_tree_map(lambda t: t.view([-1] + list(t.shape[no_batch_dims:])), _out)
@@ -252,7 +249,7 @@ def _prep_inputs(t):

no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)

def _select_chunk(t):
def _select_chunk(t: torch.Tensor) -> torch.Tensor:
return t[i : i + chunk_size] if t.shape[0] != 1 else t

i = 0
@@ -269,7 +266,7 @@ def _select_chunk(t):
no_batch_dims=len(orig_batch_dims),
)

chunks = tensor_tree_map(select_chunk, prepped_inputs)
chunks: Dict[str, Any] = tensor_tree_map(select_chunk, prepped_inputs)

# Run the layer on the chunk
output_chunk = layer(**chunks)
@@ -279,12 +276,11 @@ def _select_chunk(t):
out = tensor_tree_map(lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk)

# Put the chunk in its pre-allocated space
out_type = type(output_chunk)
if out_type is dict:
if isinstance(output_chunk, dict):

def assign(d1, d2):
def assign(d1: dict, d2: dict) -> None:
for k, v in d1.items():
if type(v) is dict:
if isinstance(v, dict):
assign(v, d2[k])
else:
if _add_into_out:
@@ -293,13 +289,13 @@ def assign(d1, d2):
v[i : i + chunk_size] = d2[k]

assign(out, output_chunk)
elif out_type is tuple:
elif isinstance(output_chunk, tuple):
for x1, x2 in zip(out, output_chunk):
if _add_into_out:
x1[i : i + chunk_size] += x2
else:
x1[i : i + chunk_size] = x2
elif out_type is torch.Tensor:
elif isinstance(output_chunk, torch.Tensor):
if _add_into_out:
out[i : i + chunk_size] += output_chunk
else:
@@ -319,24 +315,24 @@ def __init__(
self,
# Heuristically, runtimes for most of the modules in the network
# plateau earlier than this on all GPUs I've run the model on.
max_chunk_size=512,
max_chunk_size: int = 512,
):
self.max_chunk_size = max_chunk_size
self.cached_chunk_size = None
self.cached_arg_data = None
self.cached_chunk_size: Optional[int] = None
self.cached_arg_data: Optional[tuple] = None

def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
def _determine_favorable_chunk_size(self, fn: Callable, args: tuple, min_chunk_size: int) -> int:
logging.info("Tuning chunk size...")

if min_chunk_size >= self.max_chunk_size:
return min_chunk_size

candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
candidates: List[int] = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
candidates = [c for c in candidates if c > min_chunk_size]
candidates = [min_chunk_size] + candidates
candidates[-1] += 4

def test_chunk_size(chunk_size):
def test_chunk_size(chunk_size: int) -> bool:
try:
with torch.no_grad():
fn(*args, chunk_size=chunk_size)
@@ -356,13 +352,13 @@ def test_chunk_size(chunk_size):

return candidates[min_viable_chunk_size_index]

def _compare_arg_caches(self, ac1, ac2):
def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool:
consistent = True
for a1, a2 in zip(ac1, ac2):
assert type(ac1) == type(ac2)
if type(ac1) is list or type(ac1) is tuple:
if isinstance(ac1, (list, tuple)):
consistent &= self._compare_arg_caches(a1, a2)
elif type(ac1) is dict:
elif isinstance(ac1, dict):
a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])]
a2_items = [v for _, v in sorted(a2.items(), key=lambda x: x[0])]
consistent &= self._compare_arg_caches(a1_items, a2_items)
@@ -374,11 +370,11 @@ def _compare_arg_caches(self, ac1, ac2):
def tune_chunk_size(
self,
representative_fn: Callable,
args: Tuple[Any],
args: tuple,
min_chunk_size: int,
) -> int:
consistent = True
arg_data = tree_map(lambda a: a.shape if type(a) is torch.Tensor else a, args, object)
arg_data: tuple = tree_map(lambda a: a.shape if isinstance(a, torch.Tensor) else a, args, object)
if self.cached_arg_data is not None:
# If args have changed shape/value, we need to re-tune
assert len(self.cached_arg_data) == len(arg_data)
@@ -395,4 +391,6 @@ def tune_chunk_size(
)
self.cached_arg_data = arg_data

assert self.cached_chunk_size is not None

return self.cached_chunk_size
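
For reference, a hypothetical call to chunk_layer from this file, assuming its usual OpenFold-style signature of (layer, inputs, chunk_size, no_batch_dims): the inputs dict is flattened over its shared batch dimensions, the layer is run chunk by chunk, and the outputs are reassembled to the original batch shape.

import torch

from transformers.models.esm.openfold_utils.chunk_utils import chunk_layer


def toy_layer(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Stand-in for an attention-style block: a cheap per-position score.
    return (q * k).sum(dim=-1, keepdim=True)


inputs = {"q": torch.randn(4, 128, 32), "k": torch.randn(4, 128, 32)}
out = chunk_layer(toy_layer, inputs, chunk_size=16, no_batch_dims=2)
print(out.shape)  # expected: torch.Size([4, 128, 1])
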
33 changes: 17 additions & 16 deletions src/transformers/models/esm/openfold_utils/data_transforms.py
@@ -13,46 +13,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

import numpy as np
import torch

from . import residue_constants as rc
from .tensor_utils import tensor_tree_map, tree_map


def make_atom14_masks(protein):
def make_atom14_masks(protein: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = []
restype_atom37_to_atom14 = []
restype_atom14_mask = []
restype_atom14_to_atom37_list = []
restype_atom37_to_atom14_list = []
restype_atom14_mask_list = []

for rt in rc.restypes:
atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
restype_atom14_to_atom37.append([(rc.atom_order[name] if name else 0) for name in atom_names])
restype_atom14_to_atom37_list.append([(rc.atom_order[name] if name else 0) for name in atom_names])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append(
restype_atom37_to_atom14_list.append(
[(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in rc.atom_types]
)

restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names])
restype_atom14_mask_list.append([(1.0 if name else 0.0) for name in atom_names])

# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.0] * 14)
restype_atom14_to_atom37_list.append([0] * 14)
restype_atom37_to_atom14_list.append([0] * 37)
restype_atom14_mask_list.append([0.0] * 14)

restype_atom14_to_atom37 = torch.tensor(
restype_atom14_to_atom37,
restype_atom14_to_atom37_list,
dtype=torch.int32,
device=protein["aatype"].device,
)
restype_atom37_to_atom14 = torch.tensor(
restype_atom37_to_atom14,
restype_atom37_to_atom14_list,
dtype=torch.int32,
device=protein["aatype"].device,
)
restype_atom14_mask = torch.tensor(
restype_atom14_mask,
restype_atom14_mask_list,
dtype=torch.float32,
device=protein["aatype"].device,
)
@@ -85,8 +87,7 @@ def make_atom14_masks(protein):
return protein


def make_atom14_masks_np(batch):
def make_atom14_masks_np(batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
batch = tree_map(lambda n: torch.tensor(n, device=batch["aatype"].device), batch, np.ndarray)
out = make_atom14_masks(batch)
out = tensor_tree_map(lambda t: np.array(t), out)
out = tensor_tree_map(lambda t: np.array(t), make_atom14_masks(batch))
return out
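
A hypothetical call to make_atom14_masks, assuming (as in OpenFold) that only the 'aatype' key, a tensor of residue-type indices, is required; the added atom14/atom37 index and mask keys are not all visible in this excerpt.

import torch

from transformers.models.esm.openfold_utils.data_transforms import make_atom14_masks

protein = {"aatype": torch.zeros(1, 7, dtype=torch.long)}  # 7 residues of type 0
protein = make_atom14_masks(protein)
print(sorted(protein.keys()))
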