11
22
33import re
4- import pickle
4+ import marisa_trie
5+ import ujson as json
56
67from tqdm import tqdm
78from collections import defaultdict
@@ -77,11 +78,9 @@ def large_p1_gap(self, row, name):
7778
7879 def __call__ (self , row , name ):
7980 """Is a name unique enough that it should be indexed independently?
80-
8181 Args:
8282 row (models.WOFLocality)
8383 name (str)
84-
8584 Returns: bool
8685 """
8786 return not self .blocked (name ) and self .large_p1_gap (row , name )
@@ -94,10 +93,8 @@ def __init__(self, allow_bare):
9493
9594 def __call__ (self , row ):
9695 """Enumerate index keys for a city.
97-
9896 Args:
9997 row (db.Locality)
100-
10198 Yields: str
10299 """
103100 bare_names = [n for n in row .names if self .allow_bare (row , n )]
@@ -124,10 +121,8 @@ def __call__(self, row):
124121
125122def state_key_iter (row ):
126123 """Enumerate index keys for a state.
127-
128124 Args:
129125 row (db.Region)
130-
131126 Yields: str
132127 """
133128 names = (row .name ,)
@@ -147,99 +142,53 @@ def state_key_iter(row):
147142 yield ' ' .join ((abbr , usa ))
148143
149144
150- class Match :
151-
152- def __init__ (self , row ):
153- """Set model class, PK, metadata.
154- """
155- state = inspect (row )
156-
157- # Don't store the actual row, so we can serialize.
158- self ._model_cls = state .class_
159- self ._pk = state .identity
160-
161- self .data = Box (dict (row ))
162-
163- @cached_property
164- def db_row (self ):
165- """Hydrate database row, lazily.
166- """
167- return self ._model_cls .query .get (self ._pk )
168-
169-
170- class CityMatch (Match ):
171-
172- def __repr__ (self ):
173- return '%s<%s, %s, %s, wof:%d>' % (
174- self .__class__ .__name__ ,
175- self .data .name ,
176- self .data .name_a1 ,
177- self .data .name_a0 ,
178- self .data .wof_id ,
179- )
180-
181-
182- class StateMatch (Match ):
183-
184- def __repr__ (self ):
185- return '%s<%s, %s, wof:%d>' % (
186- self .__class__ .__name__ ,
187- self .data .name ,
188- self .data .name_a0 ,
189- self .data .wof_id ,
190- )
191-
192-
193145class Index :
194146
195- @classmethod
196- def load (cls , path ):
197- with open (path , 'rb' ) as fh :
198- return pickle .load (fh )
147+ def load (self , path , mmap = False ):
148+ if mmap :
149+ self ._trie .mmap (path )
150+ else :
151+ self ._trie .load (path )
199152
200153 def __init__ (self ):
201- self ._key_to_ids = defaultdict (set )
202- self ._id_to_loc = dict ()
154+ self ._trie = marisa_trie .BytesTrie ()
155+
156+ # We use prefixes here to store the keys -> ids and ids -> loc "maps" as subtrees in one marisa trie.
157+ self ._keys_prefix = "A"
158+ self ._ids_prefix = "B"
203159
204160 def __len__ (self ):
205- return len (self ._key_to_ids )
161+ return len (self ._trie . keys ( self . _keys_prefix ) )
206162
207163 def __repr__ (self ):
208164 return '%s<%d keys, %d entities>' % (
209165 self .__class__ .__name__ ,
210- len (self ._key_to_ids ),
211- len (self ._id_to_loc ),
166+ len (self ._trie . keys ( self . _keys_prefix ) ),
167+ len (self ._trie . keys ( self . _ids_prefix ) ),
212168 )
213169
214170 def __getitem__ (self , text ):
215171 """Get ids, map to records only if there is a match in the index
216172 """
217- if keyify (text ) not in self ._key_to_ids :
173+ normalized_key = self ._keys_prefix + keyify (text )
174+ val = self ._trie .get (normalized_key , None )
175+ if not val :
218176 return None
177+ ids = json .loads (val [0 ])
219178
220- ids = self ._key_to_ids [keyify (text )]
221-
222- return [self ._id_to_loc [id ] for id in ids ]
223-
224- def add_key (self , key , id ):
225- self ._key_to_ids [key ].add (id )
226-
227- def add_location (self , id , location ):
228- self ._id_to_loc [id ] = location
179+ return [json .loads (self ._trie [self ._ids_prefix + id ][0 ]) for id in ids ]
229180
230181 def locations (self ):
231- return list ( self ._id_to_loc . values () )
182+ return self ._trie . items ( self . _ids_prefix )
232183
233184 def save (self , path ):
234- with open (path , 'wb' ) as fh :
235- pickle .dump (self , fh )
185+ self ._trie .save (path )
236186
237187
238188class USCityIndex (Index ):
239189
240- @classmethod
241- def load (cls , path = US_CITY_PATH ):
242- return super ().load (path )
190+ def load (self , path = US_CITY_PATH , mmap = False ):
191+ return super ().load (path , mmap )
243192
244193 def __init__ (self , bare_name_blocklist = None ):
245194 super ().__init__ ()
@@ -248,6 +197,7 @@ def __init__(self, bare_name_blocklist=None):
248197 def build (self ):
249198 """Index all US cities.
250199 """
200+
251201 allow_bare = AllowBareCityName (blocklist = self .bare_name_blocklist )
252202
253203 iter_keys = CityKeyIter (allow_bare )
@@ -257,21 +207,27 @@ def build(self):
257207
258208 logger .info ('Indexing US cities.' )
259209
210+ key_to_ids = defaultdict (set )
211+ id_to_loc_items = list ()
212+
260213 for row in tqdm (cities ):
261214
262215 # Key -> id(s)
263216 for key in map (keyify , iter_keys (row )):
264- self . add_key ( key , row .wof_id )
217+ key_to_ids [ key ]. add ( str ( row .wof_id ) )
265218
266219 # ID -> city
267- self .add_location (row .wof_id , CityMatch (row ))
220+ id_to_loc_items .append ((self ._ids_prefix + str (row .wof_id ), bytes (json .dumps (dict (row )), encoding = "utf-8" )))
221+
222+ key_to_ids_items = [(self ._keys_prefix + key , json .dumps (list (key_to_ids [key ])).encode ("utf-8" )) for key in key_to_ids ]
223+
224+ self ._trie = marisa_trie .BytesTrie (key_to_ids_items + id_to_loc_items )
268225
269226
270227class USStateIndex (Index ):
271228
272- @classmethod
273- def load (cls , path = US_STATE_PATH ):
274- return super ().load (path )
229+ def load (self , path = US_STATE_PATH , mmap = False ):
230+ return super ().load (path , mmap )
275231
276232 def build (self ):
277233 """Index all US states.
@@ -280,11 +236,18 @@ def build(self):
280236
281237 logger .info ('Indexing US states.' )
282238
239+ key_to_ids = defaultdict (set )
240+ id_to_loc_items = list ()
241+
283242 for row in tqdm (states ):
284243
285244 # Key -> id(s)
286245 for key in map (keyify , state_key_iter (row )):
287- self . add_key ( key , row .wof_id )
246+ key_to_ids [ key ]. add ( str ( row .wof_id ) )
288247
289248 # ID -> state
290- self .add_location (row .wof_id , StateMatch (row ))
249+ id_to_loc_items .append ((self ._ids_prefix + str (row .wof_id ), bytes (json .dumps (dict (row )), encoding = "utf-8" )))
250+
251+ key_to_ids_items = [(self ._keys_prefix + key , json .dumps (list (key_to_ids [key ])).encode ("utf-8" )) for key in key_to_ids ]
252+
253+ self ._trie = marisa_trie .BytesTrie (key_to_ids_items + id_to_loc_items )
0 commit comments