fts_supporters¶
Classes

CallbackDepMixin | Functionality for validating/managing callback dependencies.
CallbackResolverMixin | Gives user-provided callbacks the ability to connect to another user-provided callback.
FTSCheckpoint | Extends/specializes ModelCheckpoint to facilitate multi-phase scheduled finetuning.
FTSEarlyStopping | Extends/specializes EarlyStopping to facilitate multi-phase scheduled finetuning.
FTSState | Dataclass to encapsulate the FinetuningScheduler internal state.
ScheduleImplMixin | Functionality for generating and executing finetuning schedules.
ScheduleParsingMixin | Functionality for parsing and validating finetuning schedules.
UniqueKeyLoader | Alters SafeLoader to enable duplicate key detection by the SafeConstructor.
Finetuning Scheduler Supporters¶
Classes composed to support scheduled finetuning.
- class finetuning_scheduler.fts_supporters.CallbackDepMixin[source]¶
Bases:
abc.ABC
Functionality for validating/managing callback dependencies.
- class finetuning_scheduler.fts_supporters.CallbackResolverMixin(callback_attrs=('ft_schedule', 'max_depth'), callback_parents={'EarlyStopping': <class 'pytorch_lightning.callbacks.early_stopping.EarlyStopping'>, 'ModelCheckpoint': <class 'pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'>}, target_callback_ref='FinetuningScheduler', support_multiple=False)[source]¶
Bases:
abc.ABC
Gives user-provided callbacks the ability to connect to another user-provided callback.
This resolution logic is provided in order to avoid callback-dependent trainer attributes (e.g. trainer.finetuningscheduler_callback).
Initialize the user-provided callback dependency resolver in accordance with the user-provided module configuration.
- Parameters
callback_attrs¶ (Tuple, optional) – Attribute signature of user-provided callback to be structurally detected and connected. Defaults to CALLBACK_ATTRS defined in the user-provided module.
callback_parents¶ (Dict, optional) – The parent classes of all user-provided callbacks in the module that should be connected to the target user-provided callback. Defaults to CALLBACK_DEP_PARENTS in the user-provided module.
target_callback_ref¶ (str, optional) – The name of the target user-provided callback to connect to. For each subclass of CALLBACK_DEP_PARENTS, an attribute named (target_callback_ref.lower())_callback will be added. Defaults to TARGET_CALLBACK_REF in the user-provided module.
support_multiple¶ (bool, optional) – Whether multiple instances of the target user-provided callback (only the first of which will be connected to) are allowed. Defaults to False.
- connect_callback(trainer, reconnect=False)[source]¶
Connect each user-provided callback dependency that needs to be connected to the target user-provided callback.
- Parameters
- Raises
MisconfigurationException – If no target callback is detected.
MisconfigurationException – If support_multiple is False and multiple target callbacks are detected.
- Return type
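A minimal usage sketch, assuming a custom callback subclasses this mixin alongside a standard Callback (the MyFTSMonitor class and the curr_depth attribute it reads are illustrative, not part of this module):

    from pytorch_lightning.callbacks import Callback
    from finetuning_scheduler.fts_supporters import CallbackResolverMixin


    class MyFTSMonitor(Callback, CallbackResolverMixin):
        """Hypothetical callback that reports the current finetuning depth each epoch."""

        def setup(self, trainer, pl_module, stage=None):
            # connect_callback inspects trainer.callbacks and, if the target callback
            # (by default FinetuningScheduler) is detected, assigns it to
            # self.finetuningscheduler_callback
            self.connect_callback(trainer)

        def on_train_epoch_end(self, trainer, pl_module):
            fts = self.finetuningscheduler_callback
            if fts is not None:
                # curr_depth is assumed here to expose the scheduler's current phase
                print(f"finetuning depth at epoch {trainer.current_epoch}: {fts.curr_depth}")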
- class finetuning_scheduler.fts_supporters.FTSCheckpoint(*args, **kwargs)[source]¶
Bases:
pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint, finetuning_scheduler.fts_supporters.CallbackResolverMixin
Extends/specializes ModelCheckpoint to facilitate multi-phase scheduled finetuning. Overrides the state_dict and load_state_dict hooks to maintain additional state (current_ckpt_depth, best_ckpt_depth, finetuningscheduler_callback). Usage of FTSCheckpoint is identical to ModelCheckpoint and FTSCheckpoint will automatically be used if a FinetuningScheduler callback is detected.
Warning
FTSCheckpoint is in beta and subject to change. For detailed usage information, see ModelCheckpoint.
- best_ckpt_depth¶
Used to track the depth of the best known checkpoint (it may be from a different training depth)
- Type
- finetuningscheduler_callback¶
Reference to the FinetuningScheduler callback being used.
- save_on_train_epoch_end¶
Whether to run checkpointing at the end of the training epoch. If this is False, then the check runs at the end of the validation. Defaults to None similar to ModelCheckpoint but is set to False during setup unless overridden.
- Type
- load_state_dict(state_dict)[source]¶
Overrides ModelCheckpoint’s load_state_dict method to load multi-phase training depth state.
- setup(trainer, pl_module, stage=None)[source]¶
Verify a valid callback configuration is present before beginning training.
- Parameters
- Raises
MisconfigurationException – If a FinetuningScheduler callback is not found on initialization (finetuningscheduler_callback is None)
MisconfigurationException – If restore_best is True and ModelCheckpoint.save_top_k is either None or 0
MisconfigurationException – If restore_best is True and monitor is None
- Return type
- state_dict()[source]¶
Overrides ModelCheckpoint’s state_dict method to maintain multi-phase training depth state.
- Returns
the callback state dictionary that will be saved.
- Return type
Dict[str, Any]
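A hedged usage sketch: because FTSCheckpoint mirrors ModelCheckpoint, it can be passed to the Trainer explicitly alongside FinetuningScheduler, with monitor and save_top_k set so the restore_best checks above are satisfied ("val_loss" is an illustrative metric name):

    from pytorch_lightning import Trainer
    from finetuning_scheduler import FinetuningScheduler
    from finetuning_scheduler.fts_supporters import FTSCheckpoint

    callbacks = [
        FinetuningScheduler(),                            # drives the multi-phase schedule
        FTSCheckpoint(monitor="val_loss", save_top_k=1),  # monitored metric and save_top_k per the setup() requirements
    ]
    trainer = Trainer(callbacks=callbacks)
    # trainer.fit(...) as usual; FTSCheckpoint tracks current_ckpt_depth/best_ckpt_depth in its state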
- class finetuning_scheduler.fts_supporters.FTSEarlyStopping(*args, **kwargs)[source]¶
Bases:
pytorch_lightning.callbacks.early_stopping.EarlyStopping, finetuning_scheduler.fts_supporters.CallbackResolverMixin
Extends/specializes EarlyStopping to facilitate multi-phase scheduled finetuning.
Adds es_phase_complete, final_phase and finetuningscheduler_callback attributes and modifies EarlyStopping._evaluate_stopping_criteria to enable multi-phase behavior. Usage of FTSEarlyStopping is identical to EarlyStopping except the former will evaluate the specified early stopping criteria at every scheduled phase. FTSEarlyStopping will automatically be used if a FinetuningScheduler callback is detected and epoch_transitions_only is False.
Warning
FTSEarlyStopping is in beta and subject to change. For detailed usage information, see EarlyStopping.
- es_phase_complete¶
Used to determine if the current phase’s early stopping criteria have been met.
- Type
- finetuningscheduler_callback¶
Reference to the FinetuningScheduler callback being used.
- check_on_train_epoch_end¶
Whether to run early stopping check at the end of the training epoch. If this is False, then the check runs at the end of the validation. Defaults to None similar to EarlyStopping but is set to False during setup unless overridden.
- Type
- setup(trainer, pl_module, stage=None)[source]¶
Ensure a FinetuningScheduler is provided before beginning training.
- Return type
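A brief sketch of the explicit form ("val_loss" is an illustrative metric name); per the description above, an FTSEarlyStopping is otherwise used automatically when a FinetuningScheduler callback is detected and epoch_transitions_only is False:

    from pytorch_lightning import Trainer
    from finetuning_scheduler import FinetuningScheduler
    from finetuning_scheduler.fts_supporters import FTSEarlyStopping

    trainer = Trainer(
        callbacks=[FinetuningScheduler(), FTSEarlyStopping(monitor="val_loss", patience=2)]
    )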
- class finetuning_scheduler.fts_supporters.FTSState(_resume_fit_from_ckpt=False, _ft_epoch=0, _ft_global_steps=0, _curr_depth=0, _best_ckpt_depth=0, _ft_sync_props=(('epoch_progress.current.completed', '_ft_epoch'), ('epoch_loop.global_step', '_ft_global_steps')), _ft_sync_objects=None, _curr_thawed_params=<factory>, _fts_ckpt_metadata=<factory>)[source]¶
Bases:
object
Dataclass to encapsulate the FinetuningScheduler internal state.
- class finetuning_scheduler.fts_supporters.ScheduleImplMixin[source]¶
Bases:
abc.ABC
Functionality for generating and executing finetuning schedules.
- static add_optimizer_groups(module, optimizer, thawed_pl, no_decay=None, lr=None, initial_denom_lr=10.0)[source]¶
Add optimizer parameter groups associated with the next scheduled finetuning depth/level and extend the relevant lr_scheduler_configs.
- Parameters
module¶ (Module) – The Module from which the target optimizer parameters will be read.
optimizer¶ (Optimizer) – The Optimizer to which parameter groups will be configured and added.
thawed_pl¶ (List) – The list of thawed/unfrozen parameters that should be added to the new parameter group(s).
no_decay¶ (Optional[list]) – A list of parameters that should always have weight_decay set to 0, e.g. ["bias", "LayerNorm.weight"]. Defaults to None.
lr¶ (Optional[float]) – The initial learning rate for the new parameter group(s). If not specified, the lr of the first scheduled finetuning depth will be used. Defaults to None.
initial_denom_lr¶ (float) – The scaling factor by which to scale the initial learning rate for new parameter groups when no initial learning rate is specified. Defaults to 10.0.
- Return type
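An illustrative call under stated assumptions (the toy module, optimizer and parameter names are placeholders; in practice FinetuningScheduler invokes this when transitioning phases):

    import torch
    from torch import nn
    from finetuning_scheduler.fts_supporters import ScheduleImplMixin

    # Toy module standing in for the LightningModule being finetuned; the optimizer
    # initially only covers the second layer's parameters.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
    optimizer = torch.optim.AdamW(model[1].parameters(), lr=1e-3, weight_decay=1e-2)

    # Add parameter groups for the next phase's parameters ("0.weight"/"0.bias" here).
    # Parameters whose names match a no_decay entry are placed in a group with
    # weight_decay=0; if lr were omitted, the first scheduled phase's lr would be used
    # and scaled by initial_denom_lr (per the parameter descriptions above).
    ScheduleImplMixin.add_optimizer_groups(
        module=model,
        optimizer=optimizer,
        thawed_pl=["0.weight", "0.bias"],
        no_decay=["bias"],
        lr=1e-5,
    )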
- static exec_ft_phase(module, thaw_pl, init_thaw=False)[source]¶
Thaw/unfreeze the provided list of parameters in the provided Module.
- Parameters
- Returns
A Tuple of two lists:
- The list of parameters newly thawed/unfrozen by this function
- A list of all currently thawed/unfrozen parameters in the target Module
- Return type
Tuple[List, List]
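A minimal sketch of a standalone call (the toy module and parameter names are placeholders; FinetuningScheduler normally calls this when executing a phase):

    from torch import nn
    from finetuning_scheduler.fts_supporters import ScheduleImplMixin

    # Toy module with everything frozen, as an earlier phase might leave it.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
    for p in model.parameters():
        p.requires_grad = False

    # Thaw the second layer; per the Returns description above, the call yields the
    # parameters newly thawed by this invocation and all currently thawed parameters.
    newly_thawed, all_thawed = ScheduleImplMixin.exec_ft_phase(
        model, thaw_pl=["1.weight", "1.bias"], init_thaw=True
    )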
- static gen_ft_schedule(module, dump_loc)[source]¶
Generate the default finetuning schedule using a naive, 2-parameters-per-level heuristic.
- Parameters
- Returns
The path to the generated schedule, by default Trainer.log_dir and named after the LightningModule subclass in use with the suffix _ft_schedule.yaml.
- Return type
- gen_implicit_schedule(sched_dir)[source]¶
Generate the default schedule, save it to sched_dir and load it into ft_schedule.
- gen_or_load_sched()[source]¶
Load an explicitly specified finetuning schedule if one is provided, otherwise generate a default one.
- Return type
- init_ft_sched()[source]¶
Generate the default finetuning schedule and/or load it into ft_schedule. Broadcast the schedule to ensure it is available for use in a distributed context.
- Return type
- init_fts()[source]¶
Initializes the finetuning schedule and prepares the first scheduled level:
1. Generate the default finetuning schedule and/or load it into ft_schedule.
2. Prepare the first scheduled finetuning level, unfreezing the relevant parameters.
- Return type
- static load_yaml_schedule(schedule_yaml_file)[source]¶
Load a schedule defined in a .yaml file and transform it into a dictionary.
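A small sketch, assuming the conventional schedule layout of integer phase keys each holding a params list (the parameter names below are placeholders for a real LightningModule):

    from pathlib import Path
    from textwrap import dedent
    from finetuning_scheduler.fts_supporters import ScheduleImplMixin

    schedule_yaml = dedent("""\
        0:
          params:
          - model.classifier.bias
          - model.classifier.weight
        1:
          params:
          - model.pooler.dense.bias
          - model.pooler.dense.weight
        """)
    sched_path = Path("explicit_ft_schedule.yaml")
    sched_path.write_text(schedule_yaml)
    schedule = ScheduleImplMixin.load_yaml_schedule(sched_path)  # dict keyed by phase index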
- static save_schedule(schedule_name, layer_config, dump_loc)[source]¶
Save the loaded or generated schedule to a directory to ensure reproducibility.
- Parameters
layer_config¶ (Dict) – The saved schedule dictionary.
dump_loc¶ (os.PathLike) – The directory to which the generated schedule (.yaml) should be written
- Returns
The path to the generated schedule, by default Trainer.log_dir and named after the LightningModule subclass in use with the suffix _ft_schedule.yaml.
- Return type
- class finetuning_scheduler.fts_supporters.ScheduleParsingMixin[source]¶
Bases:
abc.ABC
Functionality for parsing and validating finetuning schedules.
- class finetuning_scheduler.fts_supporters.UniqueKeyLoader(stream)[source]¶
Bases:
yaml.loader.SafeLoader
Alters SafeLoader to enable duplicate key detection by the SafeConstructor.
Initialize the scanner.
- construct_mapping(node, deep=False)[source]¶
Overrides the construct_mapping method of the SafeConstructor to raise a ValueError if duplicate keys are found.
Inspired by and adapted from https://stackoverflow.com/a/63215043
- Return type
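A short sketch of the behavior this enables (the duplicated phase key below is deliberately invalid):

    import yaml
    from finetuning_scheduler.fts_supporters import UniqueKeyLoader

    duplicated = """\
    0:
      params: [model.classifier.weight]
    0:
      params: [model.classifier.bias]
    """
    try:
        yaml.load(duplicated, Loader=UniqueKeyLoader)
    except ValueError as err:
        # SafeLoader would silently keep the last value for the duplicated key;
        # UniqueKeyLoader's construct_mapping override raises ValueError instead.
        print(f"duplicate key detected: {err}")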