Commit 02a1213

Decrease RAM Usage and index load and build times with Marisa Trie + Speed, Concurrency, and Unit Tests (#1)
1 parent ae0743e commit 02a1213

File tree: 10 files changed (+250 -93 lines)


Pipfile

Lines changed: 2 additions & 1 deletion

@@ -35,8 +35,9 @@ PyYAML = "*"
 Shapely = "*"
 numpy = "*"
 scipy = "*"
+marisa_trie = "*"

 [dev-packages]

 [requires]
-python_version = "3.6"
+python_version = "3.6"
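For anyone reproducing the environment: Pipfiles are managed by pipenv, so the new dependency would be pulled in with

    pipenv install marisa_trie

which also updates Pipfile.lock.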

litecoder/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -9,9 +9,9 @@

 DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')

-US_STATE_PATH = os.path.join(DATA_DIR, 'us-states.p')
+US_STATE_PATH = os.path.join(DATA_DIR, 'us-states.marisa')

-US_CITY_PATH = os.path.join(DATA_DIR, 'us-cities.p')
+US_CITY_PATH = os.path.join(DATA_DIR, 'us-cities.marisa')


 logging.basicConfig(

litecoder/usa.py

Lines changed: 45 additions & 82 deletions

@@ -1,7 +1,8 @@


 import re
-import pickle
+import marisa_trie
+import ujson as json

 from tqdm import tqdm
 from collections import defaultdict

@@ -77,11 +78,9 @@ def large_p1_gap(self, row, name):

     def __call__(self, row, name):
         """Is a name unique enough that it should be indexed independently?
-
         Args:
             row (models.WOFLocality)
             name (str)
-
         Returns: bool
         """
         return not self.blocked(name) and self.large_p1_gap(row, name)

@@ -94,10 +93,8 @@ def __init__(self, allow_bare):

     def __call__(self, row):
         """Enumerate index keys for a city.
-
         Args:
             row (db.Locality)
-
         Yields: str
         """
         bare_names = [n for n in row.names if self.allow_bare(row, n)]

@@ -124,10 +121,8 @@ def __call__(self, row):

 def state_key_iter(row):
     """Enumerate index keys for a state.
-
     Args:
         row (db.Region)
-
     Yields: str
     """
     names = (row.name,)

@@ -147,99 +142,53 @@ def state_key_iter(row):
         yield ' '.join((abbr, usa))


-class Match:
-
-    def __init__(self, row):
-        """Set model class, PK, metadata.
-        """
-        state = inspect(row)
-
-        # Don't store the actual row, so we can serialize.
-        self._model_cls = state.class_
-        self._pk = state.identity
-
-        self.data = Box(dict(row))
-
-    @cached_property
-    def db_row(self):
-        """Hydrate database row, lazily.
-        """
-        return self._model_cls.query.get(self._pk)
-
-
-class CityMatch(Match):
-
-    def __repr__(self):
-        return '%s<%s, %s, %s, wof:%d>' % (
-            self.__class__.__name__,
-            self.data.name,
-            self.data.name_a1,
-            self.data.name_a0,
-            self.data.wof_id,
-        )
-
-
-class StateMatch(Match):
-
-    def __repr__(self):
-        return '%s<%s, %s, wof:%d>' % (
-            self.__class__.__name__,
-            self.data.name,
-            self.data.name_a0,
-            self.data.wof_id,
-        )
-
-
 class Index:

-    @classmethod
-    def load(cls, path):
-        with open(path, 'rb') as fh:
-            return pickle.load(fh)
+    def load(self, path, mmap=False):
+        if mmap:
+            self._trie.mmap(path)
+        else:
+            self._trie.load(path)

     def __init__(self):
-        self._key_to_ids = defaultdict(set)
-        self._id_to_loc = dict()
+        self._trie = marisa_trie.BytesTrie()
+
+        # We use prefixes here to store the keys -> ids and ids -> loc "maps" as subtrees in one marisa trie.
+        self._keys_prefix = "A"
+        self._ids_prefix = "B"

     def __len__(self):
-        return len(self._key_to_ids)
+        return len(self._trie.keys(self._keys_prefix))

     def __repr__(self):
         return '%s<%d keys, %d entities>' % (
             self.__class__.__name__,
-            len(self._key_to_ids),
-            len(self._id_to_loc),
+            len(self._trie.keys(self._keys_prefix)),
+            len(self._trie.keys(self._ids_prefix)),
         )

     def __getitem__(self, text):
         """Get ids, map to records only if there is a match in the index
         """
-        if keyify(text) not in self._key_to_ids:
+        normalized_key = self._keys_prefix + keyify(text)
+        val = self._trie.get(normalized_key, None)
+        if not val:
             return None
+        ids = json.loads(val[0])

-        ids = self._key_to_ids[keyify(text)]
-
-        return [self._id_to_loc[id] for id in ids]
-
-    def add_key(self, key, id):
-        self._key_to_ids[key].add(id)
-
-    def add_location(self, id, location):
-        self._id_to_loc[id] = location
+        return [json.loads(self._trie[self._ids_prefix + id][0]) for id in ids]

     def locations(self):
-        return list(self._id_to_loc.values())
+        return self._trie.items(self._ids_prefix)

     def save(self, path):
-        with open(path, 'wb') as fh:
-            pickle.dump(self, fh)
+        self._trie.save(path)


 class USCityIndex(Index):

-    @classmethod
-    def load(cls, path=US_CITY_PATH):
-        return super().load(path)
+    def load(self, path=US_CITY_PATH, mmap=False):
+        return super().load(path, mmap)

     def __init__(self, bare_name_blocklist=None):
         super().__init__()

@@ -248,6 +197,7 @@ def __init__(self, bare_name_blocklist=None):
     def build(self):
         """Index all US cities.
         """
+
         allow_bare = AllowBareCityName(blocklist=self.bare_name_blocklist)

         iter_keys = CityKeyIter(allow_bare)

@@ -257,21 +207,27 @@ def build(self):

         logger.info('Indexing US cities.')

+        key_to_ids = defaultdict(set)
+        id_to_loc_items = list()
+
         for row in tqdm(cities):

             # Key -> id(s)
             for key in map(keyify, iter_keys(row)):
-                self.add_key(key, row.wof_id)
+                key_to_ids[key].add(str(row.wof_id))

             # ID -> city
-            self.add_location(row.wof_id, CityMatch(row))
+            id_to_loc_items.append((self._ids_prefix + str(row.wof_id), bytes(json.dumps(dict(row)), encoding="utf-8")))
+
+        key_to_ids_items = [(self._keys_prefix + key, json.dumps(list(key_to_ids[key])).encode("utf-8")) for key in key_to_ids]
+
+        self._trie = marisa_trie.BytesTrie(key_to_ids_items + id_to_loc_items)


 class USStateIndex(Index):

-    @classmethod
-    def load(cls, path=US_STATE_PATH):
-        return super().load(path)
+    def load(self, path=US_STATE_PATH, mmap=False):
+        return super().load(path, mmap)

     def build(self):
         """Index all US states.

@@ -280,11 +236,18 @@ def build(self):

         logger.info('Indexing US states.')

+        key_to_ids = defaultdict(set)
+        id_to_loc_items = list()
+
         for row in tqdm(states):

             # Key -> id(s)
             for key in map(keyify, state_key_iter(row)):
-                self.add_key(key, row.wof_id)
+                key_to_ids[key].add(str(row.wof_id))

             # ID -> state
-            self.add_location(row.wof_id, StateMatch(row))
+            id_to_loc_items.append((self._ids_prefix + str(row.wof_id), bytes(json.dumps(dict(row)), encoding="utf-8")))
+
+        key_to_ids_items = [(self._keys_prefix + key, json.dumps(list(key_to_ids[key])).encode("utf-8")) for key in key_to_ids]
+
+        self._trie = marisa_trie.BytesTrie(key_to_ids_items + id_to_loc_items)
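To summarize the change above: the old pair of pickled in-memory maps (key -> ids, id -> location) is replaced by a single marisa_trie.BytesTrie, with a one-character prefix ("A" for lookup keys, "B" for wof ids) carving the trie into two logical subtrees; values are JSON-encoded bytes, so lookups now return plain dicts instead of CityMatch/StateMatch objects. A minimal sketch of the pattern with toy data (the key and record below are illustrative, and stdlib json stands in for the commit's ujson):

    import json
    import marisa_trie

    KEYS, IDS = "A", "B"  # one-character subtree prefixes, as in the commit

    # A BytesTrie is built once from (unicode key, bytes value) pairs.
    trie = marisa_trie.BytesTrie([
        (KEYS + "tucson az", json.dumps(["101"]).encode("utf-8")),      # key -> id list
        (IDS + "101", json.dumps({"name": "Tucson"}).encode("utf-8")),  # id -> record
    ])

    # A lookup walks key -> ids -> records entirely inside the one trie.
    ids = json.loads(trie[KEYS + "tucson az"][0])
    print([json.loads(trie[IDS + i][0]) for i in ids])  # [{'name': 'Tucson'}]

Because the trie is immutable and succinct, it can be saved, loaded, or memory-mapped as a single file, which is what buys the RAM and load-time improvements in the commit title; the trade-off is that build() must construct the whole structure in one shot.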

tests/prod_db/conftest.py

Lines changed: 6 additions & 2 deletions

@@ -7,9 +7,13 @@

 @pytest.fixture(scope='session')
 def city_idx():
-    return USCityIndex.load()
+    city_idx = USCityIndex()
+    city_idx.load()
+    return city_idx


 @pytest.fixture(scope='session')
 def state_idx():
-    return USStateIndex.load()
+    state_idx = USStateIndex()
+    state_idx.load()
+    return state_idx
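These fixtures track the API change in litecoder/usa.py: load() is now an instance method that fills the trie in place, rather than a classmethod that unpickles and returns an index. Typical usage under the new API would look like this (query string illustrative):

    from litecoder.usa import USCityIndex

    idx = USCityIndex()
    idx.load()  # or idx.load(mmap=True) to memory-map the index file instead of reading it into RAM
    print(idx['Boston, MA'])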

tests/prod_db/test_us_city_index.py

Lines changed: 3 additions & 3 deletions

@@ -33,7 +33,7 @@ def test_cases(city_idx, query, matches, xfail):

     res = city_idx[query]

-    ids = [r.data.wof_id for r in res]
+    ids = [r["wof_id"] for r in res]

     # Exact id list match.
     assert sorted(ids) == sorted(matches)

@@ -49,6 +49,6 @@ def test_topn(city_idx, city):
     """Smoke test N most populous cities.
     """
     res = city_idx['%s, %s' % (city.name, city.name_a1)]
-    res_ids = [r.data.wof_id for r in res]
+    res_ids = [r["wof_id"] for r in res]

-    assert city.wof_id in res_ids
+    assert city.wof_id in res_ids

tests/prod_db/test_us_state_index.py

Lines changed: 3 additions & 3 deletions

@@ -28,7 +28,7 @@ def test_cases(state_idx, query, matches):

     res = state_idx[query]

-    ids = [r.data.wof_id for r in res]
+    ids = [r["wof_id"] for r in res]

     assert sorted(ids) == sorted(matches)

@@ -41,6 +41,6 @@ def test_all(state_idx, state):
     """Smoke test N most populous cities.
     """
     res = state_idx[state.name]
-    res_ids = [r.data.wof_id for r in res]
+    res_ids = [r["wof_id"] for r in res]

-    assert state.wof_id in res_ids
+    assert state.wof_id in res_ids

tests/runtime/concurrency_test.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+from multiprocessing import Pool
+from litecoder.usa import USCityIndex, USStateIndex
+import time
+
+NUM_PROCESSES = 4
+
+# Load 50 test city lookups
+with open("tests/runtime/test_city_lookups.txt", "r") as lookups_file:
+    city_tests = lookups_file.read().splitlines()
+
+# Increase the number of lookups for the speed test if necessary
+for x in range(10):
+    city_tests += city_tests
+num_tests_per_process = len(city_tests)
+num_tests = NUM_PROCESSES * num_tests_per_process
+
+# Load USCityIndex
+city_idx = USCityIndex()
+city_idx.load()
+
+
+def lookup_cities(process_num):
+    print('Process {}: looking up {} cities'.format(process_num, num_tests_per_process))
+    start_time = time.time()
+    for city in city_tests:
+        city_idx[city]
+    ms = 1000 * (time.time() - start_time)
+    print("Process {}: finished, took {}ms @ {} ms/lookup!".format(process_num, ms, float(ms / num_tests_per_process)))

+if __name__ == '__main__':
+    print("Looking up {} cities on {} processes...".format(num_tests, NUM_PROCESSES))
+    start_time = time.time()
+    with Pool(NUM_PROCESSES) as p:
+        p.map(lookup_cities, range(1, NUM_PROCESSES + 1))
+    ms = 1000 * (time.time() - start_time)
+    print("Fully finished: took {}ms @ {} ms/lookup!".format(ms, float(ms / num_tests)))
+
+    print()
+    print("Looking up all {} cities on one process...".format(num_tests), end="")
+    start_time = time.time()
+    for i in range(NUM_PROCESSES):
+        for city in city_tests:
+            city_idx[city]
+    ms = 1000 * (time.time() - start_time)
+    print("finished: took {}ms @ {} ms/lookup!".format(ms, ms / num_tests))

tests/runtime/speed_test.py

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+from litecoder.usa import USCityIndex, USStateIndex
+import time
+
+print("Loading USCityIndex... ", end="")
+start_time = time.time()
+city_idx = USCityIndex()
+city_idx.load()
+print("finished: {}s!".format(time.time() - start_time))
+
+# Load 50 test city lookups
+with open("test_city_lookups.txt", "r") as lookups_file:
+    city_tests = lookups_file.read().splitlines()
+
+# Increase the number of lookups for the speed test if necessary
+for x in range(5):
+    city_tests += city_tests
+num_tests = len(city_tests)
+print("measuring time for {} cities... ".format(num_tests), end="")
+start_time = time.time()
+for city in city_tests:
+    city_idx[city]
+ms = 1000 * (time.time() - start_time)
+print("finished: took {}ms at {} ms/lookup!".format(ms, float(ms / num_tests)))
+
+print("Loading USStateIndex... ", end="")
+start_time = time.time()
+state_idx = USStateIndex()
+state_idx.load()
+print("finished: {}s!".format(time.time() - start_time))
+
+# Load 50 test state lookups
+with open("test_state_lookups.txt", "r") as lookups_file:
+    state_tests = lookups_file.read().splitlines()
+
+# Increase the number of lookups for the speed test if necessary
+for x in range(5):
+    state_tests += state_tests
+num_tests = len(state_tests)
+print("measuring time for {} states... ".format(num_tests), end="")
+start_time = time.time()
+for state in state_tests:
+    state_idx[state]
+ms = 1000 * (time.time() - start_time)
+print("finished: took {}ms at {} ms/lookup!".format(ms, float(ms / num_tests)))
