fts_supporters¶
Classes

CallbackDepMixin – Functionality for validating/managing callback dependencies.
CallbackResolverMixin – Gives user-provided callbacks the ability to connect to another user-provided callback.
FTSCheckpoint – Extends/specializes ModelCheckpoint to facilitate multi-phase scheduled fine-tuning.
FTSEarlyStopping – Extends/specializes EarlyStopping to facilitate multi-phase scheduled fine-tuning.
FTSState – Dataclass to encapsulate the FinetuningScheduler internal state.
ScheduleImplMixin – Functionality for generating and executing fine-tuning schedules.
ScheduleParsingMixin – Functionality for parsing and validating fine-tuning schedules.
UniqueKeyLoader – Alters SafeLoader to enable duplicate key detection by the SafeConstructor.
Fine-Tuning Scheduler Supporters¶
Classes composed to support scheduled fine-tuning
- class finetuning_scheduler.fts_supporters.CallbackDepMixin(callback_dep_parents={'EarlyStopping': <class 'pytorch_lightning.callbacks.early_stopping.EarlyStopping'>, 'ModelCheckpoint': <class 'pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'>})[source]¶
Bases:
abc.ABC
Functionality for validating/managing callback dependencies.
Initialize the user-provided callback dependency validation in accordance with the user-provided module configuration.
- Parameters
callback_dep_parents¶ (Dict, optional) – The parent classes of all user-provided callbacks in the module that should be connected to the target user-provided callback. Defaults to CALLBACK_DEP_PARENTS in the user-provided module.
- class finetuning_scheduler.fts_supporters.CallbackResolverMixin(callback_attrs=('ft_schedule', 'max_depth'), callback_parents={'EarlyStopping': <class 'pytorch_lightning.callbacks.early_stopping.EarlyStopping'>, 'ModelCheckpoint': <class 'pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'>}, target_callback_ref='FinetuningScheduler', support_multiple_targets=False)[source]¶
Bases:
abc.ABC
Gives user-provided callbacks the ability to connect to another user-provided callback.
This resolution logic is provided in order to avoid callback-dependent trainer attributes (e.g. trainer.finetuningscheduler_callback).
Initialize the user-provided callback dependency resolver in accordance with the user-provided module configuration.
- Parameters
callback_attrs¶ (Tuple, optional) – Attribute signature of user-provided callback to be structurally detected and connected. Defaults to CALLBACK_ATTRS defined in the user-provided module.
callback_parents¶ (Dict, optional) – The parent classes of all user-provided callbacks in the module that should be connected to the target user-provided callback. Defaults to CALLBACK_DEP_PARENTS in the user-provided module.
target_callback_ref¶ (str, optional) – The name of the target user-provided callback to connect to. For each subclass of CALLBACK_DEP_PARENTS, an attribute named (target_callback_ref.lower())_callback will be added. Defaults to TARGET_CALLBACK_REF in the user-provided module.
support_multiple_targets¶ (bool, optional) – Whether multiple instances of the target user-provided callback (only the first of which will be connected to) are allowed. Defaults to False.
- connect_callback(trainer, reconnect=False)[source]¶
Connect each user-provided callback dependency that needs to be connected to the target user-provided callback.
- Parameters
- Raises
MisconfigurationException – If no target callback is detected.
MisconfigurationException – If support_multiple_targets is False and multiple target callbacks are detected.
- Return type
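The structural (duck-typed) resolution performed by connect_callback can be sketched outside of Lightning as follows. The class names and simplified logic below are illustrative stand-ins, not the library's actual implementation:

```python
# Minimal sketch of attribute-signature-based callback resolution
# (illustrative stand-ins, not the library's actual implementation).

CALLBACK_ATTRS = ("ft_schedule", "max_depth")

class TargetCallback:  # plays the role of FinetuningScheduler
    ft_schedule = None
    max_depth = -1

class DependentCallback:  # plays the role of FTSCheckpoint / FTSEarlyStopping
    def __init__(self):
        self.finetuningscheduler_callback = None

    def connect_callback(self, callbacks, reconnect=False):
        if self.finetuningscheduler_callback is not None and not reconnect:
            return
        # detect the target structurally: any callback exposing all CALLBACK_ATTRS
        targets = [c for c in callbacks if all(hasattr(c, a) for a in CALLBACK_ATTRS)]
        if not targets:
            raise ValueError("no target callback detected")
        if len(targets) > 1:  # behavior when support_multiple_targets is False
            raise ValueError("multiple target callbacks detected")
        self.finetuningscheduler_callback = targets[0]

target, dep = TargetCallback(), DependentCallback()
dep.connect_callback([target, object()])
```

Because detection is by attribute signature rather than by class, any callback exposing the expected attributes can be resolved as the target.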
- class finetuning_scheduler.fts_supporters.FTSCheckpoint(*args, **kwargs)[source]¶
Bases:
pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint, finetuning_scheduler.fts_supporters.CallbackResolverMixin
Extends/specializes ModelCheckpoint to facilitate multi-phase scheduled fine-tuning. Overrides the state_dict and load_state_dict hooks to maintain additional state (current_ckpt_depth, best_ckpt_depth, finetuningscheduler_callback). Usage of FTSCheckpoint is identical to ModelCheckpoint, and FTSCheckpoint will automatically be used if a FinetuningScheduler callback is detected.
Warning
FTSCheckpoint is in beta and subject to change. For detailed usage information, see ModelCheckpoint.
Note
Currently, FinetuningScheduler supports the use of one FTSCheckpoint callback instance at a time.
- best_ckpt_depth¶
Used to track the depth of the best known checkpoint (it may be from a different training depth)
- Type
- finetuningscheduler_callback¶
Reference to the
FinetuningScheduler
callback being used.
- save_on_train_epoch_end¶
Whether to run checkpointing at the end of the training epoch. If this is False, then the check runs at the end of validation. Defaults to None, similar to ModelCheckpoint, but is set to False during setup unless overridden.
- Type
Optional[bool]
- load_state_dict(state_dict)[source]¶
Overrides ModelCheckpoint's load_state_dict method to load multi-phase training depth state.
- setup(trainer, pl_module, stage)[source]¶
Verify a valid callback configuration is present before beginning training.
- Parameters
- Raises
MisconfigurationException – If a FinetuningScheduler callback is not found on initialization (finetuningscheduler_callback is None).
MisconfigurationException – If restore_best is True and ModelCheckpoint.save_top_k is either None or 0.
MisconfigurationException – If restore_best is True and monitor is None.
- Return type
- state_dict()[source]¶
Overrides ModelCheckpoint's state_dict method to maintain multi-phase training depth state.
- Returns
the callback state dictionary that will be saved.
- Return type
Dict[str, Any]
- class finetuning_scheduler.fts_supporters.FTSEarlyStopping(*args, **kwargs)[source]¶
Bases:
pytorch_lightning.callbacks.early_stopping.EarlyStopping, finetuning_scheduler.fts_supporters.CallbackResolverMixin
Extends/specializes EarlyStopping to facilitate multi-phase scheduled fine-tuning. Adds es_phase_complete, final_phase and finetuningscheduler_callback attributes and modifies EarlyStopping._evaluate_stopping_criteria to enable multi-phase behavior. Usage of FTSEarlyStopping is identical to EarlyStopping, except the former will evaluate the specified early stopping criteria at every scheduled phase. FTSEarlyStopping will automatically be used if a FinetuningScheduler callback is detected and epoch_transitions_only is False.
Warning
FTSEarlyStopping is in beta and subject to change. For detailed usage information, see EarlyStopping.
Note
Currently, FinetuningScheduler supports the use of one FTSEarlyStopping callback instance at a time.
- es_phase_complete¶
Used to determine if the current phase’s early stopping criteria have been met.
- Type
- finetuningscheduler_callback¶
Reference to the
FinetuningScheduler
callback being used.
- reduce_transition_decisions¶
Used to indicate whether the callback is operating in a distributed context without the monitored metric being synchronized (via sync_dist being set to True when logging).
- Type
- check_on_train_epoch_end¶
Whether to run the early stopping check at the end of the training epoch. If this is False, then the check runs at the end of validation. Defaults to None, similar to EarlyStopping, but is set to False during setup unless overridden.
- Type
- on_validation_end(trainer, pl_module)[source]¶
Ascertain whether the execution context of this callback requires that we reduce transition decisions over all distributed training processes.
- Parameters
pl_module¶ (LightningModule) – The LightningModule object.
- Return type
- setup(trainer, pl_module, stage)[source]¶
Ensure a FinetuningScheduler callback is provided before beginning training.
- Return type
- class finetuning_scheduler.fts_supporters.FTSState(_resume_fit_from_ckpt=False, _ft_epoch=0, _ft_global_steps=0, _curr_depth=0, _best_ckpt_depth=0, _ft_sync_props=(('epoch_progress.current.completed', '_ft_epoch'), ('epoch_loop.global_step', '_ft_global_steps')), _ft_sync_objects=None, _curr_thawed_params=<factory>, _fts_ckpt_metadata=<factory>)[source]¶
Bases:
object
Dataclass to encapsulate the FinetuningScheduler internal state.
- class finetuning_scheduler.fts_supporters.ScheduleImplMixin[source]¶
Bases:
abc.ABC
Functionality for generating and executing fine-tuning schedules.
- static add_optimizer_groups(module, optimizer, thawed_pl, no_decay=None, lr=None, apply_lambdas=False)[source]¶
Add optimizer parameter groups associated with the next scheduled fine-tuning depth/level and extend the relevant lr_scheduler_configs.
- Parameters
module¶ (Module) – The Module from which the target optimizer parameters will be read.
optimizer¶ (Optimizer) – The Optimizer to which parameter groups will be configured and added.
thawed_pl¶ (List) – The list of thawed/unfrozen parameters that should be added to the new parameter group(s).
no_decay¶ (Optional[list]) – A list of parameters that should always have weight_decay set to 0, e.g. ["bias", "LayerNorm.weight"]. Defaults to None.
lr¶ (Optional[float]) – The initial learning rate for the new parameter group(s). If not specified, the lr of the first scheduled fine-tuning depth will be used. Defaults to None.
apply_lambdas¶ (bool) – Whether to apply lr lambdas to newly added groups. Defaults to False.
- Return type
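The no_decay handling can be sketched with plain dicts standing in for torch optimizer parameter groups. The helper name, default values, and parameter names below are illustrative assumptions, not the library's actual implementation:

```python
# Sketch: split newly thawed parameters into weight-decay / no-weight-decay
# optimizer groups (plain dicts stand in for torch parameter groups).

def build_param_groups(thawed_pl, no_decay=None, lr=1e-5, weight_decay=1e-2):
    """Return optimizer parameter-group dicts for the newly thawed params."""
    if not no_decay:
        return [{"params": thawed_pl, "lr": lr, "weight_decay": weight_decay}]
    # a parameter lands in the no-decay group if its name matches any pattern
    decay = [n for n in thawed_pl if not any(nd in n for nd in no_decay)]
    nodecay = [n for n in thawed_pl if any(nd in n for nd in no_decay)]
    return [
        {"params": decay, "lr": lr, "weight_decay": weight_decay},
        {"params": nodecay, "lr": lr, "weight_decay": 0.0},  # decay disabled
    ]

groups = build_param_groups(
    ["layer.0.weight", "layer.0.bias", "norm.weight"],
    no_decay=["bias", "norm.weight"],
)
```

Each dict returned here has the shape expected by Optimizer.add_param_group, which is how new depths extend an existing optimizer.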
- static exec_ft_phase(module, thaw_pl, init_thaw=False)[source]¶
Thaw/unfreeze the provided list of parameters in the provided Module.
- Parameters
- Returns
A tuple of two lists:
- The list of parameters newly thawed/unfrozen by this function
- The list of all currently thawed/unfrozen parameters in the target Module
Tuple[List, List]
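The thaw step can be sketched without torch by using a tiny stand-in for named parameters. The stand-in class and names are illustrative assumptions:

```python
# Sketch of the thaw step: flip requires_grad on parameters named in thaw_pl
# and report (newly thawed, all currently thawed). FakeParam stands in for
# torch.nn.Parameter; names are illustrative.

class FakeParam:
    def __init__(self, requires_grad=False):
        self.requires_grad = requires_grad

def exec_ft_phase(named_params, thaw_pl):
    thawed = []
    for name, p in named_params.items():
        if name in thaw_pl and not p.requires_grad:
            p.requires_grad = True  # unfreeze
            thawed.append(name)
    # everything currently trainable, including previously thawed params
    curr_thawed = [n for n, p in named_params.items() if p.requires_grad]
    return thawed, curr_thawed

params = {"head.weight": FakeParam(True), "body.weight": FakeParam(False)}
newly, current = exec_ft_phase(params, thaw_pl=["body.weight"])
```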
- static gen_ft_schedule(module, dump_loc)[source]¶
Generate the default fine-tuning schedule using a naive, two-parameters-per-level heuristic.
- Parameters
- Returns
The path to the generated schedule (by default written to Trainer.log_dir and named after the LightningModule subclass in use, with the suffix _ft_schedule.yaml).
- Return type
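The two-parameters-per-level heuristic can be sketched as chunking the module's parameter names. The helper name, the reverse-registration-order choice, and the parameter names are illustrative assumptions, not the library's actual generator:

```python
# Sketch of a naive two-parameters-per-level schedule: walk parameter names
# in reverse registration order (output end first) and assign two per depth.

def naive_schedule(param_names, params_per_level=2):
    names = list(reversed(param_names))  # thaw later layers at shallower depths
    return {
        depth: {"params": names[i : i + params_per_level]}
        for depth, i in enumerate(range(0, len(names), params_per_level))
    }

sched = naive_schedule(
    ["body.0.weight", "body.0.bias", "head.weight", "head.bias"]
)
```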
- gen_implicit_schedule(sched_dir)[source]¶
Generate the default schedule, save it to sched_dir, and load it into ft_schedule.
- gen_or_load_sched()[source]¶
Load an explicitly specified fine-tuning schedule if one is provided; otherwise generate a default one.
- Return type
- init_ft_sched()[source]¶
Generate the default fine-tuning schedule and/or load it into ft_schedule. Broadcast the schedule to ensure it is available for use in a distributed context.
- Return type
- init_fts()[source]¶
Initializes the fine-tuning schedule and prepares the first scheduled level:
1. Generate the default fine-tuning schedule and/or load it into ft_schedule.
2. Prepare the first scheduled fine-tuning level, unfreezing the relevant parameters.
- Return type
- static load_yaml_schedule(schedule_yaml_file)[source]¶
Load a schedule defined in a .yaml file and transform it into a dictionary.
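A minimal example of the expected schedule shape: a mapping of integer phase indices, each listing the parameter names to thaw at that depth. The layer names below are hypothetical:

```yaml
# Hypothetical fine-tuning schedule: integer phases mapping to parameter lists.
0:
  params:
    - model.classifier.bias
    - model.classifier.weight
1:
  params:
    - model.encoder.layer.11.*
```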
- static save_schedule(schedule_name, layer_config, dump_loc)[source]¶
Save a loaded or generated schedule to a directory to ensure reproducibility.
- Parameters
layer_config¶ (Dict) – The saved schedule dictionary.
dump_loc¶ (os.PathLike) – The directory to which the generated schedule (.yaml) should be written
- Returns
The path to the generated schedule (by default written to Trainer.log_dir and named after the LightningModule subclass in use, with the suffix _ft_schedule.yaml).
- Return type
- class finetuning_scheduler.fts_supporters.ScheduleParsingMixin[source]¶
Bases:
abc.ABC
Functionality for parsing and validating fine-tuning schedules.
- class finetuning_scheduler.fts_supporters.UniqueKeyLoader(stream)[source]¶
Bases:
yaml.loader.SafeLoader
Alters SafeLoader to enable duplicate key detection by the SafeConstructor.
Initialize the scanner.
- construct_mapping(node, deep=False)[source]¶
Overrides the construct_mapping method of the SafeConstructor to raise a ValueError if duplicate keys are found.
Inspired by and adapted from https://stackoverflow.com/a/63215043.
- Return type
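The duplicate-key detection approach can be sketched as a SafeLoader subclass with an overridden construct_mapping. This is a sketch of the technique (it requires PyYAML; the class name and error message are illustrative, not the library's exact implementation):

```python
# Duplicate-key detection via a yaml.SafeLoader subclass (requires PyYAML).
import yaml

class UniqueKeySafeLoader(yaml.SafeLoader):
    def construct_mapping(self, node, deep=False):
        # construct each key and reject the mapping if any key repeats
        keys = [self.construct_object(k, deep=deep) for k, _ in node.value]
        seen = set()
        for key in keys:
            if key in seen:
                raise ValueError(f"duplicate key detected: {key!r}")
            seen.add(key)
        return super().construct_mapping(node, deep=deep)

valid = yaml.load("a: 1\nb: 2", Loader=UniqueKeySafeLoader)
try:
    yaml.load("a: 1\na: 2", Loader=UniqueKeySafeLoader)
    duplicate_raised = False
except ValueError:
    duplicate_raised = True
```

Plain SafeLoader would silently keep the last value for a repeated key, which is why schedule parsing benefits from an explicit check.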