Skip to content

Commit

Permalink
rename new function
Browse files Browse the repository at this point in the history
  • Loading branch information
kells1986 committed Jun 21, 2023
1 parent 59e1ec0 commit 81b634d
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions minlora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,18 @@ def apply_lora(layer, register=True, merge=False, lora_config=default_lora_confi
parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=merge)


def apply_lora_by_name(model, target_module_names, register=True, merge=False, lora_config=default_lora_config):
"""Add LoRA parameterization to specific layers in a model by names"""
for name, layer in model.named_modules():
if any([m in name for m in target_module_names]):
add_lora(layer, register=register, merge=merge, lora_config=lora_config)


def add_lora(model, lora_config=default_lora_config):
"""add lora parametrization to all layers in a model. Calling it twice will add lora twice"""
model.apply(partial(apply_lora, lora_config=lora_config))


def add_lora_by_name(model, target_module_names, lora_config=default_lora_config):
"""Add LoRA parameterization to specific layers in a model by names"""
for name, layer in model.named_modules():
if any([m in name for m in target_module_names]):
add_lora(layer, lora_config=lora_config)


def merge_lora(model):
"""merge lora parametrization to all layers in a model. This will remove all parametrization"""
model.apply(partial(apply_lora, register=False, merge=True))
Expand Down

0 comments on commit 81b634d

Please sign in to comment.