Commit 4c1b8ca

refactor: revise load_places.py
1 parent 8e9923d commit 4c1b8ca

1 file changed

apps/places/management/commands/load_places.py

Lines changed: 200 additions & 145 deletions
@@ -1,6 +1,9 @@
 import os
+import sys
 import csv
 import re
+import time
+from datetime import timedelta
 from decimal import Decimal
 
 from django.conf import settings
@@ -16,192 +19,244 @@
 except Exception:
     FAISS_AVAILABLE = False
 
+
 def norm(s: str) -> str:
     s = s or ""
     return re.sub(r"[\s\(\)\[\]\-_/·•~!@#$%^&*=+|:;\"'<>?,.]+", "", s).lower()
 
+
 class Command(BaseCommand):
-    help = "Loads Places and embeddings into the server DB from a CSV (required) + optional FAISS index."
+    help = "Loads Places and embeddings into the DB from a CSV (required) + optional FAISS index (progress/ETA/bar display)."
 
     def add_arguments(self, parser):
-        parser.add_argument("--csv", required=False,
-                            help="CSV path (default: /app/data/triptailor_full_metadata.csv, else BASE_DIR/triptailor_full_metadata.csv)")
+        parser.add_argument("--csv", required=False, help="CSV path (default: BASE_DIR/triptailor_full_metadata.csv)")
         parser.add_argument("--faiss", required=False, help="FAISS .index path (when storing vectors)")
-        parser.add_argument("--batch", type=int, default=500, help="tag/log/partial-commit interval")
-        parser.add_argument("--no-embedding", action="store_true",
-                            help="skip storing embeddings (load CSV only, without FAISS)")
-        parser.add_argument("--dim", type=int, default=None,
-                            help="vector dimension (optional); if set, rows whose vec length mismatches are skipped")
+        parser.add_argument("--dim", type=int, default=None, help="embedding dimension (optional, for validation)")
+        parser.add_argument("--batch", type=int, default=500, help="tag batch commit interval")
+        parser.add_argument("--log-interval", type=int, default=200, help="force a progress log every N rows")
+        parser.add_argument("--bar-width", type=int, default=40, help="progress bar width (number of cells)")
+        parser.add_argument("--dry-run", action="store_true", help="check parsing/speed only, without writing to the DB")
+        parser.add_argument("--skip-embedding", action="store_true", help="skip storing FAISS embeddings")
 
-    def handle(self, *args, **opts):
-        # 1) Default CSV path: container mount path first, then BASE_DIR
-        default_csv_candidates = [
-            "/app/data/triptailor_full_metadata.csv",
-            os.path.join(settings.BASE_DIR, "triptailor_full_metadata.csv"),
-        ]
-        csv_path = opts.get("csv")
-        if not csv_path:
-            for cand in default_csv_candidates:
-                if os.path.exists(cand):
-                    csv_path = cand
-                    break
-
-        if not csv_path or not os.path.exists(csv_path):
-            raise CommandError(f"CSV not found. Tried: {csv_path or default_csv_candidates}")
+    # === internal helpers ===
+    def _fmt_hms(self, seconds: float) -> str:
+        return str(timedelta(seconds=int(max(0, seconds))))
+
+    def _eta(self, done: int, total: int, start_ts: float) -> tuple[str, str]:
+        elapsed = time.time() - start_ts
+        rate = done / elapsed if elapsed > 0 and done > 0 else 0
+        remain = (total - done) / rate if rate > 0 else 0
+        return self._fmt_hms(remain), self._fmt_hms(elapsed)
+
+    def _bar(self, pct: float, width: int) -> str:
+        filled = int(round(pct * width))
+        return "[" + "#" * filled + "-" * (width - filled) + "]"
 
+    def _print_progress(self, i: int, total: int, start_ts: float,
+                        created: int, updated: int, skipped: int, bar_width: int,
+                        final: bool = False):
+        pct = (i / total) if total > 0 else 0.0
+        eta, elapsed = self._eta(i, total, start_ts)
+        line = (
+            f"{self._bar(pct, bar_width)} "
+            f"{pct*100:6.2f}% "
+            f"{i}/{total} "
+            f"elapsed={elapsed} eta={eta} "
+            f"ok={created+updated} created={created} updated={updated} skipped={skipped}"
+        )
+        # while running, overwrite the same line (\r); print a newline at the end
+        end = "\n" if final else "\r"
+        # Django's OutputWrapper can pass a carriage return through as well
+        self.stdout.write(line, ending=end)
+        self.stdout.flush()
+
+    def handle(self, *args, **opts):
+        csv_path = opts["csv"] or os.path.join(settings.BASE_DIR, "triptailor_full_metadata.csv")
         faiss_path = opts.get("faiss")
-        batch_size = int(opts["batch"])
-        no_embedding = bool(opts["no_embedding"])
-        exp_dim = opts.get("dim")
+        dim_expect = opts.get("dim")
+        batch_size = opts["batch"]
+        log_interval = max(1, opts["log_interval"])
+        bar_width = max(10, opts["bar_width"])
+        dry_run = opts["dry_run"]
+        skip_vec = opts["skip_embedding"]
 
-        # 2) (optional) restore vectors from FAISS
+        if not os.path.exists(csv_path):
+            raise CommandError(f"CSV not found: {csv_path}")
+
+        # 0) total CSV row count (counted up front for the progress bar)
+        self.stdout.write(self.style.MIGRATE_HEADING("Count CSV rows"))
+        with open(csv_path, newline="", encoding="utf-8-sig") as f:
+            total_rows = sum(1 for _ in csv.DictReader(f))
+        if total_rows == 0:
+            raise CommandError("CSV contains no data.")
+        self.stdout.write(f"- total rows: {total_rows}")
+
+        # 1) (optional) restore vectors from FAISS
         vecs = None
-        if not no_embedding and faiss_path:
+        if faiss_path and not skip_vec:
             if not FAISS_AVAILABLE:
                 self.stderr.write(self.style.WARNING("faiss module not installed; skipping embeddings. (pip install faiss-cpu)"))
             elif not os.path.exists(faiss_path):
-                self.stderr.write(self.style.WARNING(f"Skipping embeddings because there is no FAISS index: {faiss_path}"))
+                self.stderr.write(self.style.WARNING(f"No FAISS index, skipping embeddings: {faiss_path}"))
             else:
                 self.stdout.write(self.style.MIGRATE_HEADING("Load FAISS index"))
                 index = faiss.read_index(faiss_path)
                 n = index.ntotal
                 self.stdout.write(f"- index.ntotal: {n}")
+                if dim_expect:
+                    try:
+                        d = index.d
+                        if d != dim_expect:
+                            self.stderr.write(self.style.WARNING(f"embedding dimension mismatch: index.d={d}, --dim={dim_expect}"))
+                    except Exception:
+                        pass
                 try:
-                    vecs = index.reconstruct_n(0, n)  # (n, dim) float32
-                    if exp_dim is not None and vecs.shape[1] != exp_dim:
-                        self.stderr.write(self.style.WARNING(
-                            f"vector dimension mismatch: index dim={vecs.shape[1]} vs --dim={exp_dim} → skipping embedding storage"
-                        ))
-                        vecs = None
+                    vecs = index.reconstruct_n(0, n)  # (n, d) float32
                 except Exception as e:
                     self.stderr.write(self.style.WARNING(
                         f"reconstruct_n failed → skipping embedding storage (original embedding file needed). err={e}"
                     ))
                     vecs = None
 
-        # 3) Load CSV (upsert)
-        self.stdout.write(self.style.MIGRATE_HEADING("Load CSV & upsert"))
-        processed = 0
-        created_cnt = 0
-        updated_cnt = 0
+        # 2) Load CSV
+        self.stdout.write(self.style.MIGRATE_HEADING("Load CSV & upsert (progress bar)"))
+
+        created = 0
+        updated = 0
+        skipped = 0
+        processed_rows = 0
 
         # tag cache (performance)
         tag_cache = {t.name: t.id for t in Tag.objects.all().only("id", "name")}
-        new_tags_buf = []
+        new_tags: list[str] = []
+
+        start_ts = time.time()
+        last_tick = 0.0  # avoid printing more than about once per second
+
+        try:
+            with open(csv_path, newline="", encoding="utf-8-sig") as f:
+                reader = csv.DictReader(f)
+                for i, row in enumerate(reader, start=1):
+                    name = row.get("명칭")
+                    address = row.get("주소")
+                    overview = row.get("개요")
+                    lat = row.get("위도") or row.get("lat")
+                    lng = row.get("경도") or row.get("lng")
+                    summary = row.get("summary", "")
+                    external_id = row.get("external_id", None)
+                    is_unique = str(row.get("is_unique", "0")).strip() in ["1", "True", "true"]
+                    raw_cls = row.get("class", "0")
+
+                    # validate required fields
+                    if not (name and address and overview and lat and lng):
+                        skipped += 1
+                        processed_rows += 1
+                        # progress display (throttled: once per second or every N rows)
+                        now = time.time()
+                        if (i % log_interval == 0) or (now - last_tick >= 1.0) or (i == total_rows):
+                            self._print_progress(i, total_rows, start_ts, created, updated, skipped, bar_width)
+                            last_tick = now
+                        continue
+
+                    try:
+                        place_class = int(float(str(raw_cls).replace(",", ".").strip() or 0))
+                    except ValueError:
+                        place_class = 0
+
+                    try:
+                        lat = Decimal(lat)
+                        lng = Decimal(lng)
+                    except Exception:
+                        skipped += 1
+                        processed_rows += 1
+                        now = time.time()
+                        if (i % log_interval == 0) or (now - last_tick >= 1.0) or (i == total_rows):
+                            self._print_progress(i, total_rows, start_ts, created, updated, skipped, bar_width)
+                            last_tick = now
+                        continue
 
-        def flush_new_tags():
-            nonlocal new_tags_buf, tag_cache
-            if not new_tags_buf:
-                return
-            uniq = list(dict.fromkeys(new_tags_buf))  # de-duplicate
+                    region = address.split()[0] if address else ""
+
+                    # (optional) embedding
+                    embedding = None
+                    if vecs is not None and (i - 1) < len(vecs):
+                        embedding = vecs[i - 1].astype("float32").tolist()
+
+                    if not dry_run:
+                        defaults = {
+                            "region": region,
+                            "lat": lat,
+                            "lng": lng,
+                            "overview": overview,
+                            "external_id": external_id,
+                            "is_unique": is_unique,
+                            "summary": summary,
+                            "place_class": place_class,
+                        }
+                        if embedding is not None:
+                            defaults["embedding"] = embedding
+
+                        place, was_created = Place.objects.update_or_create(
+                            name=name,
+                            address=address,
+                            defaults=defaults,
+                        )
+                        created += int(was_created)
+                        updated += int(not was_created)
+
+                        # handle tags
+                        tag_str = row.get("tags", "")
+                        tag_names = [t.strip().lstrip("#") for t in tag_str.split() if t.strip()]
+
+                        for tname in tag_names:
+                            tid = tag_cache.get(tname)
+                            if tid is None:
+                                new_tags.append(tname)
+
+                        # create new tags in batches
+                        if new_tags and len(new_tags) >= batch_size:
+                            with transaction.atomic():
+                                for nt in new_tags:
+                                    obj, _ = Tag.objects.get_or_create(name=nt)
+                                    tag_cache[obj.name] = obj.id
+                            new_tags.clear()
+
+                        if tag_names:
+                            ids = [tag_cache[t] for t in tag_names if t in tag_cache]
+                            if ids:
+                                place.tags.add(*ids)
+
+                    processed_rows += 1
+
+                    # progress display (throttled: once per second or every N rows)
+                    now = time.time()
+                    if (i % log_interval == 0) or (now - last_tick >= 1.0) or (i == total_rows):
+                        self._print_progress(i, total_rows, start_ts, created, updated, skipped, bar_width)
+                        last_tick = now
+
+        except KeyboardInterrupt:
+            # finish the progress line cleanly
+            sys.stdout.write("\n")
+            sys.stdout.flush()
+            self.stderr.write(self.style.WARNING("Interrupted by user (KeyboardInterrupt). Summarizing progress so far."))
+
+        # handle any remaining new tags
+        if not dry_run and new_tags:
             with transaction.atomic():
-                for nt in uniq:
+                for nt in new_tags:
                     obj, _ = Tag.objects.get_or_create(name=nt)
                     tag_cache[obj.name] = obj.id
-            new_tags_buf.clear()
-
-        with open(csv_path, newline="", encoding="utf-8-sig") as f:
-            reader = csv.DictReader(f)
-            for i, row in enumerate(reader):
-                name = row.get("명칭") or row.get("name")
-                address = row.get("주소") or row.get("address")
-                overview = row.get("개요") or row.get("overview")
-                lat = row.get("위도") or row.get("lat")
-                lng = row.get("경도") or row.get("lng")
-                summary = row.get("summary", "")
-                external_id = row.get("external_id") or None
-                is_unique = str(row.get("is_unique", "0")).strip() in ["1", "True", "true"]
-                raw_cls = row.get("class", "0")
-
-                # check required values
-                if not (name and address and overview and lat and lng):
-                    continue
-
-                # numeric conversion
-                try:
-                    lat = Decimal(str(lat))
-                    lng = Decimal(str(lng))
-                except Exception:
-                    continue
 
-                try:
-                    place_class = int(float(str(raw_cls).replace(",", ".").strip() or 0))
-                except ValueError:
-                    place_class = 0
-
-                region = address.split()[0] if address else ""
-
-                # (optional) embedding
-                embedding = None
-                if vecs is not None and i < len(vecs):
-                    cur = vecs[i]
-                    if exp_dim is not None and cur.shape[0] != exp_dim:
-                        # skip on dimension mismatch
-                        pass
-                    else:
-                        embedding = cur.astype("float32").tolist()
-
-                # upsert key: external_id if present, otherwise (name, address)
-                lookup = {}
-                if external_id:
-                    lookup["external_id"] = external_id
-                else:
-                    lookup["name"] = name
-                    lookup["address"] = address
-
-                defaults = {
-                    "region": region,
-                    "lat": lat,
-                    "lng": lng,
-                    "overview": overview,
-                    "is_unique": is_unique,
-                    "summary": summary,
-                    "place_class": place_class,
-                }
-                if external_id:
-                    defaults["name"] = name
-                    defaults["address"] = address
-                if embedding is not None:
-                    defaults["embedding"] = embedding
-
-                place, created = Place.objects.update_or_create(
-                    **lookup,
-                    defaults=defaults,
-                )
-                if created:
-                    created_cnt += 1
-                else:
-                    updated_cnt += 1
-
-                # handle tags (sync)
-                tag_str = row.get("tags", "")
-                tag_names = [t.strip().lstrip("#") for t in tag_str.split() if t.strip()]
-                # pre-create
-                for tname in tag_names:
-                    if tname not in tag_cache:
-                        new_tags_buf.append(tname)
-
-                if len(new_tags_buf) >= batch_size:
-                    flush_new_tags()
-
-                if tag_names:
-                    ids = [tag_cache[t] for t in tag_names if t in tag_cache]
-                    if ids:
-                        place.tags.set(ids)  # sync to the current row, no duplicates
-
-                processed += 1
-                if processed % batch_size == 0:
-                    self.stdout.write(f"- processed={processed} (created={created_cnt}, updated={updated_cnt})")
-
-        # create any remaining new tags
-        flush_new_tags()
+        # print the final progress line (with a trailing newline)
+        self._print_progress(processed_rows, total_rows, start_ts, created, updated, skipped, bar_width, final=True)
 
+        # summary
         self.stdout.write(self.style.SUCCESS(
-            f"Done: processed={processed}, created={created_cnt}, updated={updated_cnt}"
+            f"Done ✅ total={total_rows}, ok={created+updated}, created={created}, updated={updated}, skipped={skipped}"
         ))
 
         # index hint
-        if vecs is not None:
+        if not dry_run and (faiss_path and not skip_vec and vecs is not None):
             self.stdout.write(self.style.HTTP_INFO(
                 "If the embedding index does not exist yet, create it in psql:\n"
                 "CREATE INDEX IF NOT EXISTS place_embedding_ivfflat "

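For reference, a minimal way to exercise the refactored command is Django's call_command — a sketch only: it assumes the file above is wired up as the load_places management command, and the FAISS index path and --dim value below are illustrative placeholders rather than values taken from this commit (the CSV path reuses the old container default).

from django.core.management import call_command

# Dry run: parse the CSV and watch the progress bar/ETA without writing to the DB.
call_command("load_places", csv="/app/data/triptailor_full_metadata.csv", dry_run=True)

# Full load, restoring embeddings from a FAISS index; --dim only validates the size.
call_command(
    "load_places",
    csv="/app/data/triptailor_full_metadata.csv",
    faiss="/app/data/triptailor.index",  # hypothetical index path
    dim=768,                             # hypothetical dimension, validation only
    batch=500,
    log_interval=200,
)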