Pipeline Parallelism

Model Specification
class deepspeed.pipe.PipelineModule(layers, num_stages=None, topology=None, loss_fn=None, seed_layers=False, seed_fn=None, base_seed=1234, partition_method='parameters', activation_checkpoint_interval=0, activation_checkpoint_func=<function checkpoint>)[source]

    Modules to be parallelized with pipeline parallelism.

    The key constraint that enables pipeline parallelism is the representation of the forward pass as a sequence of layers and the enforcement of a simple interface between them. The forward pass is implicitly defined by the module layers. The key assumption is that the output of each layer can be directly fed as input to the next, like a torch.nn.Sequential. The forward pass is implicitly:

        def forward(self, inputs):
            x = inputs
            for layer in self.layers:
                x = layer(x)
            return x
    Parameters:
        - layers (Iterable) – A sequence of layers defining the pipeline structure. Can be a torch.nn.Sequential module.
        - num_stages (int, optional) – The degree of pipeline parallelism. If not specified, topology must be provided.
        - topology (deepspeed.pipe.ProcessTopology, optional) – Defines the axes of parallelism for training. Must be provided if num_stages is None.
        - loss_fn (callable, optional) – Loss is computed as loss = loss_fn(outputs, label).
        - base_seed (int, optional) – The starting seed used for layer-wise seeding. Defaults to 1234.
        - partition_method (str, optional) – The method used to partition layers across pipeline stages. Defaults to 'parameters', which balances stages by trainable parameter count.
        - activation_checkpoint_interval (int, optional) – The granularity of activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
        - activation_checkpoint_func (callable, optional) – The function to use for activation checkpointing. Defaults to deepspeed.checkpointing.checkpoint.
    allreduce_tied_weight_gradients()[source]

        All-reduce the gradients of the tied weights between tied stages.
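A minimal construction sketch, assuming a small feed-forward model; ds_config stands in for a hypothetical DeepSpeed configuration dict, and the layer sizes, num_stages, and loss_fn are illustrative choices rather than requirements:

    import torch
    import deepspeed
    from deepspeed.pipe import PipelineModule

    # Express the forward pass as a flat sequence of layers.
    layers = [
        torch.nn.Linear(1024, 4096),
        torch.nn.ReLU(),
        torch.nn.Linear(4096, 10),
    ]

    # Partition the layers across two pipeline stages.
    model = PipelineModule(layers=layers,
                           num_stages=2,
                           loss_fn=torch.nn.CrossEntropyLoss())

    # deepspeed.initialize() returns a PipelineEngine when given a PipelineModule.
    engine, _, _, _ = deepspeed.initialize(model=model,
                                           model_parameters=model.parameters(),
                                           config=ds_config)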
class deepspeed.pipe.LayerSpec(typename, *module_args, **module_kwargs)[source]

    Building block for specifying pipeline-parallel modules.

    LayerSpec stores the type information and parameters for each stage in a PipelineModule. For example:

        nn.Sequential(
            torch.nn.Linear(self.in_dim, self.hidden_dim, bias=False),
            torch.nn.Linear(self.hidden_dim, self.out_dim)
        )

    becomes

        layer_specs = [
            LayerSpec(torch.nn.Linear, self.in_dim, self.hidden_dim, bias=False),
            LayerSpec(torch.nn.Linear, self.hidden_dim, self.out_dim),
        ]
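Because a LayerSpec defers construction of the wrapped module, each layer is only instantiated on the stage that owns it, which avoids materializing the full model on every rank. A brief sketch of handing the specs to a PipelineModule (num_stages is an illustrative value):

    model = PipelineModule(layers=layer_specs, num_stages=2)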
Training
class deepspeed.runtime.pipe.engine.PipelineEngine(*super_args, **super_kwargs)[source]

    A training engine for hybrid pipeline, data, and model parallel training.

    This engine is created by deepspeed.initialize() when a PipelineModule is provided.
    train_batch(data_iter=None)[source]

        Progress the pipeline to train the next batch of data. The engine will ingest self.train_batch_size() total samples collectively across all workers.

        An iterator over training data should be provided as an argument unless deepspeed.initialize() was provided a training set. In that event, the training data will automatically be read.

        Warning: A total of self.gradient_accumulation_steps() entries will be pulled from data_iter by each pipeline. There must be sufficient data left in data_iter or else a StopIteration will halt training.

        DeepSpeed provides a convenience class deepspeed.utils.RepeatingLoader that wraps data loaders to automatically restart upon a StopIteration.

        Parameters: data_iter (Iterator, optional) – Iterator of training data.

        Returns: The arithmetic mean of the losses computed this batch.
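A minimal training-loop sketch, where engine is the PipelineEngine returned by deepspeed.initialize() and train_loader is a hypothetical PyTorch DataLoader:

    from deepspeed.utils import RepeatingLoader

    # Wrap the loader so StopIteration never interrupts gradient accumulation.
    data_iter = iter(RepeatingLoader(train_loader))

    for step in range(1000):
        # Each call consumes gradient_accumulation_steps() micro-batches and
        # performs the forward, backward, and optimizer work for one batch.
        loss = engine.train_batch(data_iter=data_iter)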
    eval_batch(data_iter)[source]

        Evaluate the pipeline on a batch of data from data_iter. The engine will evaluate self.train_batch_size() total samples collectively across all workers.

        This method is equivalent to:

            module.eval()
            with torch.no_grad():
                output = module(batch)

        Warning: A total of self.gradient_accumulation_steps() entries will be pulled from data_iter by each pipeline. There must be sufficient data left in data_iter or else a StopIteration will halt evaluation.

        DeepSpeed provides a convenience class deepspeed.utils.RepeatingLoader that wraps data loaders to automatically restart upon a StopIteration.

        Parameters: data_iter (Iterator) – Iterator of data to evaluate.

        Returns: The arithmetic mean of the losses computed this batch.
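A short evaluation sketch under the same assumptions as above, with val_loader as a hypothetical validation DataLoader:

    from deepspeed.utils import RepeatingLoader

    val_iter = iter(RepeatingLoader(val_loader))
    # Consumes gradient_accumulation_steps() micro-batches and returns the mean loss.
    eval_loss = engine.eval_batch(data_iter=val_iter)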
    is_gradient_accumulation_boundary()[source]

        True if the engine is executing a gradient reduction or optimizer step instruction.

        This is overridden from DeepSpeedEngine to force reductions and steps when the pipeline engine is instructed to do so.

        Returns: whether reductions and optimizer steps should occur.

        Return type: bool
    module_state_dict()[source]

        Override hack to save a pipe model and return the directory path of the save.

        This method should only be called by DeepSpeed's save_checkpoint(). The recommended way of saving a PipelineModule outside of save_checkpoint() is save_state_dict().

        Returns: None
    load_module_state_dict(state_dict, strict=True)[source]

        Override hack to instead use a directory path.

        This is important because pipeline models checkpoint by layer instead of rank.

        If state_dict is not None or a str, we revert to super() expecting a dict.

        Parameters:
            - state_dict (str, None) – unused
            - strict (bool, optional) – Strict state loading. Defaults to True.
Extending Pipeline Parallelism
class deepspeed.runtime.pipe.schedule.PipeSchedule(micro_batches, stages, stage_id)[source]

    Directs the execution of a pipeline engine by generating sequences of PipeInstruction.

    Schedules are generators that yield sequences of PipeInstruction to process the micro-batches in one batch. Each yielded step is atomic in the sense that a barrier synchronization can be placed between successive steps without deadlock.

    Below is an example schedule that implements data parallelism with gradient accumulation:

        class DataParallelSchedule(PipeSchedule):
            def steps(self):
                for step_id in range(self.micro_batches):
                    cmds = [
                        LoadMicroBatch(buffer_id=0),
                        ForwardPass(buffer_id=0),
                        BackwardPass(buffer_id=0),
                    ]
                    if step_id == self.micro_batches - 1:
                        cmds.extend([
                            ReduceGrads(),
                            OptimizerStep(),
                        ])
                    yield cmds

            def num_pipe_buffers(self):
                return 1

    Parameters:
        - micro_batches (int) – The number of micro-batches that comprise a batch.
        - stages (int) – The number of pipeline stages.
        - stage_id (int) – The pipe stage that will execute the generated schedule.
    steps()[source]

        Yield a list of PipeInstruction for each step in the schedule.

        Note: Schedules must implement steps() to define the schedule.

        Returns: Instructions to be executed as one step of the pipeline.
    num_pipe_buffers()[source]

        The number of pipeline buffers that will be used by this stage.

        Note: Schedules should specialize num_pipe_buffers() for memory savings at scale.

        Returns: The number of buffers for the engine to allocate.
    stage

        Stage index used to configure this schedule.

    num_stages

        The number of total pipeline stages used to configure this schedule.

    num_micro_batches

        The number of total micro-batches used to configure this schedule.

    is_first_stage

        True if the configured stage_id is the first stage in the pipeline.

    is_last_stage

        True if the configured stage_id is the last stage in the pipeline.
class deepspeed.runtime.pipe.schedule.InferenceSchedule(micro_batches, stages, stage_id)[source]

    A schedule for inferencing batches using pipeline parallelism.
class deepspeed.runtime.pipe.schedule.TrainSchedule(micro_batches, stages, stage_id)[source]

    A schedule for training a batch using hybrid parallelism.

    Pipeline parallelism is extracted through gradient accumulation and thus convergence follows that of a data parallel approach with the same batch size.
class deepspeed.runtime.pipe.schedule.DataParallelSchedule(micro_batches, stages, stage_id)[source]

    An example schedule that trains using traditional data parallelism with gradient accumulation.
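A small sketch of how a schedule can be inspected outside the engine, using the class as documented above; the micro_batches, stages, and stage_id values are illustrative:

    from deepspeed.runtime.pipe.schedule import DataParallelSchedule

    # One data-parallel stage accumulating over 4 micro-batches.
    sched = DataParallelSchedule(micro_batches=4, stages=1, stage_id=0)

    for step_id, cmds in enumerate(sched.steps()):
        # Each yielded step is a list of PipeInstruction objects that the
        # PipelineEngine would execute atomically.
        print(step_id, [type(cmd).__name__ for cmd in cmds])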
class deepspeed.runtime.pipe.schedule.PipeInstruction(**kwargs)[source]

    Base class for all instructions to be executed by the pipeline engine.

    All keyword arguments are stored as members similar to a namedtuple. These are then accessible to the PipeEngine during execution.

    Parameters: kwargs (optional) – keyword arguments to store as members
class deepspeed.runtime.pipe.schedule.OptimizerStep(**kwargs)[source]

    Performs one step with the optimizer and zeros gradients.

    Note: Should be issued after ReduceGrads and ReduceTiedGrads.

    Note: Can be a synchronization point among data-parallel ranks.
class deepspeed.runtime.pipe.schedule.ReduceGrads(**kwargs)[source]

    Reduce the computed gradients among data-parallel processes within the stage.
class deepspeed.runtime.pipe.schedule.ReduceTiedGrads(**kwargs)[source]

    Reduce the computed gradients of tied modules within a pipeline-parallel group.

    Warning: The stages included in this synchronization point are not known until the model is partitioned among pipeline stages. In the worst case, it includes all pipeline stages. This instruction should be scheduled carefully to avoid deadlocks.
class deepspeed.runtime.pipe.schedule.BufferOpInstruction(buffer_id, **kwargs)[source]

    A pipeline instruction that operates on pipeline buffer(s).

    Parameters: buffer_id (int) – the index of the pipeline buffer to modify.
class deepspeed.runtime.pipe.schedule.LoadMicroBatch(buffer_id, **kwargs)[source]

    Load a micro-batch into a buffer.

    Roughly:

        buffers['inputs'][buffer_id] = next(data_iter)
class deepspeed.runtime.pipe.schedule.ForwardPass(buffer_id, **kwargs)[source]

    Compute a forward pass.

    Roughly:

        buffers['outputs'][buffer_id] = forward(buffers['inputs'][buffer_id])
class deepspeed.runtime.pipe.schedule.BackwardPass(buffer_id, **kwargs)[source]

    Compute a backward pass and accumulate gradients.

    Roughly:

        outputs = buffers['outputs'][buffer_id]
        gradients = buffers['gradients'][buffer_id]
        torch.autograd.backward(tensors=outputs, grad_tensors=gradients)
class deepspeed.runtime.pipe.schedule.SendActivation(buffer_id, **kwargs)[source]

    Send activations to the next stage in the pipeline.

    Roughly:

        send(buffers['outputs'][buffer_id])

    Note: The communication is blocking and must be paired with a RecvActivation on the next pipeline stage to avoid deadlock.
class deepspeed.runtime.pipe.schedule.RecvActivation(buffer_id, **kwargs)[source]

    Receive activations from the previous stage in the pipeline.

    Roughly:

        buffers['inputs'][buffer_id] = recv()

    Note: The communication is blocking and must be paired with a SendActivation on the previous pipeline stage to avoid deadlock.
class deepspeed.runtime.pipe.schedule.SendGrad(buffer_id, **kwargs)[source]

    Send computed gradients, with respect to the received activations, to the previous pipeline stage.

    Note: Only received tensors with requires_grad==True will produce gradients. Missing gradients will be replaced with None on the receiving stage.

    Note: The communication is blocking and must be paired with a RecvGrad on the previous pipeline stage to avoid deadlock.
class deepspeed.runtime.pipe.schedule.RecvGrad(buffer_id, **kwargs)[source]

    Receive computed gradients from the next pipeline stage.

    Note: Only activations with requires_grad==True will produce gradients. Missing gradients will be replaced with None.

    Note: The communication is blocking and must be paired with a SendGrad on the next pipeline stage to avoid deadlock.