|
3 | 3 |
|
4 | 4 | # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context
|
5 | 5 |
|
6 |
| -import inspect |
7 | 6 | import random
|
8 | 7 | import socket
|
9 |
| -import sys |
10 |
| -from importlib.machinery import SourceFileLoader |
11 |
| -from pathlib import Path |
12 | 8 | from typing import Union
|
13 | 9 |
|
14 | 10 | import numpy as np
|
15 | 11 | import torch
|
16 | 12 | import torch.distributed as dist
|
17 | 13 |
|
18 | 14 | from internlm.accelerator import get_accelerator
|
| 15 | +from internlm.core.context import Config |
19 | 16 | from internlm.utils.common import SingletonMeta
|
20 | 17 | from internlm.utils.logger import get_logger
|
21 | 18 | from internlm.utils.timeout import LLM_NCCL_TIMEOUT
|
|
46 | 43 | internlm_accelerator = get_accelerator()
|
47 | 44 |
|
48 | 45 |
|
49 |
| -class Config(dict): |
50 |
| - """This is a wrapper class for dict objects so that values of which can be |
51 |
| - accessed as attributes. |
52 |
| -
|
53 |
| - Args: |
54 |
| - config (dict): The dict object to be wrapped. |
55 |
| - """ |
56 |
| - |
57 |
| - def __init__(self, config: dict = None): # pylint: disable=W0231 |
58 |
| - if config is not None: |
59 |
| - for k, v in config.items(): |
60 |
| - self._add_item(k, v) |
61 |
| - |
62 |
| - def __missing__(self, key): |
63 |
| - raise KeyError(key) |
64 |
| - |
65 |
| - def __getattr__(self, key): |
66 |
| - try: |
67 |
| - value = super().__getitem__(key) |
68 |
| - return value |
69 |
| - except KeyError: |
70 |
| - raise AttributeError(key) |
71 |
| - |
72 |
| - def __setattr__(self, key, value): |
73 |
| - super().__setitem__(key, value) |
74 |
| - |
75 |
| - def _add_item(self, key, value): |
76 |
| - if isinstance(value, dict): |
77 |
| - self.__setattr__(key, Config(value)) |
78 |
| - else: |
79 |
| - self.__setattr__(key, value) |
80 |
| - |
81 |
| - def update(self, config): |
82 |
| - assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects." |
83 |
| - for k, v in config.items(): |
84 |
| - self._add_item(k, v) |
85 |
| - return self |
86 |
| - |
87 |
| - @staticmethod |
88 |
| - def from_file(filename: str): |
89 |
| - """Reads a python file and constructs a corresponding :class:`Config` object. |
90 |
| -
|
91 |
| - Args: |
92 |
| - filename (str): Name of the file to construct the return object. |
93 |
| -
|
94 |
| - Returns: |
95 |
| - :class:`Config`: A :class:`Config` object constructed with information in the file. |
96 |
| -
|
97 |
| - Raises: |
98 |
| - AssertionError: Raises an AssertionError if the file does not exist, or the file is not .py file |
99 |
| - """ |
100 |
| - |
101 |
| - # check config path |
102 |
| - if isinstance(filename, str): |
103 |
| - filepath = Path(filename).absolute() |
104 |
| - elif isinstance(filename, Path): |
105 |
| - filepath = filename.absolute() |
106 |
| - |
107 |
| - assert filepath.exists(), f"{filename} is not found, please check your configuration path" |
108 |
| - |
109 |
| - # check extension |
110 |
| - extension = filepath.suffix |
111 |
| - assert extension == ".py", "only .py files are supported" |
112 |
| - |
113 |
| - # import the config as module |
114 |
| - remove_path = False |
115 |
| - if filepath.parent not in sys.path: |
116 |
| - sys.path.insert(0, (filepath)) |
117 |
| - remove_path = True |
118 |
| - |
119 |
| - module_name = filepath.stem |
120 |
| - source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath)) |
121 |
| - module = source_file.load_module() # pylint: disable=W4902,E1120,W1505 |
122 |
| - |
123 |
| - # load into config |
124 |
| - config = Config() |
125 |
| - |
126 |
| - for k, v in module.__dict__.items(): |
127 |
| - if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v): |
128 |
| - continue |
129 |
| - else: |
130 |
| - config._add_item(k, v) |
131 |
| - |
132 |
| - # remove module |
133 |
| - del sys.modules[module_name] |
134 |
| - if remove_path: |
135 |
| - sys.path.pop(0) |
136 |
| - |
137 |
| - return config |
138 |
| - |
139 |
| - |
140 | 46 | class ParallelContext(metaclass=SingletonMeta):
|
141 | 47 | """This class provides interface functions for users to get the parallel context,
|
142 | 48 | such as the global rank, the local rank, the world size, etc. of each device.
|
|
0 commit comments