diff options
Diffstat (limited to 'tests/InferenceTest.inl')
-rw-r--r-- | tests/InferenceTest.inl | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/tests/InferenceTest.inl b/tests/InferenceTest.inl index 7ce017c6cd..4dde35403d 100644 --- a/tests/InferenceTest.inl +++ b/tests/InferenceTest.inl @@ -1,4 +1,4 @@ -// +// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // @@ -39,7 +39,7 @@ ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase( unsigned int testCaseId, unsigned int label, std::vector<typename TModel::DataType> modelInput) - : InferenceModelTestCase<TModel>(model, testCaseId, std::move(modelInput), model.GetOutputSize()) + : InferenceModelTestCase<TModel>(model, testCaseId, { std::move(modelInput) }, { model.GetOutputSize() }) , m_Label(label) , m_QuantizationParams(model.GetQuantizationParams()) , m_NumInferencesRef(numInferencesRef) @@ -52,7 +52,7 @@ ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase( template <typename TTestCaseDatabase, typename TModel> TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params) { - auto& output = this->GetOutput(); + auto& output = this->GetOutputs()[0]; const auto testCaseId = this->GetTestCaseId(); std::map<float,int> resultMap; @@ -309,7 +309,12 @@ int ClassifierInferenceTestMain(int argc, const std::vector<unsigned int>& defaultTestCaseIds, TConstructDatabaseCallable constructDatabase, const armnn::TensorShape* inputTensorShape) + { + BOOST_ASSERT(modelFilename); + BOOST_ASSERT(inputBindingName); + BOOST_ASSERT(outputBindingName); + return InferenceTestMain(argc, argv, defaultTestCaseIds, [=] () @@ -328,9 +333,14 @@ int ClassifierInferenceTestMain(int argc, typename InferenceModel::Params modelParams; modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename; - modelParams.m_InputBinding = inputBindingName; - modelParams.m_OutputBinding = outputBindingName; - modelParams.m_InputTensorShape = inputTensorShape; + modelParams.m_InputBindings = { inputBindingName }; + modelParams.m_OutputBindings = { outputBindingName }; + + if (inputTensorShape) + { + modelParams.m_InputShapes.push_back(*inputTensorShape); + } + modelParams.m_IsModelBinary = isModelBinary; modelParams.m_ComputeDevice = modelOptions.m_ComputeDevice; modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel; |