@@ -44,79 +44,31 @@ std::vector<float> StableDiffusionInvoker::encode_prompt(
 std::vector<float> StableDiffusionInvoker::diffusion_step(
     const std::vector<float>& latent, const std::vector<float>& t_emb,
     const std::vector<float>& context) {
-  // Prepare the first model's inputs
-
-  auto first_input_details =
-      TfLiteInterpreterGetInputTensor(backend_data_->first_interpreter, 0);
-  auto second_input_details =
-      TfLiteInterpreterGetInputTensor(backend_data_->first_interpreter, 1);
-  auto third_input_details =
-      TfLiteInterpreterGetInputTensor(backend_data_->first_interpreter, 2);
+  auto latent_input_details =
+      TfLiteInterpreterGetInputTensor(backend_data_->sd_interpreter, 0);
+  auto context_input_details =
+      TfLiteInterpreterGetInputTensor(backend_data_->sd_interpreter, 1);
+  auto time_stamp_embedding_input_details =
+      TfLiteInterpreterGetInputTensor(backend_data_->sd_interpreter, 2);
 
   std::copy(context.begin(), context.end(),
-            reinterpret_cast<float*>(TfLiteTensorData(first_input_details)));
+            reinterpret_cast<float*>(TfLiteTensorData(context_input_details)));
   std::copy(t_emb.begin(), t_emb.end(),
-            reinterpret_cast<float*>(TfLiteTensorData(second_input_details)));
+            reinterpret_cast<float*>(
+                TfLiteTensorData(time_stamp_embedding_input_details)));
   std::copy(latent.begin(), latent.end(),
-            reinterpret_cast<float*>(TfLiteTensorData(third_input_details)));
+            reinterpret_cast<float*>(TfLiteTensorData(latent_input_details)));
 
-  // Invoke the first model
-  if (TfLiteInterpreterInvoke(backend_data_->first_interpreter) != kTfLiteOk) {
+  // Invoke the model
+  if (TfLiteInterpreterInvoke(backend_data_->sd_interpreter) != kTfLiteOk) {
     std::cerr << "Failed to invoke the first diffusion model!" << std::endl;
     exit(-1);
   }
 
-  // Output names from the first model and corresponding input names for the
-  // second model
-  std::vector<std::string> output_names = {
-      "Identity_6",  "Identity_4", "Identity",    "input_1",    "Identity_12",
-      "Identity_11", "Identity_3", "Identity_10", "Identity_9", "Identity_5",
-      "Identity_8",  "Identity_7", "Identity_2"};
-
-  std::vector<std::string> input_names = {
-      "args_0",    "args_0_1",  "args_0_2", "args_0_4", "args_0_3",
-      "args_0_5",  "args_0_6",  "args_0_7", "args_0_8", "args_0_9",
-      "args_0_10", "args_0_11", "args_0_12"};
-
-  // Copy outputs of the first model to the inputs of the second model based on
-  // names
-  for (size_t i = 0; i < input_names.size(); ++i) {
-    int input_index = get_tensor_index_by_name(
-        backend_data_->second_interpreter, input_names[i], true);
-    int output_index = get_tensor_index_by_name(
-        backend_data_->first_interpreter, output_names[i], false);
-
-    if (input_index == -1 || output_index == -1) {
-      std::cerr << "Failed to find matching input or output tensor by name!"
-                << std::endl;
-      exit(-1);
-    }
-
-    auto first_model_output_details = TfLiteInterpreterGetOutputTensor(
-        backend_data_->first_interpreter, output_index);
-
-    float* output_data =
-        reinterpret_cast<float*>(TfLiteTensorData(first_model_output_details));
-    int output_size =
-        TfLiteTensorByteSize(first_model_output_details) / sizeof(float);
-
-    float* input_data = reinterpret_cast<float*>(
-        TfLiteTensorData(TfLiteInterpreterGetInputTensor(
-            backend_data_->second_interpreter, input_index)));
-
-    std::copy(output_data, output_data + output_size, input_data);
-  }
-
-  // Invoke the second model
-  if (TfLiteInterpreterInvoke(backend_data_->second_interpreter) != kTfLiteOk) {
-    std::cerr << "Failed to invoke the second diffusion model!" << std::endl;
-    exit(-1);
-  }
-
   float* output = reinterpret_cast<float*>(TfLiteTensorData(
-      TfLiteInterpreterGetOutputTensor(backend_data_->second_interpreter, 0)));
+      TfLiteInterpreterGetOutputTensor(backend_data_->sd_interpreter, 0)));
   int output_size = TfLiteTensorByteSize(TfLiteInterpreterGetOutputTensor(
-                        backend_data_->second_interpreter, 0)) /
+                        backend_data_->sd_interpreter, 0)) /
                     sizeof(float);
   return std::vector<float>(output, output + output_size);
 }
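
Note on the removed wiring: before this change, each denoising step ran the diffusion model as two separate TFLite interpreters, copying the first model's thirteen named outputs into the second model's inputs through a get_tensor_index_by_name helper whose definition lies outside this hunk. With the merged sd_interpreter those per-step tensor copies disappear entirely. For reference, a minimal sketch of what such a name-lookup helper can look like, assuming only the stock TFLite C API (TfLiteInterpreterGetInputTensorCount, TfLiteTensorName, and their output-side counterparts); the helper actually used in this repository may differ:

#include <string>

#include "tensorflow/lite/c/c_api.h"

// Hypothetical sketch, not the repository's implementation: scan the
// interpreter's input (or output) tensors and return the index of the one
// whose name matches, or -1 when nothing matches.
int get_tensor_index_by_name(TfLiteInterpreter* interpreter,
                             const std::string& name, bool is_input) {
  int32_t count = is_input
                      ? TfLiteInterpreterGetInputTensorCount(interpreter)
                      : TfLiteInterpreterGetOutputTensorCount(interpreter);
  for (int32_t i = 0; i < count; ++i) {
    const TfLiteTensor* tensor =
        is_input ? TfLiteInterpreterGetInputTensor(interpreter, i)
                 : TfLiteInterpreterGetOutputTensor(interpreter, i);
    if (tensor != nullptr && name == TfLiteTensorName(tensor)) return i;
  }
  return -1;  // No tensor with this name.
}

The new code sidesteps this scan by addressing tensors with fixed indices, at the cost of coupling the invoker to the converted model's input ordering, which is exactly what the next hunk has to compensate for.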
@@ -201,9 +153,9 @@ std::vector<float> StableDiffusionInvoker::run_inference(
 
   // Access the input tensors
   void* pos_ids_input_data =
-      TfLiteTensorData(TfLiteInterpreterGetInputTensor(interpreter, 0));
-  void* encoded_input_data =
       TfLiteTensorData(TfLiteInterpreterGetInputTensor(interpreter, 1));
+  void* encoded_input_data =
+      TfLiteTensorData(TfLiteInterpreterGetInputTensor(interpreter, 0));
 
   // Copy data to input tensors (type cast required for correct copy operation)
   std::memcpy(pos_ids_input_data, pos_ids.data(), pos_ids.size() * sizeof(int));
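
This hunk swaps the hard-coded input indices: the regenerated text-encoder model evidently expects the encoded tokens at index 0 and the position ids at index 1. A memcpy into a hard-wired index silently corrupts data if the ordering ever changes again, so one cheap defensive option, sketched here assuming the same interpreter and pos_ids buffers as in the hunk above, is to verify the destination tensor's byte size before copying:

// Hypothetical guard, mirroring the file's existing error-handling style:
// abort loudly if the tensor at the hard-coded index does not match the size
// of the buffer about to be copied into it.
TfLiteTensor* pos_ids_tensor = TfLiteInterpreterGetInputTensor(interpreter, 1);
if (TfLiteTensorByteSize(pos_ids_tensor) != pos_ids.size() * sizeof(int)) {
  std::cerr << "Unexpected pos_ids input tensor size!" << std::endl;
  exit(-1);
}
std::memcpy(TfLiteTensorData(pos_ids_tensor), pos_ids.data(),
            pos_ids.size() * sizeof(int));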