File size: 12,317 Bytes
2d3c552
321df30
 
 
 
 
 
 
4422fa9
4803558
 
 
4422fa9
 
 
 
 
 
4803558
4422fa9
 
 
 
 
 
 
 
 
 
4803558
4422fa9
 
 
 
321df30
4422fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321df30
4422fa9
 
321df30
4422fa9
 
 
 
 
321df30
4422fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321df30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4422fa9
 
 
 
321df30
4422fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321df30
4422fa9
321df30
 
 
 
 
 
 
4422fa9
 
321df30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4422fa9
 
 
 
 
 
 
 
 
 
 
 
 
321df30
 
 
4422fa9
321df30
4422fa9
321df30
 
 
 
4422fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321df30
 
 
 
4422fa9
321df30
4422fa9
 
 
 
 
321df30
 
 
 
 
 
 
 
 
4422fa9
 
 
321df30
 
 
 
 
 
4422fa9
321df30
 
 
4422fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
321df30
4422fa9
 
 
 
 
 
 
 
321df30
 
 
4422fa9
321df30
4422fa9
321df30
 
4422fa9
321df30
 
4422fa9
321df30
 
 
4422fa9
321df30
 
4422fa9
321df30
 
 
4422fa9
 
 
321df30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4422fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321df30
 
4422fa9
321df30
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
import spaces
import torch
import gradio as gr
from docling_core.types.doc import DoclingDocument
from docling_core.types.doc.document import DocTagsDocument
from transformers import AutoProcessor, AutoModelForVision2Seq
from pathlib import Path
import tempfile
import os
import subprocess
import sys

# Best-effort install of flash-attn at startup. The package is optional: if it
# cannot be imported or installed, the app keeps running and load_model()
# selects eager attention instead.
try:
    import flash_attn  # noqa: F401  (imported only to probe availability)
except ImportError:
    print("Flash attention not found, attempting to install...")
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"Could not install flash attention: {e}")
        # pip's output was captured above, so without this the only
        # diagnostic would be the exit status. Show the end of stderr.
        if e.stderr:
            print(e.stderr[-2000:])
        print("Continuing without flash attention...")
    else:
        print("Flash attention installed successfully")
else:
    print("Flash attention already installed")

# Shared module state filled in by load_model(); both handles stay None and
# the flag stays False until a load has succeeded.
model, processor = None, None
model_loaded = False

def load_model():
    """Populate the global ``model`` and ``processor`` once.

    Does nothing when ``model_loaded`` is already true. Otherwise it loads the
    processor and the model for the Granite-Docling checkpoint, using bfloat16
    with ``device_map="auto"`` when CUDA is available (with Flash Attention 2
    if ``flash_attn`` imports) and float32 with eager attention on CPU.

    If anything in that primary path raises, a plain float32 CPU load is
    attempted instead.

    Raises:
        Exception: whatever the fallback load raised, if it fails as well.
    """
    global model, processor, model_loaded

    # Guard clause: a previous call already did the work.
    if model_loaded:
        return

    checkpoint = "ibm-granite/granite-docling-258M"

    try:
        processor = AutoProcessor.from_pretrained(checkpoint)

        on_gpu = torch.cuda.is_available()
        device = "cuda" if on_gpu else "cpu"

        # Eager attention unless we are on GPU and flash_attn is importable.
        attn_implementation = "eager"
        if on_gpu:
            try:
                import flash_attn
            except ImportError:
                print("Flash attention not available, using eager attention")
            else:
                attn_implementation = "flash_attention_2"
                print("Using Flash Attention 2")

        print(f"Loading model on {device} with {attn_implementation}...")

        if on_gpu:
            # GPU path: bfloat16 weights, placement handled by device_map.
            model = AutoModelForVision2Seq.from_pretrained(
                checkpoint,
                torch_dtype=torch.bfloat16,
                attn_implementation=attn_implementation,
                device_map="auto",
                trust_remote_code=True,
            )
        else:
            # CPU path: float32 weights, moved explicitly.
            model = AutoModelForVision2Seq.from_pretrained(
                checkpoint,
                torch_dtype=torch.float32,
                attn_implementation="eager",
                trust_remote_code=True,
            )
            model = model.to(device)

        model_loaded = True
        print(f"Model loaded successfully on {device}")

    except Exception as e:
        print(f"Error loading model: {e}")
        # Second attempt: no explicit attention implementation, CPU only.
        try:
            processor = AutoProcessor.from_pretrained(checkpoint)
            model = AutoModelForVision2Seq.from_pretrained(
                checkpoint,
                torch_dtype=torch.float32,
                trust_remote_code=True,
            )
            model = model.to("cpu")
            model_loaded = True
            print("Model loaded on CPU as fallback")
        except Exception as fallback_error:
            print(f"Fallback loading also failed: {fallback_error}")
            raise

