BaoNhan committed on
Commit 3a79903 · verified · 1 Parent(s): 99e6a1c

Create app.py

Files changed (1)
  1. app.py +517 -0
app.py ADDED
@@ -0,0 +1,517 @@
+ # --- RAG / Semantic Search imports ---
+ import os
+ import time
+ import traceback
+ 
+ import numpy as np
+ import torch
+ from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from openai import OpenAI
+ 
+ # --- Initialize OpenAI-compatible Gemini client ---
+ # Read the key from the environment rather than hardcoding a secret in source control.
+ client = OpenAI(
+     api_key=os.environ.get("GEMINI_API_KEY", ""),
+     base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
+ )
+ 
+ # --- Functions for RAG ---
+ def md_to_kb_safe(md_text, embedding_model_name="sentence-transformers/all-MiniLM-L6-v2"):
+     """Split markdown by headers, sub-split into overlapping chunks, and embed them into an in-memory knowledge base."""
+     try:
+         headers_to_split_on = [("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")]
+         splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
+         md_chunks = splitter.split_text(md_text)
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, length_function=len)
+         final_chunks = text_splitter.split_documents(md_chunks)
+         texts = [doc.page_content for doc in final_chunks]
+         # Fall back to CPU when CUDA is unavailable or already holds ~2 GB of allocations
+         device = "cuda" if torch.cuda.is_available() and torch.cuda.memory_allocated() < 2_000_000_000 else "cpu"
+         embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_name, model_kwargs={"device": device})
+         vectors = embedding_model.embed_documents(texts)
+         kb = [{"text": texts[i], "vector": vectors[i]} for i in range(len(texts))]
+         return {"success": True, "num_chunks": len(final_chunks), "kb": kb, "embed_model": embedding_model}
+     except Exception as e:
+         return {"success": False, "error": str(e), "traceback": traceback.format_exc()}
+ 
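+ # Illustrative usage sketch (not executed by the app; the markdown snippet is hypothetical):
+ #   kb_result = md_to_kb_safe("# Title\n\nSome body text...")
+ #   if kb_result["success"]:
+ #       print(kb_result["num_chunks"], "chunks embedded")
+ 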
+ def cosine_similarity(v1, v2):
+     # cos(v1, v2) = (v1 · v2) / (||v1|| * ||v2||)
+     return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
+ 
+ def semantic_search(query, embed_model, kb, top_k=3):
+     t0 = time.time()
+     q_vec = np.array(embed_model.embed_query(query))
+     scores = [(cosine_similarity(q_vec, item["vector"]), item["text"]) for item in kb]
+     scores.sort(reverse=True, key=lambda x: x[0])
+     return scores[:top_k], time.time() - t0
+ 
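+ # Illustrative sketch (hypothetical query; assumes `kb` and `embed_model` from md_to_kb_safe):
+ #   results, elapsed = semantic_search("What is MinerU?", embed_model, kb, top_k=3)
+ # `results` is a list of (similarity_score, chunk_text) pairs, best match first.
+ 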
+ def build_context(results):
+     ctx = ""
+     for i, (score, chunk) in enumerate(results):
+         ctx += f"=== Context {i+1} ===\n{chunk}\n\n"
+     return ctx
+ 
+ def rag_answer(query, embed_model, kb):
+     t0 = time.time()
+     results, t_semantic = semantic_search(query, embed_model, kb, top_k=3)
+     context = build_context(results)
+     prompt = f"""Use ONLY the information in the following context.
+ 
+ {context}
+ 
+ Question: {query}
+ 
+ If the answer is not in the context, respond EXACTLY with:
+ "I do not have enough information to answer that."
+ """
+     response = client.chat.completions.create(
+         model="gemini-2.5-pro",
+         temperature=0,
+         messages=[
+             {"role": "system", "content": "Answer strictly using the context."},
+             {"role": "user", "content": prompt}
+         ]
+     )
+     answer = response.choices[0].message.content
+     return answer, t_semantic, time.time() - t0
+ 
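+ # Illustrative sketch (hypothetical question; assumes a built knowledge base):
+ #   answer, t_semantic, t_total = rag_answer("What does the parser output?", embed_model, kb)
+ 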
+ def evaluate_ai(response, true_answer):
+     t0 = time.time()
+     eval_prompt = f"""
+ AI Response: {response}
+ Ground Truth: {true_answer}
+ 
+ Score the AI response against the ground truth and reply with only the score.
+ Rules:
+ - 1 = very close to true answer
+ - 0.5 = partially correct
+ - 0 = incorrect
+ """
+     # Use a separate name so the completion does not shadow the `response` argument
+     eval_response = client.chat.completions.create(
+         model="gemini-2.5-pro",
+         temperature=0,
+         messages=[
+             {"role": "system", "content": "You are an evaluation system."},
+             {"role": "user", "content": eval_prompt}
+         ]
+     )
+     return eval_response.choices[0].message.content, time.time() - t0
+ 
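+ # Illustrative sketch (strings hypothetical): evaluate_ai("Paris", "Paris, France")
+ # returns (score_text, elapsed_seconds); the score is whatever text the model emits.
+ 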
+ def run_rag_pipeline(md_text_input, query, true_answer):
+     kb_result = md_to_kb_safe(md_text_input)
+     if not kb_result["success"]:
+         return f"Error creating KB:\n{kb_result['error']}", None, None
+     kb = kb_result["kb"]
+     embed_model = kb_result["embed_model"]
+     answer, t_semantic, t_rag = rag_answer(query, embed_model, kb)
+     score, t_eval = evaluate_ai(answer, true_answer)
+     timings = f"Semantic Search: {t_semantic:.2f}s | LLM Answer: {t_rag:.2f}s | Evaluation: {t_eval:.2f}s"
+     return answer, score, timings
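+ 
+ # End-to-end sketch (hypothetical inputs; mirrors the Gradio "RAG QA" tab below):
+ #   answer, score, timings = run_rag_pipeline(markdown_text, "Your question", "Expected answer")
+ 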
+ # --- Gradio app / MinerU imports ---
+ import asyncio
+ import base64
+ import re
+ import zipfile
+ from pathlib import Path
+ 
+ import click
+ import gradio as gr
+ from gradio_pdf import PDF
+ from loguru import logger
+ 
+ from mineru.cli.common import prepare_env, read_fn, aio_do_parse, pdf_suffixes, image_suffixes
+ from mineru.utils.cli_parser import arg_parse
+ from mineru.utils.hash_utils import str_sha256
+ 
+ 
+ async def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, formula_enable, table_enable, language, backend, url):
+     os.makedirs(output_dir, exist_ok=True)
+ 
+     try:
+         file_name = f'{safe_stem(Path(doc_path).stem)}_{time.strftime("%y%m%d_%H%M%S")}'
+         pdf_data = read_fn(doc_path)
+         if is_ocr:
+             parse_method = 'ocr'
+         else:
+             parse_method = 'auto'
+ 
+         if backend.startswith("vlm"):
+             parse_method = "vlm"
+ 
+         local_image_dir, local_md_dir = prepare_env(output_dir, file_name, parse_method)
+         await aio_do_parse(
+             output_dir=output_dir,
+             pdf_file_names=[file_name],
+             pdf_bytes_list=[pdf_data],
+             p_lang_list=[language],
+             parse_method=parse_method,
+             end_page_id=end_page_id,
+             formula_enable=formula_enable,
+             table_enable=table_enable,
+             backend=backend,
+             server_url=url,
+         )
+         return local_md_dir, file_name
+     except Exception as e:
+         logger.exception(e)
+         # Re-raise so to_markdown_safe can report the failure; returning None here
+         # would break the tuple unpacking in to_markdown.
+         raise
+ 
+ 
+ def compress_directory_to_zip(directory_path, output_zip_path):
+     """Compress the given directory into a ZIP file.
+ 
+     :param directory_path: path of the directory to compress
+     :param output_zip_path: path of the output ZIP file
+     """
+     try:
+         with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
+             # Walk all files and subdirectories under the directory
+             for root, dirs, files in os.walk(directory_path):
+                 for file in files:
+                     # Build the full file path
+                     file_path = os.path.join(root, file)
+                     # Compute the path relative to the directory root
+                     arcname = os.path.relpath(file_path, directory_path)
+                     # Add the file to the ZIP archive
+                     zipf.write(file_path, arcname)
+         return 0
+     except Exception as e:
+         logger.exception(e)
+         return -1
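+ 
+ # Returns 0 on success and -1 on failure (paths illustrative):
+ #   compress_directory_to_zip('./output/some_run', './output/some_run.zip')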
+ 
+ 
+ def image_to_base64(image_path):
+     with open(image_path, 'rb') as image_file:
+         return base64.b64encode(image_file.read()).decode('utf-8')
+ 
+ 
+ def replace_image_with_base64(markdown_text, image_dir_path):
+     # Match image tags in the Markdown
+     pattern = r'\!\[(?:[^\]]*)\]\(([^)]+)\)'
+ 
+     # Replace each image link with an inline base64 data URI
+     def replace(match):
+         relative_path = match.group(1)
+         full_path = os.path.join(image_dir_path, relative_path)
+         base64_image = image_to_base64(full_path)
+         return f'![{relative_path}](data:image/jpeg;base64,{base64_image})'
+ 
+     # Apply the substitution
+     return re.sub(pattern, replace, markdown_text)
+ 
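+ # Transformation sketch (file name illustrative):
+ #   ![fig](images/fig1.jpg)  ->  ![images/fig1.jpg](data:image/jpeg;base64,/9j/4AAQ...)
+ 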
+ 
+ async def to_markdown(file_path, end_pages=10, is_ocr=False, formula_enable=True, table_enable=True, language="ch", backend="pipeline", url=None):
+     file_path = to_pdf(file_path)
+     # Get the recognized markdown directory and build the archive path
+     local_md_dir, file_name = await parse_pdf(file_path, './output', end_pages - 1, is_ocr, formula_enable, table_enable, language, backend, url)
+     archive_zip_path = os.path.join('./output', str_sha256(local_md_dir) + '.zip')
+     zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
+     if zip_archive_success == 0:
+         logger.info('Compression successful')
+     else:
+         logger.error('Compression failed')
+     md_path = os.path.join(local_md_dir, file_name + '.md')
+     with open(md_path, 'r', encoding='utf-8') as f:
+         txt_content = f.read()
+     md_content = replace_image_with_base64(txt_content, local_md_dir)
+     # Path of the layout PDF produced during conversion
+     new_pdf_path = os.path.join(local_md_dir, file_name + '_layout.pdf')
+ 
+     return md_content, txt_content, archive_zip_path, new_pdf_path
+ 
+ 
+ async def to_markdown_safe(file_path, end_pages=10, is_ocr=False,
+                            formula_enable=True, table_enable=True,
+                            language="ch", backend="pipeline", url=None):
+     try:
+         return await to_markdown(file_path, end_pages, is_ocr,
+                                  formula_enable, table_enable,
+                                  language, backend, url)
+     except Exception as e:
+         err_msg = traceback.format_exc()
+         logger.error(f"Error in to_markdown: {err_msg}")
+         # Return default values so Gradio does not crash
+         return f"Error: {str(e)}", err_msg, None, None
+ 
+ 
+ latex_delimiters_type_a = [
+     {'left': '$$', 'right': '$$', 'display': True},
+     {'left': '$', 'right': '$', 'display': False},
+ ]
+ latex_delimiters_type_b = [
+     {'left': '\\(', 'right': '\\)', 'display': False},
+     {'left': '\\[', 'right': '\\]', 'display': True},
+ ]
+ latex_delimiters_type_all = latex_delimiters_type_a + latex_delimiters_type_b
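+ # Type 'a' matches $$x^2$$ (display) and $x$ (inline); type 'b' matches \[x\] (display) and \(x\) (inline).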
+ 
+ 
+ header = """
+ <html>
+ <head>
+ <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
+ <style>
+ .link-block{border:1px solid transparent;border-radius:24px;background-color:rgba(54,54,54,1);cursor:pointer!important}
+ .link-block:hover{background-color:rgba(54,54,54,0.75)!important;cursor:pointer!important}
+ .external-link{display:inline-flex;align-items:center;height:36px;line-height:36px;padding:0 16px;cursor:pointer!important}
+ .external-link,.external-link:hover{cursor:pointer!important}
+ a{text-decoration:none}
+ .link-block{display:inline-block}
+ .link-block+.link-block{margin-left:20px}
+ </style>
+ </head>
+ <body>
+ <div style="display: flex; flex-direction: column; justify-content: center; align-items: center; text-align: center; background: linear-gradient(45deg, #007bff 0%, #0056b3 100%); padding: 24px; gap: 24px; border-radius: 8px;">
+   <div style="display: flex; flex-direction: column; align-items: center; gap: 16px;">
+     <div style="display: flex; flex-direction: column; gap: 8px">
+       <h1 style="font-size: 48px; color: #fafafa; margin: 0; font-family: 'Trebuchet MS', 'Lucida Sans Unicode', 'Lucida Grande', 'Lucida Sans', Arial, sans-serif;">MinerU 2.5: PDF Extraction Demo</h1>
+     </div>
+   </div>
+   <p style="margin: 0; line-height: 1.6rem; font-size: 16px; color: #fafafa; opacity: 0.8;">A one-stop, open-source, high-quality data extraction tool that supports converting PDF to Markdown and JSON.<br></p>
+   <div class="column has-text-centered">
+     <div class="publication-links">
+       <!--Code Link.-->
+       <span class="link-block"><a href="https://github.com/opendatalab/MinerU" class="external-link button is-normal is-rounded is-dark" style="text-decoration: none; cursor: pointer"><span class="icon" style="margin-right: 4px"><i class="fab fa-github" style="color: white; margin-right: 4px"></i></span><span style="color: white">Code</span></a></span>
+       <!--Model Link.-->
+       <span class="link-block"><a href="https://huggingface.co/opendatalab/MinerU2.5-2509-1.2B" class="external-link button is-normal is-rounded is-dark" style="text-decoration: none; cursor: pointer"><span class="icon" style="margin-right: 4px"><i class="fas fa-archive" style="color: white; margin-right: 4px"></i></span><span style="color: white">Model</span></a></span>
+       <!--arXiv Link.-->
+       <span class="link-block"><a href="https://arxiv.org/abs/2409.18839" class="external-link button is-normal is-rounded is-dark" style="text-decoration: none; cursor: pointer"><span class="icon" style="margin-right: 8px"><i class="fas fa-file" style="color: white"></i></span><span style="color: white">Paper</span></a></span>
+       <!--Homepage Link.-->
+       <span class="link-block"><a href="https://mineru.net/home?source=online" class="external-link button is-normal is-rounded is-dark" style="text-decoration: none; cursor: pointer"><span class="icon" style="margin-right: 8px"><i class="fas fa-home" style="color: white"></i></span><span style="color: white">Homepage</span></a></span>
+       <!--Client Link.-->
+       <span class="link-block"><a href="https://mineru.net/client?source=online" class="external-link button is-normal is-rounded is-dark" style="text-decoration: none; cursor: pointer"><span class="icon" style="margin-right: 8px"><i class="fas fa-download" style="color: white"></i></span><span style="color: white">Download</span></a></span>
+     </div>
+   </div>
+ </div>
+ </body>
+ </html>
+ """
+ 
+ 
+ latin_lang = [
+     'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',  # noqa: E126
+     'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
+     'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
+     'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
+ ]
+ arabic_lang = ['ar', 'fa', 'ug', 'ur']
+ cyrillic_lang = [
+     'rs_cyrillic', 'bg', 'mn', 'abq', 'ady', 'kbd', 'ava',  # noqa: E126
+     'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
+ ]
+ east_slavic_lang = ["ru", "be", "uk"]
+ devanagari_lang = [
+     'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',  # noqa: E126
+     'sa', 'bgc'
+ ]
+ other_lang = ['ch', 'ch_lite', 'ch_server', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka', "el", "th"]
+ add_lang = ['latin', 'arabic', 'east_slavic', 'cyrillic', 'devanagari']
+ 
+ # all_lang = ['', 'auto']
+ all_lang = []
+ # all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])
+ all_lang.extend([*other_lang, *add_lang])
+ 
+ 
+ def safe_stem(file_path):
+     stem = Path(file_path).stem
+     # Keep only letters, digits, underscores and dots; replace everything else with an underscore
+     return re.sub(r'[^\w.]', '_', stem)
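+ # e.g. safe_stem("my report (v2).pdf") -> "my_report__v2_" (file name illustrative)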
+ 
+ 
+ def to_pdf(file_path):
+     if file_path is None:
+         return None
+ 
+     pdf_bytes = read_fn(file_path)
+ 
+     # unique_filename = f'{uuid.uuid4()}.pdf'
+     unique_filename = f'{safe_stem(file_path)}.pdf'
+ 
+     # Build the full target file path
+     tmp_file_path = os.path.join(os.path.dirname(file_path), unique_filename)
+ 
+     # Write the PDF bytes to the file
+     with open(tmp_file_path, 'wb') as tmp_pdf_file:
+         tmp_pdf_file.write(pdf_bytes)
+ 
+     return tmp_file_path
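+ # Usage sketch (file name illustrative): to_pdf("/tmp/scan.png") writes and returns
+ # "/tmp/scan.pdf", using read_fn to load the upload as PDF bytes.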
+ 
+ 
+ # Toggle option panels to match the selected backend
+ def update_interface(backend_choice):
+     if backend_choice in ["vlm-transformers", "vlm-vllm-async-engine"]:
+         return gr.update(visible=False), gr.update(visible=False)
+     elif backend_choice in ["vlm-http-client"]:
+         return gr.update(visible=True), gr.update(visible=False)
+     elif backend_choice in ["pipeline"]:
+         return gr.update(visible=False), gr.update(visible=True)
+     else:
+         # Unknown backend: leave both panels unchanged instead of returning None
+         return gr.update(), gr.update()
+ 
+ 
+ @click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
+ @click.pass_context
+ @click.option(
+     '--enable-example',
+     'example_enable',
+     type=bool,
+     help="Enable example files for input. "
+          "The example files must be placed in the `examples` folder within the directory where the command is executed.",
+     default=True,
+ )
+ @click.option(
+     '--enable-vllm-engine',
+     'vllm_engine_enable',
+     type=bool,
+     help="Enable the vLLM engine backend for faster processing.",
+     default=False,
+ )
+ @click.option(
+     '--enable-api',
+     'api_enable',
+     type=bool,
+     help="Enable the Gradio API for serving the application.",
+     default=True,
+ )
+ @click.option(
+     '--max-convert-pages',
+     'max_convert_pages',
+     type=int,
+     help="Set the maximum number of pages to convert from PDF to Markdown.",
+     default=1000,
+ )
+ @click.option(
+     '--server-name',
+     'server_name',
+     type=str,
+     help="Set the server name for the Gradio app.",
+     default=None,
+ )
+ @click.option(
+     '--server-port',
+     'server_port',
+     type=int,
+     help="Set the server port for the Gradio app.",
+     default=None,
+ )
+ @click.option(
+     '--latex-delimiters-type',
+     'latex_delimiters_type',
+     type=click.Choice(['a', 'b', 'all']),
+     help="Set the type of LaTeX delimiters to use in Markdown rendering: "
+          "'a' for the '$' style, 'b' for the '()[]' style, 'all' for both.",
+     default='all',
+ )
+ def main(ctx,
+          example_enable, vllm_engine_enable, api_enable, max_convert_pages,
+          server_name, server_port, latex_delimiters_type, **kwargs
+          ):
+ 
+     kwargs.update(arg_parse(ctx))
+ 
+     if latex_delimiters_type == 'a':
+         latex_delimiters = latex_delimiters_type_a
+     elif latex_delimiters_type == 'b':
+         latex_delimiters = latex_delimiters_type_b
+     elif latex_delimiters_type == 'all':
+         latex_delimiters = latex_delimiters_type_all
+     else:
+         raise ValueError(f"Invalid latex delimiters type: {latex_delimiters_type}.")
+ 
+     if vllm_engine_enable:
+         try:
+             print("Start init vLLM engine...")
+             from mineru.backend.vlm.vlm_analyze import ModelSingleton
+             # Instantiate the model once so later requests reuse it
+             model_singleton = ModelSingleton()
+             predictor = model_singleton.get_model(
+                 "vllm-async-engine",
+                 None,
+                 None,
+                 **kwargs
+             )
+             print("vLLM engine initialized successfully.")
+         except Exception as e:
+             logger.exception(e)
+     suffixes = [f".{suffix}" for suffix in pdf_suffixes + image_suffixes]
+     with gr.Blocks() as demo:
+         gr.HTML(header)
+         with gr.Row():
+             with gr.Column(variant='panel', scale=5):
+                 with gr.Row():
+                     input_file = gr.File(label='Please upload a PDF or image', file_types=suffixes)
+                 with gr.Row():
+                     max_pages = gr.Slider(1, max_convert_pages, int(max_convert_pages / 2), step=1, label='Max convert pages')
+                 with gr.Row():
+                     if vllm_engine_enable:
+                         drop_list = ["pipeline", "vlm-vllm-async-engine"]
+                         preferred_option = "vlm-vllm-async-engine"
+                     else:
+                         drop_list = ["pipeline", "vlm-transformers", "vlm-http-client"]
+                         preferred_option = "pipeline"
+                     backend = gr.Dropdown(drop_list, label="Backend", value=preferred_option)
+                 with gr.Row(visible=False) as client_options:
+                     url = gr.Textbox(label='Server URL', value='http://localhost:30000', placeholder='http://localhost:30000')
+                 with gr.Row(equal_height=True):
+                     with gr.Column():
+                         gr.Markdown("**Recognition Options:**")
+                         formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
+                         table_enable = gr.Checkbox(label='Enable table recognition', value=True)
+                     with gr.Column(visible=False) as ocr_options:
+                         language = gr.Dropdown(all_lang, label='Language', value='ch')
+                         is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
+                 with gr.Row():
+                     change_bu = gr.Button('Convert')
+                     clear_bu = gr.ClearButton(value='Clear')
+                 pdf_show = PDF(label='PDF preview', interactive=False, visible=True, height=800)
+                 if example_enable:
+                     example_root = os.path.join(os.getcwd(), 'examples')
+                     if os.path.exists(example_root):
+                         with gr.Accordion('Examples:'):
+                             gr.Examples(
+                                 examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
+                                           _.endswith(tuple(suffixes))],
+                                 inputs=input_file
+                             )
+ 
+             with gr.Column(variant='panel', scale=5):
+                 output_file = gr.File(label='Convert result', interactive=False)
+                 with gr.Tabs():
+                     with gr.Tab('Markdown rendering'):
+                         md = gr.Markdown(label='Markdown rendering', height=1100, show_copy_button=True,
+                                          latex_delimiters=latex_delimiters,
+                                          line_breaks=True)
+                     with gr.Tab('Markdown text'):
+                         md_text = gr.TextArea(lines=45, show_copy_button=True)
+                     with gr.Tab("RAG QA"):
+                         rag_md_text = gr.TextArea(label="Paste Markdown here", lines=15)
+                         rag_query = gr.Textbox(label="Your Question")
+                         rag_true = gr.Textbox(label="Ground Truth Answer (optional)")
+                         rag_run = gr.Button("Run RAG")
+                         rag_answer_out = gr.TextArea(label="RAG Answer", lines=15, interactive=False)
+                         rag_score_out = gr.Textbox(label="Evaluation Score")
+                         rag_timing_out = gr.Textbox(label="Timings")
+                         rag_run.click(
+                             fn=run_rag_pipeline,
+                             inputs=[rag_md_text, rag_query, rag_true],
+                             outputs=[rag_answer_out, rag_score_out, rag_timing_out]
+                         )
+         # Wire up event handlers
+         backend.change(
+             fn=update_interface,
+             inputs=[backend],
+             outputs=[client_options, ocr_options],
+             api_name=False
+         )
+         # Run update_interface once on page load so the panels match the initial backend
+         demo.load(
+             fn=update_interface,
+             inputs=[backend],
+             outputs=[client_options, ocr_options],
+             api_name=False
+         )
+         clear_bu.add([input_file, md, pdf_show, md_text, output_file, is_ocr])
+ 
+         if api_enable:
+             api_name = None
+         else:
+             api_name = False
+ 
+         input_file.change(fn=to_pdf, inputs=input_file, outputs=pdf_show, api_name=api_name)
+         change_bu.click(
+             # to_markdown_safe is async; run it to completion inside Gradio's worker thread
+             fn=lambda *args: asyncio.run(to_markdown_safe(*args)),
+             inputs=[input_file, max_pages, is_ocr, formula_enable, table_enable, language, backend, url],
+             outputs=[md, md_text, output_file, pdf_show],
+             api_name=api_name
+         )
+ 
+     demo.launch(server_name=server_name, server_port=server_port, show_api=api_enable, height=1200)
+ 
+ 
+ if __name__ == "__main__":
+     main()