# HATSAT/utils/image_utils.py
"""
Image processing utilities.
"""
import base64
from io import BytesIO

import numpy as np
import torch
from PIL import Image

from config import REQUIRED_IMAGE_SIZE, WINDOW_SIZE, UPSCALE_FACTOR


def validate_image_size(image):
    """Validate that the image is exactly the required size."""
    if image is None:
        return False, "No image provided"

    width, height = image.size
    req_width, req_height = REQUIRED_IMAGE_SIZE

    if width != req_width or height != req_height:
        return False, f"Image must be exactly {req_width}x{req_height} pixels. Your image is {width}x{height} pixels."

    return True, "Valid image size"


def upscale_image(image, model, device):
    """Upscale an image using the HAT model."""
    # Convert the PIL image to a 1x3xHxW float tensor in [0, 1].
    # convert('RGB') guards against RGBA or grayscale inputs, which would
    # otherwise break the channel permutation below.
    img_np = np.array(image.convert('RGB')).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)

    # HAT processes the image in fixed-size windows, so both spatial
    # dimensions must be multiples of WINDOW_SIZE; reflect-pad if needed.
    h, w = img_tensor.shape[2], img_tensor.shape[3]
    pad_h = (WINDOW_SIZE - h % WINDOW_SIZE) % WINDOW_SIZE
    pad_w = (WINDOW_SIZE - w % WINDOW_SIZE) % WINDOW_SIZE
    if pad_h > 0 or pad_w > 0:
        img_tensor = torch.nn.functional.pad(img_tensor, (0, pad_w, 0, pad_h), mode='reflect')

    with torch.no_grad():
        output = model(img_tensor)

    # Crop away the padded region; the output is UPSCALE_FACTOR times larger
    # than the original (unpadded) input.
    if pad_h > 0 or pad_w > 0:
        output = output[:, :, :h * UPSCALE_FACTOR, :w * UPSCALE_FACTOR]

    # Convert back to a PIL image.
    output_np = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
    output_np = np.clip(output_np * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(output_np)


def image_to_base64(image_path):
    """Convert an image file to a base64 data URL for use as a CSS background."""
    img = Image.open(image_path)
    img.thumbnail((120, 120), Image.Resampling.LANCZOS)

    buffer = BytesIO()
    img.save(buffer, format='PNG')
    img_str = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"