diff --git a/src/operators/batch-matrix-multiply-nc.c b/src/operators/batch-matrix-multiply-nc.c index 370b986a06b..8f7599fefc8 100644 --- a/src/operators/batch-matrix-multiply-nc.c +++ b/src/operators/batch-matrix-multiply-nc.c @@ -16,6 +16,7 @@ #include "src/xnnpack/compute.h" #include "src/xnnpack/config-types.h" #include "src/xnnpack/config.h" +#include "src/xnnpack/internal.h" #include "src/xnnpack/log.h" #include "src/xnnpack/math.h" #include "src/xnnpack/microfnptr.h" @@ -1479,7 +1480,7 @@ enum xnn_status xnn_setup_batch_matrix_multiply_nc_qd8_f32_qc8w( enum xnn_status xnn_setup_batch_matrix_multiply_nc_qp8_f32_qc8w( xnn_operator_t batch_matrix_multiply_op, void* workspace, - const int8_t* input_a, const float* input_b, float* output) { + const int8_t* input_a, const int8_t* input_b, float* output) { return setup_batch_matrix_multiply_nc( batch_matrix_multiply_op, xnn_operator_type_batch_matrix_multiply_nc_qp8_f32_qc8w, input_a, diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c index 87c666cfdfc..164866c9a16 100644 --- a/src/operators/fully-connected-nc.c +++ b/src/operators/fully-connected-nc.c @@ -3136,7 +3136,7 @@ enum xnn_status xnn_setup_fully_connected_nc_qdu8_f16_qc8w( } enum xnn_status xnn_setup_fully_connected_nc_qp8_f32_qc4w( - xnn_operator_t fully_connected_op, const float* input, float* output, + xnn_operator_t fully_connected_op, const int8_t* input, float* output, void* workspace) { return setup_fully_connected_nc( fully_connected_op, xnn_operator_type_fully_connected_nc_qp8_f32_qc4w, @@ -3144,7 +3144,7 @@ enum xnn_status xnn_setup_fully_connected_nc_qp8_f32_qc4w( } enum xnn_status xnn_setup_fully_connected_nc_qp8_f32_qc8w( - xnn_operator_t fully_connected_op, const float* input, float* output, + xnn_operator_t fully_connected_op, const int8_t* input, float* output, void* workspace) { return setup_fully_connected_nc( fully_connected_op, xnn_operator_type_fully_connected_nc_qp8_f32_qc8w, @@ -3152,7 +3152,7 @@ enum xnn_status xnn_setup_fully_connected_nc_qp8_f32_qc8w( } enum xnn_status xnn_setup_fully_connected_nc_qp8_f32_qb4w( - xnn_operator_t fully_connected_op, const float* input, float* output, + xnn_operator_t fully_connected_op, const int8_t* input, float* output, void* workspace) { return setup_fully_connected_nc( fully_connected_op, xnn_operator_type_fully_connected_nc_qp8_f32_qb4w, diff --git a/src/xnnpack/internal.h b/src/xnnpack/internal.h index a33e142ce6c..2ce7a8e18b8 100644 --- a/src/xnnpack/internal.h +++ b/src/xnnpack/internal.h @@ -258,7 +258,7 @@ enum xnn_status xnn_create_batch_matrix_multiply_nc_pf16( uint32_t flags, xnn_operator_t* batch_matrix_multiply_op_out); enum xnn_status xnn_create_batch_matrix_multiply_nc_pf16_const_weights( - size_t batch_size_b, size_t k, size_t n, const void* data_b, uint32_t flags, + size_t batch_size_b, size_t k, size_t n, const xnn_float16* data_b, uint32_t flags, xnn_operator_t* batch_matrix_multiply_op_out); enum xnn_status xnn_reshape_batch_matrix_multiply_nc_pf16( @@ -476,7 +476,7 @@ enum xnn_status xnn_reshape_batch_matrix_multiply_nc_qdu8_f32_qc8w( enum xnn_status xnn_setup_batch_matrix_multiply_nc_qdu8_f32_qc8w( xnn_operator_t batch_matrix_multiply_op, void* workspace, - const int8_t* input_a, const int8_t* input_b, + const int8_t* input_a, const float* input_b, const struct xnn_quantization_params* quantization_params, float* output); enum xnn_status xnn_create_fully_connected_nc_pf16(