Skip to content

Commit 9fd3227

Browse files
refactoring code and retrained models
1 parent bf244eb commit 9fd3227

File tree

8 files changed

+803
-117
lines changed

8 files changed

+803
-117
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ ArtScanner/artifacts/**
88
ArtScanner/artifacts-all/**
99

1010
.vscode/**
11+
**/.ipynb_checkpoints/**
12+
ArtScanner/data/**

ArtScanner/Tools/datagen.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import json
22
import numpy as np
33
from PIL import ImageFont, Image, ImageDraw
4-
from .. import ArtsInfo
4+
import sys
5+
sys.path.append("..")
6+
import ArtsInfo
57

68
MainAttrDatabase = json.load(open('ReliquaryLevelExcelConfigData.json'))
79
SubAttrDatabase = json.load(open('ReliquaryAffixExcelConfigData.json'))

ArtScanner/Tools/train.ipynb

Lines changed: 695 additions & 0 deletions
Large diffs are not rendered by default.

ArtScanner/build.cmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
pyinstaller --onefile --add-data "mn_model_weight.h5;." --add-data "mn_model_weight_artnames.h5;." --add-data "Tools/ReliquaryLevelExcelConfigData.json;./Tools" --add-data "Tools/ReliquaryAffixExcelConfigData.json;./Tools" --hidden-import=h5py --hidden-import=h5py.defs --hidden-import=h5py.utils --hidden-import=h5py.h5ac --hidden-import=h5py._proxy --uac-admin -n ArtScanner main.py
1+
pyinstaller --onefile --add-data "generic_model.h5;." --add-data "name_model.h5;." --add-data "Tools/ReliquaryLevelExcelConfigData.json;./Tools" --add-data "Tools/ReliquaryAffixExcelConfigData.json;./Tools" --hidden-import=h5py --hidden-import=h5py.defs --hidden-import=h5py.utils --hidden-import=h5py.h5ac --hidden-import=h5py._proxy --uac-admin -n ArtScanner main.py
Binary file not shown.

ArtScanner/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def is_admin():
9090
# margin near level number, color=233,229,220
9191

9292
# initialization
93-
ocr_model = ocr.OCR(scale_ratio=game_info.scale_ratio, model_weight=os.path.join(bundle_dir, 'mn_model_weight.h5'),
94-
ocr_model_artnames=ocr.OCR_artnames(model_weight=os.path.join(bundle_dir, 'mn_model_weight_artnames.h5')))
93+
ocr_model = ocr.OCR(scale_ratio=game_info.scale_ratio, generic_model_weight=os.path.join(bundle_dir, 'generic_model.h5'),
94+
name_model_weight=os.path.join(bundle_dir, 'name_model.h5'))
9595
art_id = 0
9696
saved = 0
9797
skipped = 0
Binary file not shown.

ArtScanner/ocr.py

Lines changed: 100 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,10 @@ class Config:
3232
subattr_3_coords = [67, 584, 560, 624]
3333
subattr_4_coords = [67, 636, 560, 676]
3434

35-
class OCR:
36-
def __init__(self, model_weight='mn_model_weight.h5', scale_ratio=1, ocr_model_artnames=None):
37-
self.scale_ratio = scale_ratio
38-
self.characters = sorted(
39-
[
40-
*set(
41-
"".join(
42-
sum(ArtsInfo.ArtNames[:-2], [])
43-
+ ArtsInfo.TypeNames
44-
+ list(ArtsInfo.MainAttrNames.values())
45-
+ list(ArtsInfo.SubAttrNames.values())
46-
+ list(".,+%0123456789")
47-
)
48-
)
49-
]
50-
)
35+
class OCRModel:
36+
def __init__(self, characters, model_weight, width, height, max_length):
37+
38+
self.characters = characters
5139
# Mapping characters to integers
5240
self.char_to_num = StringLookup(
5341
vocabulary=list(self.characters), num_oov_indices=0, mask_token=""
@@ -57,21 +45,94 @@ def __init__(self, model_weight='mn_model_weight.h5', scale_ratio=1, ocr_model_a
5745
self.num_to_char = StringLookup(
5846
vocabulary=self.char_to_num.get_vocabulary(), oov_token="", mask_token="", invert=True
5947
)
48+
49+
self.width = width
50+
self.height = height
51+
self.max_length = max_length
52+
53+
self.model = OCRModel.build_model(characters=self.characters, input_shape=(self.width, self.height))
54+
if model_weight:
55+
self.model.load_weights(model_weight)
56+
57+
def predict(self, x):
58+
return self.decode(self.model.predict(x))
59+
60+
def decode(self, pred):
61+
input_len = np.ones(pred.shape[0]) * pred.shape[1]
62+
# Use greedy search. For complex tasks, you can use beam search
63+
results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
64+
:, :self.max_length
65+
]
66+
# Iterate over the results and get back the text
67+
output_text = []
68+
for res in results:
69+
res = self.num_to_char(res)
70+
res = reduce_join(res)
71+
res = res.numpy().decode("utf-8")
72+
output_text.append(res)
73+
return output_text
74+
75+
@staticmethod
76+
def build_model(characters, input_shape):
77+
input_img = Input(
78+
shape=(input_shape[0], input_shape[1], 1), name="image", dtype="float32"
79+
)
80+
mobilenet = MobileNetV3_Small(
81+
(input_shape[0], input_shape[1], 1), 0, alpha=1.0, include_top=False
82+
).build()
83+
x = mobilenet(input_img)
84+
new_shape = ((input_shape[0] // 8), (input_shape[1] // 8) * 576)
85+
x = Reshape(target_shape=new_shape, name="reshape")(x)
86+
x = Dense(64, activation="relu", name="dense1")(x)
87+
x = Dropout(0.2)(x)
88+
89+
# RNNs
90+
x = Bidirectional(LSTM(128, return_sequences=True, dropout=0.25))(x)
91+
x = Bidirectional(LSTM(64, return_sequences=True, dropout=0.25))(x)
92+
93+
# Output layer
94+
output = Dense(len(characters) + 2, activation="softmax", name="dense2")(x)
6095

96+
# Define the model
97+
return Model(inputs=[input_img], outputs=output, name="ocr_model_v1")
98+
99+
class OCR:
100+
def __init__(self, generic_model_weight='generic_model.h5', name_model_weight='name_model.h5', scale_ratio=1):
61101
self.width = 240
62102
self.height = 16
63103
self.max_length = 15
64-
self.build_model(input_shape=(self.width, self.height))
65-
self.model.load_weights(model_weight)
66-
self.ocr_model_artnames = ocr_model_artnames
104+
self.scale_ratio = scale_ratio
105+
self.generic_characters = sorted(
106+
[
107+
*set(
108+
"".join(
109+
ArtsInfo.TypeNames
110+
+ list(ArtsInfo.MainAttrNames.values())
111+
+ list(ArtsInfo.SubAttrNames.values())
112+
+ list(".,+%0123456789")
113+
)
114+
)
115+
]
116+
)
117+
118+
self.name_characters = sorted([*set("".join(sum(ArtsInfo.ArtNames, [])))])
119+
120+
self.name_model = OCRModel(characters=self.name_characters,
121+
model_weight=name_model_weight,
122+
width=self.width, height=self.height,
123+
max_length=self.max_length)
124+
self.generic_model = OCRModel(characters=self.generic_characters,
125+
model_weight=generic_model_weight,
126+
width=self.width, height=self.height,
127+
max_length=self.max_length)
67128

68129
def detect_info(self, art_img):
69130
info = self.extract_art_info(art_img)
70-
x = np.concatenate([self.preprocess(info[key]).T[None, :, :, None] for key in sorted(info.keys())], axis=0)
71-
y = self.model.predict(x)
72-
y = self.decode(y)
73-
y[3] = self.ocr_model_artnames.reg(x[3][None])
74-
return {**{key:v for key, v in zip(sorted(info.keys()), y)}, **{'star':self.detect_star(art_img)}}
131+
generic_keys = [key for key in sorted(info.keys()) if key!='name']
132+
x = np.concatenate([self.preprocess(info[key]).T[None, :, :, None] for key in generic_keys], axis=0)
133+
y_generic = self.generic_model.predict(x)
134+
y_name = self.name_model.predict(self.preprocess(info['name']).T[None,:,:,None])
135+
return {**{key:v for key, v in zip(generic_keys, y_generic)}, **{'star':self.detect_star(art_img)}, **{'name':y_name[0]}}
75136

76137
def extract_art_info(self, art_img):
77138
name = art_img.crop([i*self.scale_ratio for i in Config.name_coords])
@@ -112,12 +173,12 @@ def to_gray(self, text_img):
112173
text_img = (text_img[..., :3] @ [[[0.299], [0.587], [0.114]]])[:, :, 0]
113174
return np.array(text_img, np.float32)
114175

115-
def normalize(self, img, auto_inverse=True):
116-
img -= img.min()
176+
def normalize(self, img, auto_inverse=True, min_jitter=0):
177+
img -= img.min() + np.random.random() * min_jitter * img.max()
117178
img /= img.max()
118179
if auto_inverse and img[-1, -1] > 0.5:
119180
img = 1 - img
120-
return img
181+
return np.array(img, np.float32)
121182

122183

123184
def crop(self, img, tol=0.7):
@@ -154,7 +215,18 @@ def pad_to_width(self, img):
154215
)
155216

156217

157-
def preprocess(self, text_img):
218+
def preprocess(self, text_img, inference=True):
219+
result = self.to_gray(text_img)
220+
if inference:
221+
result = self.normalize(result, True, 0)
222+
result = self.crop(result)
223+
else:
224+
result = self.normalize(result, True, 0.2)
225+
result = self.crop(result, np.random.random() * 0.25 + 0.6)
226+
result = self.normalize(result, False, 0)
227+
result = self.resize_to_height(result)
228+
result = self.pad_to_width(result)
229+
return result
158230
result = self.to_gray(text_img)
159231
result = self.normalize(result, True)
160232
result = self.crop(result)
@@ -163,88 +235,3 @@ def preprocess(self, text_img):
163235
result = self.pad_to_width(result)
164236
return result
165237

166-
167-
def decode(self, pred):
168-
input_len = np.ones(pred.shape[0]) * pred.shape[1]
169-
# Use greedy search. For complex tasks, you can use beam search
170-
results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
171-
:, :self.max_length
172-
]
173-
# Iterate over the results and get back the text
174-
output_text = []
175-
for res in results:
176-
res = self.num_to_char(res)
177-
res = reduce_join(res)
178-
res = res.numpy().decode("utf-8")
179-
output_text.append(res)
180-
return output_text
181-
182-
def build_model(self, input_shape):
183-
input_img = Input(
184-
shape=(input_shape[0], input_shape[1], 1), name="image", dtype="float32"
185-
)
186-
mobilenet = MobileNetV3_Small(
187-
(input_shape[0], input_shape[1], 1), 0, alpha=1.0, include_top=False
188-
).build()
189-
x = mobilenet(input_img)
190-
new_shape = ((input_shape[0] // 8), (input_shape[1] // 8) * 576)
191-
x = Reshape(target_shape=new_shape, name="reshape")(x)
192-
x = Dense(64, activation="relu", name="dense1")(x)
193-
x = Dropout(0.2)(x)
194-
195-
# RNNs
196-
x = Bidirectional(LSTM(128, return_sequences=True, dropout=0.25))(x)
197-
x = Bidirectional(LSTM(64, return_sequences=True, dropout=0.25))(x)
198-
199-
# Output layer
200-
output = Dense(len(self.characters) + 2, activation="softmax", name="dense2")(x)
201-
202-
# Define the model
203-
self.model = Model(inputs=[input_img], outputs=output, name="ocr_model_v1")
204-
205-
class OCR_artnames:
206-
def __init__(self, model_weight='mn_model_weight_artnames.h5'):
207-
self.artnames = sorted(set(sum(ArtsInfo.ArtNames, [])))
208-
209-
self.model = self.build_model(input_shape=(240, 16))
210-
self.model.load_weights(model_weight)
211-
212-
def build_model(self, input_shape):
213-
input_img = Input(
214-
shape=(input_shape[0], input_shape[1], 1), name="image", dtype="float32"
215-
)
216-
mobilenet = MobileNetV3_Small(
217-
(input_shape[0], input_shape[1], 1), 0, alpha=1.0, include_top=False
218-
).build()
219-
x = mobilenet(input_img)
220-
new_shape = ((input_shape[0] // 8), (input_shape[1] // 8) * 576)
221-
x = Reshape(target_shape=new_shape, name="reshape")(x)
222-
x = Dense(64, activation="relu", name="dense1")(x)
223-
x = Dropout(0.2)(x)
224-
225-
# RNNs
226-
x = Bidirectional(LSTM(128, return_sequences=True, dropout=0.25))(x)
227-
x = Bidirectional(LSTM(64, return_sequences=True, dropout=0.25))(x)
228-
229-
# Output layer
230-
x = Flatten(name="flatten")(x)
231-
x = Dense(
232-
len(self.artnames), activation="softmax", name="dense2"
233-
)(x)
234-
235-
output = x
236-
237-
# Define the model
238-
model = Model(inputs=[input_img], outputs=output, name="ocr_model_artnames")
239-
240-
return model
241-
242-
def decode_single(self, pred):
243-
i = pred[0].argmax()
244-
if pred[0][i] > 0.75:
245-
return self.artnames[i]
246-
else:
247-
return 'Unknown'
248-
249-
def reg(self, x):
250-
return self.decode_single(self.model.predict(x))

0 commit comments

Comments
 (0)