Skip to content

Commit 139dd5e

Browse files
feat: add command-line arguments for sample episodes and save plot directory; update MSE histogram plotting
Signed-off-by: Dongyun Kim <[email protected]>
1 parent 925f1f7 commit 139dd5e

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

physical_ai_server/physical_ai_server/evaluation/evaluation_manager.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,19 @@ def main():
274274
required=True,
275275
help='The path to the policy file to be evaluated.'
276276
)
277+
parser.add_argument(
278+
'--sample_episodes',
279+
type=int,
280+
nargs='+',
281+
default=None,
282+
help='A list of episode indices to evaluate. If not provided, all episodes are evaluated.'
283+
)
284+
parser.add_argument(
285+
'--save_plot_dir',
286+
type=str,
287+
default='./plots',
288+
help='The directory to save evaluation plots.'
289+
)
277290
args = parser.parse_args()
278291

279292
# Initialize managers
@@ -297,10 +310,10 @@ def main():
297310
results = evaluation_manager.evaluate_policy_on_dataset(
298311
inference_manager=inference_manager,
299312
dataset=dataset,
300-
sample_episodes=[1, 3, 5],
313+
sample_episodes=args.sample_episodes,
301314
plot_episodes=True,
302315
plot_summary=True,
303-
save_plot_dir='./plots'
316+
save_plot_dir=args.save_plot_dir
304317
)
305318
print(f'Evaluation results: {results}')
306319

physical_ai_server/physical_ai_server/evaluation/visualization_manager.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
class VisualizationManager:
2727

28-
def __init__(self, default_save_dir: str = '/root/.cache'):
28+
def __init__(self, default_save_dir: str = os.path.expanduser('~/.cache')):
2929
self.default_save_dir = default_save_dir
3030
os.makedirs(self.default_save_dir, exist_ok=True)
3131

@@ -212,6 +212,15 @@ def plot_episode_mse_distribution(
212212
# Create figure and axis
213213
_, ax = plt.subplots(figsize=(10, 6))
214214

215+
# Plot histogram of MSE values
216+
ax.hist(
217+
episode_mses,
218+
bins=20,
219+
color='blue',
220+
alpha=0.7,
221+
edgecolor='black',
222+
label='MSE Distribution')
223+
215224
# Add vertical lines for mean and median
216225
mean_mse = np.mean(episode_mses)
217226
median_mse = np.median(episode_mses)

0 commit comments

Comments
 (0)