# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

import numpy as np
import torch

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = {}

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached tensor holding `value`, optionally broadcast to `shape`.

    Args:
        value: Anything accepted by `np.asarray()` (scalar, sequence, array).
        shape: Optional shape to broadcast the value against. The result has
            the broadcast of the value's own shape and `shape`.
        dtype: Tensor dtype. Defaults to `torch.get_default_dtype()`.
        device: Target device (`torch.device`, string, ...). Defaults to CPU.
        memory_format: Memory format of the result. Defaults to
            `torch.contiguous_format`.

    Returns:
        The tensor stored in the module-level cache. The same object is
        returned for every call with an equal cache key, so callers must not
        modify it in place.
    """
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    # Normalize the device so that equivalent specifications (e.g. the string
    # 'cpu' and torch.device('cpu')) map to the same cache entry.
    device = torch.device('cpu') if device is None else torch.device(device)
    if memory_format is None:
        memory_format = torch.contiguous_format

    # The raw bytes of the value are part of the key, so distinct constants
    # never collide even when their shape and dtype agree.
    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        # Copy first: torch.as_tensor() may share memory with a NumPy input,
        # and the cached tensor must not alias an array owned by the caller.
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            # Broadcast via shape arithmetic only; no dummy tensor of the
            # target size needs to be allocated.
            tensor = tensor.expand(torch.broadcast_shapes(tensor.shape, shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

#----------------------------------------------------------------------------
# Variant of constant() that inherits dtype and device from the given
# reference tensor by default.

def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
    """Return `constant(value, ...)` using `ref.dtype` / `ref.device` as defaults.

    Args:
        ref: Reference object providing `dtype` and `device` attributes.
        value, shape, memory_format: Forwarded to `constant()` unchanged.
        dtype: Overrides `ref.dtype` when not None.
        device: Overrides `ref.device` when not None.

    Returns:
        The cached tensor produced by `constant()`.
    """
    if dtype is None:
        dtype = ref.dtype
    if device is None:
        device = ref.device
    return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)