added logic for deleting adapters once loaded (#34650)
* added logic for deleting adapters once loaded

* updated to the latest version of transformers, merged utility function into the source

* updated with missing check

* added peft version check

* Apply suggestions from code review

Co-authored-by: Anton Vlasjuk <[email protected]>

* changes according to reviewer

* added test for deleting adapter(s)

* styling changes

* styling changes in test

* removed redundant code

* formatted my contributions with ruff

* optimized error handling

* ruff formatted with correct config

* resolved formatting issues

---------

Co-authored-by: Anton Vlasjuk <[email protected]>
itsskofficial and vasqu authored Jan 6, 2025
1 parent 1650e0e commit ca00950
Showing 2 changed files with 126 additions and 1 deletion.
62 changes: 62 additions & 0 deletions src/transformers/integrations/peft.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import inspect
import warnings
@@ -525,3 +526,64 @@ def _dispatch_accelerate_model(
            offload_dir=offload_folder,
            **dispatch_model_kwargs,
        )

    def delete_adapter(self, adapter_names: Union[List[str], str]) -> None:
        """
        Delete an adapter's LoRA layers from the underlying model.

        Args:
            adapter_names (`Union[List[str], str]`):
                The name(s) of the adapter(s) to delete.

        Example:

        ```py
        from peft import LoraConfig
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
        model.add_adapter(LoraConfig(), adapter_name="adapter_1")

        model.delete_adapter("adapter_1")
        ```
        """

        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        # Normalize the input: a single name becomes a one-element list, and
        # `None` is treated as an empty list so that deleting nothing is a no-op
        if adapter_names is None:
            adapter_names = []
        elif isinstance(adapter_names, str):
            adapter_names = [adapter_names]

        # Check that all adapter names are present in the config
        missing_adapters = [name for name in adapter_names if name not in self.peft_config]
        if missing_adapters:
            raise ValueError(
                f"The following adapter(s) are not present and cannot be deleted: {', '.join(missing_adapters)}"
            )

        for adapter_name in adapter_names:
            for module in self.modules():
                if isinstance(module, BaseTunerLayer):
                    if hasattr(module, "delete_adapter"):
                        module.delete_adapter(adapter_name)
                    else:
                        raise ValueError(
                            "The version of PEFT you are using is not compatible; please use a version greater than 0.6.1"
                        )

            # For the transformers integration, we also need to pop the adapter from the config
            if getattr(self, "_hf_peft_config_loaded", False) and hasattr(self, "peft_config"):
                self.peft_config.pop(adapter_name, None)

        # In case all adapters are deleted, delete the config
        # and make sure to set the flag to False
        if len(self.peft_config) == 0:
            del self.peft_config
            self._hf_peft_config_loaded = False
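
Taken together, the diff gives the transformers PEFT integration a counterpart to `load_adapter` and `add_adapter`. A minimal usage sketch of the new API (the checkpoint and adapter names below are illustrative, not part of the commit):

```py
from peft import LoraConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Attach two LoRA adapters through the transformers PEFT integration
model.add_adapter(LoraConfig(init_lora_weights=False), adapter_name="adapter_1")
model.add_adapter(LoraConfig(init_lora_weights=False), adapter_name="adapter_2")

# Delete one adapter by name, or several at once with a list
model.delete_adapter("adapter_1")
model.delete_adapter(["adapter_2"])

# Once the last adapter is gone, the PEFT state is fully cleared
assert not model._hf_peft_config_loaded
```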
65 changes: 64 additions & 1 deletion tests/peft_integration/test_peft_integration.py
@@ -350,7 +350,6 @@ def test_peft_add_multi_adapter(self):
                self.assertFalse(
                    torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
                )
                self.assertFalse(
                    torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
                )
@@ -359,6 +358,70 @@
                with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
                    model.save_pretrained(tmpdirname)

    def test_delete_adapter(self):
        """
        Test `delete_adapter` with multiple adapters, edge cases, and proper
        error handling.
        """
        from peft import LoraConfig

        for model_id in self.transformers_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                model = transformers_class.from_pretrained(model_id).to(torch_device)

                # Add multiple adapters
                peft_config_1 = LoraConfig(init_lora_weights=False)
                peft_config_2 = LoraConfig(init_lora_weights=False)
                model.add_adapter(peft_config_1, adapter_name="adapter_1")
                model.add_adapter(peft_config_2, adapter_name="adapter_2")

                # Ensure adapters were added
                self.assertIn("adapter_1", model.peft_config)
                self.assertIn("adapter_2", model.peft_config)

                # Delete a single adapter
                model.delete_adapter("adapter_1")
                self.assertNotIn("adapter_1", model.peft_config)
                self.assertIn("adapter_2", model.peft_config)

                # Delete the remaining adapter
                model.delete_adapter("adapter_2")
                self.assertNotIn("adapter_2", model.peft_config)
                self.assertFalse(model._hf_peft_config_loaded)

                # Re-add adapters for the remaining checks
                model.add_adapter(peft_config_1, adapter_name="adapter_1")
                model.add_adapter(peft_config_2, adapter_name="adapter_2")

                # Deleting an unknown adapter raises while adapters are loaded;
                # the failed call must leave the loaded adapters untouched
                with self.assertRaisesRegex(ValueError, "The following adapter\\(s\\) are not present"):
                    model.delete_adapter("nonexistent_adapter")

                with self.assertRaisesRegex(ValueError, "The following adapter\\(s\\) are not present"):
                    model.delete_adapter(["adapter_1", "nonexistent_adapter"])
                self.assertIn("adapter_1", model.peft_config)

                # Delete multiple adapters at once
                model.delete_adapter(["adapter_1", "adapter_2"])
                self.assertNotIn("adapter_1", model.peft_config)
                self.assertNotIn("adapter_2", model.peft_config)
                self.assertFalse(model._hf_peft_config_loaded)

                # Deleting with an empty list or None should not raise errors
                model.add_adapter(peft_config_1, adapter_name="adapter_1")
                model.add_adapter(peft_config_2, adapter_name="adapter_2")
                model.delete_adapter([])  # No-op
                self.assertIn("adapter_1", model.peft_config)
                self.assertIn("adapter_2", model.peft_config)

                model.delete_adapter(None)  # No-op
                self.assertIn("adapter_1", model.peft_config)
                self.assertIn("adapter_2", model.peft_config)

                # Duplicate adapter names in the list are tolerated
                model.delete_adapter(["adapter_1", "adapter_1"])
                self.assertNotIn("adapter_1", model.peft_config)
                self.assertIn("adapter_2", model.peft_config)

    @require_torch_gpu
    @require_bitsandbytes
    def test_peft_from_pretrained_kwargs(self):
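
For completeness, the validation behaviour the new test pins down can be seen in isolation. A short sketch under the same assumptions as the earlier snippet, with "adapter_1" and "adapter_2" re-attached (names are illustrative):

```py
# Asking for an unknown name fails the up-front check, so nothing is
# deleted and the loaded adapters survive the failed call
try:
    model.delete_adapter(["adapter_1", "nonexistent_adapter"])
except ValueError as err:
    print(err)
    # The following adapter(s) are not present and cannot be deleted: nonexistent_adapter

assert "adapter_1" in model.peft_config
```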
