fts

Classes

FinetuningScheduler – This callback enables flexible, multi-phase, scheduled finetuning of foundational models.

Finetuning Scheduler

Used to implement flexible finetuning training schedules.

class finetuning_scheduler.fts.FinetuningScheduler(ft_schedule=None, max_depth=-1, base_max_lr=1e-05, restore_best=True, gen_ft_sched_only=False, epoch_transitions_only=False)[source]

Bases: pytorch_lightning.callbacks.finetuning.BaseFinetuning, finetuning_scheduler.fts_supporters.SchedulingMixin, finetuning_scheduler.fts_supporters.CallbackDepMixin

This callback enables flexible, multi-phase, scheduled finetuning of foundational models. Gradual unfreezing/thawing can help maximize foundational model knowledge retention while allowing (typically upper layers of) the model to optimally adapt to new tasks during transfer learning. FinetuningScheduler orchestrates the gradual unfreezing of models via a finetuning schedule that is either implicitly generated (the default) or explicitly provided by the user (more computationally efficient).

Finetuning phase transitions are driven by FTSEarlyStopping criteria (a multi-phase extension of EarlyStopping), user-specified epoch transitions or a composition of the two (the default mode). A FinetuningScheduler training session completes when the final phase of the schedule has its stopping criteria met. See Early Stopping for more details on that callback’s configuration.

Schedule definition is facilitated via gen_ft_schedule() which dumps a default finetuning schedule (by default using a naive, 2-parameters-per-level heuristic) that can be adjusted as desired by the user and subsequently passed to the callback. Implicit finetuning mode generates the default schedule and proceeds to finetune according to the generated schedule. Implicit finetuning will often be less computationally efficient than explicit finetuning but can often serve as a good baseline for subsequent explicit schedule refinement and can marginally outperform many explicit schedules.

Example:

from pytorch_lightning import Trainer
from finetuning_scheduler import FinetuningScheduler
trainer = Trainer(callbacks=[FinetuningScheduler()])
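
For a fuller illustration, the sketch below composes FinetuningScheduler with the FTSEarlyStopping and FTSCheckpoint callbacks that drive phase transitions and checkpoint tracking. It assumes those callbacks are importable from the top-level finetuning_scheduler package and that model is a LightningModule instance logging a "val_loss" metric; adapt both to your project.

from pytorch_lightning import Trainer
from finetuning_scheduler import FinetuningScheduler, FTSCheckpoint, FTSEarlyStopping

trainer = Trainer(
    callbacks=[
        FinetuningScheduler(),  # implicit mode: generate and execute the default schedule
        FTSEarlyStopping(monitor="val_loss", patience=2),  # multi-phase early stopping criteria
        FTSCheckpoint(monitor="val_loss", save_top_k=1),  # phase-aware ModelCheckpoint extension
    ]
)
trainer.fit(model)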

Define and configure a scheduled finetuning training session.

Parameters
  • ft_schedule (Union[str, dict, None]) – The finetuning schedule to be executed. Usually will be a .yaml file path but can also be a properly structured Dict. See Specifying a Finetuning Schedule for the schedule format. If a schedule is not provided, will generate and execute a default finetuning schedule using the provided LightningModule. See the default schedule. Defaults to None.

  • max_depth (int) – Maximum schedule depth to which the defined finetuning schedule should be executed. Specifying -1 or an integer > (number of defined schedule layers) will result in the entire finetuning schedule being executed. Defaults to -1.

  • base_max_lr (float) – The default maximum learning rate to use for the parameter groups associated with each scheduled finetuning depth if not explicitly specified in the finetuning schedule. If overridden to None, will be set to the lr of the first scheduled finetuning depth scaled by 1e-1. Defaults to 1e-5.

  • restore_best (bool) – If True, restore the best available (defined by the FTSCheckpoint) checkpoint before finetuning depth transitions. Defaults to True.

  • gen_ft_sched_only (bool) – If True, generate the default finetuning schedule to Trainer.log_dir (it will be named after your LightningModule subclass with the suffix _ft_schedule.yaml) and exit without training. Typically used to generate a default schedule that will be adjusted by the user before training. Defaults to False.

  • epoch_transitions_only (bool) – If True, use epoch-driven stopping criteria exclusively (rather than composing FTSEarlyStopping and epoch-driven criteria, which is the default). If using this mode, an epoch-driven transition (max_transition_epoch >= 0) must be specified for each phase. If unspecified, max_transition_epoch defaults to -1 for each phase, which signals the application of FTSEarlyStopping criteria only. epoch_transitions_only defaults to False. An illustrative sketch follows this parameter list.
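
For illustration, a minimal sketch of an explicit schedule provided as a Dict and executed with epoch-driven transitions only. The parameter-name patterns refer to a hypothetical model, and the phase keys follow the format described in Specifying a Finetuning Schedule; consult that section for the authoritative structure.

from pytorch_lightning import Trainer
from finetuning_scheduler import FinetuningScheduler

# phase 0 is thawed at the start of training; later phases are thawed at their
# max_transition_epoch boundaries since epoch_transitions_only is True
ft_schedule = {
    0: {"params": ["model.classifier.*"], "max_transition_epoch": 2},
    1: {"params": ["model.pooler.*"], "max_transition_epoch": 4},
    2: {"params": ["model.encoder.*"], "max_transition_epoch": 6, "lr": 1e-06},
}
trainer = Trainer(
    callbacks=[FinetuningScheduler(ft_schedule=ft_schedule, epoch_transitions_only=True)]
)
trainer.fit(model)  # model is a hypothetical LightningModule whose parameter names match the patterns above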

_fts_state

The internal finetuning scheduler state.

Type

finetuning_scheduler.fts_supporters.FTSState

freeze_before_training(pl_module)[source]

Freezes all model parameters so that parameter subsets can be subsequently thawed according to the finetuning schedule.

Parameters

pl_module (LightningModule) – The target LightningModule to freeze parameters of

Return type

None

load_state_dict(state_dict)[source]

After loading a checkpoint, load the saved FinetuningScheduler callback state and update the current callback state accordingly.

Parameters

state_dict (Dict[str, Any]) – The FinetuningScheduler callback state dictionary that will be loaded from the checkpoint

Return type

None

on_before_zero_grad(trainer, pl_module, optimizer)[source]

After the latest optimizer step, update the _fts_state, incrementing the global finetuning steps taken.

Parameters
  • trainer (Trainer) – The Trainer object

  • pl_module (LightningModule) – The target LightningModule

  • optimizer (Optimizer) – The Optimizer for which the latest step has been taken

Return type

None

on_fit_start(trainer, pl_module)[source]

Before beginning training, ensure an optimizer configuration supported by FinetuningScheduler is present.

Parameters
  • trainer (Trainer) – The Trainer object

  • pl_module (LightningModule) – The target LightningModule

Raises

MisconfigurationException – If more than one optimizer is configured (only single-optimizer configurations are currently supported).

Return type

None

on_train_end(trainer, pl_module)[source]

Synchronize internal _fts_state on end of training to ensure final training state is consistent with epoch semantics.

Parameters
  • trainer (Trainer) – The Trainer object

  • pl_module (LightningModule) – The target LightningModule

Return type

None

on_train_epoch_start(trainer, pl_module)[source]

Before beginning a training epoch, configure the internal _fts_state, prepare the next scheduled finetuning level, and store the updated optimizer configuration before continuing training.

Parameters
  • trainer (Trainer) – The Trainer object

  • pl_module (LightningModule) – The target LightningModule

Return type

None

restore_best_ckpt()[source]

Restore the current best model checkpoint, according to best_model_path.

Return type

None

setup(trainer, pl_module, stage=None)[source]

Validate that a compatible Strategy is being used and ensure all FinetuningScheduler callback dependencies are met. If a valid configuration is present, then either dump the default finetuning schedule OR:

  1. configure the FTSEarlyStopping callback (if relevant)

  2. initialize the _fts_state

  3. freeze the target LightningModule parameters

Finally, initialize the FinetuningScheduler training session in the training environment.

Parameters
  • trainer (Trainer) – The Trainer object

  • pl_module (LightningModule) – The target LightningModule

  • stage (Optional[str]) – The current training stage. Defaults to None.

Raises
  • SystemExit – Gracefully exit before training if only generating and not executing a finetuning schedule.

  • MisconfigurationException – If the Strategy being used is not currently compatible with the FinetuningScheduler callback.

Return type

None

should_transition(trainer)[source]

Phase transition logic is contingent on whether we are composing FTSEarlyStopping criteria with epoch-driven transition constraints or exclusively using epoch-driven transition scheduling (i.e., epoch_transitions_only is True).

Parameters

trainer (Trainer) – The Trainer object

Return type

bool

state_dict()[source]

Before saving a checkpoint, add the FinetuningScheduler callback state to be saved.

Returns

The FinetuningScheduler callback state dictionary that will be added to the checkpoint

Return type

Dict[str, Any]

step()[source]

Prepare and execute the next scheduled finetuning level:

  1. Restore the current best model checkpoint if appropriate

  2. Thaw model parameters according to the defined schedule

  3. Synchronize the states of FitLoop and _fts_state

Note

The FinetuningScheduler callback currently supports only single-schedule/single-optimizer finetuning configurations.

Return type

None

step_pg(optimizer, depth, depth_sync=True)[source]

Configure optimizer parameter groups for the next scheduled finetuning level, adding parameter groups beyond the restored optimizer state up to current_depth.

Parameters
  • optimizer (Optimizer) – The Optimizer to which parameter groups will be configured and added.

  • depth (int) – The maximum index of the finetuning schedule for which to configure the optimizer parameter groups.

  • depth_sync (bool) – If True, configure optimizer parameter groups for all depth indices greater than the restored checkpoint. If False, configure groups only for the specified depth. Defaults to True.

Return type

None

property curr_depth: int

Index of the finetuning schedule depth currently being trained.

Returns

The index of the current finetuning training depth

Return type

int

property depth_remaining: int

Remaining number of finetuning training levels in the schedule.

Returns

The number of remaining finetuning training levels

Return type

int
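
For illustration, scheduling progress can be inspected via these properties once a FinetuningScheduler callback is attached to a Trainer. The lookup below uses the generic PyTorch Lightning trainer.callbacks list and assumes an in-scope trainer with an active training session.

from finetuning_scheduler import FinetuningScheduler

# locate the FinetuningScheduler callback attached to the trainer and report progress
fts = next(cb for cb in trainer.callbacks if isinstance(cb, FinetuningScheduler))
print(f"currently training depth {fts.curr_depth}, {fts.depth_remaining} level(s) remaining")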
