# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from operator import attrgetter

import torch

from peft.config import PeftConfig
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING

from .constants import PEFT_TYPE_TO_PREFIX_MAPPING
from .other import infer_device
from .peft_types import PeftType
from .save_and_load import _insert_adapter_name_into_state_dict, load_peft_weights

# so far only LoRA is supported
CONFIG_KEYS_TO_CHECK = {PeftType.LORA: ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]}
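# A hypothetical sketch of how support for another method could be registered: list the config keys that change
# the output independently of the state_dict, e.g. (not implemented, LoHa is not supported yet):
# CONFIG_KEYS_TO_CHECK[PeftType.LOHA] = ["alpha", "rank_dropout"]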


def hotswap_adapter_from_state_dict(model, state_dict, adapter_name, parameter_prefix="lora_"):
    """
    Swap out the adapter weights from the model with the weights from state_dict.

    As of now, only LoRA is supported.

    This is a low-level function that assumes that the adapters have been checked for compatibility and that the
    state_dict has been correctly mapped to work with PEFT. For a high-level function that performs this work for
    you, use `hotswap_adapter` instead.
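
    Example:

        A minimal sketch of direct usage, assuming `model` is a PEFT model with a compatible LoRA adapter already
        loaded and that the raw checkpoint keys still need to be remapped to PEFT's internal format (the adapter
        path is hypothetical):

        ```py
        >>> from peft.utils.save_and_load import _insert_adapter_name_into_state_dict, load_peft_weights

        >>> raw_state_dict = load_peft_weights("path-adapter-1")
        >>> state_dict = _insert_adapter_name_into_state_dict(
        ...     raw_state_dict, adapter_name="default", parameter_prefix="lora_"
        ... )
        >>> hotswap_adapter_from_state_dict(model, state_dict, adapter_name="default")
        ```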

    Args:
        model (`nn.Module`):
            The model with the loaded adapter.
        state_dict (`dict[str, torch.Tensor]`):
            The state dict of the new adapter, which needs to be compatible (targeting the same modules etc.).
        adapter_name (`str`):
            The name of the adapter that should be hot-swapped, e.g. `"default"`. The name will remain the same
            after swapping.
        parameter_prefix (`str`, *optional*, defaults to `"lora_"`):
            The prefix used to identify the adapter's keys in the state dict. For LoRA, this would be `"lora_"`
            (the default).

    Raises:
        RuntimeError:
            If the old and the new adapter are not compatible, a RuntimeError is raised.
    """
    # Ensure that all the keys of the new adapter correspond exactly to the keys of the old adapter, otherwise
    # hot-swapping is not possible.
    is_compiled = hasattr(model, "_orig_mod")
    # TODO: there is probably a more precise way to identify the adapter keys
    missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)}
    unexpected_keys = set()

    # first: dry run, not swapping anything
    for key, new_val in state_dict.items():
        try:
            old_val = attrgetter(key)(model)
        except AttributeError:
            unexpected_keys.add(key)
            continue

        if is_compiled:
            missing_keys.remove("_orig_mod." + key)
        else:
            missing_keys.remove(key)

    if missing_keys or unexpected_keys:
        msg = "Hot swapping the adapter did not succeed."
        if missing_keys:
            msg += f" Missing keys: {', '.join(sorted(missing_keys))}."
        if unexpected_keys:
            msg += f" Unexpected keys: {', '.join(sorted(unexpected_keys))}."
        raise RuntimeError(msg)

    # actual swapping
    for key, new_val in state_dict.items():
        # no need to account for potential _orig_mod in key here, as torch handles that
        old_val = attrgetter(key)(model)
        if is_compiled:
            # Compiled models don't work with swap_tensors because there are weakrefs to the tensor. It is unclear
            # whether this workaround could cause trouble down the line, but the tests indicate that it works.
            old_val.data = new_val.data
        else:
            torch.utils.swap_tensors(old_val, new_val)


def _check_hotswap_configs_compatible(config0: PeftConfig, config1: PeftConfig) -> None:
    """
    Check if two configs are compatible for hot-swapping.

    Only LoRA parameters are checked for now.

    To hot-swap two adapters, their configs must be compatible. Otherwise, the results could be incorrect. E.g. if
    they use different alpha values, after hot-swapping, the alphas from the first adapter would still be used with
    the weights from the 2nd adapter, which would result in incorrect behavior. There is probably a way to swap
    these values as well, but that's not implemented yet, and we need to be careful not to trigger re-compilation
    if the model is compiled (so no modification of the dict).
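
    A brief illustration (assumes `LoraConfig` from peft; since the two configs differ only in `lora_alpha`, the
    check is expected to raise):

        ```py
        >>> from peft import LoraConfig
        >>> _check_hotswap_configs_compatible(LoraConfig(lora_alpha=8), LoraConfig(lora_alpha=16))
        Traceback (most recent call last):
        ...
        ValueError: Configs are incompatible: for lora_alpha, 8 != 16
        ```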
| """ | |
    if config0.peft_type != config1.peft_type:
        msg = f"Incompatible PEFT types found: {config0.peft_type.value} and {config1.peft_type.value}"
        raise ValueError(msg)

    if config0.peft_type not in CONFIG_KEYS_TO_CHECK:
        msg = (
            f"Hotswapping only supports {', '.join(CONFIG_KEYS_TO_CHECK.keys())} but "
            f"{config0.peft_type.value} was passed."
        )
        raise ValueError(msg)
    config_keys_to_check = CONFIG_KEYS_TO_CHECK[config0.peft_type]

    # TODO: This is a very rough check only for LoRA at the moment. Also, there might be some options that don't
    # necessarily require an error.
    config0 = config0.to_dict()
    config1 = config1.to_dict()
    sentinel = object()
    for key in config_keys_to_check:
        val0 = config0.get(key, sentinel)
        val1 = config1.get(key, sentinel)

        if val0 != val1:
            raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")


def hotswap_adapter(model, model_name_or_path, adapter_name, torch_device=None, **kwargs):
    """Substitute old adapter data with new adapter data, keeping the rest the same.

    As of now, only LoRA is supported.

    This function is useful when you want to replace the loaded adapter with a new adapter. The adapter name will
    remain the same, but the weights and other parameters will be swapped out.

    If the adapters are incompatible, e.g. targeting different layers or having different alpha values, an error
    will be raised.

    Example:

    ```py
    >>> import torch
    >>> from transformers import AutoModelForCausalLM
    >>> from peft import PeftModel
    >>> from peft.utils.hotswap import hotswap_adapter

    >>> model_id = ...
    >>> inputs = ...
    >>> device = ...
    >>> model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

    >>> # load lora 0
    >>> model = PeftModel.from_pretrained(model, "path-adapter-0")
    >>> model = torch.compile(model)  # optionally compile the model

    >>> with torch.inference_mode():
    ...     output_adapter_0 = model(inputs).logits

    >>> # replace the "default" lora adapter with the new one
    >>> hotswap_adapter(model, "path-adapter-1", adapter_name="default", torch_device=device)

    >>> with torch.inference_mode():
    ...     output_adapter_1 = model(inputs).logits
    ```

    Args:
        model ([`~PeftModel`]):
            The PEFT model with the loaded adapter.
        model_name_or_path (`str`):
            The name or path of the model to load the new adapter from.
        adapter_name (`str`):
            The name of the adapter to swap, e.g. `"default"`. The name will stay the same after swapping.
        torch_device (`str`, *optional*, defaults to None):
            The device to load the new adapter onto.
        **kwargs (`optional`):
            Additional keyword arguments used for loading the config and weights.
    """
    if torch_device is None:
        torch_device = infer_device()

    ############################
    # LOAD CONFIG AND VALIDATE #
    ############################

    config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[
        PeftConfig._get_peft_type(
            model_name_or_path,
            subfolder=kwargs.get("subfolder", None),
            revision=kwargs.get("revision", None),
            cache_dir=kwargs.get("cache_dir", None),
            use_auth_token=kwargs.get("use_auth_token", None),
            token=kwargs.get("token", None),
        )
    ]
    config = config_cls.from_pretrained(model_name_or_path, **kwargs)
    # config keys that could affect the model output besides what is determined by the state_dict
    _check_hotswap_configs_compatible(model.active_peft_config, config)

    state_dict = load_peft_weights(model_name_or_path, device=torch_device, **kwargs)

    ###########################
    # LOAD & REMAP STATE_DICT #
    ###########################

    parameter_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
    peft_model_state_dict = _insert_adapter_name_into_state_dict(
        state_dict, adapter_name=adapter_name, parameter_prefix=parameter_prefix
    )

    hotswap_adapter_from_state_dict(
        model=model,
        state_dict=peft_model_state_dict,
        adapter_name=adapter_name,
        parameter_prefix=parameter_prefix,
    )