Skip to content

Commit e99dbea

Browse files
authored
feat: download resources per task (#960)
* Rename benchmarks to allBenchmarks * User can download resources per benchmark * Expand click area of download status to text * Refactor function signature for more clarity * Add checksum validation before running a benchmark. * Fix Flutter tests * Update SONAR_SCANNER_VERSION
1 parent 1be3f66 commit e99dbea

File tree

14 files changed

+155
-102
lines changed

14 files changed

+155
-102
lines changed

flutter/integration_test/utils.dart

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Future<void> startApp(WidgetTester tester) async {
3333
Future<void> validateSettings(WidgetTester tester) async {
3434
final state = tester.state(find.byType(MaterialApp));
3535
final benchmarkState = state.context.read<BenchmarkState>();
36-
for (var benchmark in benchmarkState.benchmarks) {
36+
for (var benchmark in benchmarkState.allBenchmarks) {
3737
expect(benchmark.selectedDelegate.batchSize, greaterThanOrEqualTo(0),
3838
reason: 'batchSize must >= 0');
3939
for (var modelFile in benchmark.selectedDelegate.modelFile) {
@@ -67,7 +67,7 @@ Future<void> validateSettings(WidgetTester tester) async {
6767
Future<void> setBenchmarks(WidgetTester tester) async {
6868
final state = tester.state(find.byType(MaterialApp));
6969
final benchmarkState = state.context.read<BenchmarkState>();
70-
for (var benchmark in benchmarkState.benchmarks) {
70+
for (var benchmark in benchmarkState.allBenchmarks) {
7171
// Disable test for stable diffusion since it take too long to finish.
7272
if (benchmark.id == BenchmarkId.stableDiffusion) {
7373
benchmark.isActive = false;

flutter/lib/benchmark/benchmark.dart

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,11 @@ class Benchmark {
117117
}
118118

119119
class BenchmarkStore {
120-
final List<Benchmark> benchmarks = <Benchmark>[];
120+
final List<Benchmark> allBenchmarks = <Benchmark>[];
121+
122+
List<Benchmark> get activeBenchmarks {
123+
return allBenchmarks.where((e) => e.isActive).toList();
124+
}
121125

122126
BenchmarkStore({
123127
required pb.MLPerfConfig appConfig,
@@ -137,7 +141,7 @@ class BenchmarkStore {
137141
}
138142

139143
final enabled = taskSelection[task.id] ?? true;
140-
benchmarks.add(Benchmark(
144+
allBenchmarks.add(Benchmark(
141145
taskConfig: task,
142146
benchmarkSettings: backendSettings,
143147
isActive: enabled,
@@ -186,7 +190,7 @@ class BenchmarkStore {
186190

187191
Map<String, bool> get selection {
188192
Map<String, bool> result = {};
189-
for (var item in benchmarks) {
193+
for (var item in allBenchmarks) {
190194
result[item.id] = item.isActive;
191195
}
192196
return result;

flutter/lib/benchmark/state.dart

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,28 +54,30 @@ class BenchmarkState extends ChangeNotifier {
5454
ExtendedResult? lastResult;
5555

5656
num get result {
57-
final benchmarksCount = benchmarks
57+
final benchmarksCount = allBenchmarks
5858
.where((benchmark) => benchmark.performanceModeResult != null)
5959
.length;
6060

6161
if (benchmarksCount == 0) return 0;
6262

6363
final summaryThroughput = pow(
64-
benchmarks.fold<double>(1, (prev, i) {
64+
allBenchmarks.fold<double>(1, (prev, i) {
6565
return prev * (i.performanceModeResult?.throughput?.value ?? 1.0);
6666
}),
6767
1.0 / benchmarksCount);
6868

6969
final maxSummaryThroughput = pow(
70-
benchmarks.fold<double>(1, (prev, i) {
70+
allBenchmarks.fold<double>(1, (prev, i) {
7171
return prev * (i.info.maxThroughput);
7272
}),
7373
1.0 / benchmarksCount);
7474

7575
return summaryThroughput / maxSummaryThroughput;
7676
}
7777

78-
List<Benchmark> get benchmarks => _benchmarkStore.benchmarks;
78+
List<Benchmark> get allBenchmarks => _benchmarkStore.allBenchmarks;
79+
80+
List<Benchmark> get activeBenchmarks => _benchmarkStore.activeBenchmarks;
7981

8082
late BenchmarkStore _benchmarkStore;
8183

@@ -131,25 +133,42 @@ class BenchmarkState extends ChangeNotifier {
131133
}
132134
}
133135

134-
Future<void> loadResources({required bool downloadMissing}) async {
136+
Future<void> loadResources(
137+
{required bool downloadMissing,
138+
List<Benchmark> benchmarks = const []}) async {
135139
final newAppVersion =
136140
'${BuildInfoHelper.info.version}+${BuildInfoHelper.info.buildNumber}';
137141
var needToPurgeCache = _store.previousAppVersion != newAppVersion;
138142
_store.previousAppVersion = newAppVersion;
139143

144+
final selectedBenchmarks = benchmarks.isEmpty ? allBenchmarks : benchmarks;
140145
await Wakelock.enable();
141-
print('Start loading resources with downloadMissing=$downloadMissing');
142-
final resources = _benchmarkStore.listResources(
146+
final selectedResources = _benchmarkStore.listResources(
147+
modes: [taskRunner.perfMode, taskRunner.accuracyMode],
148+
benchmarks: selectedBenchmarks,
149+
);
150+
final allResources = _benchmarkStore.listResources(
143151
modes: [taskRunner.perfMode, taskRunner.accuracyMode],
144-
benchmarks: benchmarks,
152+
benchmarks: allBenchmarks,
145153
);
146154
try {
155+
final selectedBenchmarkIds = selectedBenchmarks
156+
.map((e) => e.benchmarkSettings.benchmarkId)
157+
.join(', ');
158+
print('Start loading resources with downloadMissing=$downloadMissing '
159+
'for $selectedBenchmarkIds');
147160
await resourceManager.handleResources(
148-
resources,
149-
needToPurgeCache,
150-
downloadMissing,
161+
resources: selectedResources,
162+
purgeOldCache: needToPurgeCache,
163+
downloadMissing: downloadMissing,
151164
);
152165
print('Finished loading resources with downloadMissing=$downloadMissing');
166+
// We still need to load all resources after download selected resources.
167+
await resourceManager.handleResources(
168+
resources: allResources,
169+
purgeOldCache: false,
170+
downloadMissing: false,
171+
);
153172
error = null;
154173
stackTrace = null;
155174
taskConfigFailedToLoad = false;
@@ -289,7 +308,7 @@ class BenchmarkState extends ChangeNotifier {
289308
}
290309

291310
void resetCurrentResults() {
292-
for (var b in _benchmarkStore.benchmarks) {
311+
for (var b in _benchmarkStore.allBenchmarks) {
293312
b.accuracyModeResult = null;
294313
b.performanceModeResult = null;
295314
}
@@ -304,7 +323,7 @@ class BenchmarkState extends ChangeNotifier {
304323
lastResult = ExtendedResult.fromJson(
305324
jsonDecode(_store.previousExtendedResult) as Map<String, dynamic>);
306325
resourceManager.resultManager
307-
.restoreResults(lastResult!.results, benchmarks);
326+
.restoreResults(lastResult!.results, allBenchmarks);
308327
_doneRunning = true;
309328
return;
310329
} catch (e, trace) {

flutter/lib/l10n/app_en.arb

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
"dialogContentMissingFiles": "The following files don't exist:",
9494
"dialogContentMissingFilesHint": "Please go to the menu Resources to download the missing files.",
9595
"dialogContentChecksumError": "The following files failed checksum validation:",
96+
"dialogContentChecksumErrorHint": "Please go to the menu Resources to clear the cache and download the files again.",
9697
"dialogContentNoSelectedBenchmarkError": "Please select at least one benchmark.",
9798

9899
"benchModePerformanceOnly": "Performance Only",
@@ -122,7 +123,9 @@
122123
"benchInfoStableDiffusionDesc": "The Text to Image Gen AI benchmark adopts Stable Diffusion v1.5 for generating images from text prompts. It is a latent diffusion model. The benchmarked Stable Diffusion v1.5 refers to a specific configuration of the model architecture that uses a downsampling-factor 8 autoencoder with an 860M UNet,123M CLIP ViT-L/14 text encoder for the diffusion model, and VAE Decoder of 49.5M parameters. The model was trained on 595k steps at resolution of 512x512, which enables it to generate high quality images. We refer you to https://huggingface.co/benjamin-paine/stable-diffusion-v1-5 for more information. The benchmark runs 20 denoising steps for inference, and uses a precalculated time embedding of size 1x1280. Reference models can be found here https://github.com/mlcommons/mobile_open/releases.\n\nFor latency benchmarking, we benchmark end to end, excluding the time embedding calculation and the tokenizer. For accuracy calculations, the app adopts the CLIP metric for text-to-image consistency, and further evaluation of the generated images using this Image Quality Aesthetic Assessment metric https://github.com/idealo/image-quality-assessment/tree/master?tab=readme-ov-file",
123124

124125
"resourceDownload": "Download",
126+
"resourceDownloadAll": "Download all",
125127
"resourceClear": "Clear",
128+
"resourceClearAll": "Clear all",
126129
"resourceChecking": "Checking download status",
127130
"resourceDownloading": "Downloading",
128131
"resourceErrorMessage": "Some resources failed to load.\nIf you didn't change config from default you can try clearing the cache.\nIf you use a custom configuration file ensure that it has correct structure or switch back to default config.",

flutter/lib/resources/cache_manager.dart

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,12 @@ class CacheManager {
9191
return deleteLoadedResources(currentResources, atLeastDaysOld);
9292
}
9393

94-
Future<void> cache(
95-
List<String> urls,
96-
void Function(double, String) reportProgress,
97-
bool purgeOldCache,
98-
bool downloadMissing) async {
94+
Future<void> cache({
95+
required List<String> urls,
96+
required void Function(double, String) onProgressUpdate,
97+
required bool purgeOldCache,
98+
required bool downloadMissing,
99+
}) async {
99100
final resourcesToDownload = <String>[];
100101
_resourcesMap = {};
101102

@@ -120,7 +121,7 @@ class CacheManager {
120121
continue;
121122
}
122123
if (downloadMissing) {
123-
await _download(resourcesToDownload, reportProgress);
124+
await _download(resourcesToDownload, onProgressUpdate);
124125
}
125126
if (purgeOldCache) {
126127
await purgeOutdatedCache(_oldFilesAgeInDays);
@@ -132,18 +133,20 @@ class CacheManager {
132133
}
133134

134135
Future<void> _download(
135-
List<String> urls, void Function(double, String) reportProgress) async {
136+
List<String> urls,
137+
void Function(double, String) onProgressUpdate,
138+
) async {
136139
var progress = 0.0;
137140
for (var url in urls) {
138141
progress += 0.1 / urls.length;
139-
reportProgress(progress, url);
142+
onProgressUpdate(progress, url);
140143
if (isResourceAnArchive(url)) {
141144
_resourcesMap[url] = await archiveCacheHelper.get(url, true);
142145
} else {
143146
_resourcesMap[url] = await fileCacheHelper.get(url, true);
144147
}
145148
progress += 0.9 / urls.length;
146-
reportProgress(progress, url);
149+
onProgressUpdate(progress, url);
147150
}
148151
}
149152
}

flutter/lib/resources/resource_manager.dart

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,11 @@ class ResourceManager {
9494
return checksum == md5Checksum;
9595
}
9696

97-
Future<void> handleResources(List<Resource> resources, bool purgeOldCache,
98-
bool downloadMissing) async {
97+
Future<void> handleResources({
98+
required List<Resource> resources,
99+
required bool purgeOldCache,
100+
required bool downloadMissing,
101+
}) async {
99102
_loadingPath = '';
100103
_loadingProgress = 0.001;
101104
_done = false;
@@ -114,14 +117,14 @@ class ResourceManager {
114117
final internetPaths = internetResources.map((e) => e.path).toList();
115118
try {
116119
await cacheManager.cache(
117-
internetPaths,
118-
(double currentProgress, String currentPath) {
120+
urls: internetPaths,
121+
onProgressUpdate: (double currentProgress, String currentPath) {
119122
_loadingProgress = currentProgress;
120123
_loadingPath = currentPath;
121124
_onUpdate();
122125
},
123-
purgeOldCache,
124-
downloadMissing,
126+
purgeOldCache: purgeOldCache,
127+
downloadMissing: downloadMissing,
125128
);
126129
} on SocketException {
127130
throw 'A network error has occurred. Please make sure you are connected to the internet.';

flutter/lib/resources/validation_helper.dart

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ class ValidationHelper {
1818
required this.selectedRunModes,
1919
});
2020

21-
List<Benchmark> get activeBenchmarks =>
22-
benchmarkStore.benchmarks.where((e) => e.isActive).toList();
23-
2421
Future<String> validateExternalResourcesDirectory(
2522
String errorDescription) async {
2623
final dataFolderPath = resourceManager.getDataFolder();
@@ -32,7 +29,7 @@ class ValidationHelper {
3229
}
3330
final resources = benchmarkStore.listResources(
3431
modes: selectedRunModes,
35-
benchmarks: activeBenchmarks,
32+
benchmarks: benchmarkStore.activeBenchmarks,
3633
);
3734
final result = await resourceManager.validateResourcesExist(resources);
3835
final missing = result[false] ?? [];
@@ -42,10 +39,22 @@ class ValidationHelper {
4239
missing.mapIndexed((i, element) => '\n${i + 1}) $element').join();
4340
}
4441

42+
Future<String> validateChecksum(String errorDescription) async {
43+
final resources = benchmarkStore.listResources(
44+
modes: selectedRunModes,
45+
benchmarks: benchmarkStore.activeBenchmarks,
46+
);
47+
final checksumFailed =
48+
await resourceManager.validateResourcesChecksum(resources);
49+
if (checksumFailed.isEmpty) return '';
50+
final mismatchedPaths = checksumFailed.map((e) => '\n${e.path}').join();
51+
return errorDescription + mismatchedPaths;
52+
}
53+
4554
Future<String> validateOfflineMode(String errorDescription) async {
4655
final resources = benchmarkStore.listResources(
4756
modes: selectedRunModes,
48-
benchmarks: activeBenchmarks,
57+
benchmarks: benchmarkStore.activeBenchmarks,
4958
);
5059
final internetResources = filterInternetResources(resources);
5160
if (internetResources.isEmpty) return '';

flutter/lib/state/task_runner.dart

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ class TaskRunner {
8080
final cooldown = store.cooldown;
8181
final cooldownDuration = Duration(seconds: store.cooldownDuration);
8282

83-
final activeBenchmarks =
84-
benchmarkStore.benchmarks.where((element) => element.isActive);
83+
final activeBenchmarks = benchmarkStore.activeBenchmarks;
8584

8685
final resultHelpers = <ResultHelper>[];
8786
for (final benchmark in activeBenchmarks) {

flutter/lib/ui/home/benchmark_config_section.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class _BenchmarkConfigSectionState extends State<BenchmarkConfigSection> {
3333
l10n = AppLocalizations.of(context)!;
3434
final childrenList = <Widget>[];
3535

36-
for (var benchmark in state.benchmarks) {
36+
for (var benchmark in state.allBenchmarks) {
3737
childrenList.add(_listTile(benchmark));
3838
childrenList.add(const Divider(height: 20));
3939
}

flutter/lib/ui/home/benchmark_result_screen.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ class _BenchmarkResultScreenState extends State<BenchmarkResultScreen>
219219

220220
Widget _detailSection() {
221221
final children = <Widget>[];
222-
for (final benchmark in state.benchmarks) {
222+
for (final benchmark in state.allBenchmarks) {
223223
final row = _benchmarkResultRow(benchmark);
224224
children.add(row);
225225
children.add(const Divider());

0 commit comments

Comments
 (0)