"""Source code for ``deepspeed.config.base``: base classes for DeepSpeed configurations."""

class ConfigError(Exception):
    """Raised when a DeepSpeed configuration is malformed or inconsistent."""


class ConfigArg:
    """A single configuration argument with a default and a current value.

    Args:
        default: Value used when no explicit ``value`` is supplied.
        value: Explicit value. ``None`` means "not supplied", so the
            argument falls back to ``default``.
    """

    def __init__(self, default=None, value=None):
        self.default = default
        self.value = default if value is None else value

    def is_valid(self):
        """Plain arguments place no constraints on their value."""
        return True

    def __repr__(self):
        # The bare value is shown (this is what config pretty-printing displays).
        return str(self.value)


class RequiredArg(ConfigArg):
    """An argument without a default; it is invalid until a value is assigned."""

    def __init__(self):
        super().__init__()

    def is_valid(self):
        # An unassigned (None) required argument fails validation outright;
        # otherwise defer to the base-class check.
        return self.value is not None and super().is_valid()


class SubConfig(ConfigArg):
    """Argument wrapping a nested :class:`Config`.

    Validation is delegated to the wrapped config.

    Raises:
        TypeError: If ``config`` is not a :class:`Config` instance.
    """

    def __init__(self, config):
        if isinstance(config, Config):
            super().__init__(value=config)
        else:
            raise TypeError(f'Expecting type Config, got {type(config)}')

    def is_valid(self):
        nested_config = self.value
        return nested_config.is_valid()


class MetaConfig(type):
    """Metaclass (i.e., class factory) for :class:`Config`.

    This is used to extract the argument class attributes and stash them in
    ``cls._class_args``. Every :class:`ConfigArg` declared as a class attribute
    is removed from the class namespace and recorded in that dict. Arguments
    collected on base classes are inherited, and a subclass can override an
    inherited argument by redeclaring the same name.
    """
    def __new__(cls, name, bases, dct):
        config_args = dict()

        # Inherit the arguments already collected on the base classes. Bases
        # are merged right-to-left so that earlier-listed bases take precedence.
        for base in reversed(bases):
            config_args.update(getattr(base, '_class_args', {}))

        # Extract configs from the class dictionary and move them to
        # _class_args; they override anything inherited under the same name.
        own_args = {key: val for key, val in dct.items() if isinstance(val, ConfigArg)}
        for key in own_args:
            del dct[key]
        config_args.update(own_args)
        dct['_class_args'] = config_args

        return super().__new__(cls, name, bases, dct)


class Config(metaclass=MetaConfig):
    """Base class for DeepSpeed configurations.

    ``Config`` is a struct with subclassing. Instances are initialized from
    keyword arguments, and therefore also from dictionaries:

    >>> c = Config(verbose=True)
    >>> c.verbose
    True
    >>> c['verbose']
    True

    You can initialize them from dictionaries:

    >>> myconf = {'verbose' : True}
    >>> c = Config.from_dict(myconf)
    >>> c.verbose
    True

    Configurations should be subclassed to group arguments by topic.
    """
    def __init__(self, **kwargs):
        import copy

        super().__init__()

        # The config arguments we are tracking. Maps name -> ConfigArg
        self._args = dict()

        # Initialize config structure and defaults. _class_args is a dict of the
        # args in the class definition. Each instance gets its own deep copy:
        # the class-level ConfigArg objects are shared, so storing them directly
        # would let a value update on one instance leak into the class and into
        # every other instance (including nested SubConfigs).
        for key, val in self._class_args.items():
            self._set_arg(key, copy.deepcopy(val))

        # Overwrite any non-defaults specified. Declared arguments keep their
        # type (RequiredArg, SubConfig, ...) and only have their value updated.
        # Unknown keys are registered as new plain ConfigArgs.
        for key, val in kwargs.items():
            if key in self._args:
                self._set_arg(key, val)
            else:
                self._set_arg(key, ConfigArg(value=val))

    def _set_arg(self, key, value):
        """Register a new argument or update the value of an existing one.

        A ``ConfigArg`` instance replaces whatever is stored under ``key``.
        Any other value updates the stored argument's value in place.
        Assigning a ``dict`` to a :class:`SubConfig` argument builds a new
        instance of the nested config's class from that dict.

        Raises:
            KeyError: If ``value`` is not a ``ConfigArg`` and ``key`` is not
                already tracked.
        """
        # Is this a fresh arg?
        if isinstance(value, ConfigArg):
            self._args[key] = value
            return

        arg = self._args[key]
        if isinstance(arg, SubConfig) and isinstance(value, dict):
            # Nested section given as a plain mapping (e.g. parsed JSON).
            value = type(arg.value).from_dict(value)
        # Update the value
        arg.value = value

    def __setattr__(self, name, value):
        # We may be at the start of __init__ before _args is set, so read it
        # straight from the instance dict rather than through __getattr__.
        args = self.__dict__.get('_args')

        # Updating an argument?
        if (args is not None) and (name in args):
            self._set_arg(name, value)
            return

        # base case
        super().__setattr__(name, value)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails: expose arg values.
        args = self.__dict__.get('_args')
        if (args is not None) and (name in args):
            return args[name].value
        raise AttributeError(
            f'{self.__class__.__name__} does not have attribute "{name}"')

    def __getitem__(self, name):
        # Subscript access mirrors attribute access: c['x'] is c.x
        return getattr(self, name)

    def resolve(self):
        """Infer any missing arguments, if possible.

        This is useful for configs such as :class:`BatchConfig` in which only a
        subset of arguments are required to complete a valid config. The base
        implementation only recurses into nested sub-configs.
        """
        # Walk the tree of subconfigs and also resolve().
        for arg in self._args.values():
            if isinstance(arg, SubConfig):
                arg.value.resolve()

    @classmethod
    def from_json(cls, json_path):
        """Build a config from the JSON object stored in the file ``json_path``."""
        import json

        # JSON text is UTF-8 (RFC 8259); do not depend on the locale encoding.
        with open(json_path, 'r', encoding='utf-8') as fin:
            config_dict = json.load(fin)
        return cls(**config_dict)

    @classmethod
    def from_dict(cls, config_dict):
        """Build a config whose arguments are the key/value pairs of ``config_dict``."""
        return cls(**config_dict)

    def is_valid(self):
        """Resolve any missing configurations and determine whether the
        configuration is valid.

        Returns:
            bool: Whether the config and all sub-configs are valid.
        """
        self.resolve()
        return all(arg.is_valid() for arg in self._args.values())

    def __str__(self):
        return self.dot_str()

    def dot_str(self, depth=0, dots_width=50):
        """Render the config as an indented ``key ..... value`` listing.

        Args:
            depth: Nesting level; each level is indented by four spaces.
            dots_width: Target column budget for the key plus its dot leader.
                Nested sub-configs are rendered with the same width.

        Returns:
            str: Multi-line representation wrapped in ``ClassName = { ... }``.
        """
        indent_width = 4
        indent = ' ' * indent_width
        lines = []
        lines.append(f'{indent * depth}{self.__class__.__name__} = {{')
        for key, val in self._args.items():
            # Recursive configurations
            if isinstance(val, SubConfig):
                config = val.value
                lines.append(config.dot_str(depth=depth + 1, dots_width=dots_width))
                continue
            # A negative count yields '' (no dots) for very long keys.
            dots = '.' * (dots_width - len(key) - (depth * indent_width))
            lines.append(f'{indent * (depth+1)}{key} {dots} {val}')
        lines.append(f'{indent * depth}}}')
        return '\n'.join(lines)