-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
119 lines (96 loc) · 3.91 KB
/
app.py
File metadata and controls
119 lines (96 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import joblib
import numpy as np
import pandas as pd
from model import TreeHealthModel
from features import Features
# Загружаем модель и LabelEncoder'ы
model = None
label_encoder = None
label_encoders = {}
cat_cardinalities = []
cat_features = []
num_features = []
def load_model_and_encoders():
global model, label_encoder, label_encoders, cat_cardinalities, cat_features, num_features
# Загружаем информацию о признаках
features = Features()
cat_features = features.get_cat_features() # Загружаем список категориальных признаков
num_features = features.get_num_features() # Загружаем список числовых признаков
# Загружаем модель
model = TreeHealthModel(
num_numeric=1, # Укажите правильное количество числовых признаков
cat_cardinalities=joblib.load('cat_cardinalities.pkl'), # Загружаем cat_cardinalities
embedding_dim=20,
hidden_dim=256
)
model.load_state_dict(torch.load('tree_health_model.pth', map_location=torch.device('cpu')))
model.eval()
# Загружаем LabelEncoder для целевой переменной
label_encoder = joblib.load('label_encoder_health.pkl')
# Загружаем LabelEncoder'ы для категориальных признаков
for col in cat_features:
label_encoders[col] = joblib.load(f'label_encoder_{col}.pkl')
# Определяем входные данные
class InputData(BaseModel):
tree_dbh: int
spc_common: object
spc_latin: object
postcode: int
borough: object
zip_city: object
steward: object
guards: object
sidewalk: object
user_type: object
root_stone: object
root_grate: object
root_other: object
trunk_wire: object
trnk_light: object
trnk_other: object
brch_light: object
brch_shoe: object
brch_other: object
curb_loc: object
# Создаем FastAPI приложение
app = FastAPI()
# Загружаем модель и LabelEncoder'ы при старте приложения
@app.on_event("startup")
async def startup_event():
load_model_and_encoders()
# Эндпоинт для предсказания
@app.post("/predict")
async def predict(data: InputData):
try:
# Преобразуем входные данные в DataFrame
input_dict = data.dict()
input_df = pd.DataFrame([input_dict])
# Преобразуем категориальные признаки с помощью LabelEncoder'ов
for col in cat_features:
input_df[col] = input_df[col].apply(lambda x: x if x in label_encoders[col].classes_ else 'unknown')
input_df[col] = label_encoders[col].transform(input_df[col])
print(input_df)
print(num_features)
# Преобразуем данные в тензоры
X_num = torch.tensor(input_df[num_features].values, dtype=torch.float32)
X_cat = torch.tensor(input_df[cat_features].values, dtype=torch.long)
print(X_num)
print(X_cat)
# Получаем предсказание
with torch.no_grad():
outputs = model(X_num, X_cat)
_, predicted = torch.max(outputs, 1)
prediction = predicted.cpu().numpy()[0]
print('kek')
# Декодируем предсказание
predicted_label = label_encoder.inverse_transform([prediction])[0]
return {"prediction": predicted_label}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
# Запуск приложения
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)