Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,18 +1,47 @@
import { useRef } from 'react';

import { $api } from '@geti-inspect/api';
import { useProjectIdentifier } from '@geti-inspect/hooks';
import { Button, FileTrigger, toast } from '@geti/ui';
import { useQueryClient } from '@tanstack/react-query';

import { useUploadStatus } from '../footer/adapters';
import { TrainModelButton } from '../train-model/train-model-button.component';
import { REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING } from './utils';

export const UploadImages = () => {
const { projectId } = useProjectIdentifier();
const queryClient = useQueryClient();
const { startUpload, updateProgress, completeUpload } = useUploadStatus();

const captureImageMutation = $api.useMutation('post', '/api/projects/{project_id}/images');
// Track progress across parallel uploads
const progressRef = useRef({ completed: 0, failed: 0, total: 0 });
Comment on lines +20 to +21
Copy link

Copilot AI Dec 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a mutable ref for tracking progress in parallel uploads creates a race condition. Multiple onSuccess/onError callbacks from captureImageMutation could update progressRef.current.completed and progressRef.current.failed simultaneously, leading to inaccurate counts. Consider using atomic state updates or a reducer pattern.

Copilot uses AI. Check for mistakes.

const captureImageMutation = $api.useMutation('post', '/api/projects/{project_id}/images', {
onSuccess: () => {
progressRef.current.completed++;
updateProgress({
completed: progressRef.current.completed + progressRef.current.failed,
total: progressRef.current.total,
failed: progressRef.current.failed,
});
},
onError: () => {
progressRef.current.failed++;
updateProgress({
completed: progressRef.current.completed + progressRef.current.failed,
total: progressRef.current.total,
failed: progressRef.current.failed,
});
},
});

const handleAddMediaItem = async (files: File[]) => {
const total = files.length;

progressRef.current = { completed: 0, failed: 0, total };
startUpload(total);

const uploadPromises = files.map((file) => {
const formData = new FormData();
formData.append('file', file);
Expand All @@ -24,10 +53,10 @@ export const UploadImages = () => {
});
});

const promises = await Promise.allSettled(uploadPromises);
await Promise.allSettled(uploadPromises);

const succeeded = promises.filter((result) => result.status === 'fulfilled').length;
const failed = promises.filter((result) => result.status === 'rejected').length;
const { failed } = progressRef.current;
completeUpload(failed === 0, failed);

const imagesOptions = $api.queryOptions('get', '/api/projects/{project_id}/images', {
params: { path: { project_id: projectId } },
Expand All @@ -44,18 +73,6 @@ export const UploadImages = () => {
actionButtons: [<TrainModelButton key='train' />],
position: 'bottom-left',
});
return;
}

if (failed === 0) {
toast({ type: 'success', message: `Uploaded ${succeeded} item(s)` });
} else if (succeeded === 0) {
toast({ type: 'error', message: `Failed to upload ${failed} item(s)` });
} else {
toast({
type: 'warning',
message: `Uploaded ${succeeded} item(s), ${failed} failed`,
});
}
};

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

import { ReactNode } from 'react';

import { renderHook } from '@testing-library/react';

import { useWebRTCConnection } from '../../../../components/stream/web-rtc-connection-provider';
import { StatusBarProvider, useStatusBar } from '../status-bar';
import { ConnectionStatusAdapter } from './connection-status.adapter';

vi.mock('../../../../components/stream/web-rtc-connection-provider', () => ({
useWebRTCConnection: vi.fn(),
}));

const wrapper = ({ children }: { children: ReactNode }) => (
<StatusBarProvider>
<ConnectionStatusAdapter />
{children}
</StatusBarProvider>
);

