Skip to content

Commit

Permalink
[P0] Fixing LoReFT rotation layer hot loading problem
Browse files Browse the repository at this point in the history
  • Loading branch information
frankaging committed Jul 25, 2024
1 parent 499aea1 commit 7510aa5
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 14 deletions.
24 changes: 17 additions & 7 deletions pyreft/interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ class LoreftIntervention(
"""
def __init__(self, **kwargs):
super().__init__(**kwargs, keep_last_dim=True)
rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"], init_orth=True)
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer, orthogonal_map='householder')
rotate_layer = LowRankRotateLayer(
self.embed_dim, kwargs["low_rank_dimension"], init_orth=True)
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
self.learned_source = torch.nn.Linear(
self.embed_dim, kwargs["low_rank_dimension"]).to(
kwargs["dtype"] if "dtype" in kwargs else torch.bfloat16)
Expand All @@ -59,16 +60,25 @@ def state_dict(self, *args, **kwargs):
for k, v in self.learned_source.state_dict().items():
state_dict[k] = v
state_dict["rotate_layer"] = self.rotate_layer.weight.data
print(self.rotate_layer.weight.data)
return state_dict

def load_state_dict(self, state_dict, *args, **kwargs):
"""
Overwrite for data-efficiency.
"""
self.learned_source.load_state_dict(state_dict, strict=False)
super().load_state_dict(state_dict, strict=False)

# Caveat: without creating a new layer, it might not work (still not sure why)
# We have to recreate a layer, and load back the columns.
overload_w = state_dict["rotate_layer"]
overload_w_width = overload_w.shape[-1]
self.rotate_layer.parametrizations.weight[0].base[:,:overload_w_width] = overload_w
rotate_layer = LowRankRotateLayer(
self.embed_dim, overload_w_width, init_orth=True).to(
self.learned_source.weight.device)
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
self.rotate_layer.parametrizations.weight[0].base[:,:overload_w_width] = overload_w.to("cuda")

return


Expand Down Expand Up @@ -112,7 +122,7 @@ class ConsreftIntervention(
def __init__(self, **kwargs):
super().__init__(**kwargs, keep_last_dim=True)
rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"], init_orth=True)
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer, orthogonal_map='householder')
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
self.learned_source = torch.nn.Parameter(
torch.rand(kwargs["low_rank_dimension"]), requires_grad=True)

Expand All @@ -137,7 +147,7 @@ class LobireftIntervention(
def __init__(self, **kwargs):
super().__init__(**kwargs, keep_last_dim=True)
rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"], init_orth=True)
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer, orthogonal_map='householder')
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
self.learned_source = torch.nn.Parameter(
torch.rand(kwargs["low_rank_dimension"]), requires_grad=True)
self.dropout = torch.nn.Dropout(kwargs["dropout"] if "dropout" in kwargs else 0.0)
Expand All @@ -162,7 +172,7 @@ class DireftIntervention(
def __init__(self, **kwargs):
super().__init__(**kwargs, keep_last_dim=True)
rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"], init_orth=True)
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer, orthogonal_map='householder')
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
self.learned_source = torch.nn.Linear(
self.embed_dim, kwargs["low_rank_dimension"]).to(
kwargs["dtype"] if "dtype" in kwargs else torch.bfloat16)
Expand Down
69 changes: 63 additions & 6 deletions pyreft/reft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,28 +79,85 @@ def compute_loss(
return_outputs=False
):
# run intervened forward pass
unit_locations = None
if "intervention_locations" in inputs:
unit_locations={"sources->base": (
None,
inputs["intervention_locations"].permute(1, 0, 2).tolist()
)}

_, cf_outputs = intervenable(
{
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"]
},
unit_locations={"sources->base": (
None,
inputs["intervention_locations"].permute(1, 0, 2).tolist()
)},
unit_locations=unit_locations,
labels=inputs["labels"],
subspaces=inputs["subspaces"].permute(1, 0, 2).tolist() if "subspaces" in inputs else None
)
# return
return (cf_outputs.loss, cf_outputs) if return_outputs else cf_outputs.loss
return (cf_outputs, cf_outputs) if return_outputs else cf_outputs.loss


class ReftTrainerForCausalLM(ReftTrainer):
def get_train_dataloader(self) -> DataLoader:
return make_dataloader(self.train_dataset, self._train_batch_size, self.data_collator, shuffle=True)


class ReftTrainerForSequenceClassification(ReftTrainer):
def compute_loss(
self,
intervenable: pv.IntervenableModel,
inputs,
return_outputs=False
):
# run intervened forward pass
unit_locations = None
if "intervention_locations" in inputs:
unit_locations={"sources->base": (
None,
inputs["intervention_locations"].permute(1, 0, 2).tolist()
)}

_, cf_outputs = intervenable(
{
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"]
},
unit_locations=unit_locations,
labels=inputs["labels"],
subspaces=inputs["subspaces"].permute(1, 0, 2).tolist() if "subspaces" in inputs else None
)
# classification loss on counterfactual labels
logits = cf_outputs.logits
labels = inputs["labels"]

if self.model.model.config.problem_type is None:
if self.model.model.num_labels == 1:
problem_type = "regression"
elif self.model.model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
problem_type = "single_label_classification"
else:
problem_type = "multi_label_classification"
else:
problem_type = self.model.model.config.problem_type

if problem_type == "regression":
loss_fct = MSELoss()
if self.model.model.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze().to(torch.bfloat16))
else:
loss = loss_fct(logits, labels.to(torch.bfloat16))
elif problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.model.model.num_labels), labels.view(-1))
elif problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)

# return
return (loss, cf_outputs) if return_outputs else loss

def evaluate(
self, ignore_keys,
):
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ protobuf>=3.20.0
matplotlib>=3.7.4
ipywidgets>=8.1.1
plotnine>=0.12.4
huggingface-hub==0.23.0
huggingface-hub
numpy>=1.26.4
accelerate>=0.29.1
sentencepiece>=0.1.96
Expand Down

0 comments on commit 7510aa5

Please sign in to comment.