from typing import Callable

import torch
from torch.optim import Optimizer

# NOTE: the Megatron-LM symbols referenced below (ModelType, MegatronModule,
# OptimizerParamScheduler, initialize_megatron, setup_model_and_optimizer,
# build_pretraining_data_loader, get_forward_backward_func,
# get_num_microbatches, save_checkpoint, training_log) come from the modules
# listed under "Main dependencies"; their exact import paths differ between
# Megatron-LM versions, so they are not spelled out here.


def pretrain(
    model_provider: Callable,
    model_type: ModelType,
    forward_step_func: Callable,
    args,
    data_loader_provider: Callable = None,
    process_non_loss_data_func: Callable = None,
    model_init_fn: Callable = None,
):
    # Set up the distributed state (data/tensor/pipeline-parallel groups, RNG, ...).
    initialize_megatron(args)

    # Build the (possibly pipeline-sharded) model, the distributed optimizer,
    # and the learning-rate scheduler.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(
        model_provider, model_type, args, model_init_fn
    )

    # Fall back to the default pretraining data loader if none was supplied.
    if data_loader_provider is None:
        data_loader_provider = build_pretraining_data_loader
    dataloader = data_loader_provider(args)
    # Run the training loop and return the last completed iteration.
    iteration = args.iteration
    if args.train_iters > 0:
        iteration = train(
            forward_step_func=forward_step_func,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            dataloader=dataloader,
            args=args,
            process_non_loss_data_func=process_non_loss_data_func,
        )
    return iteration
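
# The pipeline schedule invokes forward_step_func once per microbatch. Below is
# a minimal, hypothetical sketch of such a function: it assumes a GPT-style
# batch with 'tokens', 'labels', and 'loss_mask' entries and a model that
# returns per-token losses when labels are passed -- these are illustrative
# assumptions, not part of the code in this file.
def example_forward_step(data_iterator, model):
    """Hypothetical per-microbatch forward step: returns the output tensor and
    a callable that reduces it to (loss, metrics_for_logging)."""
    batch = next(data_iterator)
    tokens = batch['tokens']
    labels = batch['labels']
    loss_mask = batch['loss_mask'].float()

    # Assumed model interface: per-token loss tensor when labels are supplied.
    output_tensor = model(tokens, labels=labels)

    def loss_func(output_tensor):
        # Average the per-token loss over non-padded positions only.
        loss = (output_tensor * loss_mask).sum() / loss_mask.sum()
        return loss, {'lm loss': loss.detach()}

    return output_tensor, loss_func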

def train(
    forward_step_func: Callable,
    model: MegatronModule,
    optimizer: Optimizer,
    lr_scheduler: OptimizerParamScheduler,
    dataloader: torch.utils.data.DataLoader,
    args,
    process_non_loss_data_func: Callable = None,
):
    # Select the pipeline schedule (no pipelining, 1F1B, or interleaved 1F1B)
    # that matches the parallelism configuration.
    forward_backward_func = get_forward_backward_func(
        args.virtual_pipeline_model_parallel_size,
        args.pipeline_model_parallel_size,
        args.num_layers_per_virtual_pipeline_stage,
    )

    # The schedule consumes microbatches from an iterator, not the loader itself.
    data_iterator = iter(dataloader)

    # Resume from args.iteration (set when a checkpoint is loaded).
    iteration = args.iteration
    for iteration in range(args.iteration, args.train_iters):
        # Run forward and backward over all microbatches of this global batch;
        # losses_reduced collects the loss dictionaries reported by
        # forward_step_func.
        losses_reduced = forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=data_iterator,
            model=model,
            num_microbatches=get_num_microbatches(),
            seq_length=args.seq_length,
            micro_batch_size=args.micro_batch_size,
            decoder_seq_length=args.decoder_seq_length,
            layernorm_epsilon=args.layernorm_epsilon,
            hidden_size=args.hidden_size,
            params_dtype=args.params_dtype,
            fp16=args.fp16,
            bf16=args.bf16,
            fp32_residual_connection=args.fp32_residual_connection,
            async_comm=args.async_comm,
        )
        # With the local DDP implementation, gradients are all-reduced across
        # data-parallel ranks explicitly after the backward passes.
        if args.DDP_impl == 'local':
            grads = [param.grad for param in model.parameters() if param.grad is not None]
            _allreduce_gradients(grads, args)

        # Update parameters and advance the learning-rate schedule.
        optimizer.step()
        lr_scheduler.step()

        # Periodically write a checkpoint.
        if args.save and args.save_interval and iteration % args.save_interval == 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

        # Log losses, learning rate, and timing for this iteration.
        training_log(losses_reduced, optimizer, lr_scheduler, iteration, args)

    return iteration
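
A launch script might wire these pieces together roughly as follows. This is a minimal sketch rather than the upstream entry point; gpt_model_provider, gpt_forward_step, and parse_pretrain_args are hypothetical stand-ins for whatever model builder, forward step, and argument parsing a given setup provides.

if __name__ == '__main__':
    args = parse_pretrain_args()               # hypothetical: builds the args namespace
    final_iteration = pretrain(
        model_provider=gpt_model_provider,     # hypothetical: constructs the model for this rank
        model_type=ModelType.encoder_or_decoder,
        forward_step_func=gpt_forward_step,    # hypothetical: per-microbatch loss computation
        args=args,
    )
    print(f'training stopped at iteration {final_iteration}')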
📋 Main dependencies (approximate import paths are sketched below)
• megatron.core.pipeline_parallel - pipeline-parallel scheduling
• megatron.core.optimizer - distributed optimizer
• megatron.training.checkpointing - checkpoint management
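
As a rough orientation, on a recent Megatron-LM layout these dependencies surface through imports such as the following; the exact paths are an assumption and shift between releases (older trees expose several of these symbols under other megatron.* modules).

# Approximate imports on a recent Megatron-LM tree (version-dependent assumption):
from megatron.core.pipeline_parallel import get_forward_backward_func         # pipeline schedules
from megatron.core.optimizer import get_megatron_optimizer                    # distributed optimizer
from megatron.training.checkpointing import save_checkpoint, load_checkpoint  # checkpoint management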