fts_supporters

Classes

CallbackDepMixin

Functionality for validating/managing callback dependencies.

CallbackResolverMixin

Gives user-provided callbacks the ability to connect to another user-provided callback.

FTSCheckpoint

Extends/specializes ModelCheckpoint to facilitate multi-phase scheduled fine-tuning.

FTSEarlyStopping

Extends/specializes EarlyStopping to facilitate multi-phase scheduled fine-tuning.

FTSState

Dataclass to encapsulate the FinetuningScheduler internal state.

ScheduleImplMixin

Functionality for generating and executing fine-tuning schedules.

ScheduleParsingMixin

Functionality for parsing and validating fine-tuning schedules.

UniqueKeyLoader

Alters SafeLoader to enable duplicate key detection by the SafeConstructor.

Fine-Tuning Scheduler Supporters

Classes composed to support scheduled fine-tuning.

class finetuning_scheduler.fts_supporters.CallbackDepMixin(callback_dep_parents={'EarlyStopping': <class 'lightning.pytorch.callbacks.early_stopping.EarlyStopping'>, 'ModelCheckpoint': <class 'lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint'>})[source]

Bases: ABC

Functionality for validating/managing callback dependencies.

Arguments used to initialize validation of user-provided callback dependencies in accordance with the user-provided module configuration:

Parameters:

callback_dep_parents (Dict, optional) – The parent classes of all user-provided callbacks in the module that should be connected to the target user-provided callback. Defaults to CALLBACK_DEP_PARENTS in the user-provided module.

class finetuning_scheduler.fts_supporters.CallbackResolverMixin(callback_attrs=('ft_schedule', 'max_depth'), callback_parents={'EarlyStopping': <class 'lightning.pytorch.callbacks.early_stopping.EarlyStopping'>, 'ModelCheckpoint': <class 'lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint'>}, target_callback_ref='FinetuningScheduler', support_multiple_targets=False)[source]

Bases: ABC

Gives user-provided callbacks the ability to connect to another user-provided callback.

This resolution logic is provided to avoid the need for callback-dependent trainer attributes (e.g. trainer.finetuningscheduler_callback).

Arguments used to initialize the user-provided callback dependency resolver in accordance with the user-provided module configuration:

Parameters:
  • callback_attrs (Tuple, optional) – Attribute signature of user-provided callback to be structurally detected and connected. Defaults to CALLBACK_ATTRS defined in the user-provided module.

  • callback_parents (Dict, optional) – The parent classes of all user-provided callbacks in the module that should be connected to the target user-provided callback. Defaults to CALLBACK_DEP_PARENTS in the user-provided module.

  • target_callback_ref (str, optional) – The name of the target user-provided callback to connect to. For each subclass of CALLBACK_DEP_PARENTS, an attribute named (target_callback_ref.lower())_callback will be added. Defaults to TARGET_CALLBACK_REF in the user-provided module.

  • support_multiple_targets (bool, optional) – Whether multiple instances of the target user-provided callback (only the first of which will be connected to) are allowed. Defaults to False.

connect_callback(trainer, reconnect=False)[source]

Connect each user-provided callback dependency that needs to be connected to the target user-provided callback.

Parameters:
  • trainer (pl.Trainer) – The Trainer object.

  • reconnect (bool, optional) – Whether to check for an updated target callback object even if one is already resolved. Predominantly useful in the context of testing. Defaults to False.

Raises:
  • MisconfigurationException – If no target callback is detected.

  • MisconfigurationException – If support_multiple_targets is False and multiple target callbacks are detected.

Return type:

None

class finetuning_scheduler.fts_supporters.FTSCheckpoint(*args, **kwargs)[source]

Bases: ModelCheckpoint, CallbackResolverMixin

Extends/specializes ModelCheckpoint to facilitate multi-phase scheduled fine-tuning. Overrides the state_dict and load_state_dict hooks to maintain additional state (current_ckpt_depth, best_ckpt_depth, finetuningscheduler_callback). Usage of FTSCheckpoint is identical to that of ModelCheckpoint, and FTSCheckpoint will automatically be used if a FinetuningScheduler callback is detected.

Note

For detailed usage information, see ModelCheckpoint.

Note

Currently, FinetuningScheduler supports the use of one FTSCheckpoint callback instance at a time.
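
A minimal configuration sketch is shown below. MyLightningModule and my_datamodule are hypothetical placeholders for a user's own LightningModule/DataModule (the module is assumed to log the monitored "val_loss" metric); FTSCheckpoint is configured exactly as ModelCheckpoint would be:

    from lightning.pytorch import Trainer

    from finetuning_scheduler import FinetuningScheduler
    from finetuning_scheduler.fts_supporters import FTSCheckpoint

    # Hypothetical placeholders for the user's own LightningModule/DataModule;
    # the module is assumed to log the monitored "val_loss" metric.
    model = MyLightningModule()

    trainer = Trainer(
        callbacks=[
            # FTSCheckpoint requires a FinetuningScheduler callback in the same Trainer;
            # with no explicit ft_schedule provided, a default schedule is generated.
            FinetuningScheduler(),
            # Configured exactly like ModelCheckpoint, while additionally tracking
            # current_ckpt_depth/best_ckpt_depth across scheduled phases.
            FTSCheckpoint(monitor="val_loss", save_top_k=1),
        ]
    )
    trainer.fit(model, datamodule=my_datamodule)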

current_ckpt_depth

Used to track the depth of the most recently saved checkpoint.

Type:

int

best_ckpt_depth

Used to track the depth of the best known checkpoint (it may be from a different training depth).

Type:

int

finetuningscheduler_callback

Reference to the FinetuningScheduler callback being used.

Type:

lightning.pytorch.callbacks.Callback

save_on_train_epoch_end

Whether to run checkpointing at the end of the training epoch. If this is False, then the check runs at the end of validation. Defaults to None (as with ModelCheckpoint) but is set to False during setup unless overridden.

Type:

Optional[bool]

load_state_dict(state_dict)[source]

Overrides ModelCheckpoint’s load_state_dict method to load multi-phase training depth state.

Parameters:

state_dict (Dict[str, Any]) – the callback state dict of FTSCheckpoint.

Return type:

None

setup(trainer, pl_module, stage)[source]

Verify a valid callback configuration is present before beginning training.

Parameters:

trainer (Trainer) – The Trainer object

Raises:
  • MisconfigurationException – If a FinetuningScheduler callback is not found on initialization (finetuningscheduler_callback is None)

  • MisconfigurationException – If restore_best is True and ModelCheckpoint.save_top_k is either None or 0

  • MisconfigurationException – If restore_best is True and monitor is None

Return type:

None

state_dict()[source]

Overrides ModelCheckpoint’s state_dict method to maintain multi-phase training depth state.

Returns:

the callback state dictionary that will be saved.

Return type:

Dict[str, Any]

class finetuning_scheduler.fts_supporters.FTSEarlyStopping(*args, **kwargs)[source]

Bases: EarlyStopping, CallbackResolverMixin

Extends/specializes EarlyStopping to facilitate multi-phase scheduled fine-tuning.

Adds es_phase_complete, final_phase and finetuningscheduler_callback attributes and modifies EarlyStopping._evaluate_stopping_criteria to enable multi-phase behavior. Usage of FTSEarlyStopping is identical to that of EarlyStopping, except that the former evaluates the specified early stopping criteria at every scheduled phase. FTSEarlyStopping will automatically be used if a FinetuningScheduler callback is detected and epoch_transitions_only is False.

Note

For detailed usage information, see EarlyStopping.

Note

Currently, FinetuningScheduler supports the use of one FTSEarlyStopping callback instance at a time.
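
A comparable sketch for FTSEarlyStopping follows; as above, MyLightningModule and my_datamodule are hypothetical placeholders, and the callback accepts the same arguments as EarlyStopping:

    from lightning.pytorch import Trainer

    from finetuning_scheduler import FinetuningScheduler
    from finetuning_scheduler.fts_supporters import FTSEarlyStopping

    # Hypothetical placeholders; the module is assumed to log "val_loss".
    model = MyLightningModule()

    trainer = Trainer(
        callbacks=[
            FinetuningScheduler(),
            # Same arguments as EarlyStopping, but the stopping criteria are
            # evaluated at every scheduled fine-tuning phase rather than once
            # for the entire run.
            FTSEarlyStopping(monitor="val_loss", patience=2),
        ]
    )
    trainer.fit(model, datamodule=my_datamodule)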

es_phase_complete

Used to determine if the current phase’s early stopping criteria have been met.

Type:

bool

final_phase

Used to indicate whether the current phase is the final scheduled phase.

Type:

bool

finetuningscheduler_callback

Reference to the FinetuningScheduler callback being used.

Type:

lightning.pytorch.callbacks.Callback

reduce_transition_decisions

Used to indicate whether the callback is operating in a distributed context without the monitored metric being synchronized (via sync_dist being set to True when logging).

Type:

bool

check_on_train_epoch_end

Whether to run the early stopping check at the end of the training epoch. If this is False, then the check runs at the end of validation. Defaults to None (as with EarlyStopping) but is set to False during setup unless overridden.

Type:

bool

on_validation_end(trainer, pl_module)[source]

Ascertain whether the execution context of this callback requires that we reduce transition decisions over all distributed training processes.

Parameters:
  • trainer (Trainer) – The Trainer object.

  • pl_module (LightningModule) – The LightningModule being fine-tuned.

Return type:

None

setup(trainer, pl_module, stage)[source]

Ensure a FinetuningScheduler is provided before beginning training.

Return type:

None

class finetuning_scheduler.fts_supporters.FTSState(_resume_fit_from_ckpt=False, _ft_epoch=0, _ft_global_steps=0, _curr_depth=0, _best_ckpt_depth=0, _ft_sync_props=(('epoch_progress.current.completed', '_ft_epoch'), ('epoch_loop.global_step', '_ft_global_steps')), _ft_sync_objects=None, _ft_init_epoch=None, _curr_thawed_params=<factory>, _fts_ckpt_metadata=<factory>)[source]

Bases: object

Dataclass to encapsulate the FinetuningScheduler internal state.

class finetuning_scheduler.fts_supporters.ScheduleImplMixin[source]

Bases: ABC

Functionality for generating and executing fine-tuning schedules.

static add_optimizer_groups(module, optimizer, thawed_pl, no_decay=None, lr=None, apply_lambdas=False)[source]

Add optimizer parameter groups associated with the next scheduled fine-tuning depth/level and extend the relevant lr_scheduler_configs.

Parameters:
  • module (Module) – The Module from which the target optimizer parameters will be read.

  • optimizer (ParamGroupAddable) – The supported optimizer instance to which parameter groups will be configured and added.

  • thawed_pl (List) – The list of thawed/unfrozen parameters that should be added to the new parameter group(s)

  • no_decay (Optional[list]) – A list of parameters that should always have weight_decay set to 0 (e.g. ["bias", "LayerNorm.weight"]). Defaults to None.

  • lr (Optional[float]) – The initial learning rate for the new parameter group(s). If not specified, the lr of the first scheduled fine-tuning depth will be used. Defaults to None.

  • apply_lambdas (bool) – Whether to apply lr lambdas to newly added groups. Defaults to False.

Return type:

None

Note

If one relies upon the default FTS schedule, the lr provided to this method will be base_max_lr which defaults to 1e-05.

static gen_ft_schedule(module, dump_loc)[source]

Generate the default fine-tuning schedule using a naive, two-parameters-per-level heuristic.

Parameters:
  • module (Module) – The Module for which a fine-tuning schedule will be generated

  • dump_loc (Union[str, PathLike]) – The directory to which the generated schedule (.yaml) should be written

Returns:

The path to the generated schedule. By default it is written to Trainer.log_dir and named after the LightningModule subclass in use, with the suffix _ft_schedule.yaml.

Return type:

os.PathLike
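
A minimal sketch of invoking the static method directly follows. The toy Sequential module and temporary directory are illustrative only, and it is an assumption that a plain torch.nn.Module (rather than the LightningModule managed by the Trainer) suffices outside a training context; in normal use FinetuningScheduler generates and writes this schedule to Trainer.log_dir itself:

    import tempfile

    import torch

    from finetuning_scheduler.fts_supporters import ScheduleImplMixin

    # Toy module standing in for the LightningModule being fine-tuned.
    module = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.Linear(8, 2))

    with tempfile.TemporaryDirectory() as dump_loc:
        sched_path = ScheduleImplMixin.gen_ft_schedule(module, dump_loc)
        # The generated .yaml maps phase indices (0..n) to `params` lists,
        # two parameters per level by default.
        print(open(sched_path).read())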

gen_implicit_schedule(sched_dir)[source]

Generate the default schedule, save it to sched_dir, and load it into ft_schedule.

Parameters:

sched_dir (Union[str, PathLike]) – The directory to which the generated schedule should be written. Defaults to Trainer.log_dir.

Return type:

None

gen_or_load_sched()[source]

Load an explicitly specified fine-tuning schedule if one is provided; otherwise, generate a default one.

Return type:

None

init_ft_sched()[source]

Generate the default fine-tuning schedule and/or load it into ft_schedule. Broadcast the schedule to ensure it is available for use in a distributed context.

Return type:

None

init_fts()[source]

Initializes the fine-tuning schedule and prepares the first scheduled level.

Calls the relevant StrategyAdapter hooks before and after fine-tuning schedule initialization.

1. Generate the default fine-tuning schedule and/or load it into ft_schedule.
2. Prepare the first scheduled fine-tuning level, unfreezing the relevant parameters.

Return type:

None

static load_yaml_schedule(schedule_yaml_file)[source]

Load a schedule defined in a .yaml file and transform it into a dictionary.

Parameters:

schedule_yaml_file (str) – The .yaml fine-tuning schedule file

Raises:

MisconfigurationException – If the specified schedule file is not found

Returns:

the Dict representation of the fine-tuning schedule

Return type:

Dict
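
The sketch below writes a small two-phase schedule to a temporary file and loads it directly via the static method. The parameter names are hypothetical, and in normal use the schedule path is simply passed to FinetuningScheduler via its ft_schedule argument rather than loaded manually:

    import tempfile
    import textwrap
    from pathlib import Path

    from finetuning_scheduler.fts_supporters import ScheduleImplMixin

    # A two-phase schedule with hypothetical parameter names, in the same form
    # a hand-written ft_schedule .yaml would take.
    schedule_yaml = textwrap.dedent(
        """\
        0:
          params:
          - model.classifier.weight
          - model.classifier.bias
        1:
          params:
          - model.pooler.dense.weight
          - model.pooler.dense.bias
        """
    )

    with tempfile.TemporaryDirectory() as tmp:
        sched_file = Path(tmp) / "explicit_ft_schedule.yaml"
        sched_file.write_text(schedule_yaml)
        # YAML parses the phase keys as ints, so the returned Dict is keyed by 0, 1, ...
        schedule = ScheduleImplMixin.load_yaml_schedule(str(sched_file))
        print(schedule[0]["params"])  # ['model.classifier.weight', 'model.classifier.bias']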

static save_schedule(schedule_name, layer_config, dump_loc)[source]

Save the loaded or generated schedule to a directory to ensure reproducibility.

Parameters:
  • schedule_name (str) – The name of the schedule.

  • layer_config (Dict) – The saved schedule dictionary.

  • dump_loc (os.PathLike) – The directory to which the generated schedule (.yaml) should be written

Returns:

The path to the generated schedule. By default it is written to Trainer.log_dir and named after the LightningModule subclass in use, with the suffix _ft_schedule.yaml.

Return type:

os.PathLike

static sync(objs, asets, agg_func=<built-in function max>)[source]

Synchronize sets of object attributes using a given aggregation function.

Parameters:
  • objs (Tuple) – The target objects to synchronize

  • asets (Tuple) – The attribute sets to synchronize

  • agg_func (Callable) – The aggregation function used to synchronize the target object attribute sets. Defaults to max.

Return type:

None

thaw_to_depth(depth=None)[source]

Thaw/unfreeze the current pl_module to the specified fine-tuning depth (aka level).

Parameters:

depth (Optional[int]) – The depth/level to which the pl_module will be thawed. If no depth is specified, curr_depth will be used. Defaults to None.

Return type:

None

class finetuning_scheduler.fts_supporters.ScheduleParsingMixin[source]

Bases: ABC

Functionality for parsing and validating fine-tuning schedules.

reinit_lr_scheduler(new_lr_scheduler, trainer, optimizer)[source]

Reinitialize the learning rate scheduler, using a validated learning rate scheduler configuration and wrapping the existing optimizer.

Parameters:
  • new_lr_scheduler (Dict) – A dictionary defining the new lr scheduler configuration to be initialized.

  • trainer (pl.Trainer) – The Trainer object.

  • optimizer (ParamGroupAddable) – A supported optimizer instance around which the new lr scheduler will be wrapped.

Return type:

None
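
This method is normally invoked by FinetuningScheduler itself when a schedule phase requests lr scheduler reinitialization. The dictionary below is only a sketch of the kind of configuration involved; every key name in it (lr_scheduler_init, class_path, init_args, pl_lrs_cfg) is an assumption drawn from the Fine-Tuning Scheduler user guide rather than from this page:

    # Hypothetical lr scheduler reinitialization configuration (key names are assumptions).
    new_lr_scheduler = {
        "lr_scheduler_init": {
            "class_path": "torch.optim.lr_scheduler.StepLR",
            "init_args": {"step_size": 1, "gamma": 0.7},
        },
        # Optional Lightning-level lr scheduler options (interval, frequency, name, ...).
        "pl_lrs_cfg": {"interval": "epoch", "frequency": 1, "name": "Phase1_StepLR"},
    }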

reinit_optimizer(new_optimizer, trainer, init_params)[source]

Reinitialize the optimizer, using a validated optimizer reinitialization configuration.

Parameters:
  • new_optimizer (Dict) – A dictionary defining the new optimizer configuration to be initialized.

  • trainer (pl.Trainer) – The Trainer object.

  • init_params (List) – The list of parameter names with which to initialize the new optimizer.

Returns:

A handle for the newly reinitialized optimizer.

Return type:

ParamGroupAddable

class finetuning_scheduler.fts_supporters.UniqueKeyLoader(stream)[source]

Bases: SafeLoader

Alters SafeLoader to enable duplicate key detection by the SafeConstructor.

Initialize the scanner.

construct_mapping(node, deep=False)[source]

Overrides the construct_mapping method of the SafeConstructor to raise a ValueError if duplicate keys are found.

Inspired by and adapted from https://stackoverflow.com/a/63215043.

Return type:

Dict
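
A small sketch of the behavior this loader adds (the schedule fragment is illustrative): loading a mapping that repeats a phase key raises a ValueError instead of silently keeping the last value.

    import textwrap

    import yaml

    from finetuning_scheduler.fts_supporters import UniqueKeyLoader

    # A schedule fragment in which the phase key 0 appears twice.
    duplicated_schedule = textwrap.dedent(
        """\
        0:
          params:
          - model.classifier.weight
        0:
          params:
          - model.classifier.bias
        """
    )

    try:
        yaml.load(duplicated_schedule, Loader=UniqueKeyLoader)
    except ValueError as err:
        print(f"Duplicate key detected: {err}")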
