@@ -32,22 +32,10 @@ class Config:
3232 subattr_3_coords = [67 , 584 , 560 , 624 ]
3333 subattr_4_coords = [67 , 636 , 560 , 676 ]
3434
35- class OCR :
36- def __init__ (self , model_weight = 'mn_model_weight.h5' , scale_ratio = 1 , ocr_model_artnames = None ):
37- self .scale_ratio = scale_ratio
38- self .characters = sorted (
39- [
40- * set (
41- "" .join (
42- sum (ArtsInfo .ArtNames [:- 2 ], [])
43- + ArtsInfo .TypeNames
44- + list (ArtsInfo .MainAttrNames .values ())
45- + list (ArtsInfo .SubAttrNames .values ())
46- + list (".,+%0123456789" )
47- )
48- )
49- ]
50- )
35+ class OCRModel :
36+ def __init__ (self , characters , model_weight , width , height , max_length ):
37+
38+ self .characters = characters
5139 # Mapping characters to integers
5240 self .char_to_num = StringLookup (
5341 vocabulary = list (self .characters ), num_oov_indices = 0 , mask_token = ""
@@ -57,21 +45,94 @@ def __init__(self, model_weight='mn_model_weight.h5', scale_ratio=1, ocr_model_a
5745 self .num_to_char = StringLookup (
5846 vocabulary = self .char_to_num .get_vocabulary (), oov_token = "" , mask_token = "" , invert = True
5947 )
48+
49+ self .width = width
50+ self .height = height
51+ self .max_length = max_length
52+
53+ self .model = OCRModel .build_model (characters = self .characters , input_shape = (self .width , self .height ))
54+ if model_weight :
55+ self .model .load_weights (model_weight )
56+
57+ def predict (self , x ):
58+ return self .decode (self .model .predict (x ))
59+
60+ def decode (self , pred ):
61+ input_len = np .ones (pred .shape [0 ]) * pred .shape [1 ]
62+ # Use greedy search. For complex tasks, you can use beam search
63+ results = ctc_decode (pred , input_length = input_len , greedy = True )[0 ][0 ][
64+ :, :self .max_length
65+ ]
66+ # Iterate over the results and get back the text
67+ output_text = []
68+ for res in results :
69+ res = self .num_to_char (res )
70+ res = reduce_join (res )
71+ res = res .numpy ().decode ("utf-8" )
72+ output_text .append (res )
73+ return output_text
74+
75+ @staticmethod
76+ def build_model (characters , input_shape ):
77+ input_img = Input (
78+ shape = (input_shape [0 ], input_shape [1 ], 1 ), name = "image" , dtype = "float32"
79+ )
80+ mobilenet = MobileNetV3_Small (
81+ (input_shape [0 ], input_shape [1 ], 1 ), 0 , alpha = 1.0 , include_top = False
82+ ).build ()
83+ x = mobilenet (input_img )
84+ new_shape = ((input_shape [0 ] // 8 ), (input_shape [1 ] // 8 ) * 576 )
85+ x = Reshape (target_shape = new_shape , name = "reshape" )(x )
86+ x = Dense (64 , activation = "relu" , name = "dense1" )(x )
87+ x = Dropout (0.2 )(x )
88+
89+ # RNNs
90+ x = Bidirectional (LSTM (128 , return_sequences = True , dropout = 0.25 ))(x )
91+ x = Bidirectional (LSTM (64 , return_sequences = True , dropout = 0.25 ))(x )
92+
93+ # Output layer
94+ output = Dense (len (characters ) + 2 , activation = "softmax" , name = "dense2" )(x )
6095
96+ # Define the model
97+ return Model (inputs = [input_img ], outputs = output , name = "ocr_model_v1" )
98+
99+ class OCR :
100+ def __init__ (self , generic_model_weight = 'generic_model.h5' , name_model_weight = 'name_model.h5' , scale_ratio = 1 ):
61101 self .width = 240
62102 self .height = 16
63103 self .max_length = 15
64- self .build_model (input_shape = (self .width , self .height ))
65- self .model .load_weights (model_weight )
66- self .ocr_model_artnames = ocr_model_artnames
104+ self .scale_ratio = scale_ratio
105+ self .generic_characters = sorted (
106+ [
107+ * set (
108+ "" .join (
109+ ArtsInfo .TypeNames
110+ + list (ArtsInfo .MainAttrNames .values ())
111+ + list (ArtsInfo .SubAttrNames .values ())
112+ + list (".,+%0123456789" )
113+ )
114+ )
115+ ]
116+ )
117+
118+ self .name_characters = sorted ([* set ("" .join (sum (ArtsInfo .ArtNames , [])))])
119+
120+ self .name_model = OCRModel (characters = self .name_characters ,
121+ model_weight = name_model_weight ,
122+ width = self .width , height = self .height ,
123+ max_length = self .max_length )
124+ self .generic_model = OCRModel (characters = self .generic_characters ,
125+ model_weight = generic_model_weight ,
126+ width = self .width , height = self .height ,
127+ max_length = self .max_length )
67128
68129 def detect_info (self , art_img ):
69130 info = self .extract_art_info (art_img )
70- x = np . concatenate ([ self . preprocess ( info [ key ]). T [ None , :, :, None ] for key in sorted (info .keys ())], axis = 0 )
71- y = self .model . predict ( x )
72- y = self .decode ( y )
73- y [ 3 ] = self .ocr_model_artnames . reg ( x [ 3 ][ None ])
74- return {** {key :v for key , v in zip (sorted ( info . keys ()), y )}, ** {'star' :self .detect_star (art_img )}}
131+ generic_keys = [ key for key in sorted (info .keys ()) if key != 'name' ]
132+ x = np . concatenate ([ self .preprocess ( info [ key ]). T [ None , :, :, None ] for key in generic_keys ], axis = 0 )
133+ y_generic = self .generic_model . predict ( x )
134+ y_name = self .name_model . predict ( self . preprocess ( info [ 'name' ]). T [ None ,:,:, None ])
135+ return {** {key :v for key , v in zip (generic_keys , y_generic )}, ** {'star' :self .detect_star (art_img )}, ** { 'name' : y_name [ 0 ] }}
75136
76137 def extract_art_info (self , art_img ):
77138 name = art_img .crop ([i * self .scale_ratio for i in Config .name_coords ])
@@ -112,12 +173,12 @@ def to_gray(self, text_img):
112173 text_img = (text_img [..., :3 ] @ [[[0.299 ], [0.587 ], [0.114 ]]])[:, :, 0 ]
113174 return np .array (text_img , np .float32 )
114175
115- def normalize (self , img , auto_inverse = True ):
116- img -= img .min ()
176+ def normalize (self , img , auto_inverse = True , min_jitter = 0 ):
177+ img -= img .min () + np . random . random () * min_jitter * img . max ()
117178 img /= img .max ()
118179 if auto_inverse and img [- 1 , - 1 ] > 0.5 :
119180 img = 1 - img
120- return img
181+ return np . array ( img , np . float32 )
121182
122183
123184 def crop (self , img , tol = 0.7 ):
@@ -154,7 +215,18 @@ def pad_to_width(self, img):
154215 )
155216
156217
157- def preprocess (self , text_img ):
218+ def preprocess (self , text_img , inference = True ):
219+ result = self .to_gray (text_img )
220+ if inference :
221+ result = self .normalize (result , True , 0 )
222+ result = self .crop (result )
223+ else :
224+ result = self .normalize (result , True , 0.2 )
225+ result = self .crop (result , np .random .random () * 0.25 + 0.6 )
226+ result = self .normalize (result , False , 0 )
227+ result = self .resize_to_height (result )
228+ result = self .pad_to_width (result )
229+ return result
158230 result = self .to_gray (text_img )
159231 result = self .normalize (result , True )
160232 result = self .crop (result )
@@ -163,88 +235,3 @@ def preprocess(self, text_img):
163235 result = self .pad_to_width (result )
164236 return result
165237
166-
167- def decode (self , pred ):
168- input_len = np .ones (pred .shape [0 ]) * pred .shape [1 ]
169- # Use greedy search. For complex tasks, you can use beam search
170- results = ctc_decode (pred , input_length = input_len , greedy = True )[0 ][0 ][
171- :, :self .max_length
172- ]
173- # Iterate over the results and get back the text
174- output_text = []
175- for res in results :
176- res = self .num_to_char (res )
177- res = reduce_join (res )
178- res = res .numpy ().decode ("utf-8" )
179- output_text .append (res )
180- return output_text
181-
182- def build_model (self , input_shape ):
183- input_img = Input (
184- shape = (input_shape [0 ], input_shape [1 ], 1 ), name = "image" , dtype = "float32"
185- )
186- mobilenet = MobileNetV3_Small (
187- (input_shape [0 ], input_shape [1 ], 1 ), 0 , alpha = 1.0 , include_top = False
188- ).build ()
189- x = mobilenet (input_img )
190- new_shape = ((input_shape [0 ] // 8 ), (input_shape [1 ] // 8 ) * 576 )
191- x = Reshape (target_shape = new_shape , name = "reshape" )(x )
192- x = Dense (64 , activation = "relu" , name = "dense1" )(x )
193- x = Dropout (0.2 )(x )
194-
195- # RNNs
196- x = Bidirectional (LSTM (128 , return_sequences = True , dropout = 0.25 ))(x )
197- x = Bidirectional (LSTM (64 , return_sequences = True , dropout = 0.25 ))(x )
198-
199- # Output layer
200- output = Dense (len (self .characters ) + 2 , activation = "softmax" , name = "dense2" )(x )
201-
202- # Define the model
203- self .model = Model (inputs = [input_img ], outputs = output , name = "ocr_model_v1" )
204-
205- class OCR_artnames :
206- def __init__ (self , model_weight = 'mn_model_weight_artnames.h5' ):
207- self .artnames = sorted (set (sum (ArtsInfo .ArtNames , [])))
208-
209- self .model = self .build_model (input_shape = (240 , 16 ))
210- self .model .load_weights (model_weight )
211-
212- def build_model (self , input_shape ):
213- input_img = Input (
214- shape = (input_shape [0 ], input_shape [1 ], 1 ), name = "image" , dtype = "float32"
215- )
216- mobilenet = MobileNetV3_Small (
217- (input_shape [0 ], input_shape [1 ], 1 ), 0 , alpha = 1.0 , include_top = False
218- ).build ()
219- x = mobilenet (input_img )
220- new_shape = ((input_shape [0 ] // 8 ), (input_shape [1 ] // 8 ) * 576 )
221- x = Reshape (target_shape = new_shape , name = "reshape" )(x )
222- x = Dense (64 , activation = "relu" , name = "dense1" )(x )
223- x = Dropout (0.2 )(x )
224-
225- # RNNs
226- x = Bidirectional (LSTM (128 , return_sequences = True , dropout = 0.25 ))(x )
227- x = Bidirectional (LSTM (64 , return_sequences = True , dropout = 0.25 ))(x )
228-
229- # Output layer
230- x = Flatten (name = "flatten" )(x )
231- x = Dense (
232- len (self .artnames ), activation = "softmax" , name = "dense2"
233- )(x )
234-
235- output = x
236-
237- # Define the model
238- model = Model (inputs = [input_img ], outputs = output , name = "ocr_model_artnames" )
239-
240- return model
241-
242- def decode_single (self , pred ):
243- i = pred [0 ].argmax ()
244- if pred [0 ][i ] > 0.75 :
245- return self .artnames [i ]
246- else :
247- return 'Unknown'
248-
249- def reg (self , x ):
250- return self .decode_single (self .model .predict (x ))
0 commit comments