Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ models/*

# Sphinx documentation
docs/*/_build/
.idea
1 change: 1 addition & 0 deletions project/pdf2markdown/configs/pdf2markdown.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ inputs: assets/demo/formula_detection
outputs: outputs/pdf2markdown
visualize: True
merge2markdown: True
output_json_list: True
tasks:
layout_detection:
model: layout_detection_yolo
Expand Down
136 changes: 100 additions & 36 deletions project/pdf2markdown/scripts/pdf2markdown.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import re
import gc
Expand Down Expand Up @@ -38,6 +39,7 @@ def latex_rm_whitespace(s: str):
break
return s


def crop_img(input_res, input_pil_img, padding_x=0, padding_y=0):
crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
Expand All @@ -53,6 +55,7 @@ def crop_img(input_res, input_pil_img, padding_x=0, padding_y=0):
return_list = [padding_x, padding_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
return return_image, return_list


@TASK_REGISTRY.register("pdf2markdown")
class PDF2MARKDOWN(OCRTask):
def __init__(self, layout_model, mfd_model, mfr_model, ocr_model):
Expand All @@ -63,8 +66,8 @@ def __init__(self, layout_model, mfd_model, mfr_model, ocr_model):
if self.mfr_model is not None:
assert self.mfd_model is not None, "formula recognition based on formula detection, mfd_model can not be None."
self.mfr_transform = transforms.Compose([self.mfr_model.vis_processor, ])
self.color_palette = {

self.color_palette = {
'title': (255, 64, 255),
'plain text': (255, 255, 0),
'abandon': (0, 255, 255),
Expand Down Expand Up @@ -94,17 +97,16 @@ def convert_format(self, yolo_res, id_to_names, ):
}
res_list.append(new_item)
return res_list



def process_single_pdf(self, image_list):
"""predict on one image, reture text detection and recognition results.

Args:
image_list: List[PIL.Image.Image]

Returns:
List[dict]: list of PDF extract results

Return example:
[
{
Expand All @@ -125,7 +127,7 @@ def process_single_pdf(self, image_list):
"score": 0.97
},
...
],
],
"page_info": {
"page_no": 0,
"height": 2339,
Expand All @@ -147,9 +149,9 @@ def process_single_pdf(self, image_list):
layout_res = []
single_page_res = {'layout_dets': layout_res}
single_page_res['page_info'] = dict(
page_no = idx,
height = img_H,
width = img_W
page_no=idx,
height=img_H,
width=img_W
)
if self.mfd_model is not None:
mfd_res = self.mfd_model.predict([image], "")[0]
Expand All @@ -166,13 +168,13 @@ def process_single_pdf(self, image_list):
latex_filling_list.append(new_item)
bbox_img = image.crop((xmin, ymin, xmax, ymax))
mf_image_list.append(bbox_img)

pdf_extract_res.append(single_page_res)

del mfd_res
torch.cuda.empty_cache()
gc.collect()

# Formula recognition, collect all formula images in whole pdf file, then batch infer them.
if self.mfr_model is not None:
a = time.time()
Expand All @@ -187,10 +189,10 @@ def process_single_pdf(self, image_list):
for res, latex in zip(latex_filling_list, mfr_res):
res['latex'] = latex_rm_whitespace(latex)
b = time.time()
print("formula nums:", len(mf_image_list), "mfr time:", round(b-a, 2))
print("formula nums:", len(mf_image_list), "mfr time:", round(b - a, 2))

# ocr_res = self.ocr_model.predict(image)

# ocr and table recognition
for idx, image in enumerate(image_list):
layout_res = pdf_extract_res[idx]['layout_dets']
Expand Down Expand Up @@ -258,13 +260,14 @@ def process_single_pdf(self, image_list):
ocr_cost = round(time.time() - ocr_start, 2)
print(f"ocr cost: {ocr_cost}")
return pdf_extract_res

def order_blocks(self, blocks):
def calculate_oder(poly):
xmin, ymin, _, _, xmax, ymax, _, _ = poly
return ymin*3000 + xmin
return ymin * 3000 + xmin

return sorted(blocks, key=lambda item: calculate_oder(item['poly']))

def convert2md(self, extract_res):
blocks = []
spans = []
Expand All @@ -285,7 +288,7 @@ def convert2md(self, extract_res):
blocks.append(item)
else:
blocks.append(item)

blocks_types = ["title", "plain text", "figure_caption", "table_caption", "table_footnote", "isolate_formula", "formula_caption"]

need_fix_bbox = []
Expand All @@ -296,9 +299,9 @@ def convert2md(self, extract_res):
need_fix_bbox.append(block)
else:
final_block.append(block)

block_with_spans, spans = fill_spans_in_blocks(need_fix_bbox, spans, 0.6)

fix_blocks = fix_block_spans(block_with_spans)
for para_block in fix_blocks:
result = merge_para_with_text(para_block)
Expand All @@ -307,23 +310,80 @@ def convert2md(self, extract_res):
else:
para_block['saved_info']['text'] = result
final_block.append(para_block['saved_info'])

final_block = self.order_blocks(final_block)
md_text = ""
for block in final_block:
if block['category_type'] == "title":
md_text += "\n# "+block['text'] +"\n"
md_text += "\n# " + block['text'] + "\n"
elif block['category_type'] in ["isolate_formula"]:
md_text += "\n"+block['latex']+"\n"
md_text += "\n" + block['latex'] + "\n"
elif block['category_type'] in ["plain text", "figure_caption", "table_caption"]:
md_text += " "+block['text']+" "
md_text += " " + block['text'] + " "
elif block['category_type'] in ["figure", "table"]:
continue
else:
continue
return md_text

def process(self, input_path, save_dir=None, visualize=False, merge2markdown=False):

def convert2json(self, extract_res):
blocks = []
spans = []

for item in extract_res['layout_dets']:
if item['category_type'] in ['inline', 'text', 'isolated']:
text_key = 'text' if item['category_type'] == 'text' else 'latex'
xmin, ymin, _, _, xmax, ymax, _, _ = item['poly']
spans.append(
{
"type": item['category_type'],
"bbox": [xmin, ymin, xmax, ymax],
"content": item[text_key]
}
)
if item['category_type'] == "isolated":
item['category_type'] = "isolate_formula"
blocks.append(item)
else:
# 只有存在 text 且非空才加入 blocks,适用于某些文本类型
if 'text' in item and item['text'].strip():
blocks.append(item)

blocks_types = [
"title", "plain text", "figure_caption", "table_caption",
"table_footnote", "isolate_formula", "formula_caption"
]

need_fix_bbox = []
final_block = []

for block in blocks:
block_type = block["category_type"]
if block_type in blocks_types:
need_fix_bbox.append(block)
else:
final_block.append(block)

block_with_spans, spans = fill_spans_in_blocks(need_fix_bbox, spans, 0.6)

fix_blocks = fix_block_spans(block_with_spans)
for para_block in fix_blocks:
result = merge_para_with_text(para_block)
if para_block['type'] == "isolate_formula":
para_block['saved_info']['latex'] = result
else:
para_block['saved_info']['text'] = result
final_block.append(para_block['saved_info'])

final_block = self.order_blocks(final_block)

# 返回 JSON 格式的结果
return {
"blocks": final_block,
"spans": spans
}

def process(self, input_path, save_dir=None, visualize=False, merge2markdown=False, output_json_list=False):
file_list = self.prepare_input_files(input_path)
res_list = []
for fpath in file_list:
Expand All @@ -337,15 +397,15 @@ def process(self, input_path, save_dir=None, visualize=False, merge2markdown=Fal
if save_dir:
os.makedirs(save_dir, exist_ok=True)
self.save_json_result(pdf_extract_res, os.path.join(save_dir, f"{basename}.json"))

if merge2markdown:
md_content = []
for extract_res in pdf_extract_res:
md_text = self.convert2md(extract_res)
md_content.append(md_text)
with open(os.path.join(save_dir, f"{basename}.md"), "w") as f:
with open(os.path.join(save_dir, f"{basename}.md"), "w", encoding="utf-8") as f:
f.write("\n\n".join(md_content))

if visualize:
for image, page_res in zip(images, pdf_extract_res):
self.visualize_image(image, page_res['layout_dets'], cate2color=self.color_palette)
Expand All @@ -355,8 +415,12 @@ def process(self, input_path, save_dir=None, visualize=False, merge2markdown=Fal
else:
images[0].save(os.path.join(save_dir, f"{basename}.png"))

if output_json_list:
json_list = []
for extract_res in pdf_extract_res:
plain_json = self.convert2json(extract_res)
json_list.append(plain_json)
with open(os.path.join(save_dir, f"{basename}_list.json"), "w", encoding="utf-8") as f:
json.dump(json_list, f, ensure_ascii=False, indent=2)

return res_list




9 changes: 6 additions & 3 deletions project/pdf2markdown/scripts/run_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
from pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models
from pdf_extract_kit.registry.registry import TASK_REGISTRY


TASK_NAME = 'pdf2markdown'


def parse_args():
parser = argparse.ArgumentParser(description="Run a task with a given configuration file.")
parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')
return parser.parse_args()


def main(config_path):
config = load_config(config_path)
task_instances = initialize_tasks_and_models(config)
Expand All @@ -25,17 +26,19 @@ def main(config_path):
result_path = config.get('outputs', 'outputs/pdf_extract')
visualize = config.get('visualize', False)
merge2markdown = config.get('merge2markdown', False)
output_json_list = config.get('output_json_list', False)

layout_model = task_instances['layout_detection'].model if 'layout_detection' in task_instances else None
mfd_model = task_instances['formula_detection'].model if 'formula_detection' in task_instances else None
mfr_model = task_instances['formula_recognition'].model if 'formula_recognition' in task_instances else None
ocr_model = task_instances['ocr'].model if 'ocr' in task_instances else None

pdf_extract_task = TASK_REGISTRY.get(TASK_NAME)(layout_model, mfd_model, mfr_model, ocr_model)
extract_results = pdf_extract_task.process(input_data, save_dir=result_path, visualize=visualize, merge2markdown=merge2markdown)
extract_results = pdf_extract_task.process(input_data, save_dir=result_path, visualize=visualize, merge2markdown=merge2markdown, output_json_list=output_json_list)

print(f'Task done, results can be found at {result_path}')


if __name__ == "__main__":
args = parse_args()
main(args.config)