# idf/models/loggers.py
import os
from pathlib import Path
from typing import Optional, Union

import numpy as np
import torch
from PIL import Image

from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities import rank_zero_only

__all__ = [
    "LocalImageLogger",
    "TensorBoardLogger",
    "WandbLogger",
]


class LocalImageLogger(Logger):
    """Lightning logger that writes logged images to a local directory on disk."""

def __init__(
self,
save_dir: _PATH,
name: Optional[str] = "lightning_logs",
version: Optional[Union[int, str]] = None,
):
super().__init__()
self._root_dir = save_dir
self._name = name
self._version = version
@property
def name(self) -> str:
"""Get the name of the experiment.
Returns:
The name of the experiment.
"""
return self._name
@property
def version(self) -> Union[int, str]:
"""Get the experiment version.
Returns:
The experiment version if specified else the next version.
"""
if self._version is None:
self._version = 'temp'
return self._version
@property
def root_dir(self) -> str:
"""Gets the save directory where the TensorBoard experiments are saved.
Returns:
The local path to the save directory where the TensorBoard experiments are saved.
"""
return self._root_dir
@property
def log_dir(self) -> str:
"""The directory for this run.
By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the
constructor's version parameter instead of ``None`` or an int.
"""
version = self.version if isinstance(self.version, str) else f"version_{self.version}"
log_dir = os.path.join(self.root_dir, self.name, version)
log_dir = os.path.expandvars(log_dir)
log_dir = os.path.expanduser(log_dir)
return log_dir
@property
@rank_zero_experiment
def experiment(self) -> "self":
"""Actual object. To use features anywhere in your code, do the following.
Example::
logger.experiment.some_function()
"""
assert rank_zero_only.rank == 0, "tried to init log dirs in non global_rank=0"
if self.log_dir:
Path(self.log_dir).mkdir(parents=True, exist_ok=True)
return self
    def log_image(self, name, image, step=None):
        # Make sure the output directory exists even if ``experiment`` was never accessed.
        Path(self.log_dir).mkdir(parents=True, exist_ok=True)
        out_path = Path(self.log_dir) / f"{name}_{step}.png"
        if isinstance(image, torch.Tensor):
            # Assumes a CHW float tensor with values in [0, 1]; convert to HWC uint8.
            image = Image.fromarray((image * 255).permute(1, 2, 0).byte().cpu().numpy())
        elif isinstance(image, np.ndarray):
            # Assumes a CHW float array with values in [0, 1]; convert to HWC uint8.
            image = Image.fromarray(np.uint8(image * 255).transpose(1, 2, 0))
        elif isinstance(image, Image.Image):
            pass
        else:
            raise NotImplementedError(f"Unsupported image type: {type(image)}")
        image.save(out_path)
    @rank_zero_only
    def log_hyperparams(self, params):
        # This logger only writes images; hyperparameters are intentionally ignored.
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # This logger only writes images; metrics are intentionally ignored.
        pass
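

# A minimal usage sketch, not part of the original module: the ``demo`` experiment
# name, the random sample tensor, and the Trainer wiring mentioned in the comments
# are assumptions about how this logger is intended to be used.
if __name__ == "__main__":
    logger = LocalImageLogger(save_dir="logs", name="demo", version=0)
    _ = logger.experiment  # creates logs/demo/version_0/

    # A CHW float tensor in [0, 1] is saved as logs/demo/version_0/sample_0.png.
    logger.log_image("sample", torch.rand(3, 64, 64), step=0)

    # Inside a LightningModule attached to a Trainer(logger=LocalImageLogger(...)),
    # the same call would typically be made as:
    #   self.logger.log_image("reconstruction", pred_image, step=self.global_step)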