# Load model at startup: this runs when the module is imported, so the
# load happens here rather than on the first conversion request.
load_model()

def _save_html_to_tempfile(doc):
    """Export *doc* as HTML into a new temporary file and return its path.

    The file is created with ``delete=False`` so it remains on disk for the
    caller to hand to the UI; nothing here removes it afterwards.
    """
    # Reserve a unique path, then close our own handle before the exporter
    # writes to that path.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False) as tmp_file:
        html_path = tmp_file.name
    doc.save_as_html(Path(html_path))
    return html_path


@spaces.GPU(duration=120)
def process_document_gpu(image, output_format="markdown"):
    """Convert one document page image into a Docling document.

    Args:
        image: The page image (the UI supplies a PIL image).
        output_format: ``"markdown"``, ``"html"``, or any other value
            (the UI uses ``"both"``) to get every output.

    Returns:
        A 3-tuple ``(markdown, html_path, doctags)``:

        * ``"markdown"`` -> ``(markdown_text, None, None)``
        * ``"html"``     -> ``(None, path_to_html_file, None)``
        * otherwise      -> ``(markdown_text, path_to_html_file, raw_doctags)``

        On any exception the error is printed with its traceback and
        ``(error_message, None, None)`` is returned instead of raising.
    """
    global model, processor

    try:
        # Ensure model is loaded
        if not model_loaded:
            load_model()

        device = "cuda" if torch.cuda.is_available() else "cpu"

        # For ZeroGPU, the model might still need to be moved to the GPU.
        if device == "cuda":
            if hasattr(model, 'device') and model.device.type != 'cuda':
                model = model.to(device)

        print(f"Processing on {device}")

        # Single-turn chat: one image plus the conversion instruction.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "Convert this page to docling."}
                ]
            },
        ]

        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=[image], return_tensors="pt")

        # Move tensors to the target device. This produces a plain dict, so
        # entries must be read by key from here on (not as attributes).
        inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}

        with torch.no_grad():
            if device == "cuda":
                # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    generated_ids = model.generate(
                        **inputs,
                        max_new_tokens=8192,
                        do_sample=False,
                        temperature=None,
                        top_p=None
                    )
            else:
                generated_ids = model.generate(
                    **inputs,
                    max_new_tokens=8192,
                    do_sample=False
                )

        # Drop the prompt tokens so only the newly generated part is decoded.
        prompt_length = inputs["input_ids"].shape[1]
        trimmed_generated_ids = generated_ids[:, prompt_length:]
        doctags = processor.batch_decode(
            trimmed_generated_ids,
            skip_special_tokens=False,
        )[0].lstrip()

        print(f"Generated {len(doctags)} characters of DocTags")

        # Create Docling document
        doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
        doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")

        # Generate output based on format
        if output_format == "markdown":
            return doc.export_to_markdown(), None, None
        if output_format == "html":
            return None, _save_html_to_tempfile(doc), None

        # Any other choice: return every representation.
        markdown_content = doc.export_to_markdown()
        html_file = _save_html_to_tempfile(doc)
        return markdown_content, html_file, doctags

    except Exception as e:
        # UI boundary: report the failure in the markdown slot instead of raising.
        error_msg = f"Error processing document: {str(e)}"
        print(error_msg)
        import traceback
        print(traceback.format_exc())
        return error_msg, None, None

def process_document(image, output_format="markdown"):
    """Validate the upload, then delegate to the GPU-decorated converter.

    Returns the converter's ``(markdown, html_file, doctags)`` tuple, or a
    prompt message with two ``None`` slots when no image was supplied.
    """
    if image is not None:
        # Hand off to the GPU-decorated function.
        return process_document_gpu(image, output_format)

    return "Please upload an image first.", None, None

