File size: 2,046 Bytes
0def483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
Image processing utilities.
"""

import torch
import numpy as np
from PIL import Image
import base64
from io import BytesIO

from config import REQUIRED_IMAGE_SIZE, WINDOW_SIZE, UPSCALE_FACTOR


def validate_image_size(image):
    """Check that *image* has exactly the dimensions in ``REQUIRED_IMAGE_SIZE``.

    Returns:
        A ``(ok, message)`` tuple. ``ok`` is ``True`` only when both width and
        height match the configured requirement; ``message`` is a
        human-readable explanation suitable for showing to the user.
    """
    if image is None:
        return False, "No image provided"

    actual_w, actual_h = image.size
    expected_w, expected_h = REQUIRED_IMAGE_SIZE

    if (actual_w, actual_h) == (expected_w, expected_h):
        return True, "Valid image size"

    mismatch_msg = (
        f"Image must be exactly {expected_w}x{expected_h} pixels. "
        f"Your image is {actual_w}x{actual_h} pixels."
    )
    return False, mismatch_msg


def upscale_image(image, model, device):
    """Upscale an image using the HAT model.

    The image is turned into a ``(1, C, H, W)`` float tensor in ``[0, 1]``,
    padded on the bottom/right so both spatial dimensions are multiples of
    ``WINDOW_SIZE``, and run through ``model`` with gradient tracking disabled.
    If padding was added, the corresponding region is cropped back off the
    output before it is converted to an 8-bit image.

    Args:
        image: Input image (a PIL image). PIL images in any mode other than
            ``RGB`` (e.g. ``L``, ``P``, ``RGBA``) are converted to RGB first, so
            the tensor always has 3 channels.
        model: Super-resolution network, called as ``model(tensor)``.
            Presumably it expects 3-channel input in ``[0, 1]`` and returns a
            tensor ``UPSCALE_FACTOR`` times larger spatially — TODO confirm.
        device: torch device the input tensor is moved to.

    Returns:
        The upscaled image as a PIL ``Image`` built from a uint8 array.
    """
    # Grayscale/palette images yield a 2-D array (permute(2, 0, 1) would fail)
    # and RGBA yields 4 channels; normalise to RGB so the array is (H, W, 3).
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Convert PIL image to a (1, C, H, W) float tensor in [0, 1].
    img_np = np.array(image).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)

    # Ensure the image dimensions are multiples of WINDOW_SIZE.
    h, w = img_tensor.shape[2], img_tensor.shape[3]
    pad_h = (WINDOW_SIZE - h % WINDOW_SIZE) % WINDOW_SIZE
    pad_w = (WINDOW_SIZE - w % WINDOW_SIZE) % WINDOW_SIZE
    padded = pad_h > 0 or pad_w > 0

    if padded:
        # F.pad's tuple is (left, right, top, bottom) for the last two dims.
        # Note: 'reflect' mode requires each pad to be smaller than that dim.
        img_tensor = torch.nn.functional.pad(
            img_tensor, (0, pad_w, 0, pad_h), mode="reflect"
        )

    with torch.no_grad():
        output = model(img_tensor)

    # Remove the (scaled) padding if it was added.
    if padded:
        output = output[:, :, : h * UPSCALE_FACTOR, : w * UPSCALE_FACTOR]

    # Convert back to an (H, W, C) uint8 array. Round instead of truncating:
    # a plain astype() floors every value, darkening the result by ~0.5 LSB.
    output_np = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
    output_np = np.clip(np.round(output_np * 255.0), 0, 255).astype(np.uint8)

    return Image.fromarray(output_np)


def image_to_base64(image_path, size=(120, 120)):
    """Convert an image file to a PNG ``data:`` URL for use as a CSS background.

    The image is shrunk in place to fit within ``size`` with Lanczos
    resampling, re-encoded as PNG, and base64-encoded.

    Args:
        image_path: Path to the image file, passed to ``PIL.Image.open``.
        size: Maximum ``(width, height)`` of the thumbnail. Defaults to
            ``(120, 120)``, the previously hard-coded value.

    Returns:
        A string of the form ``"data:image/png;base64,<payload>"``.
    """
    # Image.open loads lazily and keeps the file handle open; the context
    # manager guarantees it is closed once we are done encoding.
    with Image.open(image_path) as img:
        img.thumbnail(size, Image.Resampling.LANCZOS)
        # PNG cannot store CMYK (e.g. some JPEGs), so save() would raise.
        thumb = img.convert("RGB") if img.mode == "CMYK" else img
        buffer = BytesIO()
        thumb.save(buffer, format="PNG")

    img_str = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"