import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

# Load model and tokenizer
model_checkpoint = 'EleutherAI/gpt-neo-125M'
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_checkpoint).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
tokenizer.pad_token = tokenizer.eos_token

# Regularization parameter lambda for the SPIN objective
lambda_reg = 0.1

# Placeholder dataset; replace with a real prompt/response dataset
dataset = [{"prompt": "Example prompt", "response": "Example response"}]
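
# A minimal sketch of a real dataset loader, assuming a JSON Lines file with one
# {"prompt": ..., "response": ...} object per line; the path and format are
# illustrative assumptions, not part of the original script. Uncomment the last
# line to use it in place of the placeholder above.
import json

def load_spin_dataset(path):
    """Load prompt/response pairs from a JSON Lines file."""
    with open(path) as f:
        return [json.loads(line) for line in f]

# dataset = load_spin_dataset("spin_data.jsonl")
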
# Define LoRA configuration
lora_config = LoraConfig(
    r=128,                   # LoRA rank
    lora_alpha=256,          # LoRA scaling factor
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",   # causal-LM task so the PEFT wrapper supports generate()
)

# Wrap the model with LoRA layers for parameter-efficient training
peft_model = get_peft_model(model, lora_config)
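
# Optional (not in the original script): report how many parameters the LoRA
# adapters leave trainable relative to the frozen base model.
peft_model.print_trainable_parameters()
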
# SPIN loss: a logistic loss over the difference between the summed log-probability
# ratios (main player vs. opponent) on the real and the synthetic responses.
def compute_spin_loss(model_logits_gt, opponent_logits_gt,
                      model_logits_syn, opponent_logits_syn,
                      ground_truth_ids, synthetic_response_ids, lambda_reg):
    # Apply softmax to convert logits to probabilities
    # Shapes after softmax: [batch_size, sequence_length, vocab_size]
    model_probs_gt = torch.nn.functional.softmax(model_logits_gt, dim=-1)
    opponent_probs_gt = torch.nn.functional.softmax(opponent_logits_gt, dim=-1)
    model_probs_syn = torch.nn.functional.softmax(model_logits_syn, dim=-1)
    opponent_probs_syn = torch.nn.functional.softmax(opponent_logits_syn, dim=-1)

    # Gather log probabilities of the actual tokens in the ground-truth sequence
    # [batch_size, sequence_length, vocab_size] -> [batch_size, sequence_length]
    log_model_probs_gt = torch.log(torch.gather(
        model_probs_gt, dim=2, index=ground_truth_ids.unsqueeze(-1)
    ).squeeze(-1))
    log_opponent_probs_gt = torch.log(torch.gather(
        opponent_probs_gt, dim=2, index=ground_truth_ids.unsqueeze(-1)
    ).squeeze(-1))

    # Gather log probabilities of the actual tokens in the synthetic sequence
    # [batch_size, sequence_length, vocab_size] -> [batch_size, sequence_length]
    log_model_probs_syn = torch.log(torch.gather(
        model_probs_syn, dim=2, index=synthetic_response_ids.unsqueeze(-1)
    ).squeeze(-1))
    log_opponent_probs_syn = torch.log(torch.gather(
        opponent_probs_syn, dim=2, index=synthetic_response_ids.unsqueeze(-1)
    ).squeeze(-1))

    # Log probability ratios for the tokens in each sequence
    # [batch_size, sequence_length]
    log_prob_ratio_gt = log_model_probs_gt - log_opponent_probs_gt
    log_prob_ratio_syn = log_model_probs_syn - log_opponent_probs_syn

    # Sum the log probability ratios over the sequence
    # [batch_size, sequence_length] -> [batch_size]
    sum_log_prob_ratio_gt = torch.sum(log_prob_ratio_gt, dim=1)
    sum_log_prob_ratio_syn = torch.sum(log_prob_ratio_syn, dim=1)

    # Combined per-sequence term, scaled by lambda_reg
    # [batch_size]
    combined_loss = lambda_reg * (sum_log_prob_ratio_gt - sum_log_prob_ratio_syn)

    # Logistic loss; softplus(-x) == log(1 + exp(-x)) but is numerically stable
    # [batch_size]
    logistic_loss = torch.nn.functional.softplus(-combined_loss)

    # Mean over the batch -> scalar
    spin_loss = logistic_loss.mean()
    return spin_loss
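
# A quick shape check (a minimal sketch, not part of the original script): run the
# loss on random tensors to confirm the expected shapes flow through. The batch size
# and sequence length below are arbitrary illustrative values.
_b, _s, _v = 2, 8, tokenizer.vocab_size
_check_loss = compute_spin_loss(
    torch.randn(_b, _s, _v), torch.randn(_b, _s, _v),
    torch.randn(_b, _s, _v), torch.randn(_b, _s, _v),
    torch.randint(0, _v, (_b, _s)), torch.randint(0, _v, (_b, _s)),
    lambda_reg,
)
print(f"Sanity-check SPIN loss on random inputs: {_check_loss.item():.4f}")
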
# Training setup
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, peft_model.parameters()), lr=5e-5)
# Training loop for T iterations
T = 5 # Set the number of iterations

for iteration in range(T):
    total_loss = 0

    # Disable adapter layers: the frozen base model plays the opponent
    peft_model.disable_adapter_layers()
    peft_model.eval()  # evaluation mode while generating and scoring as the opponent

    synthetic_data = []
    opponent_logits_gt_list = []
    for data in dataset:
        prompt = data['prompt']
        # Tokenize the prompt for the opponent model
        prompt_encoding = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True).to(device)
        prompt_ids = prompt_encoding['input_ids']
        prompt_attention_mask = prompt_encoding['attention_mask']

        with torch.no_grad():
            # Generate a synthetic response with the opponent model
            synthetic_response_ids = peft_model.generate(
                input_ids=prompt_ids,
                attention_mask=prompt_attention_mask,
                max_length=50
            )
            synthetic_data.append(synthetic_response_ids)

            # Opponent logits for the ground-truth response
            ground_truth = data['response']
            ground_truth_encoding = tokenizer(
                ground_truth, return_tensors='pt', padding=True, truncation=True
            ).to(device)
            ground_truth_ids = ground_truth_encoding['input_ids']
            ground_truth_attention_mask = ground_truth_encoding['attention_mask']
            opponent_logits_gt = peft_model(
                input_ids=ground_truth_ids,
                attention_mask=ground_truth_attention_mask
            ).logits
            opponent_logits_gt_list.append(opponent_logits_gt)

    # Enable adapter layers: the LoRA-augmented model is the main player
    peft_model.enable_adapter_layers()
    peft_model.train()  # training mode for the main player

    # Train the main player on the real responses and the opponent's synthetic responses
    for i, data in enumerate(dataset):
        # Tokenize the ground-truth response for training
        ground_truth_encoding = tokenizer(
            data['response'], return_tensors='pt', padding=True, truncation=True
        ).to(device)
        ground_truth_ids = ground_truth_encoding['input_ids']
        ground_truth_attention_mask = ground_truth_encoding['attention_mask']
        synthetic_response_ids = synthetic_data[i].to(device)
        opponent_logits_gt = opponent_logits_gt_list[i]

        # Main player logits for the ground-truth and synthetic responses
        main_player_logits_gt = peft_model(
            input_ids=ground_truth_ids,
            attention_mask=ground_truth_attention_mask
        ).logits
        main_player_logits_syn = peft_model(
            input_ids=synthetic_response_ids
        ).logits

        # Opponent logits for the synthetic responses (adapter layers disabled)
        peft_model.disable_adapter_layers()
        with torch.no_grad():
            opponent_logits_syn = peft_model(
                input_ids=synthetic_response_ids
            ).logits
        peft_model.enable_adapter_layers()

        # Compute the SPIN loss
        loss = compute_spin_loss(
            main_player_logits_gt, opponent_logits_gt,
            main_player_logits_syn, opponent_logits_syn,
            ground_truth_ids, synthetic_response_ids, lambda_reg
        )
        total_loss += loss.item()

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Report the average loss for this iteration
    average_loss = total_loss / len(dataset)
    print(f"Iteration {iteration + 1}/{T}, Average Loss: {average_loss:.4f}")

# Keep a handle on the final model parameters
final_model_params = peft_model.state_dict()
print("Training complete.")