Training Setup¶
Argument Parsing¶
DeepSpeed uses the argparse library to supply command-line configuration to the DeepSpeed runtime. Use deepspeed.add_config_arguments() to add DeepSpeed's built-in arguments to your application's parser.
import argparse
import deepspeed

parser = argparse.ArgumentParser(description='My training script.')
parser.add_argument('--local_rank', type=int, default=-1,
                    help='local rank passed from distributed launcher')

# Include DeepSpeed configuration arguments
parser = deepspeed.add_config_arguments(parser)

cmd_args = parser.parse_args()
deepspeed.add_config_arguments(parser)[source]¶
Update the argument parser to enable parsing of DeepSpeed command line arguments. The set of DeepSpeed arguments includes the following: 1) --deepspeed: boolean flag to enable DeepSpeed 2) --deepspeed_config <json file path>: path of a JSON configuration file to configure the DeepSpeed runtime.
Parameters: parser – argument parser
Returns: Updated Parser
Return type: parser
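Since --deepspeed_config points at a JSON file, it helps to see what such a file contains. The sketch below is a minimal illustrative configuration; the field names (train_batch_size, gradient_accumulation_steps, optimizer, fp16) follow the DeepSpeed config schema, but the specific values are placeholders, not recommendations:

```json
{
  "train_batch_size": 8,
  "gradient_accumulation_steps": 1,
  "optimizer": {
    "type": "Adam",
    "params": { "lr": 0.00015 }
  },
  "fp16": { "enabled": true }
}
```

You would then launch with --deepspeed --deepspeed_config ds_config.json (the filename here is arbitrary).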
Training Initialization¶
The entrypoint for all training with DeepSpeed is deepspeed.initialize(). It will initialize the distributed backend if it is not initialized already.
Example usage:
model_engine, optimizer, _, _ = deepspeed.initialize(args=cmd_args,
model=net,
model_parameters=net.parameters())
deepspeed.initialize(args, model, optimizer=None, model_parameters=None, training_data=None, lr_scheduler=None, mpu=None, dist_init_required=None, collate_fn=None, config_params=None)[source]¶
Initialize the DeepSpeed Engine.
Parameters: - args – a dictionary containing local_rank and deepspeed_config file location
- model – Required: nn.Module before applying any wrappers
- optimizer – Optional: a user defined optimizer, this is typically used instead of defining an optimizer in the DeepSpeed json config.
- model_parameters – Optional: An iterable of torch.Tensors or dicts. Specifies what Tensors should be optimized.
- training_data – Optional: Dataset of type torch.utils.data.Dataset
- lr_scheduler – Optional: Learning Rate Scheduler Object. It should define get_lr(), step(), state_dict(), and load_state_dict() methods
- mpu – Optional: A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}()
- dist_init_required – Optional: None will auto-initialize torch.distributed if needed, otherwise the user can force it to be initialized or not via boolean.
- collate_fn – Optional: Merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
Returns: A tuple of engine, optimizer, training_dataloader, lr_scheduler
- engine: DeepSpeed runtime engine which wraps the client model for distributed training.
- optimizer: Wrapped optimizer if a user-defined optimizer is supplied, or if an optimizer is specified in the JSON config, else None.
- training_dataloader: DeepSpeed dataloader if training_data was supplied, otherwise None.
- lr_scheduler: Wrapped lr scheduler if a user lr_scheduler is passed, or if an lr_scheduler is specified in the JSON configuration. Otherwise None.
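The returned engine replaces both the raw model and the optimizer in the training loop: forward through the engine, then engine.backward(loss) and engine.step(). A runnable sketch of that loop shape follows; since the real engine comes from deepspeed.initialize() and requires DeepSpeed to be installed, a minimal stub stands in for it here, mimicking only the three calls the loop uses:

```python
# Sketch of the canonical DeepSpeed training-loop shape. In real code,
# model_engine comes from deepspeed.initialize(); the stub below is a
# stand-in so the loop is runnable without DeepSpeed installed.
class _EngineStub:
    """Stand-in exposing the three engine calls the loop uses."""

    def __call__(self, batch):
        # Forward pass; returns a pretend "loss" for illustration.
        return sum(batch)

    def backward(self, loss):
        # The real engine handles loss scaling / gradient accumulation here,
        # replacing the usual loss.backward().
        self.last_loss = loss

    def step(self):
        # The real engine performs the optimizer step, LR-scheduler step,
        # and gradient zeroing, replacing optimizer.step().
        pass


model_engine = _EngineStub()  # real code: deepspeed.initialize(...)[0]
for batch in [[1.0, 2.0], [3.0, 4.0]]:
    loss = model_engine(batch)       # forward through the engine
    model_engine.backward(loss)      # instead of loss.backward()
    model_engine.step()              # instead of optimizer.step()

print(model_engine.last_loss)
```

Note that gradient accumulation and mixed precision are driven by the JSON config, so the loop body stays the same regardless of those settings.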
Distributed Initialization¶
Optional distributed backend initialization, separate from deepspeed.initialize(). Useful in scenarios where the user wants to use torch distributed calls before calling deepspeed.initialize(), such as when using model parallelism, pipeline parallelism, or certain data loader scenarios.
deepspeed.init_distributed(dist_backend='nccl', auto_mpi_discovery=True, distributed_port=29500, verbose=True, timeout=datetime.timedelta(seconds=1800), init_method=None)[source]¶
Initialize the torch.distributed backend, potentially performing MPI discovery if needed.
Parameters: - dist_backend – Optional (str). torch distributed backend, e.g., nccl, mpi, gloo
- auto_mpi_discovery – Optional (bool). If distributed environment variables are not set, attempt to discover them from MPI
- distributed_port – Optional (int). torch distributed backend port
- verbose – Optional (bool). verbose logging
- timeout – Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
- init_method – Optional (string). URL specifying how to initialize the torch.distributed process group. Default is “env://” if no init_method or store is specified.
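With the default init_method of “env://”, torch.distributed reads rendezvous information from environment variables. The sketch below shows the standard variable names that convention uses; the values are single-process placeholders, and in practice a launcher (e.g. the deepspeed launcher or mpirun with auto_mpi_discovery) sets them for you:

```shell
# Rendezvous variables read by the "env://" init_method
# (illustrative single-process values).
export MASTER_ADDR=127.0.0.1   # host of the rank-0 process
export MASTER_PORT=29500       # matches the distributed_port default above
export RANK=0                  # global rank of this process
export WORLD_SIZE=1            # total number of processes

echo "rendezvous at tcp://${MASTER_ADDR}:${MASTER_PORT} (rank ${RANK}/${WORLD_SIZE})"
```

If these variables are already exported, init_distributed() can pick them up without any MPI discovery.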