aboutsummaryrefslogtreecommitdiff
path: root/tests/InferenceTest.inl
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-01-24 17:05:36 +0000
committerAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-01-30 11:25:56 +0000
commit7cf0eaa26c1fb29ca9df97e4734ec7c1e10f81c4 (patch)
tree407b519ede76b235c54907fe80411970741e8a00 /tests/InferenceTest.inl
parent28d3d63cc0a33f8396b32fa8347c03912c065911 (diff)
downloadarmnn-7cf0eaa26c1fb29ca9df97e4734ec7c1e10f81c4.tar.gz
IVGCVSW-2564 Add support for multiple input and output bindings in InferenceModel
Change-Id: I64d724367d42dca4b768b6c6e42acda714985950
Diffstat (limited to 'tests/InferenceTest.inl')
-rw-r--r--tests/InferenceTest.inl22
1 files changed, 16 insertions, 6 deletions
diff --git a/tests/InferenceTest.inl b/tests/InferenceTest.inl
index 7ce017c6cd..4dde35403d 100644
--- a/tests/InferenceTest.inl
+++ b/tests/InferenceTest.inl
@@ -1,4 +1,4 @@
-//
+//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -39,7 +39,7 @@ ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
unsigned int testCaseId,
unsigned int label,
std::vector<typename TModel::DataType> modelInput)
- : InferenceModelTestCase<TModel>(model, testCaseId, std::move(modelInput), model.GetOutputSize())
+ : InferenceModelTestCase<TModel>(model, testCaseId, { std::move(modelInput) }, { model.GetOutputSize() })
, m_Label(label)
, m_QuantizationParams(model.GetQuantizationParams())
, m_NumInferencesRef(numInferencesRef)
@@ -52,7 +52,7 @@ ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
template <typename TTestCaseDatabase, typename TModel>
TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
{
- auto& output = this->GetOutput();
+ auto& output = this->GetOutputs()[0];
const auto testCaseId = this->GetTestCaseId();
std::map<float,int> resultMap;
@@ -309,7 +309,12 @@ int ClassifierInferenceTestMain(int argc,
const std::vector<unsigned int>& defaultTestCaseIds,
TConstructDatabaseCallable constructDatabase,
const armnn::TensorShape* inputTensorShape)
+
{
+ BOOST_ASSERT(modelFilename);
+ BOOST_ASSERT(inputBindingName);
+ BOOST_ASSERT(outputBindingName);
+
return InferenceTestMain(argc, argv, defaultTestCaseIds,
[=]
()
@@ -328,9 +333,14 @@ int ClassifierInferenceTestMain(int argc,
typename InferenceModel::Params modelParams;
modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
- modelParams.m_InputBinding = inputBindingName;
- modelParams.m_OutputBinding = outputBindingName;
- modelParams.m_InputTensorShape = inputTensorShape;
+ modelParams.m_InputBindings = { inputBindingName };
+ modelParams.m_OutputBindings = { outputBindingName };
+
+ if (inputTensorShape)
+ {
+ modelParams.m_InputShapes.push_back(*inputTensorShape);
+ }
+
modelParams.m_IsModelBinary = isModelBinary;
modelParams.m_ComputeDevice = modelOptions.m_ComputeDevice;
modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;