about summary refs log tree commit diff
path: root/delegate/test/FullyConnectedTestHelper.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/test/FullyConnectedTestHelper.hpp')
-rw-r--r--  delegate/test/FullyConnectedTestHelper.hpp  12
1 file changed, 4 insertions, 8 deletions
diff --git a/delegate/test/FullyConnectedTestHelper.hpp b/delegate/test/FullyConnectedTestHelper.hpp
index 20c1102bd9..517d932f29 100644
--- a/delegate/test/FullyConnectedTestHelper.hpp
+++ b/delegate/test/FullyConnectedTestHelper.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2020, 2023-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -10,12 +10,8 @@
#include <armnn_delegate.hpp>
#include <DelegateTestInterpreter.hpp>
-#include <flatbuffers/flatbuffers.h>
-#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/version.h>
-#include <doctest/doctest.h>
-
namespace
{
@@ -164,8 +160,7 @@ std::vector<char> CreateFullyConnectedTfLiteModel(tflite::TensorType tensorType,
}
template <typename T>
-void FullyConnectedTest(std::vector<armnn::BackendId>& backends,
- tflite::TensorType tensorType,
+void FullyConnectedTest(tflite::TensorType tensorType,
tflite::ActivationFunctionType activationType,
const std::vector <int32_t>& inputTensorShape,
const std::vector <int32_t>& weightsTensorShape,
@@ -174,6 +169,7 @@ void FullyConnectedTest(std::vector<armnn::BackendId>& backends,
std::vector <T>& inputValues,
std::vector <T>& expectedOutputValues,
std::vector <T>& weightsData,
+ const std::vector<armnn::BackendId>& backends = {},
bool constantWeights = true,
float quantScale = 1.0f,
int quantOffset = 0)
@@ -196,7 +192,7 @@ void FullyConnectedTest(std::vector<armnn::BackendId>& backends,
CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
// Setup interpreter with Arm NN Delegate applied.
- auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
+ auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, CaptureAvailableBackends(backends));
CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);