Skip to content

Commit 8f7b15b

Browse files
committed
fix eval
1 parent d702e75 commit 8f7b15b

File tree

4 files changed

+48
-20
lines changed

4 files changed

+48
-20
lines changed

mindocr/data/transforms/det_east_transforms.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import numpy as np
55
from PIL import Image
66
from shapely.geometry import Polygon
7+
import json
8+
import ast
79

810
__all__ = ["EASTProcessTrain"]
911

@@ -414,8 +416,17 @@ def _extract_vertices(self, data_labels):
414416
"""
415417
vertices_list = []
416418
labels_list = []
417-
data_labels = eval(data_labels)
418-
for data_label in data_labels:
419+
try:
420+
parsed_data = json.loads(data_labels)
421+
except json.JSONDecodeError:
422+
try:
423+
parsed_data = ast.literal_eval(data_labels)
424+
except (ValueError, SyntaxError) as e:
425+
raise ValueError(f"Invalid data format: {str(e)}") from e
426+
427+
if not isinstance(parsed_data, list):
428+
raise ValueError("Data labels should be a list")
429+
for data_label in parsed_data:
419430
vertices = data_label["points"]
420431
vertices = [item for point in vertices for item in point]
421432
vertices_list.append(vertices)

mindocr/data/transforms/svtr_transform.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -546,10 +546,18 @@ def __init__(self, max_text_length, character_dict_path=None, use_space_char=Fal
546546

547547
self.ctc_encode = CTCLabelEncodeForSVTR(max_text_length, character_dict_path, use_space_char, **kwargs)
548548
self.gtc_encode_type = gtc_encode
549+
# Pls explicitly specify the supported gtc_encode classes and obtain the class objects through dictionaries.
550+
supported_gtc_encode = {}
549551
if gtc_encode is None:
550552
self.gtc_encode = SARLabelEncodeForSVTR(max_text_length, character_dict_path, use_space_char, **kwargs)
551553
else:
552-
self.gtc_encode = eval(gtc_encode)(max_text_length, character_dict_path, use_space_char, **kwargs)
554+
# Mindocr currently does not have a module that requires a custom `gtc_encode` input parameter, and will not
555+
# enter this branch at present. If it is supported later, please directly obtain the class reference through
556+
# a specific dict, and do not use the `eval` function.
557+
if gtc_encode not in supported_gtc_encode:
558+
raise ValueError(f"Get unsupported gtc_encode {gtc_encode}")
559+
self.gtc_encode = supported_gtc_encode[gtc_encode](max_text_length, character_dict_path,
560+
use_space_char, **kwargs)
553561

554562
def __call__(self, data):
555563
data_ctc = copy.deepcopy(data)

mindocr/postprocess/builder.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,24 @@
2323

2424
__all__ = ["build_postprocess"]
2525

26-
supported_postprocess = (
27-
det_db_postprocess.__all__
28-
+ det_pse_postprocess.__all__
29-
+ det_east_postprocess.__all__
30-
+ rec_postprocess.__all__
31-
+ cls_postprocess.__all__
32-
+ rec_abinet_postprocess.__all__
33-
+ kie_ser_postprocess.__all__
34-
+ kie_re_postprocess.__all__
35-
+ layout_postprocess.__all__
36-
+ table_postprocess.__all__
37-
)
26+
SUPPORTED_POSTPROCESS = {
27+
"DBPostprocess": DBPostprocess,
28+
"PSEPostprocess": PSEPostprocess,
29+
"EASTPostprocess": EASTPostprocess,
30+
"CTCLabelDecode": CTCLabelDecode,
31+
"RecCTCLabelDecode": RecCTCLabelDecode,
32+
"RecAttnLabelDecode": RecAttnLabelDecode,
33+
"RecMasterLabelDecode": RecMasterLabelDecode,
34+
"VisionLANPostProcess": VisionLANPostProcess,
35+
"SARLabelDecode": SARLabelDecode,
36+
"ClsPostprocess": ClsPostprocess,
37+
"ABINetLabelDecode": ABINetLabelDecode,
38+
"VQASerTokenLayoutLMPostProcess": VQASerTokenLayoutLMPostProcess,
39+
"VQAReTokenLayoutLMPostProcess": VQAReTokenLayoutLMPostProcess,
40+
"YOLOv8Postprocess": YOLOv8Postprocess,
41+
"Layoutlmv3Postprocess": Layoutlmv3Postprocess,
42+
"TableMasterLabelDecode": TableMasterLabelDecode,
43+
}
3844

3945

4046
def build_postprocess(config: dict):
@@ -57,11 +63,11 @@ def build_postprocess(config: dict):
5763
>>> postprocess
5864
"""
5965
proc = config.pop("name")
60-
if proc in supported_postprocess:
61-
postprocessor = eval(proc)(**config)
62-
elif proc is None:
66+
if proc is None:
6367
return None
68+
if proc in SUPPORTED_POSTPROCESS:
69+
postprocessor = SUPPORTED_POSTPROCESS[proc](**config)
6470
else:
65-
raise ValueError(f"Invalid postprocess name {proc}, support postprocess are {supported_postprocess}")
71+
raise ValueError(f"Invalid postprocess name {proc}, support postprocess are {SUPPORTED_POSTPROCESS.keys()}")
6672

6773
return postprocessor

tools/arg_parser.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ def _parse_options(opts: list):
6262
"=" in opt_str
6363
), "Invalid option {}. A valid option must be in the format of {{key_name}}={{value}}".format(opt_str)
6464
k, v = opt_str.strip().split("=")
65-
options[k] = yaml.load(v, Loader=yaml.Loader)
65+
try:
66+
options[k] = yaml.load(v, Loader=yaml.SafeLoader)
67+
except yaml.YAMLError as e:
68+
raise ValueError(f"Failed to parse value for key '{k}': {str(e)}") from e
6669
# print('Parsed options: ', options)
6770

6871
return options

0 commit comments

Comments
 (0)