File size: 8,379 Bytes
62c4995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301b940
62c4995
301b940
62c4995
 
 
 
 
47112d4
 
 
 
 
 
 
 
 
 
 
 
 
62c4995
 
47112d4
62c4995
 
 
 
47112d4
62c4995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47112d4
 
c32c169
 
 
 
 
 
62c4995
 
 
 
 
 
 
 
 
 
 
47112d4
c32c169
62c4995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301b940
62c4995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301b940
62c4995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
874a882
 
 
 
 
 
 
62c4995
874a882
62c4995
 
 
 
 
 
 
 
 
 
 
 
 
f188628
7dad940
66b3113
62c4995
 
 
55280e0
62c4995
 
 
ea0bb69
62c4995
00cd4df
62c4995
 
 
2cb4eb3
d1c598e
62c4995
 
 
 
 
 
 
 
f188628
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import gradio as gr
import torch
import os
import sys
from PIL import Image, ImageDraw
from transformers import AutoModel, AutoProcessor, AutoTokenizer, GenerationConfig
from huggingface_hub import snapshot_download
import spaces
from typing import Optional, Tuple, Dict, Any, Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

# Fetch the model repository snapshot. The returned directory is reused below
# both as an extra import location and as the path the model is loaded from.
print("Downloading model snapshot to ensure all scripts are present...")
model_dir = snapshot_download(repo_id="nvidia/NVIDIA-Nemotron-Parse-v1.1")
print(f"Model downloaded to: {model_dir}")

# Expose the snapshot directory to the import system; the `postprocessing`
# module imported next is presumably one of the scripts shipped in it.
sys.path.append(model_dir)

try:
    from postprocessing import (
        extract_classes_bboxes,
        transform_bbox_to_original,
        postprocess_text,
    )
except ImportError as import_error:
    # Without these helpers the app cannot parse model output, so report and abort.
    print(f"Error importing postprocessing: {import_error}")
    raise import_error
else:
    print("Successfully imported postprocessing functions.")

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Custom "steel_blue" palette (shades c50 .. c950, lightest to darkest hex
# values). It is attached as an attribute on gradio's `colors` module so that
# `SteelBlueTheme` can use `colors.steel_blue` as its default secondary hue.
colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)

