112 lines
2.9 KiB
Python
112 lines
2.9 KiB
Python
from tkinter import W
|
|
import torch
|
|
import wandb
|
|
import numpy as np
|
|
from PIL import Image
|
|
from einops import rearrange
|
|
from typing import Protocol, Tuple, Optional
|
|
|
|
|
|
class Logger(Protocol):
|
|
def metrics(self, metrics: dict, epoch: int):
|
|
"""loss etc."""
|
|
|
|
def hyperparameters(self, hyperparameters: dict):
|
|
"""model states"""
|
|
|
|
def predictions(self, predictions: dict):
|
|
"""inference time stuff"""
|
|
|
|
def images(self, images: np.ndarray):
|
|
"""log images"""
|
|
|
|
|
|
class WandbLogger:
|
|
def __init__(self, project: str, entity: str, name: Optional[str], notes: str):
|
|
self.project = project
|
|
self.entity = entity
|
|
self.notes = notes
|
|
self.experiment = wandb.init(project=project, entity=entity, notes=notes)
|
|
self.experiment.name = name
|
|
|
|
self.data_dict = {}
|
|
|
|
def metrics(self, metrics: dict):
|
|
"""loss etc."""
|
|
|
|
self.data_dict.update(metrics)
|
|
|
|
def hyperparameters(self, hyperparameters: dict):
|
|
"""model states"""
|
|
self.experiment.config.update(hyperparameters, allow_val_change=True)
|
|
|
|
def predictions(self, predictions: dict):
|
|
"""inference time stuff"""
|
|
|
|
def image(self, image: dict):
|
|
"""log images to wandb"""
|
|
self.data_dict.update({'Generate Image' : image})
|
|
|
|
def video(self, images: str, title: str):
|
|
"""log images to wandb"""
|
|
|
|
images = np.uint8(rearrange(images, 't b c h w -> b t c h w'))
|
|
self.data_dict.update({f"{title}": wandb.Video(images, fps=20)})
|
|
|
|
def flush(self):
|
|
self.experiment.log(self.data_dict)
|
|
self.data_dict = {}
|
|
|
|
|
|
class DebugLogger:
|
|
def __init__(self, project: str, entity: str, name: str, notes: str):
|
|
self.project = project
|
|
self.entity = entity
|
|
self.name = name
|
|
self.notes = notes
|
|
|
|
def metrics(self, metrics: dict, epoch: int = None):
|
|
"""
|
|
loss etc.
|
|
"""
|
|
print(f"metrics: {metrics}")
|
|
|
|
def hyperparameters(self, hyperparameters: dict):
|
|
"""
|
|
model states
|
|
"""
|
|
print(f"hyperparameters: {hyperparameters}")
|
|
|
|
def predictions(self, predictions: dict):
|
|
"""
|
|
inference time stuff
|
|
"""
|
|
|
|
|
|
class Checkpoint:
|
|
def __init__(self, checkpoint_path):
|
|
self.checkpoint_path = checkpoint_path
|
|
|
|
def load(self) -> Tuple:
|
|
checkpoint = torch.load(self.checkpoint_path)
|
|
model = checkpoint["model"]
|
|
optimizer = checkpoint["optimizer"]
|
|
epoch = checkpoint["epoch"]
|
|
loss = checkpoint["loss"]
|
|
return (model, optimizer, epoch, loss)
|
|
|
|
def save(self, model: torch.nn.Module, optimizer, epoch, loss):
|
|
checkpoint = {
|
|
"model": model,
|
|
"optimizer": optimizer,
|
|
"epoch": epoch,
|
|
"loss": loss,
|
|
}
|
|
import random
|
|
import string
|
|
|
|
name = "".join(random.choices(string.ascii_letters, k=10)) + ".tar"
|
|
torch.save(checkpoint, f"{name}")
|
|
|
|
|