From 196a57dbae22d399aabe9088c105e1849201d934 Mon Sep 17 00:00:00 2001 From: jsleep Date: Mon, 18 Oct 2021 17:21:13 -0700 Subject: [PATCH] ort trainer backend --- .../azureml/submit_ortds.py | 25 +++++---- .../config-ort.yaml | 33 ++++++++++++ .../config-ortds.yaml | 1 + .../deepspeed_methods/deepspeed_utils.py | 2 +- .../cnndailymail_text_summarization/train.py | 4 +- pymarlin/core/ort_trainer_backend.py | 52 +++++++++++++++++++ pymarlin/core/trainer_backend.py | 35 ++----------- pymarlin/core/trainer_backend_factory.py | 37 +++++++++++++ 8 files changed, 141 insertions(+), 48 deletions(-) create mode 100644 examples/cnndailymail_text_summarization/config-ort.yaml create mode 100644 pymarlin/core/ort_trainer_backend.py create mode 100644 pymarlin/core/trainer_backend_factory.py diff --git a/examples/cnndailymail_text_summarization/azureml/submit_ortds.py b/examples/cnndailymail_text_summarization/azureml/submit_ortds.py index 84358f1..e95bfd7 100644 --- a/examples/cnndailymail_text_summarization/azureml/submit_ortds.py +++ b/examples/cnndailymail_text_summarization/azureml/submit_ortds.py @@ -1,19 +1,18 @@ -from azureml.core import Experiment, Workspace, ScriptRunConfig +from azureml.core import Experiment, Workspace, ScriptRunConfig, Datastore from azureml.core.compute import AmlCompute from azureml.core.runconfig import MpiConfiguration # put your AML workspace config.json in this directory! 
ws = Workspace.from_config() ws_details = ws.get_details() -ds = ws.get_default_datastore() +ds = Datastore(ws, 'ws2_ds') -gpu_compute_target = AmlCompute(workspace=ws, name='sriovdedicated1') +gpu_compute_target = AmlCompute(workspace=ws, name='LoRA-ND') print(gpu_compute_target.status.serialize()) from azureml.core import Dataset from azureml.data import OutputFileDatasetConfig - # create input/output datasets def get_input_dataset(datastore, path_on_datastore, dataset_name): dataset = Dataset.File.from_files(path=[(datastore, path_on_datastore)]) @@ -25,7 +24,7 @@ def get_output_dataset(datastore, path_on_datastore, dataset_name): def get_args(outputSuffix="deepspeed_ort_amp_nopadding_v100_8"): all_params_default = [ '--data_path', get_input_dataset(ds, f'datasets/cnn_dm/preprocessed/bart/', "data_path"), - '--config_path', 'config-ortds.yaml', + '--config_path', 'config-ort.yaml', ] return all_params_default @@ -33,21 +32,22 @@ def get_args(outputSuffix="deepspeed_ort_amp_nopadding_v100_8"): from azureml.core import Environment # Creates the environment inside a Docker container. -pytorch_env = Environment(name='myEnv') +pytorch_env = Environment(name='pymarlin-ort-ds') pytorch_env.docker.enabled = True # docker file in this directory built for your convenience -pytorch_env.docker.base_image = "pymarlin/base-gpu:cuda11.1.cudnn8.ds.ort" + +pytorch_env.docker.base_image = "pymarlin/pymarlin.cuda11.1" pytorch_env.python.user_managed_dependencies = True pytorch_env.python.interpreter_path = '/opt/miniconda/bin/python' mpi = MpiConfiguration() -#NCv3_24rs - 4 16GB V100 GPU's per node -mpi.process_count_per_node = 4 -mpi.node_count = 2 +# NDv2, 8 GPU's per node +mpi.process_count_per_node = 8 +mpi.node_count = 1 # ds.upload_files(['local path to preprocessed data'], 'datasets/cnn_dm/preprocessed/bart') -script = "train_ortds.py" +script = "train.py" codepath = '..' 
config = ScriptRunConfig(source_directory=codepath, @@ -57,14 +57,13 @@ def get_args(outputSuffix="deepspeed_ort_amp_nopadding_v100_8"): environment=pytorch_env, distributed_job_config=mpi) -experiment_name = 'pymarlin_summarization_bart_ortds' +experiment_name = 'summarization_bart_ort_backend' experiment = Experiment(ws, name=experiment_name) run = experiment.submit(config) run.tag('nodes', f'{mpi.node_count}') run.tag('process_count_per_node', f'{mpi.process_count_per_node}') -run.tag('notes', '2 node with ort+ds') print("Submitted run") print(f"\n{run.get_portal_url()}") diff --git a/examples/cnndailymail_text_summarization/config-ort.yaml b/examples/cnndailymail_text_summarization/config-ort.yaml new file mode 100644 index 0000000..a7b516e --- /dev/null +++ b/examples/cnndailymail_text_summarization/config-ort.yaml @@ -0,0 +1,33 @@ +data_path: 'D:/data/cnn_cln' + +trainer: + max_train_steps_per_epoch : null # Maximum train steps per epoch. + max_val_steps_per_epoch : null # Maximum validation steps per epoch. + train_batch_size: 32 # Training global batch size. + val_batch_size: 32 # Validation batch size per GPU. + epochs: 3 # Total epochs to run. + gpu_batch_size_limit : 4 # Max limit for GPU batch size during training. + disable_tqdm : False + writers: ["stdout", "aml", "tensorboard"] + backend: 'ddp-amp-ort' +module: + max_length_encoder : 1024 + max_length_decoder : 128 +wrt: + tb_log_dir : 'logs' +stat: + log_steps : 50 +chkp: + checkpoint : True + delete_existing_checkpoints: False + save_dir: 'outputs' #aml output path. does not require mounting + load_dir: null + load_filename: null + +# add more from BartForConditionalGeneration.generate? +generate: + max_length: 128 + do_sample : False + num_beams : 5 +# support everything in a yaml. ignore (print warning) everything that's not present. 
+# Do not add the requirement to define anything in the parser other than yamls diff --git a/examples/cnndailymail_text_summarization/config-ortds.yaml b/examples/cnndailymail_text_summarization/config-ortds.yaml index 900fd82..9a8775d 100644 --- a/examples/cnndailymail_text_summarization/config-ortds.yaml +++ b/examples/cnndailymail_text_summarization/config-ortds.yaml @@ -9,6 +9,7 @@ trainer: train_batch_size: 32 # Training global batch size. val_batch_size: 32 # Validation batch size per GPU. epochs: 3 # Total epochs to run. + ort: True gpu_batch_size_limit : 4 # Max limit for GPU batch size during training. disable_tqdm : True writers: ["stdout", "aml", "tensorboard"] diff --git a/examples/cnndailymail_text_summarization/deepspeed_methods/deepspeed_utils.py b/examples/cnndailymail_text_summarization/deepspeed_methods/deepspeed_utils.py index 95f89eb..baf097b 100644 --- a/examples/cnndailymail_text_summarization/deepspeed_methods/deepspeed_utils.py +++ b/examples/cnndailymail_text_summarization/deepspeed_methods/deepspeed_utils.py @@ -46,6 +46,6 @@ def get_core_model(model, deepspeed_flag=False, ort_flag=False): if deepspeed_flag: module = module.module if ort_flag: - module = module._original_module + module = module._module_metadata.original_module return module diff --git a/examples/cnndailymail_text_summarization/train.py b/examples/cnndailymail_text_summarization/train.py index 484ec09..cddc9c9 100644 --- a/examples/cnndailymail_text_summarization/train.py +++ b/examples/cnndailymail_text_summarization/train.py @@ -40,8 +40,8 @@ def __init__( generate_kwargs = {} ): super().__init__() - self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base") - self.tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base") + self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-large") + self.tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-large") self.max_lr = max_lr self.max_length_encoder = 
max_length_encoder self.max_length_decoder = max_length_decoder diff --git a/pymarlin/core/ort_trainer_backend.py b/pymarlin/core/ort_trainer_backend.py new file mode 100644 index 0000000..51ae00b --- /dev/null +++ b/pymarlin/core/ort_trainer_backend.py @@ -0,0 +1,52 @@ +from .trainer_backend import * +import sys +from pymarlin.utils.logger import getlogger +import torch.nn as nn + +class ORTTrainerBackend(AbstractTrainerBackendDecorator): + def __init__(self, trainer_backend): + super().__init__(trainer_backend) + self.logger = getlogger(__file__,log_level='DEBUG') + + # TODO: add these under TrainerBackendDecoratorPassThrough, which ORT, Opacus can inherit from + # so that DDP backend can get/set from wrapped SingleProcess* + def __getattribute__(self, name): + # self.logger.debug(f'__getattribute__(name={name})') + if name in ('trainer_backend','init','__init__','logger', '_core_model', 'core_model') : + return super().__getattribute__(name) + else: + return self.trainer_backend.__getattribute__(name) + + def __setattr__(self, name, value): + # self.logger.debug(f'__setattr_(name={name},value={value})') + if name in ('trainer_backend','init','__init__','logger', '_core_model', 'core_model') : + super().__setattr__(name, value) + else: + self.trainer_backend.__setattr__(name, value) + + @property + def core_model(self): + return self._core_model + + @core_model.setter + def core_model(self, model): + self._core_model = model + + def init(self, args: TrainerBackendArguments): + super().init(args) + try: + from torch_ort import ORTModule + except: + self.logger.error("could not import ORTModule") + sys.exit(1) + + assert(hasattr(self.trainer_backend.model, 'model'), 'self.trainer_backend.model.model does not exist') + assert(isinstance(self.trainer_backend.model.model, nn.Module), "expected module_inteface.model of type torch.nn.Module") + + # get the reference and save it before ORTModule wrap + self.core_model = self.trainer_backend.model.model + module = 
self.trainer_backend.model # TODO: should we change trainer_backend.model to module? + module.get_core_model = lambda: self.core_model + + self.logger.info("Wrapping trainer_backend.model.model") + self.trainer_backend.model.model = ORTModule(self.trainer_backend.model.model) \ No newline at end of file diff --git a/pymarlin/core/trainer_backend.py b/pymarlin/core/trainer_backend.py index 2336185..e2f0563 100644 --- a/pymarlin/core/trainer_backend.py +++ b/pymarlin/core/trainer_backend.py @@ -13,7 +13,7 @@ Alternatively a user can provide a custom `TrainerBackend`. """ from tqdm.auto import tqdm -from abc import ABC, abstractmethod +from abc import ABC, abstractmethod, abstractproperty import dataclasses from typing import Iterable, Optional, Union import warnings @@ -31,32 +31,12 @@ SequentialDistributedSampler, ) - try: from apex import amp except ImportError: amp = None from functools import wraps -def build_trainer_backend(trainer_backend_name, *args, **kwargs): - """Factory for trainer_backends - - Args: - trainer_backend_name (str): TrainerBackend Name. 
Possible choices are currently: sp, sp-amp, sp-amp-apex, ddp, ddp-amp, ddp-amp-apex - args (sequence): TrainerBackend positional arguments - kwargs (dict): TrainerBackend keyword arguments - """ - factory_dict = { - "sp": SingleProcess, - "sp-amp": SingleProcessAmp, - "sp-amp-apex": SingleProcessApexAmp, - "ddp": DDPTrainerBackendFactory(SingleProcess), - "ddp-amp": DDPTrainerBackendFactory(SingleProcessAmp), - "ddp-amp-apex": DDPTrainerBackendFactory(SingleProcessApexAmp), - } - return factory_dict[trainer_backend_name](*args, **kwargs) - - @dataclasses.dataclass class TrainerBackendArguments: """ @@ -106,13 +86,11 @@ def get_batches_completed(self): def get_global_steps_completed(self): pass - @property - @abstractmethod + @abstractproperty def train_sampler(self): return RandomSampler - @property - @abstractmethod + @abstractproperty def val_sampler(self): return SequentialSampler @@ -712,10 +690,3 @@ def train_sampler(self): @property def val_sampler(self): return SequentialDistributedSampler - -def DDPTrainerBackendFactory(trainer_backend_cls): # pylint: disable=invalid-name - def create(*args, gather_frequency: Optional[int] = None, **kwargs): - # pull out args to DDPTrainerBackend if needed here. - return DDPTrainerBackend(trainer_backend_cls(*args, **kwargs), gather_frequency=gather_frequency) - - return create diff --git a/pymarlin/core/trainer_backend_factory.py b/pymarlin/core/trainer_backend_factory.py new file mode 100644 index 0000000..bcc46c7 --- /dev/null +++ b/pymarlin/core/trainer_backend_factory.py @@ -0,0 +1,37 @@ +from .trainer_backend import * +from .ort_trainer_backend import ORTTrainerBackend + +def build_trainer_backend(trainer_backend_name, *args, **kwargs): + """Factory for trainer_backends + + Args: + trainer_backend_name (str): TrainerBackend Name. 
Possible choices are currently: sp, sp-amp, sp-amp-apex, ddp, ddp-amp, ddp-amp-ort, ddp-amp-apex + args (sequence): TrainerBackend positional arguments + kwargs (dict): TrainerBackend keyword arguments + """ + factory_dict = { + "sp": SingleProcess, + "sp-amp": SingleProcessAmp, + "sp-amp-apex": SingleProcessApexAmp, + "ddp": DDPTrainerBackendFactory(SingleProcess), + "ddp-amp-ort": DDPORTTrainerBackendFactory(SingleProcessAmp), + "ddp-amp": DDPTrainerBackendFactory(SingleProcessAmp), + "ddp-amp-apex": DDPTrainerBackendFactory(SingleProcessApexAmp), + } + return factory_dict[trainer_backend_name](*args, **kwargs) + +def DDPTrainerBackendFactory(trainer_backend_cls): # pylint: disable=invalid-name + def create(*args, gather_frequency: Optional[int] = None, **kwargs): + # pull out args to DDPTrainerBackend if needed here. + return DDPTrainerBackend(trainer_backend_cls(*args, **kwargs), gather_frequency=gather_frequency) + + return create + +# testing TODO: refactor factory logic to do hierarchical decoration (sp->ort->ddp/deepspeed) +def DDPORTTrainerBackendFactory(trainer_backend_cls): + def create(*args, gather_frequency: Optional[int] = None, **kwargs): + return DDPTrainerBackend( + ORTTrainerBackend(trainer_backend_cls(*args, **kwargs)), + gather_frequency=gather_frequency) + + return create