diff --git a/pyreft/interventions.py b/pyreft/interventions.py index afe0446..53f7d05 100644 --- a/pyreft/interventions.py +++ b/pyreft/interventions.py @@ -60,14 +60,13 @@ def state_dict(self, *args, **kwargs): for k, v in self.learned_source.state_dict().items(): state_dict[k] = v state_dict["rotate_layer"] = self.rotate_layer.weight.data - print(self.rotate_layer.weight.data) return state_dict def load_state_dict(self, state_dict, *args, **kwargs): """ Overwrite for data-efficiency. """ - super().load_state_dict(state_dict, strict=False) + self.learned_source.load_state_dict(state_dict, strict=False) # Caveat: without creating a new layer, it might not work (still not sure why) # We have to recreate a layer, and load back the columns. @@ -77,7 +76,8 @@ def load_state_dict(self, state_dict, *args, **kwargs): self.embed_dim, overload_w_width, init_orth=True).to( self.learned_source.weight.device) self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer) - self.rotate_layer.parametrizations.weight[0].base[:,:overload_w_width] = overload_w.to("cuda") + self.rotate_layer.parametrizations.weight[0].base[:,:overload_w_width] = overload_w + assert torch.allclose(self.rotate_layer.weight.data, overload_w.data) == True # we must match! return