Commit 1053d6b

RSMNYS authored (with co-authors freedomtan, anhappdev, and mohitmundhra)
refactor: make stable diffusion flow work with one sd model (#913)
* refactor: make stable diffusion flow work with one diffusion model
* fix: make use of the dynamic models
* chore: formatting and code cleanup

Co-authored-by: Koan-Sin Tan <[email protected]>
Co-authored-by: Anh <[email protected]>
Co-authored-by: Mohit Mundhra <[email protected]>
1 parent a659b30 commit 1053d6b

3 files changed: +30, -89 lines

mobile_back_tflite/cpp/backend_tflite/stable_diffusion_invoker.cc

Lines changed: 16 additions & 64 deletions
@@ -44,79 +44,31 @@ std::vector<float> StableDiffusionInvoker::encode_prompt(
 std::vector<float> StableDiffusionInvoker::diffusion_step(
     const std::vector<float>& latent, const std::vector<float>& t_emb,
     const std::vector<float>& context) {
-  // Prepare the first model's inputs
-
-  auto first_input_details =
-      TfLiteInterpreterGetInputTensor(backend_data_->first_interpreter, 0);
-  auto second_input_details =
-      TfLiteInterpreterGetInputTensor(backend_data_->first_interpreter, 1);
-  auto third_input_details =
-      TfLiteInterpreterGetInputTensor(backend_data_->first_interpreter, 2);
+  auto latent_input_details =
+      TfLiteInterpreterGetInputTensor(backend_data_->sd_interpreter, 0);
+  auto context_input_details =
+      TfLiteInterpreterGetInputTensor(backend_data_->sd_interpreter, 1);
+  auto time_stamp_embedding_input_details =
+      TfLiteInterpreterGetInputTensor(backend_data_->sd_interpreter, 2);
 
   std::copy(context.begin(), context.end(),
-            reinterpret_cast<float*>(TfLiteTensorData(first_input_details)));
+            reinterpret_cast<float*>(TfLiteTensorData(context_input_details)));
   std::copy(t_emb.begin(), t_emb.end(),
-            reinterpret_cast<float*>(TfLiteTensorData(second_input_details)));
+            reinterpret_cast<float*>(
+                TfLiteTensorData(time_stamp_embedding_input_details)));
   std::copy(latent.begin(), latent.end(),
-            reinterpret_cast<float*>(TfLiteTensorData(third_input_details)));
+            reinterpret_cast<float*>(TfLiteTensorData(latent_input_details)));
 
-  // Invoke the first model
-  if (TfLiteInterpreterInvoke(backend_data_->first_interpreter) != kTfLiteOk) {
+  // Invoke the model
+  if (TfLiteInterpreterInvoke(backend_data_->sd_interpreter) != kTfLiteOk) {
     std::cerr << "Failed to invoke the first diffusion model!" << std::endl;
     exit(-1);
   }
 
-  // Output names from the first model and corresponding input names for the
-  // second model
-  std::vector<std::string> output_names = {
-      "Identity_6", "Identity_4", "Identity", "input_1", "Identity_12",
-      "Identity_11", "Identity_3", "Identity_10", "Identity_9", "Identity_5",
-      "Identity_8", "Identity_7", "Identity_2"};
-
-  std::vector<std::string> input_names = {
-      "args_0", "args_0_1", "args_0_2", "args_0_4", "args_0_3",
-      "args_0_5", "args_0_6", "args_0_7", "args_0_8", "args_0_9",
-      "args_0_10", "args_0_11", "args_0_12"};
-
-  // Copy outputs of the first model to the inputs of the second model based on
-  // names
-  for (size_t i = 0; i < input_names.size(); ++i) {
-    int input_index = get_tensor_index_by_name(
-        backend_data_->second_interpreter, input_names[i], true);
-    int output_index = get_tensor_index_by_name(
-        backend_data_->first_interpreter, output_names[i], false);
-
-    if (input_index == -1 || output_index == -1) {
-      std::cerr << "Failed to find matching input or output tensor by name!"
-                << std::endl;
-      exit(-1);
-    }
-
-    auto first_model_output_details = TfLiteInterpreterGetOutputTensor(
-        backend_data_->first_interpreter, output_index);
-
-    float* output_data =
-        reinterpret_cast<float*>(TfLiteTensorData(first_model_output_details));
-    int output_size =
-        TfLiteTensorByteSize(first_model_output_details) / sizeof(float);
-
-    float* input_data = reinterpret_cast<float*>(
-        TfLiteTensorData(TfLiteInterpreterGetInputTensor(
-            backend_data_->second_interpreter, input_index)));
-
-    std::copy(output_data, output_data + output_size, input_data);
-  }
-
-  // Invoke the second model
-  if (TfLiteInterpreterInvoke(backend_data_->second_interpreter) != kTfLiteOk) {
-    std::cerr << "Failed to invoke the second diffusion model!" << std::endl;
-    exit(-1);
-  }
-
   float* output = reinterpret_cast<float*>(TfLiteTensorData(
-      TfLiteInterpreterGetOutputTensor(backend_data_->second_interpreter, 0)));
+      TfLiteInterpreterGetOutputTensor(backend_data_->sd_interpreter, 0)));
   int output_size = TfLiteTensorByteSize(TfLiteInterpreterGetOutputTensor(
-                        backend_data_->second_interpreter, 0)) /
+                        backend_data_->sd_interpreter, 0)) /
                     sizeof(float);
   return std::vector<float>(output, output + output_size);
 }
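
For context on how a single-model diffusion_step is driven (none of the loop below appears in this commit): it is typically invoked once per scheduler timestep, and twice per step when classifier-free guidance is used, once with the empty-prompt context and once with the real prompt. A minimal sketch under those assumptions, passing the step function in so the code is self-contained, using an assumed 320-float sinusoidal timestep embedding, and eliding the proper scheduler update:

#include <cmath>
#include <vector>

// Signature matching StableDiffusionInvoker::diffusion_step above; passed in
// as a function pointer so this sketch stands alone.
using StepFn = std::vector<float> (*)(const std::vector<float>&,
                                      const std::vector<float>&,
                                      const std::vector<float>&);

std::vector<float> denoise(StepFn diffusion_step, std::vector<float> latent,
                           const std::vector<float>& prompt_context,
                           const std::vector<float>& uncond_context,
                           const std::vector<int>& timesteps,
                           float guidance_scale) {
  for (int t : timesteps) {
    // Sinusoidal timestep embedding; the 320-float width is an assumption.
    std::vector<float> t_emb(320);
    const size_t half = t_emb.size() / 2;
    for (size_t i = 0; i < half; ++i) {
      const float freq = std::exp(-std::log(10000.0f) * i / half);
      t_emb[i] = std::sin(t * freq);
      t_emb[half + i] = std::cos(t * freq);
    }
    // Classifier-free guidance: one pass without the prompt, one with it,
    // then blend the two noise estimates.
    const auto uncond = diffusion_step(latent, t_emb, uncond_context);
    const auto cond = diffusion_step(latent, t_emb, prompt_context);
    for (size_t i = 0; i < latent.size(); ++i) {
      // A real pipeline would feed this guided estimate into a scheduler
      // update (e.g. DDIM); that step is omitted for brevity.
      latent[i] = uncond[i] + guidance_scale * (cond[i] - uncond[i]);
    }
  }
  return latent;
}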
@@ -201,9 +153,9 @@ std::vector<float> StableDiffusionInvoker::run_inference(
 
   // Access the input tensors
   void* pos_ids_input_data =
-      TfLiteTensorData(TfLiteInterpreterGetInputTensor(interpreter, 0));
-  void* encoded_input_data =
       TfLiteTensorData(TfLiteInterpreterGetInputTensor(interpreter, 1));
+  void* encoded_input_data =
+      TfLiteTensorData(TfLiteInterpreterGetInputTensor(interpreter, 0));
 
   // Copy data to input tensors (type cast required for correct copy operation)
   std::memcpy(pos_ids_input_data, pos_ids.data(), pos_ids.size() * sizeof(int));
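
The hunk above swaps which input index receives the position ids versus the encoded tokens, matching the tensor ordering of the dynamic text encoder model. Hard-coded indices like these are fragile across model conversions; a name-based lookup is one defensive alternative. A sketch using the TfLite C API (not code from this commit):

#include <cstdint>
#include <cstring>

#include "tensorflow/lite/c/c_api.h"

// Returns the index of the input tensor whose name matches `name`,
// or -1 if no input has that name.
int find_input_by_name(const TfLiteInterpreter* interpreter,
                       const char* name) {
  const int32_t count = TfLiteInterpreterGetInputTensorCount(interpreter);
  for (int32_t i = 0; i < count; ++i) {
    const TfLiteTensor* tensor =
        TfLiteInterpreterGetInputTensor(interpreter, i);
    if (std::strcmp(TfLiteTensorName(tensor), name) == 0) return i;
  }
  return -1;
}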

mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.cc

Lines changed: 12 additions & 21 deletions
@@ -66,39 +66,31 @@ mlperf_backend_ptr_t StableDiffusionPipeline::backend_create(
 
   // Load models from the provided directory path
   std::string text_encoder_path =
-      std::string(model_path) + "/text_encoder.tflite";
-  std::string first_model_path =
-      std::string(model_path) + "/first_model.tflite";
-  std::string second_model_path =
-      std::string(model_path) + "/second_model.tflite";
-  std::string decoder_path = std::string(model_path) + "/decoder.tflite";
+      std::string(model_path) + "/sd_text_encoder_dynamic.tflite";
+  std::string sd_model_path =
+      std::string(model_path) + "/sd_diffusion_model_dynamic.tflite";
+  std::string decoder_path =
+      std::string(model_path) + "/sd_decoder_dynamic.tflite";
 
   backend_data->text_encoder_model =
       TfLiteModelCreateFromFile(text_encoder_path.c_str());
-  backend_data->first_model =
-      TfLiteModelCreateFromFile(first_model_path.c_str());
-  backend_data->second_model =
-      TfLiteModelCreateFromFile(second_model_path.c_str());
+  backend_data->sd_model = TfLiteModelCreateFromFile(sd_model_path.c_str());
   backend_data->decoder_model = TfLiteModelCreateFromFile(decoder_path.c_str());
 
-  if (!backend_data->text_encoder_model || !backend_data->first_model ||
-      !backend_data->second_model || !backend_data->decoder_model) {
+  if (!backend_data->text_encoder_model || !backend_data->sd_model ||
+      !backend_data->decoder_model) {
     delete backend_data;
     return nullptr;
   }
 
   backend_data->text_encoder_interpreter =
       create_interpreter(backend_data->text_encoder_model);
-  backend_data->first_interpreter =
-      create_interpreter(backend_data->first_model);
-  backend_data->second_interpreter =
-      create_interpreter(backend_data->second_model);
+  backend_data->sd_interpreter = create_interpreter(backend_data->sd_model);
   backend_data->decoder_interpreter =
       create_interpreter(backend_data->decoder_model);
 
   if (!backend_data->text_encoder_interpreter ||
-      !backend_data->first_interpreter || !backend_data->second_interpreter ||
-      !backend_data->decoder_interpreter) {
+      !backend_data->sd_interpreter || !backend_data->decoder_interpreter) {
     backend_delete(backend_data);
     return nullptr;
   }
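
create_interpreter is called throughout this hunk but its body lies outside the diff. For readers unfamiliar with the TfLite C API, a plausible minimal implementation looks like the following; this is an illustration, not the repository's actual helper, and the thread count is an arbitrary assumption:

#include "tensorflow/lite/c/c_api.h"

// Builds an interpreter for a loaded model and allocates its tensors,
// returning nullptr on any failure.
TfLiteInterpreter* create_interpreter(TfLiteModel* model) {
  TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
  TfLiteInterpreterOptionsSetNumThreads(options, 4);  // thread count assumed
  TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
  TfLiteInterpreterOptionsDelete(options);  // safe to delete after creation
  if (interpreter &&
      TfLiteInterpreterAllocateTensors(interpreter) != kTfLiteOk) {
    TfLiteInterpreterDelete(interpreter);
    return nullptr;
  }
  return interpreter;
}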
@@ -142,8 +134,7 @@ void StableDiffusionPipeline::backend_delete(mlperf_backend_ptr_t backend_ptr) {
   SDBackendData* backend_data = static_cast<SDBackendData*>(backend_ptr);
   if (backend_data) {
     TfLiteModelDelete(backend_data->text_encoder_model);
-    TfLiteModelDelete(backend_data->first_model);
-    TfLiteModelDelete(backend_data->second_model);
+    TfLiteModelDelete(backend_data->sd_model);
     TfLiteModelDelete(backend_data->decoder_model);
     delete backend_data;
   }
@@ -214,7 +205,7 @@ mlperf_status_t StableDiffusionPipeline::backend_set_input(
     ++token_count;
   }
 
-  std::vector<int> unconditioned_tokens(87, 49407);
+  std::vector<int> unconditioned_tokens(77, 49407);
   unconditioned_tokens[0] = 49406;
 
   backend_data->input_prompt_tokens.assign(tokens, tokens + token_count);
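
The 87 to 77 change aligns the unconditional prompt with the CLIP text encoder convention used by Stable Diffusion: input sequences are a fixed 77 tokens long, 49406 is the start-of-text id, and 49407 doubles as the end-of-text and padding id. In isolation, building the unconditional token buffer looks like this:

#include <vector>

// Unconditional ("empty prompt") tokens for a CLIP-style text encoder:
// one start-of-text id followed by 76 end-of-text/padding ids.
std::vector<int> make_unconditioned_tokens() {
  std::vector<int> tokens(77, 49407);  // fill with end-of-text / padding
  tokens[0] = 49406;                   // start-of-text
  return tokens;
}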

mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h

Lines changed: 2 additions & 4 deletions
@@ -27,13 +27,11 @@ struct SDBackendData {
   const char *accelerator = "CPU";
 
   TfLiteModel *text_encoder_model{nullptr};
-  TfLiteModel *first_model{nullptr};
-  TfLiteModel *second_model{nullptr};
+  TfLiteModel *sd_model{nullptr};
   TfLiteModel *decoder_model{nullptr};
 
   TfLiteInterpreter *text_encoder_interpreter{nullptr};
-  TfLiteInterpreter *first_interpreter{nullptr};
-  TfLiteInterpreter *second_interpreter{nullptr};
+  TfLiteInterpreter *sd_interpreter{nullptr};
   TfLiteInterpreter *decoder_interpreter{nullptr};
 
   std::vector<int> input_prompt_tokens;
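
A design note on the struct above: the raw TfLiteModel and TfLiteInterpreter pointers rely on backend_delete for cleanup. An alternative the commit does not take would tie the lifetimes to the struct itself via unique_ptr with custom deleters; a sketch:

#include <memory>

#include "tensorflow/lite/c/c_api.h"

// Custom deleters route destruction through the TfLite C API.
struct ModelDeleter {
  void operator()(TfLiteModel* m) const { TfLiteModelDelete(m); }
};
struct InterpreterDeleter {
  void operator()(TfLiteInterpreter* i) const { TfLiteInterpreterDelete(i); }
};

using ModelPtr = std::unique_ptr<TfLiteModel, ModelDeleter>;
using InterpreterPtr = std::unique_ptr<TfLiteInterpreter, InterpreterDeleter>;

// With these aliases, members such as
//   ModelPtr sd_model;
//   InterpreterPtr sd_interpreter;
// would free themselves when SDBackendData is destroyed.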
