import bisect
import functools
import logging
import numbers
import os
import platform
import signal
import sys
import traceback
import warnings

import torch
from pytorch_lightning import seed_everything

LOGGER = logging.getLogger(__name__)

# signal.SIGUSR1 is not available on Windows; on non-Linux systems fall back to a
# placeholder value so the debug-signal helpers below remain importable.
if platform.system() != 'Linux':
    signal.SIGUSR1 = 1


def check_and_warn_input_range(tensor, min_value, max_value, name):
    actual_min = tensor.min()
    actual_max = tensor.max()
    if actual_min < min_value or actual_max > max_value:
        warnings.warn(f"{name} must be in the {min_value}..{max_value} range, "
                      f"but its values span {actual_min}..{actual_max}")


def sum_dict_with_prefix(target, cur_dict, prefix, default=0):
    for k, v in cur_dict.items():
        target_key = prefix + k
        target[target_key] = target.get(target_key, default) + v


def average_dicts(dict_list):
    result = {}
    norm = 1e-3  # small offset keeps the divisor strictly positive
    for dct in dict_list:
        sum_dict_with_prefix(result, dct, '')
        norm += 1
    for k in list(result):
        result[k] /= norm
    return result
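
# Illustrative usage (assumption, not from the original file): averaging per-batch
# metric dicts. Note that the 1e-3 offset in `norm` makes results sit slightly below
# the exact mean.
#   average_dicts([{'loss': 1.0, 'acc': 0.5}, {'loss': 3.0, 'acc': 0.7}])
#   # -> {'loss': ~2.0, 'acc': ~0.6}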


def add_prefix_to_keys(dct, prefix):
    return {prefix + k: v for k, v in dct.items()}


def set_requires_grad(module, value):
    for param in module.parameters():
        param.requires_grad = value


def flatten_dict(dct):
    result = {}
    for k, v in dct.items():
        if isinstance(k, tuple):
            k = '_'.join(k)
        if isinstance(v, dict):
            for sub_k, sub_v in flatten_dict(v).items():
                result[f'{k}_{sub_k}'] = sub_v
        else:
            result[k] = v
    return result
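
# Illustrative example (not part of the original file): nested keys and tuple keys
# are collapsed into underscore-joined flat keys.
#   flatten_dict({'a': {'b': 1}, ('c', 'd'): 2})
#   # -> {'a_b': 1, 'c_d': 2}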


class LinearRamp:
    def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
        self.start_value = start_value
        self.end_value = end_value
        self.start_iter = start_iter
        self.end_iter = end_iter

    def __call__(self, i):
        if i < self.start_iter:
            return self.start_value
        if i >= self.end_iter:
            return self.end_value
        part = (i - self.start_iter) / (self.end_iter - self.start_iter)
        return self.start_value * (1 - part) + self.end_value * part


class LadderRamp:
    def __init__(self, start_iters, values):
        self.start_iters = start_iters
        self.values = values
        assert len(values) == len(start_iters) + 1, (len(values), len(start_iters))

    def __call__(self, i):
        segment_i = bisect.bisect_right(self.start_iters, i)
        return self.values[segment_i]


def get_ramp(kind='ladder', **kwargs):
    if kind == 'linear':
        return LinearRamp(**kwargs)
    if kind == 'ladder':
        return LadderRamp(**kwargs)
    raise ValueError(f'Unexpected ramp kind: {kind}')
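
# Illustrative usage (assumption, not from the original file): a ladder ramp takes one
# more value than it has boundaries; bisect picks the segment for the current iteration.
#   ramp = get_ramp(kind='ladder', start_iters=[1000, 5000], values=[0.1, 0.01, 0.001])
#   ramp(500)   # -> 0.1
#   ramp(1000)  # -> 0.01  (a boundary iteration belongs to the next segment)
#   ramp(7000)  # -> 0.001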


def print_traceback_handler(sig, frame):
    LOGGER.warning(f'Received signal {sig}')
    bt = ''.join(traceback.format_stack())
    LOGGER.warning(f'Requested stack trace:\n{bt}')


def register_debug_signal_handlers(sig=signal.SIGUSR1, handler=print_traceback_handler):
    LOGGER.warning(f'Setting signal {sig} handler {handler}')
    signal.signal(sig, handler)
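
# Usage note (not from the original file): after calling register_debug_signal_handlers()
# on Linux, sending SIGUSR1 to the process (e.g. `kill -USR1 <pid>`) logs the current
# Python stack trace without stopping the run.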


def handle_deterministic_config(config):
    seed = dict(config).get('seed', None)
    if seed is None:
        return False

    seed_everything(seed)
    return True


def get_shape(t):
    if torch.is_tensor(t):
        return tuple(t.shape)
    elif isinstance(t, dict):
        return {n: get_shape(q) for n, q in t.items()}
    elif isinstance(t, (list, tuple)):
        return [get_shape(q) for q in t]
    elif isinstance(t, numbers.Number):
        return type(t)
    else:
        raise ValueError('unexpected type {}'.format(type(t)))
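
# Illustrative example (not from the original file): get_shape descends into nested
# containers and reports tensor shapes, which is handy for logging batch structure.
#   get_shape({'image': torch.zeros(2, 3, 256, 256), 'scale': 1.0})
#   # -> {'image': (2, 3, 256, 256), 'scale': <class 'float'>}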


def get_has_ddp_rank():
    master_port = os.environ.get('MASTER_PORT', None)
    node_rank = os.environ.get('NODE_RANK', None)
    local_rank = os.environ.get('LOCAL_RANK', None)
    world_size = os.environ.get('WORLD_SIZE', None)
    has_rank = master_port is not None or node_rank is not None or local_rank is not None or world_size is not None
    return has_rank


def handle_ddp_subprocess():
    def main_decorator(main_func):
        @functools.wraps(main_func)
        def new_main(*args, **kwargs):
            # Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE
            parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
            has_parent = parent_cwd is not None
            has_rank = get_has_ddp_rank()
            assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'

            if has_parent:
                # we are in the worker: reuse the parent's hydra working directory
                sys.argv.extend([
                    f'hydra.run.dir={parent_cwd}',
                    # 'hydra/hydra_logging=disabled',
                    # 'hydra/job_logging=disabled'
                ])
            # do nothing if this is a top-level process
            # TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization

            return main_func(*args, **kwargs)
        return new_main
    return main_decorator


def handle_ddp_parent_process():
    parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
    has_parent = parent_cwd is not None
    has_rank = get_has_ddp_rank()
    assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'

    if parent_cwd is None:
        os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd()

    return has_parent
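
# Sketch of the intended wiring (an assumption based on the helpers above, not code from
# the original file): the training entry point is wrapped so that DDP worker subprocesses
# spawned by the Trainer reuse the parent's hydra run directory, while the parent records
# its working directory right after hydra has set it up.
#
#   @handle_ddp_subprocess()
#   @hydra.main(config_path='configs', config_name='train')  # hypothetical config names
#   def main(config):
#       is_in_ddp_subprocess = handle_ddp_parent_process()
#       ...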