fts_supporters

Classes

CallbackDepMixin

Functionality for validating/managing callback dependencies.

CallbackResolverMixin

Give user-provided callbacks the ability to connect to another user-provided callback.

FTSCheckpoint

Extends/specializes ModelCheckpoint to facilitate multi-phase scheduled fine-tuning.

FTSEarlyStopping

Extends/specializes EarlyStopping to facilitate multi-phase scheduled fine-tuning.

FTSState

Dataclass to encapsulate the FinetuningScheduler internal state.

ScheduleImplMixin

Functionality for generating and executing fine-tuning schedules.

ScheduleParsingMixin

Functionality for parsing and validating fine-tuning schedules.

UniqueKeyLoader

Alters SafeLoader to enable duplicate key detection by the SafeConstructor.

Fine-Tuning Scheduler Supporters

Classes composed to support scheduled fine-tuning

class finetuning_scheduler.fts_supporters.CallbackDepMixin[source]

Bases: abc.ABC

Functionality for validating/managing callback dependencies.

class finetuning_scheduler.fts_supporters.CallbackResolverMixin(callback_attrs=('ft_schedule', 'max_depth'), callback_parents={'EarlyStopping': <class 'pytorch_lightning.callbacks.early_stopping.EarlyStopping'>, 'ModelCheckpoint': <class 'pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'>}, target_callback_ref='FinetuningScheduler', support_multiple=False)[source]

Bases: abc.ABC

Give user-provided callbacks the ability to connect to another user-provided callback.

This resolution logic is provided in order to avoid callback-dependent trainer attributes (e.g. trainer.finetuningscheduler_callback).

Initialize the user-provided callback dependency resolver in accordance with the user-provided module configuration.

Parameters
  • callback_attrs (Tuple, optional) – Attribute signature of user-provided callback to be structurally detected and connected. Defaults to CALLBACK_ATTRS defined in the user-provided module.

  • callback_parents (Dict, optional) – The parent classes of all user-provided callbacks in the module that should be connected to the target user-provided callback. Defaults to CALLBACK_DEP_PARENTS in the user-provided module.

  • target_callback_ref (str, optional) – The name of the target user-provided callback to connect to. For each subclass of CALLBACK_DEP_PARENTS, an attribute named (target_callback_ref.lower())_callback will be added. Defaults to TARGET_CALLBACK_REF in the user-provided module.

  • support_multiple (bool, optional) – Whether multiple instances of the target user-provided callback (only the first of which will be connected to) are allowed. Defaults to False.

connect_callback(trainer, reconnect=False)[source]

Connect each user-provided callback dependency that needs to be connected to the target user-provided callback.

Parameters
  • trainer (pl.Trainer) – The Trainer object.

  • reconnect (bool, optional) – Whether to check for an updated target callback object even if one is already resolved. Predominantly useful in the context of testing. Defaults to False.

Raises
  • MisconfigurationException – If no target callback is detected.

  • MisconfigurationException – If support_multiple is False and multiple target callbacks are detected.

Return type

None
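The resolution described above can be sketched with stand-in classes. This is an illustration under stated assumptions, not the library's implementation: FakeTrainer, FakeScheduler, FakeDependent, resolve_target and connect are hypothetical names, "structural detection" is assumed to mean "has every attribute listed in callback_attrs", and a plain RuntimeError stands in for MisconfigurationException.

```python
from typing import Any, List, Sequence


class FakeTrainer:
    """Hypothetical stand-in for a Trainer: it only holds a list of callbacks."""

    def __init__(self, callbacks: List[Any]) -> None:
        self.callbacks = callbacks


class FakeScheduler:
    """Hypothetical target callback exposing the default attribute signature."""

    def __init__(self) -> None:
        self.ft_schedule = None
        self.max_depth = -1


class FakeDependent:
    """Hypothetical dependent callback (e.g. a checkpoint or early-stopping callback)."""


def resolve_target(trainer: FakeTrainer, callback_attrs: Sequence[str],
                   support_multiple: bool = False) -> Any:
    """Find the target callback by its attribute signature (assumed detection rule)."""
    matches = [cb for cb in trainer.callbacks
               if all(hasattr(cb, attr) for attr in callback_attrs)]
    if not matches:
        raise RuntimeError("no target callback detected")
    if len(matches) > 1 and not support_multiple:
        raise RuntimeError("multiple target callbacks detected")
    return matches[0]  # only the first match is connected to


def connect(dependent: Any, trainer: FakeTrainer, callback_attrs: Sequence[str],
            target_callback_ref: str = "FinetuningScheduler",
            support_multiple: bool = False, reconnect: bool = False) -> None:
    """Attach the resolved target to the dependent as <target_callback_ref.lower()>_callback."""
    attr_name = f"{target_callback_ref.lower()}_callback"
    if getattr(dependent, attr_name, None) is not None and not reconnect:
        return  # already resolved and no refresh requested
    target = resolve_target(trainer, callback_attrs, support_multiple)
    setattr(dependent, attr_name, target)
```

With the default target_callback_ref, the dependent ends up with a finetuningscheduler_callback attribute, so it never has to look the scheduler up through the trainer.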

class finetuning_scheduler.fts_supporters.FTSCheckpoint(*args, **kwargs)[source]

Bases: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint, finetuning_scheduler.fts_supporters.CallbackResolverMixin

Extends/specializes ModelCheckpoint to facilitate multi-phase scheduled fine-tuning. Overrides the state_dict and load_state_dict hooks to maintain additional state (current_ckpt_depth, best_ckpt_depth, finetuningscheduler_callback). Usage of FTSCheckpoint is identical to ModelCheckpoint and FTSCheckpoint will automatically be used if a FinetuningScheduler callback is detected.

Warning

FTSCheckpoint is in beta and subject to change. For detailed usage information, see ModelCheckpoint.

current_ckpt_depth

Used to track the depth of the most recently saved checkpoint.

Type

int

best_ckpt_depth

Used to track the depth of the best known checkpoint (it may be from a different training depth).

Type

int

finetuningscheduler_callback

Reference to the FinetuningScheduler callback being used.

Type

pytorch_lightning.callbacks.Callback

save_on_train_epoch_end

Whether to run checkpointing at the end of the training epoch. If this is False, the check runs at the end of validation. Defaults to None, as in ModelCheckpoint, but is set to False during setup unless overridden.

Type

Optional[bool]

load_state_dict(state_dict)[source]

Overrides ModelCheckpoint’s load_state_dict method to load multi-phase training depth state.

Parameters

state_dict (Dict[str, Any]) – the callback state dict of FTSCheckpoint.

Return type

None

setup(trainer, pl_module, stage=None)[source]

Verify a valid callback configuration is present before beginning training.

Parameters

trainer (Trainer) – The Trainer object

Raises
  • MisconfigurationException – If a FinetuningScheduler callback is not found on initialization (finetuningscheduler_callback is None)

  • MisconfigurationException – If restore_best is True and ModelCheckpoint.save_top_k is either None or 0

  • MisconfigurationException – If restore_best is True and monitor is None

Return type

None

state_dict()[source]

Overrides ModelCheckpoint’s state_dict method to maintain multi-phase training depth state.

Returns

the callback state dictionary that will be saved.

Return type

Dict[str, Any]
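The state_dict/load_state_dict override pattern described above can be sketched as follows. BaseCheckpoint and DepthAwareCheckpoint are hypothetical stand-ins (the real base class is ModelCheckpoint), and the base state key shown here is an assumption chosen only for illustration.

```python
from typing import Any, Dict, Optional


class BaseCheckpoint:
    """Hypothetical stand-in for a checkpoint callback that has its own state."""

    def __init__(self) -> None:
        self.best_model_score: Optional[float] = None

    def state_dict(self) -> Dict[str, Any]:
        return {"best_model_score": self.best_model_score}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.best_model_score = state_dict["best_model_score"]


class DepthAwareCheckpoint(BaseCheckpoint):
    """Extends the parent state with multi-phase depth tracking."""

    def __init__(self) -> None:
        super().__init__()
        self.current_ckpt_depth = 0
        self.best_ckpt_depth = 0

    def state_dict(self) -> Dict[str, Any]:
        # Start from the parent's state, then add the depth-tracking keys.
        state = super().state_dict()
        state["current_ckpt_depth"] = self.current_ckpt_depth
        state["best_ckpt_depth"] = self.best_ckpt_depth
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Restore the parent's state first, then the depth-tracking keys.
        super().load_state_dict(state_dict)
        self.current_ckpt_depth = state_dict.get("current_ckpt_depth", 0)
        self.best_ckpt_depth = state_dict.get("best_ckpt_depth", 0)
```

Falling back to 0 when a key is missing is a design choice of this sketch; it lets a state dictionary saved by the plain parent class still load.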

class finetuning_scheduler.fts_supporters.FTSEarlyStopping(*args, **kwargs)[source]

Bases: pytorch_lightning.callbacks.early_stopping.EarlyStopping, finetuning_scheduler.fts_supporters.CallbackResolverMixin

Extends/specializes EarlyStopping to facilitate multi-phase scheduled fine-tuning.

Adds es_phase_complete, final_phase and finetuningscheduler_callback attributes and modifies EarlyStopping._evaluate_stopping_criteria to enable multi-phase behavior. Usage of FTSEarlyStopping is identical to EarlyStopping except that the former evaluates the specified early stopping criteria at every scheduled phase. FTSEarlyStopping will automatically be used if a FinetuningScheduler callback is detected and epoch_transitions_only is False.

Warning

FTSEarlyStopping is in beta and subject to change. For detailed usage information, see EarlyStopping.

es_phase_complete

Used to determine if the current phase’s early stopping criteria have been met.

Type

bool

final_phase

Used to indicate whether the current phase is the final scheduled phase.

Type

bool

finetuningscheduler_callback

Reference to the FinetuningScheduler callback being used.

Type

pytorch_lightning.callbacks.Callback

check_on_train_epoch_end

Whether to run the early stopping check at the end of the training epoch. If this is False, the check runs at the end of validation. Defaults to None, as in EarlyStopping, but is set to False during setup unless overridden.

Type

bool

setup(trainer, pl_module, stage=None)[source]

Ensure a FinetuningScheduler is provided before beginning training.

Return type

None

class finetuning_scheduler.fts_supporters.FTSState(_resume_fit_from_ckpt=False, _ft_epoch=0, _ft_global_steps=0, _curr_depth=0, _best_ckpt_depth=0, _ft_sync_props=(('epoch_progress.current.completed', '_ft_epoch'), ('epoch_loop.global_step', '_ft_global_steps')), _ft_sync_objects=None, _curr_thawed_params=<factory>, _fts_ckpt_metadata=<factory>)[source]

Bases: object

Dataclass to encapsulate the FinetuningScheduler internal state.

class finetuning_scheduler.fts_supporters.ScheduleImplMixin[source]

Bases: abc.ABC

Functionality for generating and executing fine-tuning schedules.

static add_optimizer_groups(module, optimizer, thawed_pl, no_decay=None, lr=None, apply_lambdas=False)[source]

Add optimizer parameter groups associated with the next scheduled fine-tuning depth/level and extend the relevant lr_scheduler_configs.

Parameters
  • module (Module) – The Module from which the target optimizer parameters will be read.

  • optimizer (Optimizer) – The Optimizer to which parameter groups will be configured and added.

  • thawed_pl (List) – The list of thawed/unfrozen parameters that should be added to the new parameter group(s)

  • no_decay (Optional[list]) – A list of parameters that should always have weight_decay set to 0. e.g.: [“bias”, “LayerNorm.weight”]. Defaults to None.

  • lr (Optional[float]) – The initial learning rate for the new parameter group(s). If not specified, the lr of the first scheduled fine-tuning depth will be used. Defaults to None.

  • apply_lambdas (bool) – Whether to apply lr lambdas to newly added groups. Defaults to False.

Return type

None
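The role of no_decay in the parameters above can be illustrated with a small grouping helper. This is a sketch, not the library's code: build_param_groups is a hypothetical name, matching no_decay entries by substring against parameter names is an assumption, and the parameters are represented by placeholder objects so the example runs without PyTorch.

```python
from typing import Any, Dict, List, Optional, Tuple


def build_param_groups(thawed: List[Tuple[str, Any]], lr: float, weight_decay: float,
                       no_decay: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """Split newly thawed (name, parameter) pairs into decayed and non-decayed groups."""
    patterns = no_decay or []

    def is_no_decay(name: str) -> bool:
        # Assumption: a parameter is exempt if any no_decay entry occurs in its name.
        return any(pattern in name for pattern in patterns)

    decayed = [param for name, param in thawed if not is_no_decay(name)]
    exempt = [param for name, param in thawed if is_no_decay(name)]

    groups: List[Dict[str, Any]] = []
    if decayed:
        groups.append({"params": decayed, "lr": lr, "weight_decay": weight_decay})
    if exempt:
        # Exempt parameters always get weight_decay set to 0.
        groups.append({"params": exempt, "lr": lr, "weight_decay": 0.0})
    return groups
```

With real tensors, each returned dictionary has the shape accepted by torch.optim.Optimizer.add_param_group.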

static exec_ft_phase(module, thaw_pl, init_thaw=False)[source]

Thaw/unfreeze the provided list of parameters in the provided Module

Parameters
  • module (Module) – The Module that will have parameters selectively unfrozen/thawed.

  • thaw_pl (List) – The list of parameters that should be thawed/unfrozen in the Module

  • init_thaw (bool) – If True, modifies message to user accordingly. Defaults to False.

Returns

A Tuple of two lists.
  1. The list of newly thawed/unfrozen parameters thawed by this function

  2. A list of all currently thawed/unfrozen parameters in the target Module

Return type

Tuple[List, List]
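The two returned lists can be illustrated with a minimal thawing routine. Param, TinyModule and thaw are hypothetical stand-ins; the sketch assumes that parameters are identified by name and that thawing means setting requires_grad to True, as with PyTorch parameters.

```python
from typing import Dict, Iterable, List, Tuple


class Param:
    """Hypothetical stand-in for a parameter with a requires_grad flag."""

    def __init__(self, requires_grad: bool = False) -> None:
        self.requires_grad = requires_grad


class TinyModule:
    """Hypothetical stand-in for a module exposing named_parameters()."""

    def __init__(self, params: Dict[str, Param]) -> None:
        self._params = params

    def named_parameters(self) -> Iterable[Tuple[str, Param]]:
        return self._params.items()


def thaw(module: TinyModule, thaw_pl: List[str]) -> Tuple[List[str], List[str]]:
    """Unfreeze the named parameters and report what changed."""
    newly_thawed: List[str] = []
    for name, param in module.named_parameters():
        if name in thaw_pl and not param.requires_grad:
            param.requires_grad = True
            newly_thawed.append(name)
    # Second list: everything that is trainable after this call.
    currently_thawed = [n for n, p in module.named_parameters() if p.requires_grad]
    return newly_thawed, currently_thawed
```

Keeping the two lists separate lets a caller add only the newly thawed parameters to the optimizer while still logging the full trainable set.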

static gen_ft_schedule(module, dump_loc)[source]

Generate the default fine-tuning schedule using a naive heuristic that assigns two parameters to each level.

Parameters
  • module (Module) – The Module for which a fine-tuning schedule will be generated

  • dump_loc (Union[str, PathLike]) – The directory to which the generated schedule (.yaml) should be written

Returns

The path to the generated schedule (by default in Trainer.log_dir and named after the LightningModule subclass in use, with the suffix _ft_schedule.yaml).

Return type

os.PathLike
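A two-parameters-per-level heuristic can be sketched as below. The function name naive_schedule, the {depth: {"params": [...]}} layout, and the choice to start from the last-registered (output-side) parameters are all assumptions made for illustration; consult a generated schedule file for the actual layout.

```python
from typing import Dict, List


def naive_schedule(param_names: List[str], per_level: int = 2) -> Dict[int, Dict[str, List[str]]]:
    """Group parameter names into consecutive levels of `per_level` names each."""
    # Assumption: thaw from the output side first, so reverse registration order.
    ordered = list(reversed(param_names))
    return {
        depth: {"params": ordered[start:start + per_level]}
        for depth, start in enumerate(range(0, len(ordered), per_level))
    }
```

For a model whose parameters are registered as enc.weight, enc.bias, head.weight, head.bias, this yields level 0 containing the two head parameters and level 1 containing the two encoder parameters. A schedule generated this way is a starting point to be edited, not a tuned schedule.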

gen_implicit_schedule(sched_dir)[source]

Generate the default schedule, save it to sched_dir and load it into ft_schedule

Parameters

sched_dir (PathLike) – The directory to which the generated schedule should be written. Defaults to Trainer.log_dir.

Return type

None

gen_or_load_sched()[source]

Load an explicitly specified fine-tuning schedule if one is provided, otherwise generate a default one.

Return type

None

init_ft_sched()[source]

Generate the default fine-tuning schedule and/or load it into ft_schedule. Broadcast the schedule to ensure it is available for use in a distributed context.

Return type

None

init_fts()[source]

Initializes the fine-tuning schedule and prepares the first scheduled level:
  1. Generate the default fine-tuning schedule and/or load it into ft_schedule.

  2. Prepare the first scheduled fine-tuning level, unfreezing the relevant parameters.

Return type

None

static load_yaml_schedule(schedule_yaml_file)[source]

Load a schedule defined in a .yaml file and transform it into a dictionary.

Parameters

schedule_yaml_file (str) – The .yaml fine-tuning schedule file

Raises

MisconfigurationException – If the specified schedule file is not found

Returns

the Dict representation of the fine-tuning schedule

Return type

Dict

static save_schedule(schedule_name, layer_config, dump_loc)[source]

Save a loaded or generated schedule to a directory to ensure reproducibility.

Parameters
  • schedule_name (str) – The name of the schedule.

  • layer_config (Dict) – The saved schedule dictionary.

  • dump_loc (os.PathLike) – The directory to which the generated schedule (.yaml) should be written

Returns

The path to the generated schedule (by default in Trainer.log_dir and named after the LightningModule subclass in use, with the suffix _ft_schedule.yaml).

Return type

os.PathLike

static sync(objs, asets, agg_func=<built-in function max>)[source]

Synchronize sets of object attributes using a given aggregation function.

Parameters
  • objs (Tuple) – The target objects to synchronize

  • asets (Tuple) – The attribute sets to synchronize

  • agg_func (Callable) – The aggregation function used to synchronize the target object attribute sets. Defaults to max.

Return type

None
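One way to read the synchronization described above is sketched here. It assumes that each attribute set holds one dotted attribute path per target object, in the same order as objs (an interpretation suggested by the paired entries in FTSState's default _ft_sync_props). The helper names rgetattr, rsetattr and sync_attrs are hypothetical.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any, Callable, Tuple


def rgetattr(obj: Any, path: str) -> Any:
    """Resolve a dotted attribute path such as 'a.b.c'."""
    return reduce(getattr, path.split("."), obj)


def rsetattr(obj: Any, path: str, value: Any) -> None:
    """Set the attribute at the end of a dotted attribute path."""
    parent_path, _, leaf = path.rpartition(".")
    parent = rgetattr(obj, parent_path) if parent_path else obj
    setattr(parent, leaf, value)


def sync_attrs(objs: Tuple[Any, ...], asets: Tuple[Tuple[str, ...], ...],
               agg_func: Callable = max) -> None:
    """For each attribute set, aggregate the values and write the result back to every object."""
    for aset in asets:
        agreed = agg_func([rgetattr(obj, path) for obj, path in zip(objs, aset)])
        for obj, path in zip(objs, aset):
            rsetattr(obj, path, agreed)


# Usage: a loop object and a state object disagree on the epoch count.
loop = SimpleNamespace(progress=SimpleNamespace(completed=3))
state = SimpleNamespace(_ft_epoch=5)
sync_attrs((loop, state), (("progress.completed", "_ft_epoch"),))
```

With max as the aggregation function, both counters take the larger value, so neither object moves backwards.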

thaw_to_depth(depth=None)[source]

Thaw/unfreeze the current pl_module to the specified fine-tuning depth (aka level)

Parameters

depth (Optional[int]) – The depth/level to which the pl_module will be thawed. If no depth is specified, curr_depth will be used. Defaults to None.

Return type

None
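Thawing "to a depth" is cumulative: every level up to and including the requested one contributes parameters. A minimal sketch of that selection, assuming a schedule shaped as {depth: {"params": [...]}} (an illustrative layout) and a hypothetical helper name params_up_to_depth:

```python
from typing import Dict, List


def params_up_to_depth(schedule: Dict[int, Dict[str, List[str]]], depth: int) -> List[str]:
    """Collect the parameter names of every level from 0 through `depth`."""
    return [
        name
        for level in sorted(schedule)
        if level <= depth
        for name in schedule[level]["params"]
    ]
```

The resulting list is what a thawing routine would then unfreeze in the module.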

class finetuning_scheduler.fts_supporters.ScheduleParsingMixin[source]

Bases: abc.ABC

Functionality for parsing and validating fine-tuning schedules.

reinit_lr_scheduler(new_lr_scheduler, trainer, optimizer)[source]

Reinitialize the learning rate scheduler, using a validated learning rate scheduler configuration and wrapping the existing optimizer.

Parameters
  • new_lr_scheduler (Dict) – A dictionary defining the new lr scheduler configuration to be initialized.

  • trainer (pl.Trainer) – The Trainer object.

  • optimizer (Optimizer) – The Optimizer around which the new lr scheduler will be wrapped.

Return type

None

class finetuning_scheduler.fts_supporters.UniqueKeyLoader(stream)[source]

Bases: yaml.loader.SafeLoader

Alters SafeLoader to enable duplicate key detection by the SafeConstructor.

Initialize the scanner.

construct_mapping(node, deep=False)[source]

Overrides the construct_mapping method of the SafeConstructor to raise a ValueError if duplicate keys are found.

Inspired by and adapted from https://stackoverflow.com/a/63215043

Return type

Dict
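The duplicate-key check can be illustrated without a YAML dependency by swapping in the standard library's json module, whose object_pairs_hook plays a role similar to a construct_mapping override. This is an analogous technique, not the UniqueKeyLoader implementation, and reject_duplicates is a hypothetical name.

```python
import json
from typing import Any, Dict, List, Tuple


def reject_duplicates(pairs: List[Tuple[str, Any]]) -> Dict[str, Any]:
    """Build a mapping from (key, value) pairs, refusing repeated keys."""
    mapping: Dict[str, Any] = {}
    for key, value in pairs:
        if key in mapping:
            raise ValueError(f"Duplicate {key!r} key found.")
        mapping[key] = value
    return mapping


# A well-formed document parses as usual.
ok = json.loads('{"0": ["head"], "1": ["enc"]}', object_pairs_hook=reject_duplicates)

# A repeated key raises instead of silently keeping only the last value.
try:
    json.loads('{"0": ["head"], "0": ["enc"]}', object_pairs_hook=reject_duplicates)
    duplicate_detected = False
except ValueError:
    duplicate_detected = True
```

The motivation is the same in both formats: without such a check, a schedule that repeats a depth key would lose a level without any error being reported.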
