fts¶
Classes
FinetuningScheduler – This callback enables flexible, multi-phase, scheduled finetuning of foundational models.
Finetuning Scheduler¶
Used to implement flexible finetuning training schedules.
- class finetuning_scheduler.fts.FinetuningScheduler(ft_schedule=None, max_depth=-1, base_max_lr=1e-05, restore_best=True, gen_ft_sched_only=False, epoch_transitions_only=False)[source]¶
Bases: pytorch_lightning.callbacks.finetuning.BaseFinetuning, finetuning_scheduler.fts_supporters.SchedulingMixin, finetuning_scheduler.fts_supporters.CallbackDepMixin
This callback enables flexible, multi-phase, scheduled finetuning of foundational models. Gradual unfreezing/thawing can help maximize foundational model knowledge retention while allowing (typically upper layers of) the model to optimally adapt to new tasks during transfer learning.
FinetuningScheduler orchestrates the gradual unfreezing of models via a finetuning schedule that is either implicitly generated (the default) or explicitly provided by the user (more computationally efficient). Finetuning phase transitions are driven by FTSEarlyStopping criteria (a multi-phase extension of EarlyStopping), user-specified epoch transitions, or a composition of the two (the default mode). A FinetuningScheduler training session completes when the final phase of the schedule has its stopping criteria met. See Early Stopping for more details on that callback's configuration.

Schedule definition is facilitated via gen_ft_schedule(), which dumps a default finetuning schedule (by default using a naive, 2-parameters-per-level heuristic) that can be adjusted as desired by the user and subsequently passed to the callback. Implicit finetuning mode generates the default schedule and proceeds to finetune according to the generated schedule. Implicit finetuning will often be less computationally efficient than explicit finetuning but can often serve as a good baseline for subsequent explicit schedule refinement and can marginally outperform many explicit schedules.

Example:
from pytorch_lightning import Trainer
from finetuning_scheduler import FinetuningScheduler

trainer = Trainer(callbacks=[FinetuningScheduler()])
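To refine a schedule before training, you can first have the callback dump the default schedule and exit (see the gen_ft_sched_only parameter below). A minimal sketch, assuming model is an existing LightningModule instance:

from pytorch_lightning import Trainer
from finetuning_scheduler import FinetuningScheduler

# Dump the default schedule to Trainer.log_dir and exit without training; the
# generated <LightningModuleSubclass>_ft_schedule.yaml can then be hand-edited
# and passed back via ft_schedule on a subsequent run. `model` is assumed to
# be an existing LightningModule instance.
trainer = Trainer(callbacks=[FinetuningScheduler(gen_ft_sched_only=True)])
trainer.fit(model)  # exits (via SystemExit) after writing the schedule; see setup() below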
Define and configure a scheduled finetuning training session.
- Parameters
ft_schedule (Union[str, dict, None]) – The finetuning schedule to be executed. Usually will be a .yaml file path but can also be a properly structured Dict (see the sketch after this list). See Specifying a Finetuning Schedule for the schedule format. If a schedule is not provided, will generate and execute a default finetuning schedule using the provided LightningModule. See the default schedule. Defaults to None.
max_depth (int) – Maximum schedule depth to which the defined finetuning schedule should be executed. Specifying -1 or an integer > (number of defined schedule layers) will result in the entire finetuning schedule being executed. Defaults to -1.
base_max_lr (float) – The default maximum learning rate to use for the parameter groups associated with each scheduled finetuning depth if not explicitly specified in the finetuning schedule. If overridden to None, will be set to the lr of the first scheduled finetuning depth scaled by 1e-1. Defaults to 1e-5.
restore_best (bool) – If True, restore the best available (defined by the FTSCheckpoint) checkpoint before finetuning depth transitions. Defaults to True.
gen_ft_sched_only (bool) – If True, generate the default finetuning schedule to Trainer.log_dir (it will be named after your LightningModule subclass with the suffix _ft_schedule.yaml) and exit without training. Typically used to generate a default schedule that will be adjusted by the user before training. Defaults to False.
epoch_transitions_only (bool) – If True, use epoch-driven stopping criteria exclusively (rather than composing FTSEarlyStopping and epoch-driven criteria, which is the default). If using this mode, an epoch-driven transition (max_transition_epoch >= 0) must be specified for each phase. If unspecified, max_transition_epoch defaults to -1 for each phase, which signals the application of FTSEarlyStopping criteria only. Defaults to False.
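For illustration, a minimal explicit schedule passed as a dict might look like the following sketch. The parameter names (model.classifier.*, model.pooler.*) are hypothetical and depend on your LightningModule; see Specifying a Finetuning Schedule for the authoritative format.

from pytorch_lightning import Trainer
from finetuning_scheduler import FinetuningScheduler

# Hypothetical two-phase schedule: phase 0 thaws the classifier head at the
# start of training; phase 1 thaws the pooler once phase 0's FTSEarlyStopping
# criteria are met or epoch 3 is reached, whichever comes first.
ft_schedule = {
    0: {"params": ["model.classifier.bias", "model.classifier.weight"]},
    1: {
        "params": ["model.pooler.dense.bias", "model.pooler.dense.weight"],
        "max_transition_epoch": 3,
        "lr": 1e-5,  # overrides base_max_lr for this phase's parameter groups
    },
}
trainer = Trainer(callbacks=[FinetuningScheduler(ft_schedule=ft_schedule)])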
- _fts_state¶
The internal finetuning scheduler state.
- freeze_before_training(pl_module)[source]¶
Freezes all model parameters so that parameter subsets can be subsequently thawed according to the finetuning schedule.
- Parameters
pl_module (LightningModule) – The target LightningModule whose parameters will be frozen
- Return type
None
- load_state_dict(state_dict)[source]¶
After loading a checkpoint, load the saved FinetuningScheduler callback state and update the current callback state accordingly.
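In practice this state round-trips through standard Lightning checkpointing; a minimal sketch, assuming model is an existing LightningModule instance and a hypothetical checkpoint path:

from pytorch_lightning import Trainer
from finetuning_scheduler import FinetuningScheduler

trainer = Trainer(callbacks=[FinetuningScheduler()])
# Resuming from a checkpoint invokes load_state_dict with the saved
# FinetuningScheduler state ("last.ckpt" is a hypothetical path).
trainer.fit(model, ckpt_path="last.ckpt")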
- on_before_zero_grad(trainer, pl_module, optimizer)[source]¶
After the latest optimizer step, update the _fts_state, incrementing the global finetuning steps taken.
- Parameters
pl_module (LightningModule) – The LightningModule object
optimizer (Optimizer) – The Optimizer to which parameter groups will be configured and added.
- Return type
None
- on_fit_start(trainer, pl_module)[source]¶
Before beginning training, ensure an optimizer configuration supported by FinetuningScheduler is present.
- Parameters
pl_module (LightningModule) – The LightningModule object
- Raises
MisconfigurationException – If more than one optimizer is configured, as this indicates a configuration error
- Return type
None
- on_train_end(trainer, pl_module)[source]¶
Synchronize internal _fts_state on end of training to ensure the final training state is consistent with epoch semantics.
- Parameters
pl_module (LightningModule) – The LightningModule object
- Return type
None
- on_train_epoch_start(trainer, pl_module)[source]¶
Before beginning a training epoch, configure the internal _fts_state, prepare the next scheduled finetuning level, and store the updated optimizer configuration before continuing training.
- Parameters
pl_module (LightningModule) – The LightningModule object
- Return type
None
- restore_best_ckpt()[source]¶
Restore the current best model checkpoint, according to best_model_path.
- Return type
None
- setup(trainer, pl_module, stage=None)[source]¶
Validate that a compatible Strategy is being used and ensure all FinetuningScheduler callback dependencies are met. If a valid configuration is present, either dump the default finetuning schedule, or:
1. configure the FTSEarlyStopping callback (if relevant)
2. initialize the _fts_state
3. freeze the target LightningModule parameters
Finally, initialize the FinetuningScheduler training session in the training environment.
- Parameters
pl_module (LightningModule) – The LightningModule object
stage (Optional[str]) – The RunningStage.{SANITY_CHECKING,TRAINING,VALIDATING}. Defaults to None.
- Raises
SystemExit – Gracefully exit before training if only generating and not executing a finetuning schedule.
MisconfigurationException – If the Strategy being used is not currently compatible with the FinetuningScheduler callback.
- Return type
None
- should_transition(trainer)[source]¶
Phase transition logic is contingent on whether we are composing FTSEarlyStopping criteria with epoch-driven transition constraints or exclusively using epoch-driven transition scheduling (i.e., epoch_transitions_only is True).
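The decision described above can be summarized with the following illustrative sketch (a simplification for exposition, not the library's actual implementation; all names are hypothetical):

def should_transition_sketch(early_stopping_met: bool, current_epoch: int,
                             max_transition_epoch: int,
                             epoch_transitions_only: bool) -> bool:
    # An epoch-driven transition applies once a non-negative
    # max_transition_epoch has been reached for the current phase.
    epoch_driven = 0 <= max_transition_epoch <= current_epoch
    if epoch_transitions_only:
        return epoch_driven  # epoch-driven criteria exclusively
    # Default mode: compose FTSEarlyStopping and epoch-driven criteria.
    return early_stopping_met or epoch_driven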
- state_dict()[source]¶
Before saving a checkpoint, add the FinetuningScheduler callback state to be saved.
- Returns
The FinetuningScheduler callback state dictionary that will be added to the checkpoint
- Return type
Dict[str, Any]
- step()[source]¶
Prepare and execute the next scheduled finetuning level:
1. Restore the current best model checkpoint if appropriate
2. Thaw model parameters according to the defined schedule
3. Synchronize the states of FitLoop and _fts_state

Note
The FinetuningScheduler callback initially only supports single-schedule/optimizer finetuning configurations.
- Return type
None
- step_pg(optimizer, depth, depth_sync=True)[source]¶
Configure optimizer parameter groups for the next scheduled finetuning level, adding parameter groups beyond the restored optimizer state up to current_depth.
- Parameters
optimizer (Optimizer) – The Optimizer to which parameter groups will be configured and added.
depth (int) – The maximum index of the finetuning schedule for which to configure the optimizer parameter groups.
depth_sync (bool) – If True, configure optimizer parameter groups for all depth indices greater than the restored checkpoint. If False, configure groups only for the specified depth. Defaults to True.
- Return type
None
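Underneath, thawing a phase amounts to adding the newly trainable parameters to the optimizer as a fresh parameter group; a minimal sketch of that mechanism in plain PyTorch (the model attributes and learning rates shown are hypothetical):

import torch
from torch import nn

# Hypothetical two-phase model: the classifier is trainable from phase 0,
# while the pooler starts frozen and is thawed at the phase-1 depth transition.
model = nn.ModuleDict({
    "classifier": nn.Linear(768, 2),
    "pooler": nn.Linear(768, 768),
})
for p in model["pooler"].parameters():
    p.requires_grad = False

optimizer = torch.optim.AdamW(model["classifier"].parameters(), lr=1e-3)

# On the depth transition, the newly thawed parameters join the optimizer as
# an additional parameter group (here with a base_max_lr-style default lr).
for p in model["pooler"].parameters():
    p.requires_grad = True
optimizer.add_param_group({"params": model["pooler"].parameters(), "lr": 1e-5})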