From 7510aa55cb303b66c25298d712b7c188b7a1f372 Mon Sep 17 00:00:00 2001
From: frankaging
Date: Wed, 24 Jul 2024 18:40:30 -0700
Subject: [PATCH 1/2] [P0] Fixing LoReFT rotation layer hot loading problem

---
 pyreft/interventions.py | 24 +++++++++-----
 pyreft/reft_trainer.py  | 69 +++++++++++++++++++++++++++++++++++++----
 requirements.txt        |  2 +-
 3 files changed, 81 insertions(+), 14 deletions(-)

diff --git a/pyreft/interventions.py b/pyreft/interventions.py
index 3a25f7a..afe0446 100644
--- a/pyreft/interventions.py
+++ b/pyreft/interventions.py
@@ -34,8 +34,9 @@ class LoreftIntervention(
     """
     def __init__(self, **kwargs):
         super().__init__(**kwargs, keep_last_dim=True)
-        rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"], init_orth=True)
-        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer, orthogonal_map='householder')
+        rotate_layer = LowRankRotateLayer(
+            self.embed_dim, kwargs["low_rank_dimension"], init_orth=True)
+        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
         self.learned_source = torch.nn.Linear(
             self.embed_dim, kwargs["low_rank_dimension"]).to(
             kwargs["dtype"] if "dtype" in kwargs else torch.bfloat16)
@@ -59,16 +60,25 @@ def state_dict(self, *args, **kwargs):
         for k, v in self.learned_source.state_dict().items():
             state_dict[k] = v
         state_dict["rotate_layer"] = self.rotate_layer.weight.data
+        print(self.rotate_layer.weight.data)
         return state_dict
 
     def load_state_dict(self, state_dict, *args, **kwargs):
         """
         Overwrite for data-efficiency.
         """
-        self.learned_source.load_state_dict(state_dict, strict=False)
+        super().load_state_dict(state_dict, strict=False)
+
+        # Caveat: without creating a new layer, it might not work (still not sure why)
+        # We have to recreate a layer, and load back the columns.
         overload_w = state_dict["rotate_layer"]
         overload_w_width = overload_w.shape[-1]
-        self.rotate_layer.parametrizations.weight[0].base[:,:overload_w_width] = overload_w
+        rotate_layer = LowRankRotateLayer(
+            self.embed_dim, overload_w_width, init_orth=True).to(
+            self.learned_source.weight.device)
+        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
+        self.rotate_layer.parametrizations.weight[0].base[:,:overload_w_width] = overload_w.to("cuda")
+
         return
@@ -112,7 +122,7 @@ class ConsreftIntervention(
     def __init__(self, **kwargs):
         super().__init__(**kwargs, keep_last_dim=True)
         rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"], init_orth=True)
-        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer, orthogonal_map='householder')
+        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
         self.learned_source = torch.nn.Parameter(
             torch.rand(kwargs["low_rank_dimension"]), requires_grad=True)
@@ -137,7 +147,7 @@ class LobireftIntervention(
     def __init__(self, **kwargs):
         super().__init__(**kwargs, keep_last_dim=True)
         rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"], init_orth=True)
-        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer, orthogonal_map='householder')
+        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
         self.learned_source = torch.nn.Parameter(
             torch.rand(kwargs["low_rank_dimension"]), requires_grad=True)
         self.dropout = torch.nn.Dropout(kwargs["dropout"] if "dropout" in kwargs else 0.0)
@@ -162,7 +172,7 @@ class DireftIntervention(
     def __init__(self, **kwargs):
         super().__init__(**kwargs, keep_last_dim=True)
         rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"], init_orth=True)
-        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer, orthogonal_map='householder')
+        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
         self.learned_source = torch.nn.Linear(
             self.embed_dim, kwargs["low_rank_dimension"]).to(
             kwargs["dtype"] if "dtype" in kwargs else torch.bfloat16)
diff --git a/pyreft/reft_trainer.py b/pyreft/reft_trainer.py
index 01f4308..187335c 100644
--- a/pyreft/reft_trainer.py
+++ b/pyreft/reft_trainer.py
@@ -79,28 +79,85 @@ def compute_loss(
         return_outputs=False
     ):
         # run intervened forward pass
+        unit_locations = None
+        if "intervention_locations" in inputs:
+            unit_locations={"sources->base": (
+                None,
+                inputs["intervention_locations"].permute(1, 0, 2).tolist()
+            )}
+
         _, cf_outputs = intervenable(
             {
                 "input_ids": inputs["input_ids"],
                 "attention_mask": inputs["attention_mask"]
             },
-            unit_locations={"sources->base": (
-                None,
-                inputs["intervention_locations"].permute(1, 0, 2).tolist()
-            )},
+            unit_locations=unit_locations,
             labels=inputs["labels"],
             subspaces=inputs["subspaces"].permute(1, 0, 2).tolist() if "subspaces" in inputs else None
         )
         # return
-        return (cf_outputs.loss, cf_outputs) if return_outputs else cf_outputs.loss
+        return (cf_outputs, cf_outputs) if return_outputs else cf_outputs.loss
 
 class ReftTrainerForCausalLM(ReftTrainer):
     def get_train_dataloader(self) -> DataLoader:
         return make_dataloader(self.train_dataset, self._train_batch_size, self.data_collator, shuffle=True)
-    
+
 class ReftTrainerForSequenceClassification(ReftTrainer):
+    def compute_loss(
+        self,
+        intervenable: pv.IntervenableModel,
+        inputs,
+        return_outputs=False
+    ):
+        # run intervened forward pass
+        unit_locations = None
+        if "intervention_locations" in inputs:
+            unit_locations={"sources->base": (
+                None,
+                inputs["intervention_locations"].permute(1, 0, 2).tolist()
+            )}
+
+        _, cf_outputs = intervenable(
+            {
+                "input_ids": inputs["input_ids"],
+                "attention_mask": inputs["attention_mask"]
+            },
+            unit_locations=unit_locations,
+            labels=inputs["labels"],
+            subspaces=inputs["subspaces"].permute(1, 0, 2).tolist() if "subspaces" in inputs else None
+        )
+        # classification loss on counterfactual labels
+        logits = cf_outputs.logits
+        labels = inputs["labels"]
+
+        if self.model.model.config.problem_type is None:
+            if self.model.model.num_labels == 1:
+                problem_type = "regression"
+            elif self.model.model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                problem_type = "single_label_classification"
+            else:
+                problem_type = "multi_label_classification"
+        else:
+            problem_type = self.model.model.config.problem_type
+
+        if problem_type == "regression":
+            loss_fct = MSELoss()
+            if self.model.model.num_labels == 1:
+                loss = loss_fct(logits.squeeze(), labels.squeeze().to(torch.bfloat16))
+            else:
+                loss = loss_fct(logits, labels.to(torch.bfloat16))
+        elif problem_type == "single_label_classification":
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.model.model.num_labels), labels.view(-1))
+        elif problem_type == "multi_label_classification":
+            loss_fct = BCEWithLogitsLoss()
+            loss = loss_fct(logits, labels)
+
+        # return
+        return (loss, cf_outputs) if return_outputs else loss
+
     def evaluate(
         self,
         ignore_keys,
     ):
diff --git a/requirements.txt b/requirements.txt
index 9fbd513..99328cd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,7 @@ protobuf>=3.20.0
 matplotlib>=3.7.4
 ipywidgets>=8.1.1
 plotnine>=0.12.4
-huggingface-hub==0.23.0
+huggingface-hub
 numpy>=1.26.4
 accelerate>=0.29.1
 sentencepiece>=0.1.96

From 5bee6b882a40bc979888525c945d8cfbb2ac309e Mon Sep 17 00:00:00 2001
From: frankaging
Date: Wed, 24 Jul 2024 21:05:02 -0700
Subject: [PATCH 2/2] remove logging and clean up

---
 pyreft/interventions.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pyreft/interventions.py b/pyreft/interventions.py
index afe0446..53f7d05 100644
--- a/pyreft/interventions.py
+++ b/pyreft/interventions.py
@@ -60,14 +60,13 @@ def state_dict(self, *args, **kwargs):
         for k, v in self.learned_source.state_dict().items():
             state_dict[k] = v
         state_dict["rotate_layer"] = self.rotate_layer.weight.data
-        print(self.rotate_layer.weight.data)
         return state_dict
 
     def load_state_dict(self, state_dict, *args, **kwargs):
         """
         Overwrite for data-efficiency.
         """
-        super().load_state_dict(state_dict, strict=False)
+        self.learned_source.load_state_dict(state_dict, strict=False)
 
         # Caveat: without creating a new layer, it might not work (still not sure why)
         # We have to recreate a layer, and load back the columns.
@@ -77,7 +76,8 @@ def load_state_dict(self, state_dict, *args, **kwargs):
             self.embed_dim, overload_w_width, init_orth=True).to(
             self.learned_source.weight.device)
         self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
-        self.rotate_layer.parametrizations.weight[0].base[:,:overload_w_width] = overload_w.to("cuda")
+        self.rotate_layer.parametrizations.weight[0].base[:,:overload_w_width] = overload_w
+        assert torch.allclose(self.rotate_layer.weight.data, overload_w.data) == True # we must match!
 
         return
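For illustration, here is a minimal, self-contained sketch of the hot-loading idea, written against plain PyTorch rather than pyreft's actual classes. The LowRankRotate module, the sizes, and the tolerance below are illustrative assumptions, not pyreft code. The patch itself recreates the low-rank rotation layer on the right device, re-wraps it with torch.nn.utils.parametrizations.orthogonal, and copies the saved orthonormal columns into parametrizations.weight[0].base; the sketch instead uses the documented assignment route, where setting a parametrized weight goes through the parametrization's right_inverse.

import torch

# Simplified stand-in for pyreft's LowRankRotateLayer (illustrative, not the real class):
# a tall embed_dim x rank weight whose columns should stay orthonormal.
class LowRankRotate(torch.nn.Module):
    def __init__(self, n, m):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(n, m))
        torch.nn.init.orthogonal_(self.weight)

embed_dim, rank = 16, 4  # illustrative sizes

# Training-time layer: the orthogonal parametrization keeps the columns orthonormal.
layer = torch.nn.utils.parametrizations.orthogonal(LowRankRotate(embed_dim, rank))
saved = layer.weight.detach().clone()  # analogous to what state_dict() stores as "rotate_layer"

# Hot load: build a fresh parametrized layer and push the saved columns back through
# the parametrization; assigning to the parametrized weight calls right_inverse.
reloaded = torch.nn.utils.parametrizations.orthogonal(LowRankRotate(embed_dim, rank))
with torch.no_grad():
    reloaded.weight = saved

# The reloaded weight should reproduce the saved columns up to numerical tolerance.
print(torch.allclose(reloaded.weight, saved, atol=1e-5))

If this round trip does not hold, a re-created intervention comes up with a re-initialized rotation after loading a checkpoint, which appears to be the hot-loading failure the first patch title refers to.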