Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is licensed under a Creative Commons | |
| # Attribution-NonCommercial-ShareAlike 4.0 International License. | |
| # You should have received a copy of the license along with this | |
| # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ | |
| import numpy as np | |
| import torch | |
#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached tensor holding `value`, optionally broadcast to `shape`.

    The first call for a given combination of (value, shape, dtype, device,
    memory_format) builds the tensor and stores it in `_constant_cache`.
    Later calls with an equal combination return the very same tensor
    object, so no new host-to-device copy is made.

    Args:
        value: Anything accepted by `np.asarray` (scalar, list, ndarray, ...).
        shape: Optional target shape. The result has the broadcast of
            `value`'s shape and `shape` (same rule as `torch.broadcast_tensors`).
        dtype: Torch dtype of the result. Defaults to
            `torch.get_default_dtype()`.
        device: Device of the result (`torch.device` or anything it accepts,
            e.g. 'cuda:0'). Defaults to CPU.
        memory_format: Memory layout of the result. Defaults to
            `torch.contiguous_format`.

    Returns:
        A contiguous tensor. It is shared with every other caller that asked
        for the same constant, so it must not be modified in place.

    Raises:
        RuntimeError: if `value`'s shape cannot be broadcast with `shape`.
    """
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    # Normalize so that e.g. 'cuda' and torch.device('cuda') share one cache
    # entry instead of producing duplicate tensors (and duplicate copies).
    device = torch.device('cpu') if device is None else torch.device(device)
    if memory_format is None:
        memory_format = torch.contiguous_format

    # Key on the raw bytes plus shape/dtype so equal-valued inputs hit the
    # same entry regardless of object identity.
    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        # Copy so the cached tensor never aliases caller-owned numpy memory.
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            # Broadcast via a zero-copy view rather than allocating a
            # throwaway tensor of the full target shape just to read its shape.
            tensor = tensor.expand(torch.broadcast_shapes(tensor.shape, shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor
#----------------------------------------------------------------------------
# Variant of constant() that inherits dtype and device from the given
# reference tensor by default.

def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
    """Like `constant()`, but fill in a missing dtype/device from `ref`.

    Args:
        ref: Reference object. Its `.dtype` / `.device` are used for
            whichever of `dtype` / `device` is None.
        value, shape, memory_format: Forwarded unchanged to `constant()`.
        dtype, device: Explicit overrides. When given, `ref` is not read for
            that attribute.

    Returns:
        Whatever `constant()` returns for the resolved arguments.
    """
    resolved_dtype = ref.dtype if dtype is None else dtype
    resolved_device = ref.device if device is None else device
    return constant(value, shape=shape, dtype=resolved_dtype,
                    device=resolved_device, memory_format=memory_format)