diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 5f31a584e..db7a35231 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -111,6 +111,8 @@ def register_module_override(self, module, param_name, config):
 
 
 class Optimizer8bit(torch.optim.Optimizer):
+    _FSDP_WRAPPED_QUANT_STATE_KEY = "__bnb_optimizer_quant_state__"
+
     def __init__(self, params, defaults, optim_bits=32, is_paged=False):
         """
         Base 8-bit optimizer class.
@@ -152,6 +154,34 @@ def fill_qmap(self):
         self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
         self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
 
+    def state_dict(self):
+        """Return optimizer state, wrapping quantization tensors for FSDP compatibility.
+
+        FSDP's full_optim_state_dict gathers all tensor states across ranks.
+        Quantization states (state1, state2, absmax, etc.) have different shapes
+        than model parameters, causing gather operations to fail. By wrapping
+        these tensors in a nested dict, FSDP skips them during gathering.
+        """
+        state_dict = super().state_dict()
+
+        # Deep copy the state to avoid modifying the original optimizer state
+        # PyTorch's state_dict() only does a shallow copy
+        state_dict["state"] = {
+            k: {kk: vv for kk, vv in v.items()} if isinstance(v, dict) else v for k, v in state_dict["state"].items()
+        }
+
+        # Wrap quantization-specific tensors in a nested dict to hide them from FSDP
+        for param_state in state_dict["state"].values():
+            if isinstance(param_state, dict):
+                quant_state = {}
+                keys_to_wrap = [k for k in param_state if k in self.non_castable_tensor_keys]
+                for key in keys_to_wrap:
+                    quant_state[key] = param_state.pop(key)
+                if quant_state:
+                    param_state[self._FSDP_WRAPPED_QUANT_STATE_KEY] = quant_state
+
+        return state_dict
+
     def __setstate__(self, state):
         super().__setstate__(state)
 
@@ -166,6 +196,13 @@ def load_state_dict(self, state_dict, move_to_device=True):
         """
         # deepcopy, to be consistent with module API
         state_dict = deepcopy(state_dict)
+
+        # Unwrap quantization states that were wrapped for FSDP compatibility
+        for param_state in state_dict["state"].values():
+            if isinstance(param_state, dict) and self._FSDP_WRAPPED_QUANT_STATE_KEY in param_state:
+                quant_state = param_state.pop(self._FSDP_WRAPPED_QUANT_STATE_KEY)
+                param_state.update(quant_state)
+
         # Validate the state_dict
         groups = self.param_groups
         saved_groups = state_dict["param_groups"]
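
For illustration only, a minimal round-trip sketch of the wrap/unwrap behavior above (not part of the diff). It assumes a CUDA device, an arbitrarily sized Linear layer, and bitsandbytes' existing bnb.optim.Adam8bit; the wrapper key name simply mirrors _FSDP_WRAPPED_QUANT_STATE_KEY from the patch:

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(128, 128).cuda()
    opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

    # One step so the optimizer materializes its per-parameter state.
    model(torch.randn(8, 128, device="cuda")).sum().backward()
    opt.step()

    sd = opt.state_dict()
    param_state = sd["state"][0]

    # Quantization tensors (state1, state2, absmax, qmap, ...) now sit under the
    # wrapper key instead of at the top level of the param state, so FSDP's
    # full_optim_state_dict gather skips them.
    assert "__bnb_optimizer_quant_state__" in param_state
    assert "state1" not in param_state

    # load_state_dict() unwraps the nested dict, restoring the original layout.
    opt.load_state_dict(sd)
    assert "state1" in opt.state[model.weight]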