describe('ConnectionStatusAdapter', () => {
beforeEach(() => {
vi.clearAllMocks();
});

it('maps connected status', () => {
vi.mocked(useWebRTCConnection).mockReturnValue({
status: 'connected',
start: vi.fn(),
stop: vi.fn(),
webRTCConnectionRef: { current: null },
});

const { result } = renderHook(() => useStatusBar(), { wrapper });

expect(result.current.connection).toBe('connected');
});

it('maps idle to disconnected', () => {
vi.mocked(useWebRTCConnection).mockReturnValue({
status: 'idle',
start: vi.fn(),
stop: vi.fn(),
webRTCConnectionRef: { current: null },
});

const { result } = renderHook(() => useStatusBar(), { wrapper });

expect(result.current.connection).toBe('disconnected');
});

it('maps failed status', () => {
vi.mocked(useWebRTCConnection).mockReturnValue({
status: 'failed',
start: vi.fn(),
stop: vi.fn(),
webRTCConnectionRef: { current: null },
});

const { result } = renderHook(() => useStatusBar(), { wrapper });

expect(result.current.connection).toBe('failed');
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

import { useEffect } from 'react';

import { useWebRTCConnection } from '../../../../components/stream/web-rtc-connection-provider';
import { ConnectionStatus, useStatusBar } from '../status-bar';

const CONNECTION_STATUS_MAP: Record<string, ConnectionStatus> = {
connected: 'connected',
connecting: 'connecting',
failed: 'failed',
idle: 'disconnected',
disconnected: 'disconnected',
};

export const ConnectionStatusAdapter = () => {
const { setConnection } = useStatusBar();
const { status } = useWebRTCConnection();

useEffect(() => {
const connectionStatus = CONNECTION_STATUS_MAP[status] || 'disconnected';

setConnection(connectionStatus);
}, [status, setConnection]);

return null;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

export { ConnectionStatusAdapter } from './connection-status.adapter';
export { TrainingStatusAdapter } from './training-status.adapter';
export { useExportStatus } from './use-export-status';
export { useUploadStatus } from './use-upload-status';
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

import { ReactNode } from 'react';

import { QueryClient, QueryClientProvider } from '@tanstack/react-query';
import { renderHook, waitFor } from '@testing-library/react';
import { HttpResponse } from 'msw';
import { MemoryRouter, Route, Routes } from 'react-router-dom';
import { describe, expect, it, vi } from 'vitest';

import { http } from '../../../../api/utils';
import { server } from '../../../../msw-node-setup';
import { StatusBarProvider, useStatusBar } from '../status-bar';
import { TrainingStatusAdapter } from './training-status.adapter';

const wrapper = ({ children }: { children: ReactNode }) => (
<QueryClientProvider client={new QueryClient()}>
<MemoryRouter initialEntries={['/projects/test-project/inspect']}>
<Routes>
<Route
path='/projects/:projectId/inspect'
element={
<StatusBarProvider>
<TrainingStatusAdapter />
{children}
</StatusBarProvider>
}
/>
</Routes>
</MemoryRouter>
</QueryClientProvider>
);

describe('TrainingStatusAdapter', () => {
it('does not set status when no training job exists', async () => {
server.use(http.get('/api/jobs', ({ response }) => response(200).json({ jobs: [] })));

const { result } = renderHook(() => useStatusBar(), { wrapper });

await waitFor(() => {
expect(result.current.activeStatus).toBeNull();
});
});

it('sets training status when training job is running', async () => {
const trainingJob = {
id: 'job-1',
project_id: 'test-project',
type: 'training' as const,
status: 'running' as const,
progress: 45,
message: 'Epoch 5/10',
payload: { model_name: 'EfficientAd' },
};
server.use(http.get('/api/jobs', ({ response }) => response(200).json({ jobs: [trainingJob] })));

const { result } = renderHook(() => useStatusBar(), { wrapper });

await waitFor(() => {
expect(result.current.activeStatus).toEqual(
expect.objectContaining({
id: 'training',
type: 'training',
message: 'Training EfficientAd...',
detail: 'Epoch 5/10',
progress: 45,
variant: 'info',
isCancellable: true,
})
);
});
});

it('sets training status when training job is pending', async () => {
const trainingJob = {
id: 'job-2',
project_id: 'test-project',
type: 'training' as const,
status: 'pending' as const,
progress: 0,
message: 'Waiting...',
payload: { model_name: 'Padim' },
};
server.use(http.get('/api/jobs', ({ response }) => response(200).json({ jobs: [trainingJob] })));

const { result } = renderHook(() => useStatusBar(), { wrapper });

await waitFor(() => {
expect(result.current.activeStatus?.message).toBe('Training Padim...');
});
});

it('ignores training jobs from other projects', async () => {
const trainingJob = {
id: 'job-3',
project_id: 'other-project',
type: 'training' as const,
status: 'running' as const,
progress: 50,
message: 'Training...',
payload: { model_name: 'Test' },
};
server.use(http.get('/api/jobs', ({ response }) => response(200).json({ jobs: [trainingJob] })));

const { result } = renderHook(() => useStatusBar(), { wrapper });

await waitFor(() => {
expect(result.current.activeStatus).toBeNull();
});
});

it('ignores completed or failed training jobs', async () => {
const completedJob = {
id: 'job-4',
project_id: 'test-project',
type: 'training' as const,
status: 'completed' as const,
progress: 100,
message: 'Done',
payload: { model_name: 'Test' },
};
server.use(http.get('/api/jobs', ({ response }) => response(200).json({ jobs: [completedJob] })));

const { result } = renderHook(() => useStatusBar(), { wrapper });

await waitFor(() => {
expect(result.current.activeStatus).toBeNull();
});
});

it('calls cancelJob when onCancel is triggered', async () => {
const cancelJobSpy = vi.fn();
const trainingJob = {
id: 'job-5',
project_id: 'test-project',
type: 'training' as const,
status: 'running' as const,
progress: 30,
message: 'Training...',
payload: { model_name: 'EfficientAd' },
};
server.use(
http.get('/api/jobs', ({ response }) => response(200).json({ jobs: [trainingJob] })),
http.post('/api/jobs/{job_id}:cancel', ({ params }) => {
cancelJobSpy(params.job_id);
return HttpResponse.json({}, { status: 204 });
})
);

const { result } = renderHook(() => useStatusBar(), { wrapper });

await waitFor(() => {
expect(result.current.activeStatus?.onCancel).toBeDefined();
});

result.current.activeStatus?.onCancel?.();

await waitFor(() => {
expect(cancelJobSpy).toHaveBeenCalledWith('job-5');
});
});
});
Loading
Loading