Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions model_api/cpp/adapters/include/adapters/inference_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class InferenceAdapter
virtual ~InferenceAdapter() = default;

virtual InferenceOutput infer(const InferenceInput& input) = 0;
virtual void infer(const InferenceInput& input, InferenceOutput& output) = 0;
virtual void setCallback(std::function<void(ov::InferRequest, CallbackData)> callback) = 0;
virtual void inferAsync(const InferenceInput& input, CallbackData callback_args) = 0;
virtual bool isReady() = 0;
Expand All @@ -48,6 +49,9 @@ class InferenceAdapter
const std::string& device = "", const ov::AnyMap& compilationConfig = {},
size_t max_num_requests = 0) = 0;
virtual ov::PartialShape getInputShape(const std::string& inputName) const = 0;
virtual ov::PartialShape getOutputShape(const std::string& inputName) const = 0;
virtual ov::element::Type_t getInputDatatype(const std::string& inputName) const = 0;
virtual ov::element::Type_t getOutputDatatype(const std::string& outputName) const = 0;
virtual std::vector<std::string> getInputNames() const = 0;
virtual std::vector<std::string> getOutputNames() const = 0;
virtual const ov::AnyMap& getModelConfig() const = 0;
Expand Down
4 changes: 4 additions & 0 deletions model_api/cpp/adapters/include/adapters/openvino_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class OpenVINOInferenceAdapter :public InferenceAdapter
OpenVINOInferenceAdapter() = default;

virtual InferenceOutput infer(const InferenceInput& input) override;
virtual void infer(const InferenceInput& input, InferenceOutput& output) override;
virtual void inferAsync(const InferenceInput& input, const CallbackData callback_args) override;
virtual void setCallback(std::function<void(ov::InferRequest, const CallbackData)> callback);
virtual bool isReady();
Expand All @@ -42,6 +43,9 @@ class OpenVINOInferenceAdapter :public InferenceAdapter
size_t max_num_requests = 1) override;
virtual size_t getNumAsyncExecutors() const;
virtual ov::PartialShape getInputShape(const std::string& inputName) const override;
virtual ov::PartialShape getOutputShape(const std::string& outputName) const override;
virtual ov::element::Type_t getInputDatatype(const std::string& inputName) const override;
virtual ov::element::Type_t getOutputDatatype(const std::string& outputName) const override;
virtual std::vector<std::string> getInputNames() const override;
virtual std::vector<std::string> getOutputNames() const override;
virtual const ov::AnyMap& getModelConfig() const override;
Expand Down
13 changes: 13 additions & 0 deletions model_api/cpp/adapters/src/openvino_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ void OpenVINOInferenceAdapter::loadModel(const std::shared_ptr<const ov::Model>&
}
}

void OpenVINOInferenceAdapter::infer(const InferenceInput&, InferenceOutput&) {
throw std::runtime_error("Not implemented");
}

InferenceOutput OpenVINOInferenceAdapter::infer(const InferenceInput& input) {
auto request = asyncQueue->operator[](asyncQueue->get_idle_request_id());
// Fill input blobs
Expand Down Expand Up @@ -95,6 +99,9 @@ size_t OpenVINOInferenceAdapter::getNumAsyncExecutors() const {
ov::PartialShape OpenVINOInferenceAdapter::getInputShape(const std::string& inputName) const {
return compiledModel.input(inputName).get_partial_shape();
}
ov::PartialShape OpenVINOInferenceAdapter::getOutputShape(const std::string& outputName) const {
return compiledModel.output(outputName).get_shape();
}

void OpenVINOInferenceAdapter::initInputsOutputs() {
for (const auto& input : compiledModel.inputs()) {
Expand All @@ -105,6 +112,12 @@ void OpenVINOInferenceAdapter::initInputsOutputs() {
outputNames.push_back(output.get_any_name());
}
}
ov::element::Type_t OpenVINOInferenceAdapter::getInputDatatype(const std::string&) const {
throw std::runtime_error("Not implemented");
}
ov::element::Type_t OpenVINOInferenceAdapter::getOutputDatatype(const std::string&) const {
throw std::runtime_error("Not implemented");
}

std::vector<std::string> OpenVINOInferenceAdapter::getInputNames() const {
return inputNames;
Expand Down
Loading