55from networksecurity .utils .ml_metric .classification_metric import get_classification_score
66from networksecurity .entity .config_entity import ModelTrainerConfigEntity , DataTransformationConfigEntity
77from networksecurity .entity .artifact_entity import DataTransformationArtifactEntity , ModelTrainerArtifactEntity
8- from networksecurity .utils .main_utils import save_obj , load_obj , load_numpy_array_data
8+ from networksecurity .utils .main_utils import save_obj , load_obj , load_numpy_array_data , evaluate_models
99from networksecurity .utils .model_metric .estimator import NetworkModel
10+ from sklearn .metrics import accuracy_score , f1_score , precision_score , recall_score
11+ from sklearn .model_selection import train_test_split
12+ from sklearn .linear_model import LogisticRegression
13+ from sklearn .ensemble import RandomForestClassifier , GradientBoostingClassifier , AdaBoostClassifier
14+ from sklearn .tree import DecisionTreeClassifier
15+ from sklear .neighbors import KNeighborsClassifier
16+
1017
1118class ModelTrainer :
1219 def __init__ (self , model_trainer_config : ModelTrainerConfigEntity ,
13- data_transformation_config : DataTransformationConfigEntity ,
1420 data_transformation_artifact : DataTransformationArtifactEntity ):
1521 try :
1622 self .model_trainer_config = model_trainer_config
17- self .data_transformation_config = data_transformation_config
1823 self .data_transformation_artifact = data_transformation_artifact
1924 self .logger = Custom_Logger ().get_logger ()
2025 self .logger .info ("Model Trainer initialized with configuration and artifacts." )
2126 except Exception as e :
2227 raise CustomException (e , sys ) from e
2328
29+ def train_model (self ,X_train , y_train , X_test , y_test ) -> NetworkModel :
30+ try :
31+ self .logger .info ("Starting model training process." )
32+ models = {
33+ 'LogisticRegression' : LogisticRegression (),
34+ 'RandomForestClassifier' : RandomForestClassifier (),
35+ 'GradientBoostingClassifier' : GradientBoostingClassifier (),
36+ 'AdaBoostClassifier' : AdaBoostClassifier (),
37+ 'DecisionTreeClassifier' : DecisionTreeClassifier (),
38+ 'KNeighborsClassifier' : KNeighborsClassifier ()
39+ }
40+
41+ params = {
42+ "Decision Tree" : {
43+ 'criterion' :['gini' , 'entropy' , 'log_loss' ],
44+ # 'splitter':['best','random'],
45+ # 'max_features':['sqrt','log2'],
46+ },
47+ "Random Forest" :{
48+ # 'criterion':['gini', 'entropy', 'log_loss'],
49+
50+ # 'max_features':['sqrt','log2',None],
51+ 'n_estimators' : [8 ,16 ,32 ,128 ,256 ]
52+ },
53+ "Gradient Boosting" :{
54+ # 'loss':['log_loss', 'exponential'],
55+ 'learning_rate' :[.1 ,.01 ,.05 ,.001 ],
56+ 'subsample' :[0.6 ,0.7 ,0.75 ,0.85 ,0.9 ],
57+ # 'criterion':['squared_error', 'friedman_mse'],
58+ # 'max_features':['auto','sqrt','log2'],
59+ 'n_estimators' : [8 ,16 ,32 ,64 ,128 ,256 ]
60+ },
61+ "Logistic Regression" :{},
62+ "AdaBoost" :{
63+ 'learning_rate' :[.1 ,.01 ,.001 ],
64+ 'n_estimators' : [8 ,16 ,32 ,64 ,128 ,256 ]
65+ }
66+ }
67+
68+ model_report :dict = evaluate_models (X_train = X_train , y_train = y_train , X_test = X_test , y_test = y_test , models = models , params = params )
69+ except Exception as e :
70+ raise CustomException (e , sys ) from e
71+
2472 def initiate_model_trainer (self ) -> ModelTrainerArtifactEntity :
2573 try :
26- pass
74+ train_file_path = self .data_transformation_artifact .transformed_train_file_path
75+ test_file_path = self .data_transformation_artifact .transformed_test_file_path
76+
77+ # Load the transformed train and test data
78+ train_arr = load_numpy_array_data (file_path = train_file_path )
79+ test_arr = load_numpy_array_data (file_path = test_file_path )
80+ pass
81+ # Split the data into features and target variable
82+ X_train , y_train = train_arr [:, :- 1 ], train_arr [:, - 1 ]
83+ X_test , y_test = test_arr [:, :- 1 ], test_arr [:, - 1 ]
84+ self .logger .info ("Data loaded and split into features and target variable." )
85+
86+ model = self .train_model (X_train , y_train )
87+ self .logger .info ("Model trained successfully." )
2788 except Exception as e :
2889 raise CustomException (e , sys )
0 commit comments