def clear_results():
    """Return blank values for the markdown, HTML file and DocTags outputs."""
    blank_markdown = ""
    blank_html_file = None
    blank_doctags = ""
    return blank_markdown, blank_html_file, blank_doctags

# Create Gradio interface.
# Layout: header, then a two-column row (inputs and controls on the left,
# tabbed outputs on the right), then an examples accordion and a footer.
# The emoji in the labels below are stored as real Unicode characters.
with gr.Blocks(
    title="Docling Document Converter",
    theme=gr.themes.Soft(),
    css="""
    .header { 
        text-align: center; 
        margin-bottom: 2rem; 
    }
    .format-selector { 
        margin-top: 1rem; 
    }
    .markdown-output {
        max-height: 600px;
        overflow-y: auto;
        padding: 10px;
        border: 1px solid #ddd;
        border-radius: 5px;
        background-color: #f9f9f9;
    }
    """
) as demo:
    gr.Markdown(
        """
        # 📄 Docling Document Converter
        
        Upload an image of a document page and convert it to structured markdown or HTML using the IBM Granite-Docling model.
        
        This space uses ZeroGPU for efficient processing. The model converts document images into structured formats while preserving layout and formatting.
        
        ---
        """,
        elem_classes="header"
    )
    
    with gr.Row():
        # Left column: image upload, format selection and action buttons.
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Upload Document Image",
                type="pil",
                height=400,
                sources=["upload", "clipboard"],
                show_label=True
            )
            
            format_choice = gr.Radio(
                choices=["markdown", "html", "both"],
                value="markdown",
                label="Output Format",
                info="Choose the output format for the converted document",
                elem_classes="format-selector"
            )
            
            with gr.Row():
                process_btn = gr.Button(
                    "🚀 Convert Document",
                    variant="primary",
                    size="lg",
                    scale=2
                )
                
                clear_btn = gr.Button(
                    "🗑️ Clear",
                    variant="secondary",
                    size="lg",
                    scale=1
                )
            
            # Usage tips (static text)
            gr.Markdown(
                """
                ### ℹ️ Tips:
                - Upload clear, high-resolution images for best results
                - The model works best with text documents, tables, and structured content
                - Processing may take a few moments depending on document complexity
                """
            )
        
        # Right column: one tab per output produced by process_document.
        with gr.Column(scale=2):
            with gr.Tab("📝 Markdown Output"):
                markdown_output = gr.Markdown(
                    value="",
                    label="Structured Markdown",
                    show_copy_button=True,
                    elem_classes="markdown-output"
                )
            
            with gr.Tab("🌐 HTML Output"):
                html_output = gr.File(
                    label="Download HTML File",
                    file_types=[".html"],
                    visible=True
                )
            
            with gr.Tab("🏷️ Raw DocTags"):
                doctags_output = gr.Textbox(
                    label="Raw DocTags Output",
                    lines=15,
                    max_lines=30,
                    show_copy_button=True,
                    placeholder="Raw DocTags will appear here after processing..."
                )
    
    # Event handlers: both callbacks return (markdown, html file, doctags),
    # matching the order of the outputs lists.
    process_btn.click(
        fn=process_document,
        inputs=[image_input, format_choice],
        outputs=[markdown_output, html_output, doctags_output],
        show_progress="full"
    )
    
    clear_btn.click(
        fn=clear_results,
        outputs=[markdown_output, html_output, doctags_output]
    )
    
    # Examples section
    with gr.Accordion("📚 Example Documents", open=False):
        gr.Examples(
            examples=[
                ["https://huggingface.co/ibm-granite/granite-docling-258M/resolve/main/assets/new_arxiv.png"],
            ],
            inputs=[image_input],
            label="Click to load an example document",
            cache_examples=False
        )
    
    # Footer
    gr.Markdown(
        """
        ---
        <div style="text-align: center; margin-top: 2rem;">
            <p>Powered by <a href="https://huggingface.co/ibm-granite/granite-docling-258M" target="_blank">IBM Granite-Docling-258M</a></p>
            <p>Built with ❤️ using Gradio and Hugging Face Spaces</p>
        </div>
        """
    )

# Launch the app only when this file is executed as a script; importing the
# module builds `demo` but does not start it.
if __name__ == "__main__":
    demo.launch()