-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.py
57 lines (38 loc) · 1.62 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from torch.utils.data import Dataset
from collator import FinetuneDataCollatorWithPadding, EvalDataCollatorWithPadding
class RecformerTrainDataset(Dataset):
    """Map-style dataset yielding one training item sequence per user.

    Users are indexed in sorted key order, so a given index always refers to
    the same user for a fixed ``user2train``.
    """

    def __init__(self, user2train, collator: FinetuneDataCollatorWithPadding):
        """Store the per-user sequences and the batch collator.

        Args:
            user2train: dict of sequence data, mapping user --> item sequence.
            collator: callable that ``collate_fn`` applies to the list of
                ``{'items': sequence}`` records built for one batch.
        """
        self.user2train = user2train
        self.collator = collator
        # Sorting makes the index -> user mapping independent of dict insertion order.
        self.users = sorted(user2train.keys())

    def __len__(self):
        """Return the number of users; each user contributes exactly one sample."""
        return len(self.users)

    def __getitem__(self, index):
        """Return the item sequence stored for the ``index``-th user.

        The value stored in ``user2train`` is returned as-is (no copy is made).
        """
        return self.user2train[self.users[index]]

    def collate_fn(self, data):
        """Wrap each raw sequence as ``{'items': sequence}`` and pass the batch to the collator."""
        records = [{'items': sequence} for sequence in data]
        return self.collator(records)
class RecformerEvalDataset(Dataset):
    """Map-style dataset of ``(input sequence, label)`` pairs for evaluation.

    When ``mode == "val"``:
    - users are the keys of ``user2val``;
    - the input is the user's training sequence;
    - the label is the user's ``user2val`` entry.

    Any other ``mode`` selects the test setting:
    - users are the keys of ``user2test``;
    - the input is the training sequence concatenated with the validation entry;
    - the label is the user's ``user2test`` entry.
    """

    def __init__(
        self,
        user2train,
        user2val,
        user2test,
        mode,
        collator: EvalDataCollatorWithPadding,
    ):
        """Store the split dictionaries and select the users to evaluate.

        Args:
            user2train: mapping user -> training item sequence.
            user2val: mapping user -> validation entry. In test mode it is
                concatenated with the training sequence via ``+``, so it is
                presumably a list of held-out item(s) — TODO confirm.
            user2test: mapping user -> test entry (the label in test mode).
            mode: ``"val"`` for validation. Every other value is treated as
                test; there is no validation of this argument.
            collator: callable that ``collate_fn`` applies to the list of
                ``{'items', 'label'}`` records built for one batch.
        """
        self.user2train = user2train
        self.user2val = user2val
        self.user2test = user2test
        self.collator = collator
        # Only users that actually have a held-out entry for this split are evaluated.
        split = self.user2val if mode == "val" else self.user2test
        self.users = list(split.keys())
        self.mode = mode

    def __len__(self):
        """Return the number of users evaluated in the selected split."""
        return len(self.users)

    def __getitem__(self, index):
        """Return ``(history, target)`` for the ``index``-th user.

        Raises:
            KeyError: if the user is missing from one of the dictionaries
                consulted for the current mode.
        """
        user = self.users[index]
        if self.mode == "val":
            history = self.user2train[user]
            target = self.user2val[user]
        else:
            # In the test setting, the validation entry becomes part of the observed history.
            history = self.user2train[user] + self.user2val[user]
            target = self.user2test[user]
        return history, target

    def collate_fn(self, data):
        """Convert ``(sequence, label)`` pairs to ``{'items', 'label'}`` records and collate them."""
        records = [{'items': sample[0], 'label': sample[1]} for sample in data]
        return self.collator(records)