-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathuint4_tensor.py
123 lines (106 loc) · 4.86 KB
/
uint4_tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import torch._prims_common as utils
def down_size(size):
    """Return *size* with its last dimension halved (packed uint4 storage shape)."""
    *leading, last = size
    assert last % 2 == 0, f"{size} last dim not divisible by two"
    return (*leading, last // 2)
def up_size(size):
    """Return *size* with its last dimension doubled (logical uint4 shape)."""
    head, last = size[:-1], size[-1]
    return (*head, last * 2)
def fill_defaults(args, n, defaults_tail):
    """
    __torch_dispatch__ doesn't guarantee the number of arguments you are
    passed (e.g., defaulted arguments are not passed); but usually it is
    convenient to pad out the arguments list with defaults.  This function
    helps you do that.

    Args:
        args: the list of positional arguments passed to __torch_dispatch__
        n: the number of arguments you are expecting to get
        defaults_tail: default values for the arguments, starting from the
            end of the list

    Example:
        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
        [1, 2, 3, 4, 5]
        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
        [1, 2, 3, None, None]
    """
    if n - len(defaults_tail) > len(args):
        raise RuntimeError("not enough defaults to fill arguments")
    r = list(args)
    # defaults_tail is aligned with the *end* of the n-slot argument list, so
    # the missing trailing positions correspond to its last (n - len(args))
    # entries.  When len(args) >= n the slice start is >= len(defaults_tail)
    # and nothing is appended, matching the original loop's behavior.
    r.extend(defaults_tail[len(defaults_tail) - (n - len(args)):])
    return r
# from
# https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233
def unpack_uint4(quantized_data) -> torch.Tensor:
    """Get the original weight from the normalized float weight format"""
    # Each uint8 byte packs two 4-bit values: the high nibble comes first,
    # the low nibble second.  Stacking on a new trailing dim doubles the
    # last dimension of the result relative to the packed input.
    high_nibbles = (quantized_data >> 4).to(torch.uint8)
    low_nibbles = (quantized_data & 0b1111).to(torch.uint8)
    return torch.stack([high_nibbles, low_nibbles], dim=-1)
class UInt4Tensor(torch.Tensor):
    """Wrapper tensor subclass storing uint4 data packed two-per-byte in uint8.

    The wrapper reports an *unpacked* logical shape (last dim doubled relative
    to the packed ``elem`` storage, via ``up_size``).  Only a handful of aten
    ops are handled in ``__torch_dispatch__``; anything else raises
    ``NotImplementedError``.
    """

    @staticmethod
    def __new__(cls, elem):
        # TODO: uint64 here is wrong, need a real dtype. Don't try to(int64)
        # weird shit will happen
        # `elem` is the packed storage; the advertised shape is its up-sized
        # (unpacked) shape.  dtype=int64 is a placeholder, per the TODO above.
        assert elem.dtype is torch.uint8
        return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.int64)

    def __init__(self, elem):
        # Packed uint8 backing tensor; each byte holds two uint4 values.
        self.elem = elem

    def tolist(self):
        # Unpack to uint8 first so the nested lists contain the 0..15 values.
        return self.to(torch.uint8).tolist()

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        # Dispatch table for the small set of supported aten ops.  Each branch
        # operates on the packed `elem` and re-wraps, except where unpacking
        # to uint8 is unavoidable.
        if func is torch.ops.aten.view.default:
            self, size = args
            size = utils.infer_size(size, self.numel())
            assert not kwargs
            # WARNING: views not preserved
            # Reshape the packed storage to the down-sized (packed) target
            # shape; this copies/reshapes rather than aliasing.
            return UInt4Tensor(self.elem.reshape(down_size(size)))
        elif func is torch.ops.aten._to_copy.default:
            self, = args
            # Only the exact kwargs {'dtype': torch.uint8} is supported; any
            # other dtype/device/layout combination falls through to the error.
            if kwargs == {'dtype': torch.uint8}:
                return unpack_uint4(self.elem).view(self.shape)  # no wrap
            else:
                raise NotImplementedError(f"_to_copy {kwargs}")
        elif func is torch.ops.aten.unbind.int:
            # This is tricky. Given torch.tensor([0, 1, 2, 3]) we want to
            # create four tensors containing one element each. But we can't
            # do this with uint4 because such a tensor's size is not divisible
            # by bytes. What I am going to do instead is promote to uint8
            # when this happens
            self, dim = fill_defaults(args, 2, [0])
            if dim != self.dim() - 1:
                raise NotImplementedError(f"unbind dim={dim}")
            else:
                # We're unbinding the last dimension, need to promote
                return torch.ops.aten._to_copy.default(self, dtype=torch.uint8).unbind(dim)
        elif func is torch.ops.aten.select.int:
            self, dim, index = args
            if dim != self.dim() - 1:
                # Selecting on a non-last dim maps 1:1 onto the packed storage.
                return UInt4Tensor(torch.ops.aten.select.int(self.elem, dim, index))
            else:
                # Selecting on the last dim could land mid-byte; unsupported.
                raise NotImplementedError(f"select dim={dim}")
        elif func is torch.ops.aten.slice.Tensor:
            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
            if dim == self.dim() - 1:
                # hard case
                if step != 1:
                    raise NotImplementedError(f"slice step={step}")
                # Slices on the packed dim must align to byte boundaries
                # (even start; even end unless it covers the whole dim).
                assert start % 2 == 0, start
                # NOTE(review): assumes `end` is a concrete int here (torch
                # typically passes sys.maxsize for open-ended slices); a
                # literal None would raise TypeError on the comparison — confirm.
                assert end >= self.shape[dim] or end % 2 == 0, end
                return UInt4Tensor(torch.ops.aten.slice.Tensor(self.elem, dim, start // 2, end // 2, 1))
            else:
                # easy case
                return UInt4Tensor(torch.ops.aten.slice.Tensor(self.elem, dim, start, end, step))
        raise NotImplementedError(f"{func}")

    # Disable the __torch_function__ layer so only __torch_dispatch__ runs.
    __torch_function__ = torch._C._disabled_torch_function_impl
# Demo: three identical rows of packed bytes; each byte expands to two uint4
# values, so the reported (unpacked) shape is (3, 16), not the packed (3, 8).
_row = [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF]
x = UInt4Tensor(torch.tensor([_row, _row, _row], dtype=torch.uint8))
print(x.shape)
print(x.to(torch.uint8))
print(x)
print(x[0:1, :])
print(x[:, 2:6])