FSDP Scheduled Fine-Tuning

Overview

FinetuningScheduler (FTS) now supports flexible, multi-phase, scheduled fine-tuning with the Fully Sharded Data Parallel (FSDP) strategy (FSDPStrategy). This tutorial assumes a basic understanding of FSDP training; please see this PyTorch tutorial for a good introduction to FSDP.

As with standard FSDP usage, FSDP wrapping of a LightningModule can be performed either by providing an auto_wrap_policy or (for maximal control) by overriding the configure_sharded_model method of LightningModule and manually wrapping the module.

This tutorial walks through the configuration of an example multi-phase, scheduled FSDP fine-tuning training session and largely uses the same code as the basic scheduled fine-tuning for SuperGLUE examples.

Example: Multi-Phase Scheduled Fine-Tuning with FSDP

Demonstration FTS FSDP training/profiling configurations and a DDP baseline for comparison are available under ./fts_examples/stable/config/advanced/fsdp.

This FTS FSDP training example has the same dependencies as the basic scheduled fine-tuning for SuperGLUE examples, except that PyTorch >= 2.0 is required.

Note

This version of FSDPStrategyAdapter supports stable PyTorch releases >= 2.0.

Note

The examples below are not configured to execute a full training session but instead to generate the minimal meaningful profiling statistics for analysis and exposition (e.g., using only 2 batches, very limited epochs, etc.).

The demo schedule configurations are composed with the basic FTS example’s shared defaults (./config/fts_defaults.yaml) and can be executed as follows:

cd ./fts_examples/stable

# Profiled demo of FSDP scheduled fine-tuning using the ``awp_overrides`` option:
python fts_superglue.py fit --config config/advanced/fsdp/fts_fsdp_awp_overrides_profile.yaml

# Profiled demo of comparable DDP scheduled fine-tuning baseline:
python fts_superglue.py fit --config config/advanced/fsdp/fts_ddp_fsdp_baseline_profile.yaml

# Profiled demo of FSDP scheduled fine-tuning with CPU Offloading but full precision
# (for reference, not reviewed in this tutorial)
python fts_superglue.py fit --config config/advanced/fsdp/fts_fsdp_awp_overrides_offload_profile.yaml

FSDP Wrapping For Scheduled Fine-Tuning

As with standard FSDP module wrapping, one can use an auto_wrap_policy to wrap a model for FSDP scheduled fine-tuning. In the current FTS release, there is only one FTS-specific FSDP configuration enhancement to consider: the awp_overrides list.

awp_overrides is an optional list of module names that should be wrapped in separate FSDP instances, complementing the modules that would be individually wrapped by the auto_wrap_policy provided in the FSDPStrategy configuration.

Starting with a defined auto_wrap_policy and providing module name-based complements/overrides as needed using awp_overrides is often the most expedient approach to auto-wrapping models in alignment with a fine-tuning schedule.

We start by defining a simple fine-tuning schedule that we would like to ensure our module wrapping supports:

0:
  params:
  - model.classifier.*
  max_transition_epoch: 1
1:
  params:
  - model.pooler.dense.*
  - model.deberta.encoder.layer.11.(output|attention|intermediate).*
  max_transition_epoch: 2
2:
  params:
  - model.deberta.encoder.layer.([0-9]|10).(output|attention|intermediate).*
  - model.deberta.encoder.LayerNorm.bias
  - model.deberta.encoder.LayerNorm.weight
  - model.deberta.encoder.rel_embeddings.weight
  # excluding these parameters from the schedule to enhance the debugging demonstration
  #- model.deberta.embeddings.LayerNorm.bias
  #- model.deberta.embeddings.LayerNorm.weight
  #- model.deberta.embeddings.word_embeddings.weight
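Because schedule phases identify parameters by explicit name or by regular expression, it can be useful to preview which parameters a given pattern will resolve to before training. Below is a minimal illustrative sketch (it mirrors the spirit of FTS's name/pattern resolution, not its exact implementation; lightning_module is assumed to be an instance of the example's LightningModule):

# illustrative sketch only: preview which parameters a schedule pattern matches
import re

def preview_matches(module, pattern: str) -> list:
    # parameter names are matched against the module's named_parameters()
    regex = re.compile(pattern)
    return [n for n, _ in module.named_parameters() if regex.fullmatch(n)]

# e.g.:
# preview_matches(lightning_module, "model.deberta.encoder.layer.11.(output|attention|intermediate).*")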

We define the auto_wrap_policy for our DeBERTa-v3 module as follows:

strategy:
  class_path: lightning.pytorch.strategies.FSDPStrategy
  init_args:
    # other FSDP args as desired ...
    auto_wrap_policy:
      class_path: torch.distributed.fsdp.wrap.ModuleWrapPolicy
      init_args:
        module_classes: !!set
          ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Layer
          ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Embeddings
          ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Encoder
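For reference, a roughly equivalent programmatic definition of this policy and strategy (a sketch, assuming the relevant transformers classes are importable in your environment) looks like:

# sketch: programmatic equivalent of the auto_wrap_policy configuration above
from lightning.pytorch.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers.models.deberta_v2.modeling_deberta_v2 import (
    DebertaV2Embeddings,
    DebertaV2Encoder,
    DebertaV2Layer,
)

auto_wrap_policy = ModuleWrapPolicy({DebertaV2Layer, DebertaV2Embeddings, DebertaV2Encoder})
strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy)  # other FSDP args as desired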

We’ll inspect the rationale for this policy below, but first, notice we have not referenced our classifier and pooler layers. Because we would like to thaw our classifier and pooler layers in phases separate from some other layers, we need to wrap these layers separately as well. If, however, we specified separate wrapping of all Linear layers in our auto_wrap_policy, we would end up unnecessarily (and in many cases problematically) wrapping the many Linear layers within our already FSDP-wrapped modules (DebertaV2Layer, etc.).

To facilitate module wrapping in alignment with fine-tuning schedule phases, FTS provides the awp_overrides feature, which allows users to provide module name-based complements to a given auto_wrap_policy.

In this case, simply listing the names of (or regex patterns matching) the modules we would like to wrap separately allows us to achieve FSDP wrapping that aligns with our fine-tuning schedule. FTS support for FSDP training is provided via a StrategyAdapter (FSDPStrategyAdapter). Configuration for FTS extensions of strategies like FSDP is passed to FTS via the strategy_adapter_cfg configuration dictionary.

So in our example, we can pass the awp_overrides configuration option to FTS like so:

# in ./fts_examples/stable/config/advanced/fsdp/fts_fsdp_awp_overrides_profile.yaml
...
  - class_path: finetuning_scheduler.FinetuningScheduler
    init_args:
      ft_schedule: ./config/RteBoolqModule_ft_schedule_deberta_base_fsdp.yaml
      max_depth: 2
      strategy_adapter_cfg:
        awp_overrides: ["model.pooler.dense", "model.classifier"]
...
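Equivalently, when constructing the callback in code rather than via the CLI configuration, the same options can be passed directly (a sketch):

# sketch: constructing the FTS callback with awp_overrides programmatically
from finetuning_scheduler import FinetuningScheduler

fts = FinetuningScheduler(
    ft_schedule="./config/RteBoolqModule_ft_schedule_deberta_base_fsdp.yaml",
    max_depth=2,
    strategy_adapter_cfg={"awp_overrides": ["model.pooler.dense", "model.classifier"]},
)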

Finally, we configure the FSDP training strategy as usual, for instance specifying activation_checkpointing and cpu_offload configurations in addition to the auto_wrap_policy we defined above:

# in ./fts_examples/stable/config/advanced/fsdp/fts_fsdp_awp_overrides_profile.yaml
...
  strategy:
    class_path: lightning.pytorch.strategies.FSDPStrategy
    init_args:
      cpu_offload: false
      activation_checkpointing:
      - transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Layer
      auto_wrap_policy:
        class_path: torch.distributed.fsdp.wrap.ModuleWrapPolicy
        init_args:
          module_classes: !!set
            ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Layer
            ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Embeddings
            ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Encoder
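For completeness, here is a condensed programmatic sketch of the same strategy configuration, wired together with the FTS callback defined above (other Trainer arguments used by the example config are omitted):

# sketch: the full FSDP strategy plus the FTS callback wired into a Trainer
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy
from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2Layer

strategy = FSDPStrategy(
    cpu_offload=False,
    activation_checkpointing=[DebertaV2Layer],
    auto_wrap_policy=auto_wrap_policy,  # as defined above
)
trainer = Trainer(strategy=strategy, callbacks=[fts])  # plus accelerator/devices and other Trainer args as desired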

That’s all there is to it! We’ve successfully defined our fine-tuning schedule and FSDP wrapped our model in a manner that supports FSDP multi-phase scheduled fine-tuning.

Additional FSDP Wrapping and Debugging Guidance

In order to support multi-phase scheduled fine-tuning with FSDP, FTS's key precondition is that the defined fine-tuning schedule phases have disjoint sets of FSDP-flattened parameters (a FlatParameter is created when a set of modules is wrapped in an FSDP instance/unit). This constraint derives from the fact that the requires_grad attribute currently (as of PyTorch 2.0.0) must be the same for all parameters flattened into the same FlatParameter. 1
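Conceptually, the precondition requires that no FlatParameter be claimed by more than one fine-tuning phase. A toy illustration of such a check (not FTS's actual validation code) follows:

# toy illustration of the phase-wise disjointness precondition (not FTS's actual validation code)
def phases_are_disjoint(phase_to_flat_params: dict) -> bool:
    seen = {}
    for phase, flat_params in phase_to_flat_params.items():
        for fp in flat_params:
            if fp in seen and seen[fp] != phase:
                return False  # this FlatParameter spans more than one phase
            seen[fp] = phase
    return True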

FTS will attempt to validate that the module is wrapped in a manner that aligns with the defined fine-tuning schedule phases prior to the start of training and provide detailed feedback for the user if a misalignment is discovered.

For example, note that because we wanted to thaw some DebertaV2Layer modules separately from others, we directed FSDP to wrap each DebertaV2Layer in its own FSDP instance rather than wrapping just the entire DebertaV2Encoder.

What happens if we direct FSDP to wrap only the DebertaV2Layer modules, without also wrapping DebertaV2Encoder and DebertaV2Embeddings?

FTS stops before beginning training and provides extensive context via this error message:

"Fine-tuning schedule phases do not have disjoint FSDP-flattened parameter sets. Because the `requires_grad` attribute of FSDP-flattened parameters currently must be the same for all flattened parameters, fine-tuning schedules must avoid thawing parameters in the same FSDP-flattened parameter in different phases. Please ensure parameters associated with each phase are wrapped in separate phase-aligned FSDP instances.

In this particular case, there are parameters not included in your fine-tuning schedule that span more than one fine-tuning phase. HINT: parameters associated with unwrapped modules will be included in the top-level (aka 'root') FSDP instance so ensuring all modules associated with fine-tuning scheduled parameters are wrapped separately from the top-level FSDP instance may avoid triggering this exception.

The following logical parameters are associated with an FSDP-flattened parameter that spans more than one fine-tuning phase. The mapping of each logical parameter with the module name wrapped by its associated FSDP instance is provided below:

{'model.deberta.embeddings.LayerNorm.bias': 'DebertaV2ForSequenceClassification',
 'model.deberta.embeddings.LayerNorm.weight': 'DebertaV2ForSequenceClassification',
 'model.deberta.embeddings.word_embeddings.weight': 'DebertaV2ForSequenceClassification',
 'model.deberta.encoder.LayerNorm.bias': 'DebertaV2ForSequenceClassification',
 'model.deberta.encoder.LayerNorm.weight': 'DebertaV2ForSequenceClassification',
 'model.deberta.encoder.rel_embeddings.weight': 'DebertaV2ForSequenceClassification'}"

This helps us understand that these parameters all belong to the same top-level FSDP instance (the instance that wraps DebertaV2ForSequenceClassification). Because we failed to specify separate wrapping of the DebertaV2Encoder, the parameters directly associated with that module fell to the top-level/root FSDP instance to be managed. While the DebertaV2Embeddings parameters were not included in our schedule, they must still be wrapped by FSDP and so end up in the same top-level FlatParameter alongside the DebertaV2Encoder parameters. Had training been permitted to proceed in this case, the DebertaV2Embeddings parameters would have been thawed along with the DebertaV2Encoder parameters in phase 2, violating our specified fine-tuning schedule.

To avoid violating the phase-wise disjointness constraint, we add DebertaV2Encoder to our auto_wrap_policy. While not technically required, we add DebertaV2Embeddings separately as well for future experimental flexibility.
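When debugging wrapping alignment yourself, it can also help to inspect which module each FSDP instance ends up managing once the model has been wrapped. A minimal diagnostic sketch (pl_module is assumed to be the already-wrapped LightningModule, e.g. accessed from a callback hook):

# diagnostic sketch: list the module managed by each FSDP instance
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

for fsdp_inst in FSDP.fsdp_modules(pl_module):
    print(type(fsdp_inst.module).__name__)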

As always, if needed, one can alternatively override configure_sharded_model and manually wrap a given LightningModule to align with a desired fine-tuning schedule.
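A minimal manual-wrapping sketch along those lines, using the module names from this example (the surrounding module definition and other details are elided; RteBoolqModuleManualWrap is a hypothetical subclass of the example's LightningModule, and Lightning's FSDPStrategy invokes this hook within its wrapping context):

# sketch: manual, phase-aligned FSDP wrapping via configure_sharded_model
from torch.distributed.fsdp.wrap import wrap

class RteBoolqModuleManualWrap(RteBoolqModule):  # hypothetical subclass of the example module
    def configure_sharded_model(self) -> None:
        # wrap innermost, phase-aligned modules first, then their parents
        self.model.classifier = wrap(self.model.classifier)
        self.model.pooler.dense = wrap(self.model.pooler.dense)
        for i, layer in enumerate(self.model.deberta.encoder.layer):
            self.model.deberta.encoder.layer[i] = wrap(layer)
        self.model.deberta.embeddings = wrap(self.model.deberta.embeddings)
        self.model.deberta.encoder = wrap(self.model.deberta.encoder)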

Warning

FSDPStrategyAdapter is in BETA and subject to change. Its interface may introduce breaking changes and new features with upcoming PyTorch releases.

Note

The no_decay attribute that FTS supports on LightningModule with the base StrategyAdapter is not currently supported in the context of FSDP fine-tuning.

Tip

If you want to extend FTS to use a custom, currently unsupported strategy or override current FTS behavior with a given training strategy, subclassing StrategyAdapter is a way to do so.

Footnotes

1

Once this PyTorch FSDP feature is implemented, PyTorch should allow FlatParameter objects constructed in use_orig_params mode to contain original parameters with non-uniform requires_grad. Depending upon the implementation of that feature, some of the aforementioned constraints on FSDP fine-tuning schedules may be relaxed.
