From 87b0acfa90b50371771549a016adbfa21af48af3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AF=9B=E5=B7=B3=E7=85=9C?= Date: Wed, 14 May 2025 17:20:32 +0800 Subject: [PATCH] add convert2json method MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 毛巳煜 --- .gitignore | 1 + .../pdf2markdown/configs/pdf2markdown.yaml | 1 + project/pdf2markdown/scripts/pdf2markdown.py | 136 +++++++++++++----- project/pdf2markdown/scripts/run_project.py | 9 +- 4 files changed, 108 insertions(+), 39 deletions(-) diff --git a/.gitignore b/.gitignore index ea20b51..32bb398 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,4 @@ models/* # Sphinx documentation docs/*/_build/ +.idea \ No newline at end of file diff --git a/project/pdf2markdown/configs/pdf2markdown.yaml b/project/pdf2markdown/configs/pdf2markdown.yaml index 29014fc..365128e 100644 --- a/project/pdf2markdown/configs/pdf2markdown.yaml +++ b/project/pdf2markdown/configs/pdf2markdown.yaml @@ -2,6 +2,7 @@ inputs: assets/demo/formula_detection outputs: outputs/pdf2markdown visualize: True merge2markdown: True +output_json_list: True tasks: layout_detection: model: layout_detection_yolo diff --git a/project/pdf2markdown/scripts/pdf2markdown.py b/project/pdf2markdown/scripts/pdf2markdown.py index a8bdd98..8785ce4 100644 --- a/project/pdf2markdown/scripts/pdf2markdown.py +++ b/project/pdf2markdown/scripts/pdf2markdown.py @@ -1,3 +1,4 @@ +import json import os import re import gc @@ -38,6 +39,7 @@ def latex_rm_whitespace(s: str): break return s + def crop_img(input_res, input_pil_img, padding_x=0, padding_y=0): crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1]) crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5]) @@ -53,6 +55,7 @@ def crop_img(input_res, input_pil_img, padding_x=0, padding_y=0): return_list = [padding_x, padding_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height] return return_image, return_list + @TASK_REGISTRY.register("pdf2markdown") class PDF2MARKDOWN(OCRTask): def __init__(self, layout_model, mfd_model, mfr_model, ocr_model): @@ -63,8 +66,8 @@ def __init__(self, layout_model, mfd_model, mfr_model, ocr_model): if self.mfr_model is not None: assert self.mfd_model is not None, "formula recognition based on formula detection, mfd_model can not be None." self.mfr_transform = transforms.Compose([self.mfr_model.vis_processor, ]) - - self.color_palette = { + + self.color_palette = { 'title': (255, 64, 255), 'plain text': (255, 255, 0), 'abandon': (0, 255, 255), @@ -94,17 +97,16 @@ def convert_format(self, yolo_res, id_to_names, ): } res_list.append(new_item) return res_list - - + def process_single_pdf(self, image_list): """predict on one image, reture text detection and recognition results. - + Args: image_list: List[PIL.Image.Image] - + Returns: List[dict]: list of PDF extract results - + Return example: [ { @@ -125,7 +127,7 @@ def process_single_pdf(self, image_list): "score": 0.97 }, ... - ], + ], "page_info": { "page_no": 0, "height": 2339, @@ -147,9 +149,9 @@ def process_single_pdf(self, image_list): layout_res = [] single_page_res = {'layout_dets': layout_res} single_page_res['page_info'] = dict( - page_no = idx, - height = img_H, - width = img_W + page_no=idx, + height=img_H, + width=img_W ) if self.mfd_model is not None: mfd_res = self.mfd_model.predict([image], "")[0] @@ -166,13 +168,13 @@ def process_single_pdf(self, image_list): latex_filling_list.append(new_item) bbox_img = image.crop((xmin, ymin, xmax, ymax)) mf_image_list.append(bbox_img) - + pdf_extract_res.append(single_page_res) - + del mfd_res torch.cuda.empty_cache() gc.collect() - + # Formula recognition, collect all formula images in whole pdf file, then batch infer them. if self.mfr_model is not None: a = time.time() @@ -187,10 +189,10 @@ def process_single_pdf(self, image_list): for res, latex in zip(latex_filling_list, mfr_res): res['latex'] = latex_rm_whitespace(latex) b = time.time() - print("formula nums:", len(mf_image_list), "mfr time:", round(b-a, 2)) - + print("formula nums:", len(mf_image_list), "mfr time:", round(b - a, 2)) + # ocr_res = self.ocr_model.predict(image) - + # ocr and table recognition for idx, image in enumerate(image_list): layout_res = pdf_extract_res[idx]['layout_dets'] @@ -258,13 +260,14 @@ def process_single_pdf(self, image_list): ocr_cost = round(time.time() - ocr_start, 2) print(f"ocr cost: {ocr_cost}") return pdf_extract_res - + def order_blocks(self, blocks): def calculate_oder(poly): xmin, ymin, _, _, xmax, ymax, _, _ = poly - return ymin*3000 + xmin + return ymin * 3000 + xmin + return sorted(blocks, key=lambda item: calculate_oder(item['poly'])) - + def convert2md(self, extract_res): blocks = [] spans = [] @@ -285,7 +288,7 @@ def convert2md(self, extract_res): blocks.append(item) else: blocks.append(item) - + blocks_types = ["title", "plain text", "figure_caption", "table_caption", "table_footnote", "isolate_formula", "formula_caption"] need_fix_bbox = [] @@ -296,9 +299,9 @@ def convert2md(self, extract_res): need_fix_bbox.append(block) else: final_block.append(block) - + block_with_spans, spans = fill_spans_in_blocks(need_fix_bbox, spans, 0.6) - + fix_blocks = fix_block_spans(block_with_spans) for para_block in fix_blocks: result = merge_para_with_text(para_block) @@ -307,23 +310,80 @@ def convert2md(self, extract_res): else: para_block['saved_info']['text'] = result final_block.append(para_block['saved_info']) - + final_block = self.order_blocks(final_block) md_text = "" for block in final_block: if block['category_type'] == "title": - md_text += "\n# "+block['text'] +"\n" + md_text += "\n# " + block['text'] + "\n" elif block['category_type'] in ["isolate_formula"]: - md_text += "\n"+block['latex']+"\n" + md_text += "\n" + block['latex'] + "\n" elif block['category_type'] in ["plain text", "figure_caption", "table_caption"]: - md_text += " "+block['text']+" " + md_text += " " + block['text'] + " " elif block['category_type'] in ["figure", "table"]: continue else: continue return md_text - - def process(self, input_path, save_dir=None, visualize=False, merge2markdown=False): + + def convert2json(self, extract_res): + blocks = [] + spans = [] + + for item in extract_res['layout_dets']: + if item['category_type'] in ['inline', 'text', 'isolated']: + text_key = 'text' if item['category_type'] == 'text' else 'latex' + xmin, ymin, _, _, xmax, ymax, _, _ = item['poly'] + spans.append( + { + "type": item['category_type'], + "bbox": [xmin, ymin, xmax, ymax], + "content": item[text_key] + } + ) + if item['category_type'] == "isolated": + item['category_type'] = "isolate_formula" + blocks.append(item) + else: + # 只有存在 text 且非空才加入 blocks,适用于某些文本类型 + if 'text' in item and item['text'].strip(): + blocks.append(item) + + blocks_types = [ + "title", "plain text", "figure_caption", "table_caption", + "table_footnote", "isolate_formula", "formula_caption" + ] + + need_fix_bbox = [] + final_block = [] + + for block in blocks: + block_type = block["category_type"] + if block_type in blocks_types: + need_fix_bbox.append(block) + else: + final_block.append(block) + + block_with_spans, spans = fill_spans_in_blocks(need_fix_bbox, spans, 0.6) + + fix_blocks = fix_block_spans(block_with_spans) + for para_block in fix_blocks: + result = merge_para_with_text(para_block) + if para_block['type'] == "isolate_formula": + para_block['saved_info']['latex'] = result + else: + para_block['saved_info']['text'] = result + final_block.append(para_block['saved_info']) + + final_block = self.order_blocks(final_block) + + # 返回 JSON 格式的结果 + return { + "blocks": final_block, + "spans": spans + } + + def process(self, input_path, save_dir=None, visualize=False, merge2markdown=False, output_json_list=False): file_list = self.prepare_input_files(input_path) res_list = [] for fpath in file_list: @@ -337,15 +397,15 @@ def process(self, input_path, save_dir=None, visualize=False, merge2markdown=Fal if save_dir: os.makedirs(save_dir, exist_ok=True) self.save_json_result(pdf_extract_res, os.path.join(save_dir, f"{basename}.json")) - + if merge2markdown: md_content = [] for extract_res in pdf_extract_res: md_text = self.convert2md(extract_res) md_content.append(md_text) - with open(os.path.join(save_dir, f"{basename}.md"), "w") as f: + with open(os.path.join(save_dir, f"{basename}.md"), "w", encoding="utf-8") as f: f.write("\n\n".join(md_content)) - + if visualize: for image, page_res in zip(images, pdf_extract_res): self.visualize_image(image, page_res['layout_dets'], cate2color=self.color_palette) @@ -355,8 +415,12 @@ def process(self, input_path, save_dir=None, visualize=False, merge2markdown=Fal else: images[0].save(os.path.join(save_dir, f"{basename}.png")) + if output_json_list: + json_list = [] + for extract_res in pdf_extract_res: + plain_json = self.convert2json(extract_res) + json_list.append(plain_json) + with open(os.path.join(save_dir, f"{basename}_list.json"), "w", encoding="utf-8") as f: + json.dump(json_list, f, ensure_ascii=False, indent=2) + return res_list - - - - \ No newline at end of file diff --git a/project/pdf2markdown/scripts/run_project.py b/project/pdf2markdown/scripts/run_project.py index 4c1605c..243e652 100644 --- a/project/pdf2markdown/scripts/run_project.py +++ b/project/pdf2markdown/scripts/run_project.py @@ -8,14 +8,15 @@ from pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models from pdf_extract_kit.registry.registry import TASK_REGISTRY - TASK_NAME = 'pdf2markdown' + def parse_args(): parser = argparse.ArgumentParser(description="Run a task with a given configuration file.") parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.') return parser.parse_args() + def main(config_path): config = load_config(config_path) task_instances = initialize_tasks_and_models(config) @@ -25,17 +26,19 @@ def main(config_path): result_path = config.get('outputs', 'outputs/pdf_extract') visualize = config.get('visualize', False) merge2markdown = config.get('merge2markdown', False) + output_json_list = config.get('output_json_list', False) layout_model = task_instances['layout_detection'].model if 'layout_detection' in task_instances else None mfd_model = task_instances['formula_detection'].model if 'formula_detection' in task_instances else None mfr_model = task_instances['formula_recognition'].model if 'formula_recognition' in task_instances else None ocr_model = task_instances['ocr'].model if 'ocr' in task_instances else None - + pdf_extract_task = TASK_REGISTRY.get(TASK_NAME)(layout_model, mfd_model, mfr_model, ocr_model) - extract_results = pdf_extract_task.process(input_data, save_dir=result_path, visualize=visualize, merge2markdown=merge2markdown) + extract_results = pdf_extract_task.process(input_data, save_dir=result_path, visualize=visualize, merge2markdown=merge2markdown, output_json_list=output_json_list) print(f'Task done, results can be found at {result_path}') + if __name__ == "__main__": args = parse_args() main(args.config)