Spaces:
Running
on
Zero
Running
on
Zero
| from PIL import Image | |
| import os | |
| from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment | |
| from lightning.pytorch.utilities import rank_zero_only | |
| from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional, Union | |
| from lightning.fabric.utilities.types import _PATH | |
| import torch | |
| import numpy as np | |
# Public API of this module: the two re-exported Lightning loggers plus the
# image logger defined below, so that ``from <module> import *`` exposes it.
__all__ = [
    "TensorBoardLogger",
    "WandbLogger",
    "LocalImageLogger",
]
class LocalImageLogger(Logger):
    """Lightning logger that writes images to a local directory as PNG files.

    Images are saved under ``<save_dir>/<name>/<version dir>`` (see
    :attr:`log_dir`). Hyperparameters and scalar metrics are accepted but
    discarded.

    Args:
        save_dir: Root directory under which all runs are stored.
        name: Experiment name, used as a sub-directory of ``save_dir``.
        version: Run identifier. A string is used verbatim as the run
            directory name, an int becomes ``version_<int>``, and ``None``
            resolves to ``'temp'`` the first time :attr:`version` is read.
    """

    def __init__(
        self,
        save_dir: _PATH,
        name: Optional[str] = "lightning_logs",
        version: Optional[Union[int, str]] = None,
    ):
        super().__init__()
        self._root_dir = save_dir
        self._name = name
        self._version = version

    # NOTE: the accessors below are properties because this class itself reads
    # them as attributes (``self.version``, ``self.root_dir``, ``self.log_dir``).
    # As plain methods, ``log_dir`` would hand bound methods to ``os.path.join``.

    @property
    def name(self) -> str:
        """Get the name of the experiment.

        Returns:
            The name of the experiment.
        """
        return self._name

    @property
    def version(self) -> Union[int, str]:
        """Get the experiment version.

        Returns:
            The version passed to the constructor, or ``'temp'`` when none was
            given. The fallback is stored, so later reads return the same value.
        """
        if self._version is None:
            self._version = 'temp'
        return self._version

    @property
    def root_dir(self) -> str:
        """Get the root save directory.

        Returns:
            The ``save_dir`` value passed to the constructor.
        """
        return self._root_dir

    @property
    def log_dir(self) -> str:
        """The directory for this run.

        String versions are used as the directory name unchanged; any other
        version is rendered as ``version_<version>``. Environment variables
        and a leading ``~`` in the resulting path are expanded.

        Returns:
            ``<root_dir>/<name>/<version dir>`` as a string.
        """
        current = self.version
        version_dir = current if isinstance(current, str) else f"version_{current}"
        joined = os.path.join(self.root_dir, self.name, version_dir)
        return os.path.expanduser(os.path.expandvars(joined))

    @property
    def experiment(self) -> "LocalImageLogger":
        """Actual object. To use features anywhere in your code, do the following.

        Example::

            logger.experiment.some_function()

        Accessing this property creates :attr:`log_dir` if it does not exist.

        Returns:
            This logger instance itself.

        Raises:
            AssertionError: If accessed when ``rank_zero_only.rank`` is not 0.
        """
        assert rank_zero_only.rank == 0, "tried to init log dirs in non global_rank=0"
        run_dir = self.log_dir
        if run_dir:
            Path(run_dir).mkdir(parents=True, exist_ok=True)
        return self

    def log_image(self, name, image, step=None):
        """Save ``image`` as ``<log_dir>/<name>_<step>.png``.

        Args:
            name: File-name prefix for the saved image.
            image: A ``torch.Tensor``, ``numpy.ndarray`` or ``PIL.Image.Image``
                (subclasses included). Tensors and arrays are multiplied by
                255, cast to uint8 and moved from channel-first to
                channel-last layout before conversion.
                NOTE(review): this presumably expects float data in [0, 1];
                out-of-range values are not clamped -- confirm with callers.
            step: Value appended to the file name; ``None`` is rendered
                literally as ``None``.

        Raises:
            NotImplementedError: If ``image`` is none of the supported types.
        """
        # isinstance (not ``type(...) ==``) so that subclasses such as the
        # file-specific classes returned by ``Image.open`` are accepted.
        if isinstance(image, torch.Tensor):
            pixels = (image * 255).permute(1, 2, 0).byte().cpu().numpy()
            image = Image.fromarray(pixels)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(np.uint8(image * 255).transpose(1, 2, 0))
        elif not isinstance(image, Image.Image):
            raise NotImplementedError()

        # The run directory is otherwise only created by ``experiment``;
        # create it here so saving does not depend on that having been read.
        target_dir = Path(self.log_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        image.save(target_dir / f"{name}_{step}.png")

    def log_hyperparams(self, params):
        """Accept hyperparameters and discard them (this logger stores images only)."""
        pass

    def log_metrics(self, metrics, step=None):
        """Accept scalar metrics and discard them (this logger stores images only)."""
        pass