diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-01-24 17:05:36 +0000 |
---|---|---|
committer | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-01-30 11:25:56 +0000 |
commit | 7cf0eaa26c1fb29ca9df97e4734ec7c1e10f81c4 (patch) | |
tree | 407b519ede76b235c54907fe80411970741e8a00 /tests/InferenceTest.inl | |
parent | 28d3d63cc0a33f8396b32fa8347c03912c065911 (diff) | |
download | armnn-7cf0eaa26c1fb29ca9df97e4734ec7c1e10f81c4.tar.gz |
IVGCVSW-2564 Add support for multiple input and output bindings in InferenceModel
Change-Id: I64d724367d42dca4b768b6c6e42acda714985950
Diffstat (limited to 'tests/InferenceTest.inl')
-rw-r--r-- | tests/InferenceTest.inl | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/tests/InferenceTest.inl b/tests/InferenceTest.inl index 7ce017c6cd..4dde35403d 100644 --- a/tests/InferenceTest.inl +++ b/tests/InferenceTest.inl @@ -1,4 +1,4 @@ -// +// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // @@ -39,7 +39,7 @@ ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase( unsigned int testCaseId, unsigned int label, std::vector<typename TModel::DataType> modelInput) - : InferenceModelTestCase<TModel>(model, testCaseId, std::move(modelInput), model.GetOutputSize()) + : InferenceModelTestCase<TModel>(model, testCaseId, { std::move(modelInput) }, { model.GetOutputSize() }) , m_Label(label) , m_QuantizationParams(model.GetQuantizationParams()) , m_NumInferencesRef(numInferencesRef) @@ -52,7 +52,7 @@ ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase( template <typename TTestCaseDatabase, typename TModel> TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params) { - auto& output = this->GetOutput(); + auto& output = this->GetOutputs()[0]; const auto testCaseId = this->GetTestCaseId(); std::map<float,int> resultMap; @@ -309,7 +309,12 @@ int ClassifierInferenceTestMain(int argc, const std::vector<unsigned int>& defaultTestCaseIds, TConstructDatabaseCallable constructDatabase, const armnn::TensorShape* inputTensorShape) + { + BOOST_ASSERT(modelFilename); + BOOST_ASSERT(inputBindingName); + BOOST_ASSERT(outputBindingName); + return InferenceTestMain(argc, argv, defaultTestCaseIds, [=] () @@ -328,9 +333,14 @@ int ClassifierInferenceTestMain(int argc, typename InferenceModel::Params modelParams; modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename; - modelParams.m_InputBinding = inputBindingName; - modelParams.m_OutputBinding = outputBindingName; - modelParams.m_InputTensorShape = inputTensorShape; + modelParams.m_InputBindings = { inputBindingName }; + modelParams.m_OutputBindings = { outputBindingName }; + + if (inputTensorShape) + { + modelParams.m_InputShapes.push_back(*inputTensorShape); + } + modelParams.m_IsModelBinary = isModelBinary; modelParams.m_ComputeDevice = modelOptions.m_ComputeDevice; modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel; |