class SteelBlueTheme(Soft):
    """Variant of gradio's ``Soft`` theme with a gray base and steel-blue accents.

    All constructor arguments are keyword-only and are forwarded unchanged to
    ``Soft.__init__``. After the base theme is built, a fixed set of style
    variables (backgrounds, button gradients, slider colour, block styling)
    is overridden through ``set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        # Base theme options, passed straight through to the parent class.
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)

        # Page and panel backgrounds (light and dark variants).
        background_styles = {
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
        }

        # Primary buttons use gradients of the secondary (steel-blue) hue.
        primary_button_styles = {
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_800)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_500)",
        }

        # Secondary buttons use shades of the primary (gray) hue.
        secondary_button_styles = {
            "button_secondary_text_color": "black",
            "button_secondary_text_color_hover": "white",
            "button_secondary_background_fill": "linear-gradient(90deg, *primary_300, *primary_300)",
            "button_secondary_background_fill_hover": "linear-gradient(90deg, *primary_400, *primary_400)",
            "button_secondary_background_fill_dark": "linear-gradient(90deg, *primary_500, *primary_600)",
            "button_secondary_background_fill_hover_dark": "linear-gradient(90deg, *primary_500, *primary_500)",
        }

        # Sliders, block chrome and remaining accents.
        misc_styles = {
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }

        # Apply every override in a single `set` call.
        super().set(
            **background_styles,
            **primary_button_styles,
            **secondary_button_styles,
            **misc_styles,
        )

# Theme instance built with all defaults; passed to `demo.launch(theme=...)`
# in the `__main__` guard at the bottom of the file.
steel_blue_theme = SteelBlueTheme()

# Extra CSS passed to `demo.launch(css=...)`: enlarges the headings inside the
# elements with ids `main-title` and `output-title`.
# NOTE(review): only `main-title` is assigned as an `elem_id` in this file;
# no component uses `output-title`, so that rule currently matches nothing here.
css = """
#main-title h1 { font-size: 2.3em !important; }
#output-title h2 { font-size: 2.1em !important; }
"""

print("Loading Model components...")

# Processor and model are both loaded from the local snapshot directory.
# `trust_remote_code=True` allows the custom code in that snapshot to run.
processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)

# bfloat16 weights when CUDA is available, float32 otherwise.
model = AutoModel.from_pretrained(
    model_dir,
    trust_remote_code=True,
    torch_dtype=(
        torch.float32 if not torch.cuda.is_available() else torch.bfloat16
    ),
)
model = model.to(device)
model = model.eval()

# Use the generation settings shipped with the model when they can be loaded;
# otherwise fall back to a default config capped at 4096 new tokens.
try:
    generation_config = GenerationConfig.from_pretrained(
        model_dir,
        trust_remote_code=True,
    )
except Exception as load_error:
    print(f"Warning: Could not load GenerationConfig: {load_error}. Using default.")
    generation_config = GenerationConfig(max_new_tokens=4096)

print("Model loaded successfully.")

@spaces.GPU
def process_ocr_task(image):
    """
    Parse a document image with NVIDIA-Nemotron-Parse-v1.1.

    Runs the model on the image, splits the generated text into layout
    elements (class, bounding box, text), draws one outlined box per element
    on a copy of the image and assembles the element texts into one string.

    Args:
        image: PIL image of the document page, or None when nothing was uploaded.

    Returns:
        A ``(text, image)`` tuple:
        - when ``image`` is None: a prompt message and None;
        - when the model output cannot be split into boxes: the raw generated
          text and the unmodified input image;
        - otherwise: the assembled text (tables and figures wrapped in marker
          lines) and a copy of the image with the detected boxes drawn on it.
          If the assembled text is blank, the raw generated text is returned
          instead.
    """
    if image is None:
        return "Please upload an image first.", None

    task_prompt = "</s><s><predict_bbox><predict_classes><output_markdown>"

    inputs = processor(images=[image], text=task_prompt, return_tensors="pt").to(device)

    if device.type == 'cuda':
        # The model is loaded in bfloat16 when CUDA is available, so float32
        # tensors are cast to match; other dtypes are left untouched.
        inputs = {k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v for k, v in inputs.items()}

    print("👊 Running inference...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            generation_config=generation_config
        )

    generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]

    # Best effort: both parsing the boxes out of the generated text and mapping
    # them back to image coordinates operate on raw model output. If either
    # step fails, return the raw text instead of failing the whole request.
    try:
        classes, bboxes, texts = extract_classes_bboxes(generated_text)
        bboxes = [transform_bbox_to_original(bbox, image.width, image.height) for bbox in bboxes]
    except Exception as e:
        print(f"Error extracting boxes: {e}")
        return generated_text, image

    table_format = 'latex'
    text_format = 'markdown'
    blank_text_in_figures = False

    processed_texts = [
        postprocess_text(
            text,
            cls=cls,
            table_format=table_format,
            text_format=text_format,
            blank_text_in_figures=blank_text_in_figures
        )
        for text, cls in zip(texts, classes)
    ]

    # Draw on a copy so the caller's image is never modified.
    result_image = image.copy()
    draw = ImageDraw.Draw(result_image)

    # Outline colour per layout class; unlisted classes fall back to red.
    color_map = {
        "Table": "red",
        "Figure": "blue",
        "Text": "green",
        "Title": "purple"
    }

    # Collect the pieces and join once at the end instead of growing a string.
    output_parts = []

    for cls, bbox, txt in zip(classes, bboxes, processed_texts):
        # Normalize coordinates to prevent PIL ValueError (x1 >= x0)
        x1, y1, x2, y2 = bbox
        xmin = min(x1, x2)
        ymin = min(y1, y2)
        xmax = max(x1, x2)
        ymax = max(y1, y2)

        color = color_map.get(cls, "red")
        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=3)

        if cls == "Table":
            output_parts.append(f"\n\n--- [Table] ---\n{txt}\n-----------------\n")
        elif cls == "Figure":
            # Figures contribute only a placeholder, not their text.
            output_parts.append("\n\n--- [Figure] ---\n(Figure Detected)\n-----------------\n")
        else:
            output_parts.append(f"{txt}\n")

    final_output_text = "".join(output_parts)

    # Nothing usable was assembled: show the raw model output instead.
    if not final_output_text.strip() and generated_text:
        final_output_text = generated_text

    return final_output_text, result_image

# UI layout: an input column (image upload, run button, example images) next
# to a wider output column (parsed text and the annotated image).
with gr.Blocks() as demo:
    gr.Markdown("# **NVIDIA Nemotron Parse OCR**", elem_id="main-title")
    gr.Markdown("Upload a document image to extract text, tables, and layout structures using NVIDIA's [Nemotron Parse](https://huggingface.co/nvidia/NVIDIA-Nemotron-Parse-v1.1) model.")

    with gr.Row():
        # Left column: input controls.
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                sources=["upload", "clipboard"],
                height=400,
            )
            submit_btn = gr.Button("Process Document", variant="primary")

            # Bundled sample images examples/1.jpg .. examples/5.jpg.
            examples = gr.Examples(
                examples=[f"examples/{index}.jpg" for index in range(1, 6)],
                inputs=image_input,
                label="Examples",
            )

        # Right column: results.
        with gr.Column(scale=2):
            output_text = gr.Textbox(
                label="Parsed Content (Markdown/LaTeX)",
                lines=12,
                interactive=True,
            )
            output_image = gr.Image(label="Layout Detection", type="pil")

    # Clicking the button runs the OCR task and fills both output components.
    submit_btn.click(
        fn=process_ocr_task,
        inputs=[image_input],
        outputs=[output_text, output_image],
    )

if __name__ == "__main__":
    # Configure the request queue (max_size=30) first, then start the app with
    # the custom CSS and theme defined earlier in this file.
    demo.queue(max_size=30)
    demo.launch(
        css=css,
        theme=steel_blue_theme,
        share=True,
        mcp_server=True,
        ssr_mode=False,
    )