Skip to content

Commit 322484d

Browse files
committed
add tuning result table
Signed-off-by: He, Xin3 <[email protected]>
1 parent 13aacc1 commit 322484d

File tree

1 file changed

+48
-1
lines changed

1 file changed

+48
-1
lines changed

neural_compressor/common/base_tuning.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Sized, Tuple, Union
1919

2020
from neural_compressor.common.base_config import BaseConfig
21-
from neural_compressor.common.utils import TuningLogger, logger
21+
from neural_compressor.common.utils import Statistics, TuningLogger, logger
2222

2323
__all__ = [
2424
"Evaluator",
@@ -423,6 +423,47 @@ def add_trial_result(self, trial_index: int, trial_result: Union[int, float], qu
423423
trial_record = _TrialRecord(trial_index, trial_result, quant_config)
424424
self.tuning_history.append(trial_record)
425425

426+
# Print tuning results table
427+
self._print_trial_results_table(trial_index, trial_result)
428+
429+
def _print_trial_results_table(self, trial_index: int, trial_result: Union[int, float]) -> None:
    """Print a summary table for a single auto-tune trial.

    Renders trial progress, the baseline and target accuracy, the current
    result with a pass/fail status, and the best result seen so far, using
    the ``Statistics`` table formatter.

    Args:
        trial_index: Index of the trial being reported.
        trial_result: Evaluation result (e.g. accuracy) of this trial.
    """
    # Hoist the baseline check instead of repeating `is not None` per field.
    has_baseline = self.baseline is not None
    baseline_val = self.baseline if has_baseline else 0.0
    baseline_str = f"{baseline_val:.4f}" if has_baseline else "N/A"
    target_threshold_str = (
        f"{baseline_val * (1 - self.tuning_config.tolerable_loss):.4f}" if has_baseline else "N/A"
    )

    # Relative loss vs. baseline. Guard the division: a baseline of exactly
    # 0.0 previously raised ZeroDivisionError; report "N/A" instead.
    relative_loss_str = "N/A"
    if has_baseline and baseline_val != 0:
        relative_loss_val = (baseline_val - trial_result) / baseline_val
        relative_loss_str = f"{relative_loss_val*100:.2f}%"

    # Best result so far; the current trial has already been appended to
    # self.tuning_history by the caller, so this includes it.
    best_result = max(record.trial_result for record in self.tuning_history)

    # A trial passes when it meets the tolerable-loss target relative to baseline.
    if has_baseline and trial_result >= (baseline_val * (1 - self.tuning_config.tolerable_loss)):
        status = "✅ PASSED"
    else:
        status = "❌ FAILED"

    # Combined fields keep the table compact (two columns, four rows).
    field_names = ["📊 Metric", "📈 Value"]
    output_data = [
        ["Trial / Progress", f"{len(self.tuning_history)}/{self.tuning_config.max_trials}"],
        ["Baseline / Target", f"{baseline_str} / {target_threshold_str}"],
        ["Current / Status", f"{trial_result:.4f} | {status}"],
        ["Best / Relative Loss", f"{best_result:.4f} / {relative_loss_str}"],
    ]

    # Statistics handles layout and emits the table through the logger.
    Statistics(
        output_data, header=f"🎯 Auto-Tune Trial #{trial_index} Results", field_names=field_names
    ).print_stat()
466+
426467
def set_baseline(self, baseline: float):
427468
"""Set the baseline value for auto-tune.
428469
@@ -488,4 +529,10 @@ def init_tuning(tuning_config: TuningConfig) -> Tuple[ConfigLoader, TuningLogger
488529
config_loader = ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler)
489530
tuning_logger = TuningLogger()
490531
tuning_monitor = TuningMonitor(tuning_config)
532+
533+
# Update max_trials based on actual number of available configurations
534+
actual_config_count = len(config_loader.config_set)
535+
if tuning_config.max_trials > actual_config_count:
536+
tuning_config.max_trials = actual_config_count
537+
491538
return config_loader, tuning_logger, tuning_monitor

0 commit comments

Comments
 (0)