
Source code for finetuning_scheduler.fts

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Fine-Tuning Scheduler
^^^^^^^^^^^^^^^^^^^^^

Used to implement flexible fine-tuning training schedules

"""
import logging
from copy import deepcopy
from pprint import pformat
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import lightning.pytorch as pl
import torch
from lightning.fabric.utilities import rank_zero_info
from lightning.fabric.utilities.distributed import ReduceOp
from lightning.pytorch.callbacks import BaseFinetuning
from lightning.pytorch.strategies.strategy import Strategy
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_warn

from finetuning_scheduler.fts_supporters import (
    CallbackDepMixin,
    FTSEarlyStopping,
    FTSState,
    ScheduleImplMixin,
    ScheduleParsingMixin,
    STRATEGY_ADAPTERS,
)
from finetuning_scheduler.strategy_adapters.base import StrategyAdapter
from finetuning_scheduler.types import ParamGroupAddable

log = logging.getLogger(__name__)


class FinetuningScheduler(ScheduleImplMixin, ScheduleParsingMixin, CallbackDepMixin, BaseFinetuning):
    r"""
    This callback enables flexible, multi-phase, scheduled fine-tuning of foundation models. Gradual
    unfreezing/thawing can help maximize foundation model knowledge retention while allowing (typically upper layers
    of) the model to optimally adapt to new tasks during transfer learning.

    :class:`~finetuning_scheduler.fts.FinetuningScheduler` orchestrates the gradual unfreezing of models via a
    fine-tuning schedule that is either implicitly generated (the default) or explicitly provided by the user (more
    computationally efficient). Fine-tuning phase transitions are driven by
    :class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` criteria (a multi-phase extension of
    :external+pl:class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping`), user-specified epoch transitions
    or a composition of the two (the default mode). A :class:`~finetuning_scheduler.fts.FinetuningScheduler` training
    session completes when the final phase of the schedule has its stopping criteria met. See
    :ref:`Early Stopping<common/early_stopping:Early stopping>` for more details on that callback's configuration.

    Schedule definition is facilitated via
    :meth:`~finetuning_scheduler.fts_supporters.ScheduleImplMixin.gen_ft_schedule` which dumps a default fine-tuning
    schedule (by default using a naive, 2-parameters per level heuristic) which can be adjusted as desired by the
    user and subsequently passed to the callback. Implicit fine-tuning mode generates the default schedule and
    proceeds to fine-tune according to the generated schedule. Implicit fine-tuning will often be less computationally
    efficient than explicit fine-tuning but can often serve as a good baseline for subsequent explicit schedule
    refinement and can marginally outperform many explicit schedules.

    Example::

        import lightning as L
        from finetuning_scheduler import FinetuningScheduler

        trainer = L.Trainer(callbacks=[FinetuningScheduler()])

    .. note::

       Currently, :class:`~finetuning_scheduler.fts.FinetuningScheduler` does not support the use of multiple
       :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` or
       :class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` callback instances.

    .. note::

       While :class:`~finetuning_scheduler.fts.FinetuningScheduler` supports the use of
       :external+torch:class:`~torch.distributed.optim.ZeroRedundancyOptimizer`, setting ``overlap_with_ddp`` to
       ``True`` is not supported because that optimizer mode only supports a single parameter group.
    """

    pl_module: pl.LightningModule
    strategy_adapter: StrategyAdapter

    def __init__(
        self,
        ft_schedule: Optional[Union[str, dict]] = None,
        max_depth: int = -1,
        base_max_lr: float = 1e-5,
        restore_best: bool = True,
        gen_ft_sched_only: bool = False,
        epoch_transitions_only: bool = False,
        reinit_optim_cfg: Optional[Dict] = None,
        reinit_lr_cfg: Optional[Dict] = None,
        strategy_adapter_cfg: Optional[Dict] = None,
        custom_strategy_adapter: Optional[Dict[str, str]] = None,
        allow_untested: bool = False,
        apply_lambdas_new_pgs: bool = False,
        logging_level: int = logging.INFO,
        enforce_phase0_params: bool = True,
    ):
        r"""
        Arguments used to define and configure a scheduled fine-tuning training session:

        Args:
            ft_schedule: The fine-tuning schedule to be executed. Usually will be a .yaml file path but can also be a
                properly structured Dict. See
                :ref:`Specifying a Fine-Tuning Schedule<index:Specifying a fine-tuning schedule>` for the basic
                schedule format. See
                :ref:`LR Scheduler Reinitialization<explicit-lr-reinitialization-schedule>` for more complex schedule
                configurations (including per-phase LR scheduler reinitialization). If a schedule is not provided,
                will generate and execute a default fine-tuning schedule using the provided
                :external+pl:class:`~lightning.pytorch.core.module.LightningModule`. See
                :ref:`the default schedule<index:The Default Fine-Tuning Schedule>`. Defaults to ``None``.
            max_depth: Maximum schedule depth to which the defined fine-tuning schedule should be executed.
                Specifying -1 or an integer > (number of defined schedule layers) will result in the entire
                fine-tuning schedule being executed. Defaults to -1.
            base_max_lr: The default maximum learning rate to use for the parameter groups associated with each
                scheduled fine-tuning depth if not explicitly specified in the fine-tuning schedule. If overridden to
                ``None``, will be set to the ``lr`` of the first scheduled fine-tuning depth. Defaults to 1e-5.
            restore_best: If ``True``, restore the best available (defined by the
                :class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint`) checkpoint before fine-tuning depth
                transitions. Defaults to ``True``.
            gen_ft_sched_only: If ``True``, generate the default fine-tuning schedule to ``Trainer.log_dir`` (it will
                be named after your :external+pl:class:`~lightning.pytorch.core.module.LightningModule` subclass with
                the suffix ``_ft_schedule.yaml``) and exit without training. Typically used to generate a default
                schedule that will be adjusted by the user before training. Defaults to ``False``.
            epoch_transitions_only: If ``True``, use epoch-driven stopping criteria exclusively (rather than composing
                :class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` and epoch-driven criteria which is the
                default). If using this mode, an epoch-driven transition (``max_transition_epoch`` >= 0) must be
                specified for each phase. If unspecified, ``max_transition_epoch`` defaults to -1 for each phase
                which signals the application of :class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping`
                criteria only. ``epoch_transitions_only`` defaults to ``False``.
            reinit_optim_cfg: An optimizer reinitialization configuration dictionary consisting of at minimum a
                nested ``optimizer_init`` dictionary with a ``class_path`` key specifying the class of the optimizer
                to be instantiated. Optionally, an ``init_args`` dictionary of arguments with which to initialize the
                optimizer may be included. A ``reinit_lr_cfg`` configuration can also be specified concurrently. By
                way of example, one could configure this dictionary via the
                :external+pl:class:`~lightning.pytorch.cli.LightningCLI` with the following:

                .. code-block:: yaml

                    reinit_optim_cfg:
                        optimizer_init:
                            class_path: torch.optim.SGD
                            init_args:
                                lr: 1.0e-05
                                momentum: 0.9
                                weight_decay: 1.0e-06

            reinit_lr_cfg: A lr scheduler reinitialization configuration dictionary consisting of at minimum a nested
                ``lr_scheduler_init`` dictionary with a ``class_path`` key specifying the class of the lr scheduler
                to be instantiated. Optionally, an ``init_args`` dictionary of arguments to initialize the lr
                scheduler with may be included. Additionally, one may optionally include arguments to pass to PyTorch
                Lightning's lr scheduler configuration
                :class:`~lightning.pytorch.utilities.types.LRSchedulerConfig` in the ``pl_lrs_cfg`` dictionary. A
                ``reinit_optim_cfg`` configuration can also be specified concurrently. By way of example, one could
                configure this dictionary via the :external+pl:class:`~lightning.pytorch.cli.LightningCLI` with the
                following:

                .. code-block:: yaml

                    reinit_lr_cfg:
                        lr_scheduler_init:
                            class_path: torch.optim.lr_scheduler.StepLR
                            init_args:
                                step_size: 1
                                gamma: 0.7
                        pl_lrs_cfg:
                            interval: epoch
                            frequency: 1
                            name: Implicit_Reinit_LR_Scheduler
                        use_current_optimizer_pg_lrs: true

            allow_untested: If ``True``, allows the use of custom or unsupported training strategies and lr
                schedulers (e.g. ``single_tpu``, ``MyCustomStrategy``, ``MyCustomLRScheduler``). Defaults to
                ``False``.

                .. note:: Custom or officially unsupported strategies and lr schedulers can be used by setting
                    :paramref:`~finetuning_scheduler.fts.FinetuningScheduler.allow_untested` to ``True``. Some
                    officially unsupported strategies may work unaltered and are only unsupported due to the
                    ``Fine-Tuning Scheduler`` project's lack of CI/testing resources for that strategy (e.g.
                    ``single_tpu``). Most unsupported strategies and schedulers, however, are currently unsupported
                    because they require varying degrees of modification to be compatible. For instance, with respect
                    to strategies, ``deepspeed`` will require a
                    :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` similar to the one written for
                    ``FSDP`` (:class:`~finetuning_scheduler.strategy_adapters.FSDPStrategyAdapter`) to be written
                    before support can be added (PRs welcome!), while ``tpu_spawn`` would require an override of the
                    current broadcast method to include python objects. Regarding lr schedulers,
                    :external+torch:class:`~torch.optim.lr_scheduler.ChainedScheduler` and
                    :external+torch:class:`~torch.optim.lr_scheduler.SequentialLR` are examples of schedulers not
                    currently supported due to the configuration complexity and semantic conflicts supporting them
                    would introduce. If a supported torch lr scheduler does not meet your requirements, one can
                    always subclass a supported lr scheduler and modify it as required (e.g.
                    :external+torch:class:`~torch.optim.lr_scheduler.LambdaLR` is especially useful for this).

            strategy_adapter_cfg: A configuration dictionary that will be applied to the
                :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` associated with the current training
                :external+pl:class:`~lightning.pytorch.strategies.Strategy`. See the relevant
                :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` documentation for strategy-specific
                configuration options. Defaults to ``None``.
            custom_strategy_adapter: A dictionary associating the canonical ``strategy_flag`` associated with a
                :external+pl:class:`~lightning.pytorch.strategies.Strategy` (potentially a custom user-registered
                one) to the fully qualified path of a
                :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` subclass. This is an experimental
                feature that is subject to change. Requires ``allow_untested`` to be set to ``True``. Defaults to
                ``None``.
            apply_lambdas_new_pgs: If ``True``, applies most recent lambda in ``lr_lambdas`` list to newly added
                optimizer groups for lr schedulers that have a ``lr_lambdas`` attribute. Note this option only
                applies to phases without reinitialized lr schedulers. Phases with defined lr scheduler
                reinitialization configs will always apply the specified lambdas. Defaults to ``False``.
            logging_level: Sets the logging level for :class:`~finetuning_scheduler.fts.FinetuningScheduler`.
                Defaults to ``INFO``.
            enforce_phase0_params: Whether :class:`~finetuning_scheduler.fts.FinetuningScheduler` will reconfigure
                the user-configured optimizer (configured via ``configure_optimizers``) to optimize the parameters
                (and only those parameters) scheduled to be optimized in phase 0 of the current fine-tuning schedule.
                Reconfiguration will only take place if FTS discovers the set of parameters to be initially thawed
                and present in the optimizer differs from the parameters specified in phase 0. Only the parameters
                included in the optimizer are affected; the choice of optimizer, lr scheduler etc. remains unaltered.
                Defaults to ``True``.

        Attributes:
            _fts_state: The internal :class:`~finetuning_scheduler.fts.FinetuningScheduler` state.
            strategy_adapter_cfg: A configuration dictionary that will be applied to the
                :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter` associated with the current training
                :external+pl:class:`~lightning.pytorch.strategies.Strategy`.
            epoch_transitions_only: Whether to use epoch-driven stopping criteria exclusively.
            base_max_lr: The default maximum learning rate to use for the parameter groups associated with each
                scheduled fine-tuning depth if not explicitly specified in the fine-tuning schedule. If overridden to
                ``None``, will be set to the ``lr`` of the first scheduled fine-tuning depth. Defaults to 1e-5.
        """
        super().__init__()
        self._fts_state = FTSState()
        self.max_depth = max_depth
        self.restore_best = restore_best
        self.ft_schedule = ft_schedule
        self.base_max_lr = base_max_lr
        self.gen_ft_sched_only = gen_ft_sched_only
        self.epoch_transitions_only = epoch_transitions_only
        self.reinit_optim_cfg = reinit_optim_cfg
        self.reinit_lr_cfg = reinit_lr_cfg
        self.strategy_adapter_cfg = strategy_adapter_cfg or {}
        self.custom_strategy_adapter = custom_strategy_adapter
        self.allow_untested = allow_untested
        self.apply_lambdas_new_pgs = apply_lambdas_new_pgs
        self.enforce_phase0_params = enforce_phase0_params
        self._has_reinit_schedule = False
        rz_logger = logging.getLogger("lightning.pytorch.utilities.rank_zero")
        rz_logger.setLevel(logging_level)

    @property
    def curr_depth(self) -> int:
        """Index of the fine-tuning schedule depth currently being trained.

        Returns:
            int: The index of the current fine-tuning training depth
        """
        return self._fts_state._curr_depth

    @property
    def depth_remaining(self) -> int:
        """Remaining number of fine-tuning training levels in the schedule.

        Returns:
            int: The number of remaining fine-tuning training levels
        """
        return max(self.max_depth - self._fts_state._curr_depth, 0)

    @staticmethod
    def _supported_strategy_flags() -> Sequence[str]:
        return (
            "ddp",
            "ddp_find_unused_parameters_false",
            "ddp_find_unused_parameters_true",
            "ddp_spawn",
            "ddp_fork",
            "ddp_notebook",
            "single_device",
            "fsdp",
            "fsdp_cpu_offload",
            # "deepspeed",  # relevant FTS strategy adapter not yet available, PRs welcome!
        )

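    # Illustrative sketch (not part of the original module): an explicit schedule may be passed to ``ft_schedule``
    # directly as a dict rather than a .yaml path. Assuming hypothetical parameter names ``model.classifier.*`` and
    # ``model.pooler.*``, a minimal two-phase schedule might look like:
    #
    #     ft_schedule = {
    #         0: {"params": ["model.classifier.bias", "model.classifier.weight"]},
    #         1: {
    #             "params": ["model.pooler.dense.bias", "model.pooler.dense.weight"],
    #             "lr": 1e-6,
    #             "max_transition_epoch": 3,
    #         },
    #     }
    #     fts = FinetuningScheduler(ft_schedule=ft_schedule)
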
    def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
        """Freezes all model parameters so that parameter subsets can be subsequently thawed according to the
        fine-tuning schedule.

        Args:
            pl_module (:external+pl:class:`~lightning.pytorch.core.module.LightningModule`): The target
                :external+pl:class:`~lightning.pytorch.core.module.LightningModule` to freeze parameters of
        """
        self.freeze(modules=pl_module, train_bn=False)

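    # Rough sketch of the effect of the ``freeze`` call above (standard ``BaseFinetuning`` semantics; with
    # ``train_bn=False`` BatchNorm layers are frozen along with everything else):
    #
    #     for p in pl_module.parameters():
    #         p.requires_grad = False
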
    def step(self) -> None:
        """Prepare and execute the next scheduled fine-tuning level

        1. Restore the current best model checkpoint if appropriate
        2. Thaw model parameters according to the defined schedule
        3. Synchronize the states of ``FitLoop`` and :attr:`~finetuning_scheduler.fts.FinetuningScheduler._fts_state`

        .. note::

           The :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback initially only supports
           single-schedule/optimizer fine-tuning configurations
        """
        assert self.pl_module.trainer is not None
        pre_reinit_state = self._save_pre_reinit_lr_state(self.pl_module.trainer)
        if not self._fts_state._resume_fit_from_ckpt:
            if self.restore_best:
                self.restore_best_ckpt()
                self.step_pg(
                    depth=self.curr_depth,
                    optimizer=self.pl_module.trainer.optimizers[0],  # type: ignore[arg-type]
                    pre_reinit_state=pre_reinit_state,
                )
            else:
                self.step_pg(
                    depth=self.curr_depth,
                    optimizer=self.pl_module.trainer.optimizers[0],  # type: ignore[arg-type]
                    depth_sync=False,
                    pre_reinit_state=pre_reinit_state,
                )
        else:
            self.thaw_to_depth()
        if self.depth_remaining == 0 and not self.epoch_transitions_only:
            assert self.pl_module.trainer.early_stopping_callback is not None
            self.pl_module.trainer.early_stopping_callback.final_phase = True  # type: ignore[attr-defined]
        assert self._fts_state._ft_sync_objects is not None
        if self._fts_state._resume_fit_from_ckpt:
            # ensure multi-phase training session loops are synchronized for a fresh epoch restart
            self._maybe_sync_loops()
        FinetuningScheduler.sync(self._fts_state._ft_sync_objects, self._fts_state._ft_sync_props)
        if self.pl_module._compiler_ctx and self.pl_module._compiler_ctx.get("compiler", None) == "dynamo":
            # reset currently required as ``AOTAutograd`` is getting confused by ``requires_grad`` alteration
            torch._dynamo.reset()
        rank_zero_info(f"Multi-phase fine-tuned training continuing at level {self.curr_depth}.")
        if self.depth_remaining == 0:
            max_epochs_msg = f"`max_epochs` ({self.pl_module.trainer.fit_loop.max_epochs}) is reached."
            composition_msg = "the early stopping conditions are met or " + max_epochs_msg
            rank_zero_info(
                f"Given the current configuration of `max_depth` ({self.max_depth}), this training session"
                f" will now end when {max_epochs_msg if self.epoch_transitions_only else composition_msg}"
            )

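    # Illustrative phase-progression trace for ``step`` (hypothetical three-phase schedule, ``max_depth=-1``):
    #
    #     phase 0: only phase-0 params trainable  -> stopping criteria met -> step() -> depth 1
    #     phase 1: phase 0+1 params trainable     -> stopping criteria met -> step() -> depth 2
    #     phase 2: all scheduled params trainable -> final-phase criteria met -> training session ends
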
    def step_pg(
        self,
        optimizer: ParamGroupAddable,
        depth: int,
        depth_sync: bool = True,
        pre_reinit_state: Optional[Tuple] = None,
    ) -> None:
        """Configure optimizer parameter groups for the next scheduled fine-tuning level, adding parameter groups
        beyond the restored optimizer state up to
        :paramref:`~finetuning_scheduler.fts.FinetuningScheduler.curr_depth` and reinitializing the optimizer and/or
        learning rate scheduler as configured.

        Args:
            optimizer (:class:`~finetuning_scheduler.types.ParamGroupAddable`): The supported optimizer instance to
                which parameter groups will be configured and added.
            depth: The maximum index of the fine-tuning schedule for which to configure the optimizer parameter
                groups.
            depth_sync: If ``True``, configure optimizer parameter groups for all depth indices greater than the
                restored checkpoint. If ``False``, configure groups only for the specified depth. Defaults to
                ``True``.
        """
        next_tl: Dict = {}
        assert isinstance(self.ft_schedule, dict)
        assert isinstance(self.pl_module, pl.LightningModule)
        assert isinstance(self.pl_module.trainer, pl.Trainer)
        # if the target depth is 0, implicit optimizer reinitialization should not be executed
        new_optimizer_cfg = (
            (self.reinit_optim_cfg or self.ft_schedule[depth].get("new_optimizer", None)) if depth > 0 else None
        )
        restored_depth = -1 if new_optimizer_cfg else self._fts_state._best_ckpt_depth
        if depth_sync or new_optimizer_cfg:
            thaw_layers = {d: tl for d, tl in self.ft_schedule.items() if d > restored_depth}.items()
        else:
            thaw_layers = {depth: self.ft_schedule[depth]}.items()
        for i, orig_next_tl in thaw_layers:
            next_tl = deepcopy(orig_next_tl)
            if i <= depth:
                next_tl["params"] = self.strategy_adapter.fts_optim_transform(next_tl["params"])
                _, self._fts_state._curr_thawed_params = self.strategy_adapter.exec_ft_phase(
                    self.pl_module, thaw_pl=next_tl["params"]
                )
                if new_optimizer_cfg and i == 0:
                    # If reinitializing the optimizer, we need to re-add the initial parameter groups (phase 0)
                    optimizer = self.reinit_optimizer(
                        new_optimizer=new_optimizer_cfg, trainer=self.pl_module.trainer, init_params=next_tl["params"]
                    )
                else:
                    # Add pgs and configure the learning rate scheduler using the current optimizer/schedule
                    self._add_pgs_config_lrs(optimizer, next_tl, depth, depth == i, pre_reinit_state)

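    # A minimal sketch (assuming standard ``torch.optim`` semantics) of what adding a parameter group for a newly
    # thawed phase amounts to, ignoring FTS bookkeeping such as ``no_decay`` handling and lambda application:
    #
    #     thawed = [p for n, p in self.pl_module.named_parameters() if n in next_tl["params"]]
    #     optimizer.add_param_group({"params": thawed, "lr": next_tl.get("lr", optimizer.defaults["lr"])})
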
    def _add_pgs_config_lrs(
        self,
        optimizer: ParamGroupAddable,
        next_tl: Dict,
        depth: int,
        is_target_depth: bool,
        pre_reinit_state: Optional[Tuple],
    ) -> None:
        """Add optimizer parameter groups and potentially reinitialize/reconfigure the learning rate scheduler
        according to a given schedule phase configuration.

        Args:
            optimizer (:class:`~finetuning_scheduler.types.ParamGroupAddable`): The supported optimizer instance to
                which parameter groups will be configured and added.
            next_tl (Dict): A dictionary containing the schedule configuration associated with the current phase
                context.
            depth (int): The current target depth.
            is_target_depth (bool): Whether this restoration stage is the current target depth.
            pre_reinit_state (Tuple): Contingent on restoration context, lr scheduler and optimizer lr state to
                restore.
        """
        # if the target depth is 0, lr scheduler reinitialization should not be executed
        new_scheduler_cfg = (self.reinit_lr_cfg or next_tl.get("new_lr_scheduler", None)) if depth > 0 else None
        if is_target_depth and depth > 0:
            # NB: The latest optimizer and lr scheduler lr state will be re-applied to all existing param groups
            # before implementing the next phase.
            assert pre_reinit_state
            self._restore_latest_lr_state(*pre_reinit_state)
        FinetuningScheduler.add_optimizer_groups(
            module=self.pl_module,
            optimizer=optimizer,
            thawed_pl=next_tl["params"],
            lr=next_tl.get("lr", optimizer.defaults["lr"]),
            no_decay=getattr(self.pl_module, "no_decay", None),
            apply_lambdas=self.apply_lambdas_new_pgs,
        )
        if new_scheduler_cfg:
            self.reinit_lr_scheduler(
                new_lr_scheduler=new_scheduler_cfg, trainer=self.pl_module.trainer, optimizer=optimizer
            )
        else:
            self._maybe_warn_lr_lambdas()

    def _maybe_warn_lr_lambdas(self) -> None:
        """If appropriate, warn the user that ``lr_lambdas`` will not be applied given the current configuration."""
        if self.pl_module.trainer.lr_scheduler_configs:
            for config in self.pl_module.trainer.lr_scheduler_configs:
                show_warn_lambdas = (
                    hasattr(config.scheduler, "lr_lambdas")
                    and config.scheduler.lr_lambdas[-1] is not None  # type: ignore[union-attr]
                    and not self.apply_lambdas_new_pgs
                )
                if show_warn_lambdas:
                    rank_zero_warn(
                        "The lr scheduler used in this phase has lr_lambdas but will use a "
                        "configured lr for new parameter groups because `apply_lambdas_new_pgs` is "
                        "set to the default of `False`. If you would like new groups to have lr "
                        "lambdas applied, set `apply_lambdas_new_pgs` to `True`."
                    )

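    # Context sketch for the warning above: ``torch.optim.lr_scheduler.LambdaLR`` stores its factor functions in an
    # ``lr_lambdas`` list attribute, e.g.:
    #
    #     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)
    #     scheduler.lr_lambdas  # [<lambda>]; with ``apply_lambdas_new_pgs=True``, FTS also applies the most
    #                           # recent lambda to newly added parameter groups
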
    def restore_best_ckpt(self) -> None:
        """Restore the current best model checkpoint, according to
        :paramref:`~finetuning_scheduler.fts_supporters.FTSCheckpoint.best_model_path`."""
        assert self.pl_module.trainer is not None
        # wait for all processes to be ready to restore ckpt before restoring
        self.pl_module.trainer.strategy.barrier("setup_next_level")
        # if restarting across multiple depths, need to ensure we're restoring optimizer state appropriately
        # by resetting optimizer groups and allowing state dict to be reset commensurate w/ ckpt state
        for opt_idx, optimizer in enumerate(self.pl_module.trainer.optimizers):
            optimizer.param_groups = BaseFinetuning._apply_mapping_to_param_groups(
                self._fts_state._fts_ckpt_metadata["best_ckpt_pgs"][opt_idx],
                dict(self.pl_module.named_parameters()),
            )
            if self.strategy_adapter.using_sharded_optimizer:
                ScheduleImplMixin._repartition_sharded_optim(optimizer)  # type: ignore[arg-type]
        # we're restoring everything but callbacks and loops, otherwise, checkpoint_connector.restore() could be used
        assert self.pl_module.trainer.checkpoint_callback is not None
        checkpoint_path = self.pl_module.trainer.checkpoint_callback.best_model_path  # type: ignore[attr-defined]
        try:
            self.pl_module.trainer._checkpoint_connector.resume_start(checkpoint_path=checkpoint_path)
        except KeyError as ke:
            # we may want to allow training to progress conditioned on context of restoration
            self._maybe_allow_incompatible_reinit_ckpt(ke)
        self.pl_module.trainer._checkpoint_connector.restore_datamodule()
        self.pl_module.trainer._checkpoint_connector.restore_model()
        # we need to override checkpoint_connector.restore_training_state() to bypass loop restoration
        # if additional customizations are required, may make sense to subclass _CheckpointConnector at some point
        self._restore_training_state()
        self.pl_module.trainer._checkpoint_connector.resume_end()

    def _restore_training_state(self) -> None:
        """Restore training state without restoring loops from the pre-loaded checkpoint.

        This includes the precision settings, optimizer states and learning rate scheduler states.
        """
        assert self.pl_module is not None and self.pl_module.trainer is not None
        checkpoint_connector = self.pl_module.trainer._checkpoint_connector
        # restore precision plugin (scaler etc.)
        checkpoint_connector.restore_precision_plugin_state()
        # checkpoint_connector.restore_training_state() would restore loops here
        # checkpoint_connector.restore_loops()
        assert self.pl_module.trainer.state.fn is not None
        if self.pl_module.trainer.state.fn == TrainerFn.FITTING:
            try:
                # enable strategy adapters to restore optimizer if `Strategy.lightning_restore_optimizer` is
                # overridden
                self.strategy_adapter.on_before_restore_optimizers_and_lrs()
                # restore optimizers and schedulers state
                checkpoint_connector.restore_optimizers_and_schedulers()
            except KeyError as ke:
                self._maybe_allow_incompatible_reinit_ckpt(ke)

    def _maybe_allow_incompatible_reinit_ckpt(self, key_error: KeyError) -> None:
        """Inspect context for a given ``KeyError`` and permit continued training if using lr scheduler or optimizer
        reinitialization.

        Args:
            key_error (KeyError): The current key error encountered during checkpoint restoration.

        Raises:
            key_error: If not training in the context of lr scheduler or optimizer reinitialization, the provided
                ``KeyError`` will be raised.
        """
        if self._has_reinit_schedule:
            rank_zero_warn(
                "Incompatible checkpoint detected when attempting to restore the optimizer and/or lr "
                "scheduler from a previous phase. Attempting to proceed with next phase of training since this "
                "schedule reinitializes the optimizer and/or lr scheduler.\n"
                "HINT: If subsequent errors are encountered, you can either set ``restore_best`` to ``False`` "
                "or alter your reinitialization schedule for the relevant training components (i.e. optimizer, "
                "lr scheduler)."
            )
        else:
            raise key_error

    def _reduce_transition(self, strategy: Strategy, decision: bool) -> bool:
        """Reduce a transition decision across all world processes (effectively a global `any` collective).

        Args:
            strategy (Strategy): The PL :external+pl:class:`~lightning.pytorch.strategies.Strategy` context to use.
            decision (bool): The local process decision.

        Returns:
            bool: The reduced decision across all world processes.
        """
        decision = torch.tensor(int(decision), device=strategy.root_device)
        decision = bool(strategy.reduce(decision, reduce_op=ReduceOp.SUM))  # type:ignore[arg-type]
        return decision

    def _sync_es_state(self, trainer: "pl.Trainer") -> None:
        """Synchronize the :class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` callback transition state
        across distributed processes to avoid rare transition divergences when the user does not set ``sync_dist``
        to ``True`` in logging the monitored metric.
        Args:
            trainer (:external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer`): The
                :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object
        """
        early_stopping_callback = trainer.early_stopping_callback
        assert early_stopping_callback is not None
        # reset the FTSEarlyStopping state in every distributed process if any world process is ready to transition
        # to the next fine-tuning phase
        should_reset_es = early_stopping_callback.es_phase_complete
        should_reset_es = self._reduce_transition(trainer.strategy, should_reset_es)
        if should_reset_es != early_stopping_callback.es_phase_complete:
            warn_msg = (
                "The FTSEarlyStopping quantity you are monitoring is not being synchronized across processes and FTS"
                " has detected that two or more world processes are disagreeing on whether to continue the current"
                " training phase. Training is continuing by transitioning all training processes to the next"
                " fine-tuning phase when the early stopping conditions of any training process are met. To avoid"
                " this behavior in the future, you may want to either log the monitored metric with ``sync_dist``"
                " set to ``True`` or increase the configured FTSEarlyStopping ``min_delta``."
            )
            rank_zero_warn(warn_msg)
            early_stopping_callback._transition_es_phase()

    def _strategy_setup(self, trainer: "pl.Trainer") -> None:
        """Validate that a compatible :external+pl:class:`~lightning.pytorch.strategies.Strategy` is being used and
        connect the relevant :class:`~finetuning_scheduler.strategy_adapters.StrategyAdapter`.

        Args:
            trainer (:external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer`): The
                :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object

        Raises:
            MisconfigurationException: If the :external+pl:class:`~lightning.pytorch.strategies.Strategy` being used
                is not currently compatible with the :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback.
        """
        strategy = trainer.strategy
        connector_flag = getattr(trainer._accelerator_connector, "_strategy_flag", None)
        strategy_flag = connector_flag.strategy_name if isinstance(connector_flag, Strategy) else connector_flag
        supported = [t.lower() for t in self._supported_strategy_flags()]
        if strategy_flag and strategy_flag not in supported:  # type: ignore[attr-defined]
            if not self.allow_untested:
                raise MisconfigurationException(
                    "FTS has not yet been adapted for or rigorously tested using the specified distributed strategy."
                    f" Please select from currently compatible distributed strategies ({supported}) or if you would"
                    " like to attempt to use the currently specified strategy, pass ``allow_untested=True`` to the"
                    " FinetuningScheduler callback when adding it."
                )
            else:
                warn_msg = (
                    "Allowing untested strategy"
                    f" '{strategy}' because ``allow_untested`` is ``True``."  # type: ignore[attr-defined]
                )
                rank_zero_warn(warn_msg)
        if self.custom_strategy_adapter:
            strategy_cls = self._import_strategy_adapter(strategy_flag, self.custom_strategy_adapter)
            rank_zero_info(
                f"Imported custom strategy adapter class type `{strategy_cls}` associated with the current strategy"
                f" `{strategy_flag}`."
            )
        else:
            strategy_cls = STRATEGY_ADAPTERS.get(strategy_flag, StrategyAdapter)
        self.strategy_adapter = strategy_cls(**self.strategy_adapter_cfg)
        self.strategy_adapter.connect(self)

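    # A minimal sketch of the same "global any" pattern used by ``_reduce_transition`` above, expressed with raw
    # ``torch.distributed`` collectives (assumes an initialized process group; ``global_any`` is a hypothetical
    # helper name):
    #
    #     import torch.distributed as dist
    #
    #     def global_any(local_decision: bool, device: torch.device) -> bool:
    #         flag = torch.tensor(int(local_decision), device=device)
    #         dist.all_reduce(flag, op=dist.ReduceOp.SUM)  # nonzero sum => at least one rank voted True
    #         return bool(flag.item())
    #
    # For ``_strategy_setup``, ``custom_strategy_adapter`` maps a canonical strategy flag to an adapter class path,
    # e.g. (hypothetical names): {"my_custom_ddp": "my_pkg.adapters.MyCustomStrategyAdapter"}; note it requires
    # ``allow_untested=True``.
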
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """Validate that a compatible :external+pl:class:`~lightning.pytorch.strategies.Strategy` is being used and
        ensure all :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback dependencies are met. If a valid
        configuration is present, then either dump the default fine-tuning schedule OR

        1. configure the :class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` callback (if relevant)
        2. initialize the :attr:`~finetuning_scheduler.fts.FinetuningScheduler._fts_state`
        3. freeze the target :external+pl:class:`~lightning.pytorch.core.module.LightningModule` parameters

        Finally, initialize the :class:`~finetuning_scheduler.fts.FinetuningScheduler` training session in the
        training environment.

        Args:
            trainer (:external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer`): The
                :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object
            pl_module (:external+pl:class:`~lightning.pytorch.core.module.LightningModule`): The
                :external+pl:class:`~lightning.pytorch.core.module.LightningModule` object
            stage: The ``RunningStage.{SANITY_CHECKING,TRAINING,VALIDATING}``. Defaults to None.

        Raises:
            SystemExit: Gracefully exit before training if only generating and not executing a fine-tuning schedule.
        """
        self._callback_dep_setup(trainer, pl_module, stage)
        self._strategy_setup(trainer)
        if self.gen_ft_sched_only:
            if trainer.is_global_zero:
                assert trainer.log_dir is not None
                _ = ScheduleImplMixin.gen_ft_schedule(pl_module, trainer.log_dir)
                log.info("Bypassing training, generating fine-tuning schedule for review and subsequent fine-tuning")
            raise SystemExit(0)
        if not self.epoch_transitions_only:
            assert isinstance(trainer.early_stopping_callback, FTSEarlyStopping)
            trainer.early_stopping_callback.final_phase = False
            trainer.early_stopping_callback.es_phase_complete = False
        self._fts_state._ft_sync_objects = pl_module.trainer.fit_loop, self._fts_state
        if trainer.ckpt_path:
            self._fts_state._resume_fit_from_ckpt = True
        self.freeze_before_training(pl_module)
        self.pl_module = pl_module  # save pl_module ref for downstream configuration convenience
        self.init_fts()

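    # Illustrative workflow sketch for the ``gen_ft_sched_only`` branch above (``MyModel`` is a hypothetical
    # ``LightningModule`` subclass):
    #
    #     import lightning as L
    #     from finetuning_scheduler import FinetuningScheduler
    #
    #     # first run: dump ``MyModel_ft_schedule.yaml`` to ``Trainer.log_dir`` and exit via ``SystemExit``
    #     L.Trainer(callbacks=[FinetuningScheduler(gen_ft_sched_only=True)]).fit(MyModel())
    #     # then edit the generated YAML and train with ``ft_schedule`` pointing at it
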
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Before beginning training, ensure an optimizer configuration supported by
        :class:`~finetuning_scheduler.fts.FinetuningScheduler` is present.

        Args:
            trainer (:external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer`): The
                :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object
            pl_module (:external+pl:class:`~lightning.pytorch.core.module.LightningModule`): The
                :external+pl:class:`~lightning.pytorch.core.module.LightningModule` object

        Raises:
            MisconfigurationException: If more than one optimizer is configured, since this indicates a
                configuration error.
        """
        self.strategy_adapter.on_before_fts_fit_start()
        if len(trainer.optimizers) > 1:
            raise MisconfigurationException("FTS currently only supports a single-optimizer configuration")
        if getattr(trainer.optimizers[0], "_overlap_with_ddp", False):
            raise MisconfigurationException(
                "Configuring an optimizer using `overlap_with_ddp=True` is not supported"
                " with FTS since that optimizer mode only supports a single parameter"
                " group."
            )
        if trainer.lr_scheduler_configs:
            self._is_supported_lr(type(trainer.lr_scheduler_configs[0].scheduler))
        if self.curr_depth == 0:
            assert isinstance(self.ft_schedule, Dict)
            self._validate_opt_init()
        super().on_fit_start(trainer, pl_module)

    def state_dict(self) -> Dict[str, Any]:
        """Before saving a checkpoint, add the :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback state
        to be saved.

        Returns:
            Dict[str, Any]: The :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback state dictionary
                that will be added to the checkpoint
        """
        # callback state such as `_ft_init_epoch` does not currently need to be persisted
        assert self.pl_module is not None and self.pl_module.trainer is not None
        trainer = self.pl_module.trainer
        checkpoint_callback = trainer.checkpoint_callback
        if checkpoint_callback.current_score == checkpoint_callback.best_model_score:  # type: ignore[union-attr]
            self._fts_state._best_ckpt_depth = self._fts_state._curr_depth
            for opt_idx, _ in enumerate(trainer.optimizers):
                self._fts_state._fts_ckpt_metadata["best_ckpt_pgs"][opt_idx] = deepcopy(
                    self._internal_optimizer_metadata[opt_idx]
                )
        self._fts_state._fts_ckpt_metadata["current_ckpt_depth"] = self._fts_state._curr_depth
        self._fts_state._fts_ckpt_metadata["best_ckpt_depth"] = self._fts_state._best_ckpt_depth
        return {
            "internal_optimizer_metadata": self._internal_optimizer_metadata,
            "fts_metadata": self._fts_state._fts_ckpt_metadata,
        }

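    # Sketch of the persisted payload shape (keys taken from the method above; example values are hypothetical):
    #
    #     {
    #         "internal_optimizer_metadata": {0: [...]},  # per-optimizer parameter group metadata
    #         "fts_metadata": {
    #             "best_ckpt_pgs": {0: [...]},
    #             "current_ckpt_depth": 1,
    #             "best_ckpt_depth": 1,
    #         },
    #     }
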
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """After loading a checkpoint, load the saved :class:`~finetuning_scheduler.fts.FinetuningScheduler`
        callback state and update the current callback state accordingly.

        Args:
            state_dict: The :class:`~finetuning_scheduler.fts.FinetuningScheduler` callback state dictionary that
                will be loaded from the checkpoint
        """
        self._restarting = True
        self._internal_optimizer_metadata = state_dict["internal_optimizer_metadata"]
        self._fts_state._fts_ckpt_metadata = state_dict["fts_metadata"]
        if self._fts_state._resume_fit_from_ckpt:  # if resuming training, on_fit_start will already be called
            # if resuming from a checkpoint, we need to update current fts depth from the used ckpt
            self._fts_state._curr_depth = self._fts_state._fts_ckpt_metadata["current_ckpt_depth"]
            # if we're restoring from a non-best ckpt depth, ensure it is the new training incarnation's initial best
            self._fts_state._best_ckpt_depth = self._fts_state._fts_ckpt_metadata["current_ckpt_depth"]

    def should_transition(self, trainer: "pl.Trainer") -> bool:
        """Phase transition logic is contingent on whether we are composing
        :class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` criteria with epoch-driven transition
        constraints or exclusively using epoch-driven transition scheduling (i.e.,
        :attr:`~finetuning_scheduler.fts.FinetuningScheduler.epoch_transitions_only` is ``True``).

        Args:
            trainer (:external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer`): The
                :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object
        """
        assert self.pl_module is not None
        assert isinstance(self.ft_schedule, Dict)
        early_stopping_callback = trainer.early_stopping_callback
        curr_max_epoch = (
            self.ft_schedule[self.curr_depth]["max_transition_epoch"]
            if self.depth_remaining > 0
            else trainer.fit_loop.max_epochs
        )
        if not self.epoch_transitions_only:  # if we're considering FTSEarlyStopping criteria
            assert early_stopping_callback is not None
            # in the edge case where transition decisions diverge among distributed processes because the user is
            # running in a distributed context without ``sync_dist`` set to ``True`` and ``min_delta`` is
            # sufficiently low, we should reduce the transition decision over all training processes to avoid
            # deadlocks
            if early_stopping_callback.reduce_transition_decisions:
                self._sync_es_state(trainer)
            is_final_phase = early_stopping_callback.final_phase  # type: ignore[attr-defined]
            epoch_driven_transition = (
                True if not is_final_phase and (0 <= curr_max_epoch <= trainer.current_epoch) else False
            )
            if early_stopping_callback.es_phase_complete or epoch_driven_transition:  # type: ignore[attr-defined]
                phase_transition = True
            else:
                phase_transition = False
        else:  # we're only considering epoch-driven transition constraints
            phase_transition = True if 0 <= curr_max_epoch <= trainer.current_epoch else False
        return phase_transition

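    # Decision sketch restating the logic above: with composed criteria (the default), transition when the
    # FTSEarlyStopping phase is complete OR (not in the final phase AND 0 <= max_transition_epoch <= current_epoch);
    # with ``epoch_transitions_only=True``, transition purely on the epoch bound.
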
    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Before beginning a training epoch, configure the internal
        :attr:`~finetuning_scheduler.fts.FinetuningScheduler._fts_state`, prepare the next scheduled fine-tuning
        level and store the updated optimizer configuration before continuing training.

        Args:
            trainer (:external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer`): The
                :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object
            pl_module (:external+pl:class:`~lightning.pytorch.core.module.LightningModule`): The
                :external+pl:class:`~lightning.pytorch.core.module.LightningModule` object
        """
        # if resuming from a ckpt, we need to sync fts_state
        if self._fts_state._resume_fit_from_ckpt:
            self.step()
            self._fts_state._resume_fit_from_ckpt = False
        # increment ft_epoch on each train epoch
        assert isinstance(self._fts_state._ft_init_epoch, int)
        if trainer.current_epoch > self._fts_state._ft_init_epoch:
            assert self._fts_state._ft_sync_objects is not None
            self.sync(self._fts_state._ft_sync_objects, self._fts_state._ft_sync_props)
            if self.should_transition(trainer):
                self._fts_state._curr_depth += 1  # increment depth
                self.step()
                rank_zero_debug(
                    f"Current depth is {self.curr_depth}."
                    "\nCurrent logical parameters thawed by Fine-Tuning Scheduler:\n "
                    f"{pformat(self.strategy_adapter.logical_param_translation(self._fts_state._curr_thawed_params))}."
                    "\nCurrent actual parameters thawed by Fine-Tuning Scheduler:\n"
                    f"{pformat(self._fts_state._curr_thawed_params)}. "
                )
                if not self.epoch_transitions_only:
                    assert isinstance(trainer.early_stopping_callback, FTSEarlyStopping)
                    trainer.early_stopping_callback._reset_es_phase()
            if self.depth_remaining == 0:
                if not self.epoch_transitions_only:
                    assert isinstance(trainer.early_stopping_callback, FTSEarlyStopping)
                    trainer.early_stopping_callback.final_phase = True
        # capture optimizer config for all optimizers (though initially we'll only support a single optimizer)
        for opt_idx, optimizer in enumerate(trainer.optimizers):
            num_saved_groups = (
                len(self._internal_optimizer_metadata[opt_idx])
                if opt_idx in self._internal_optimizer_metadata
                else 0
            )
            current_param_groups = optimizer.param_groups
            self._store(pl_module, opt_idx, num_saved_groups, current_param_groups)

    def on_before_zero_grad(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        optimizer: ParamGroupAddable,  # type: ignore[override]
    ) -> None:
        """After the latest optimizer step, update the
        :attr:`~finetuning_scheduler.fts.FinetuningScheduler._fts_state`, incrementing the global fine-tuning steps
        taken.

        Args:
            trainer (:external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer`): The
                :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object
            pl_module (:external+pl:class:`~lightning.pytorch.core.module.LightningModule`): The
                :external+pl:class:`~lightning.pytorch.core.module.LightningModule` object
            optimizer (:class:`~finetuning_scheduler.types.ParamGroupAddable`): The supported optimizer instance to
                which parameter groups will be configured and added.
        """
        self._fts_state._ft_global_steps += 1

    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Synchronize internal :attr:`~finetuning_scheduler.fts.FinetuningScheduler._fts_state` on end of training
        to ensure final training state is consistent with epoch semantics.

        Args:
            trainer: The :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object
            pl_module (:external+pl:class:`~lightning.pytorch.core.module.LightningModule`): The
                :external+pl:class:`~lightning.pytorch.core.module.LightningModule` object
        """
        assert self._fts_state._ft_sync_objects is not None
        self.sync(self._fts_state._ft_sync_objects, self._fts_state._ft_sync_props)
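
# Illustrative end-to-end usage sketch (not part of the original module), assuming a hypothetical
# ``LightningModule`` subclass ``MyModel`` and a user-edited ``MyModel_ft_schedule.yaml``:
#
#     import lightning as L
#     from finetuning_scheduler import FinetuningScheduler
#
#     fts = FinetuningScheduler(ft_schedule="MyModel_ft_schedule.yaml", restore_best=True)
#     trainer = L.Trainer(callbacks=[fts])  # FTS adds its FTSEarlyStopping/FTSCheckpoint dependencies if absent
#     trainer.fit(MyModel())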
