@@ -44,79 +44,31 @@ std::vector<float> StableDiffusionInvoker::encode_prompt(
 std::vector<float> StableDiffusionInvoker::diffusion_step(
     const std::vector<float>& latent, const std::vector<float>& t_emb,
     const std::vector<float>& context) {
-  // Prepare the first model's inputs
-
-  auto first_input_details =
-      TfLiteInterpreterGetInputTensor(backend_data_->first_interpreter, 0);
-  auto second_input_details =
-      TfLiteInterpreterGetInputTensor(backend_data_->first_interpreter, 1);
-  auto third_input_details =
-      TfLiteInterpreterGetInputTensor(backend_data_->first_interpreter, 2);
+  auto latent_input_details =
+      TfLiteInterpreterGetInputTensor(backend_data_->sd_interpreter, 0);
+  auto context_input_details =
+      TfLiteInterpreterGetInputTensor(backend_data_->sd_interpreter, 1);
+  auto time_stamp_embedding_input_details =
+      TfLiteInterpreterGetInputTensor(backend_data_->sd_interpreter, 2);
 
   std::copy(context.begin(), context.end(),
-            reinterpret_cast<float*>(TfLiteTensorData(first_input_details)));
+            reinterpret_cast<float*>(TfLiteTensorData(context_input_details)));
   std::copy(t_emb.begin(), t_emb.end(),
-            reinterpret_cast<float*>(TfLiteTensorData(second_input_details)));
+            reinterpret_cast<float*>(
+                TfLiteTensorData(time_stamp_embedding_input_details)));
   std::copy(latent.begin(), latent.end(),
-            reinterpret_cast<float*>(TfLiteTensorData(third_input_details)));
+            reinterpret_cast<float*>(TfLiteTensorData(latent_input_details)));
 
-  // Invoke the first model
-  if (TfLiteInterpreterInvoke(backend_data_->first_interpreter) != kTfLiteOk) {
+  // Invoke the model
+  if (TfLiteInterpreterInvoke(backend_data_->sd_interpreter) != kTfLiteOk) {
     std::cerr << "Failed to invoke the first diffusion model!" << std::endl;
     exit(-1);
   }
 
-  // Output names from the first model and corresponding input names for the
-  // second model
-  std::vector<std::string> output_names = {
-      "Identity_6", "Identity_4", "Identity", "input_1", "Identity_12",
-      "Identity_11", "Identity_3", "Identity_10", "Identity_9", "Identity_5",
-      "Identity_8", "Identity_7", "Identity_2"};
-
-  std::vector<std::string> input_names = {
-      "args_0", "args_0_1", "args_0_2", "args_0_4", "args_0_3",
-      "args_0_5", "args_0_6", "args_0_7", "args_0_8", "args_0_9",
-      "args_0_10", "args_0_11", "args_0_12"};
-
-  // Copy outputs of the first model to the inputs of the second model based on
-  // names
-  for (size_t i = 0; i < input_names.size(); ++i) {
-    int input_index = get_tensor_index_by_name(
-        backend_data_->second_interpreter, input_names[i], true);
-    int output_index = get_tensor_index_by_name(
-        backend_data_->first_interpreter, output_names[i], false);
-
-    if (input_index == -1 || output_index == -1) {
-      std::cerr << "Failed to find matching input or output tensor by name!"
-                << std::endl;
-      exit(-1);
-    }
-
-    auto first_model_output_details = TfLiteInterpreterGetOutputTensor(
-        backend_data_->first_interpreter, output_index);
-
-    float* output_data =
-        reinterpret_cast<float*>(TfLiteTensorData(first_model_output_details));
-    int output_size =
-        TfLiteTensorByteSize(first_model_output_details) / sizeof(float);
-
-    float* input_data = reinterpret_cast<float*>(
-        TfLiteTensorData(TfLiteInterpreterGetInputTensor(
-            backend_data_->second_interpreter, input_index)));
-
-    std::copy(output_data, output_data + output_size, input_data);
-  }
-
-  // Invoke the second model
-  if (TfLiteInterpreterInvoke(backend_data_->second_interpreter) != kTfLiteOk) {
-    std::cerr << "Failed to invoke the second diffusion model!" << std::endl;
-    exit(-1);
-  }
-
   float* output = reinterpret_cast<float*>(TfLiteTensorData(
-      TfLiteInterpreterGetOutputTensor(backend_data_->second_interpreter, 0)));
+      TfLiteInterpreterGetOutputTensor(backend_data_->sd_interpreter, 0)));
   int output_size = TfLiteTensorByteSize(TfLiteInterpreterGetOutputTensor(
-                        backend_data_->second_interpreter, 0)) /
+                        backend_data_->sd_interpreter, 0)) /
                     sizeof(float);
   return std::vector<float>(output, output + output_size);
 }
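
Note: the removed branch above stitched a two-part diffusion model together at runtime, copying the first interpreter's outputs into the second interpreter's inputs by tensor name via get_tensor_index_by_name; with the fused sd_interpreter that plumbing disappears. For context, a minimal sketch of what such a name-based lookup can look like against the TFLite C API is shown below. The signature and body are assumptions for illustration, not this repo's actual implementation:

// Sketch (assumed signature): scan the interpreter's input or output tensors
// and return the index whose name matches, or -1 if none does.
#include <string>

#include "tensorflow/lite/c/c_api.h"

int get_tensor_index_by_name(TfLiteInterpreter* interpreter,
                             const std::string& name, bool is_input) {
  int32_t count = is_input
                      ? TfLiteInterpreterGetInputTensorCount(interpreter)
                      : TfLiteInterpreterGetOutputTensorCount(interpreter);
  for (int32_t i = 0; i < count; ++i) {
    const TfLiteTensor* tensor =
        is_input ? TfLiteInterpreterGetInputTensor(interpreter, i)
                 : TfLiteInterpreterGetOutputTensor(interpreter, i);
    if (tensor != nullptr && name == TfLiteTensorName(tensor)) return i;
  }
  return -1;  // Caller treats -1 as "not found".
}

Matching by name rather than by index keeps such a copy loop robust to tensor reordering between converter versions, which is presumably why the removed code paired output_names with input_names positionally.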
@@ -201,9 +153,9 @@ std::vector<float> StableDiffusionInvoker::run_inference(
 
   // Access the input tensors
   void* pos_ids_input_data =
-      TfLiteTensorData(TfLiteInterpreterGetInputTensor(interpreter, 0));
-  void* encoded_input_data =
       TfLiteTensorData(TfLiteInterpreterGetInputTensor(interpreter, 1));
+  void* encoded_input_data =
+      TfLiteTensorData(TfLiteInterpreterGetInputTensor(interpreter, 0));
 
   // Copy data to input tensors (type cast required for correct copy operation)
   std::memcpy(pos_ids_input_data, pos_ids.data(), pos_ids.size() * sizeof(int));
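
Note: this second hunk only swaps which physical input index each logical input maps to: after the change, tensor 1 receives the position IDs and tensor 0 the encoded prompt. Since index remaps like this are easy to get silently wrong, one defensive option is to compare the destination tensor's byte size against the source buffer before each copy. A sketch, assuming pos_ids is a std::vector<int> as the surrounding memcpy suggests:

// Hypothetical guard before the memcpy above: fail loudly if the buffer and
// tensor byte sizes disagree (e.g. after another input-order change).
TfLiteTensor* pos_ids_tensor = TfLiteInterpreterGetInputTensor(interpreter, 1);
if (TfLiteTensorByteSize(pos_ids_tensor) != pos_ids.size() * sizeof(int)) {
  std::cerr << "pos_ids does not match input tensor 1 byte size!" << std::endl;
  exit(-1);
}
std::memcpy(TfLiteTensorData(pos_ids_tensor), pos_ids.data(),
            pos_ids.size() * sizeof(int));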