aboutsummaryrefslogtreecommitdiff
path: root/delegate/test/SplitTestHelper.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/test/SplitTestHelper.hpp')
-rw-r--r--delegate/test/SplitTestHelper.hpp138
1 files changed, 51 insertions, 87 deletions
diff --git a/delegate/test/SplitTestHelper.hpp b/delegate/test/SplitTestHelper.hpp
index 503fbc85ae..1d5f459148 100644
--- a/delegate/test/SplitTestHelper.hpp
+++ b/delegate/test/SplitTestHelper.hpp
@@ -8,17 +8,15 @@
#include "TestUtils.hpp"
#include <armnn_delegate.hpp>
+#include <DelegateTestInterpreter.hpp>
#include <flatbuffers/flatbuffers.h>
-#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
-#include <tensorflow/lite/model.h>
-#include <schema_generated.h>
#include <tensorflow/lite/version.h>
-#include <doctest/doctest.h>
+#include <schema_generated.h>
-#include <string>
+#include <doctest/doctest.h>
namespace
{
@@ -113,7 +111,7 @@ std::vector<char> CreateSplitTfLiteModel(tflite::TensorType tensorType,
modelDescription,
flatBufferBuilder.CreateVector(buffers));
- flatBufferBuilder.Finish(flatbufferModel);
+ flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
@@ -132,7 +130,7 @@ void SplitTest(tflite::TensorType tensorType,
float quantScale = 1.0f,
int quantOffset = 0)
{
- using namespace tflite;
+ using namespace delegateTestInterpreter;
std::vector<char> modelBuffer = CreateSplitTfLiteModel(tensorType,
axisTensorShape,
inputTensorShape,
@@ -141,51 +139,34 @@ void SplitTest(tflite::TensorType tensorType,
numSplits,
quantScale,
quantOffset);
- const Model* tfLiteModel = GetModel(modelBuffer.data());
-
- // Create TfLite Interpreters
- std::unique_ptr<Interpreter> armnnDelegate;
- CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
- (&armnnDelegate) == kTfLiteOk);
- CHECK(armnnDelegate != nullptr);
- CHECK(armnnDelegate->AllocateTensors() == kTfLiteOk);
-
- std::unique_ptr<Interpreter> tfLiteDelegate;
- CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
- (&tfLiteDelegate) == kTfLiteOk);
- CHECK(tfLiteDelegate != nullptr);
- CHECK(tfLiteDelegate->AllocateTensors() == kTfLiteOk);
-
- // Create the ArmNN Delegate
- armnnDelegate::DelegateOptions delegateOptions(backends);
- std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
- theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
- armnnDelegate::TfLiteArmnnDelegateDelete);
- CHECK(theArmnnDelegate != nullptr);
-
- // Modify armnnDelegateInterpreter to use armnnDelegate
- CHECK(armnnDelegate->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
-
- // Set input data
- armnnDelegate::FillInput<T>(tfLiteDelegate, 1, inputValues);
- armnnDelegate::FillInput<T>(armnnDelegate, 1, inputValues);
-
- // Run EnqueWorkload
- CHECK(tfLiteDelegate->Invoke() == kTfLiteOk);
- CHECK(armnnDelegate->Invoke() == kTfLiteOk);
+ // Setup interpreter with just TFLite Runtime.
+ auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
+ CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
+ CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 1) == kTfLiteOk);
+ CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
+
+ // Setup interpreter with Arm NN Delegate applied.
+ auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
+ CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
+ CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 1) == kTfLiteOk);
+ CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
// Compare output data
for (unsigned int i = 0; i < expectedOutputValues.size(); ++i)
{
- armnnDelegate::CompareOutputData<T>(tfLiteDelegate,
- armnnDelegate,
- outputTensorShapes[i],
- expectedOutputValues[i],
- i);
+ std::vector<T> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(i);
+ std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(i);
+
+ std::vector<T> armnnOutputValues = armnnInterpreter.GetOutputResult<T>(i);
+ std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(i);
+
+ armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues[i]);
+ armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputTensorShapes[i]);
}
- tfLiteDelegate.reset(nullptr);
- armnnDelegate.reset(nullptr);
+ tfLiteInterpreter.Cleanup();
+ armnnInterpreter.Cleanup();
+
} // End of SPLIT Test
std::vector<char> CreateSplitVTfLiteModel(tflite::TensorType tensorType,
@@ -288,7 +269,7 @@ std::vector<char> CreateSplitVTfLiteModel(tflite::TensorType tensorType,
modelDescription,
flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
- flatBufferBuilder.Finish(flatbufferModel);
+ flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
@@ -309,7 +290,7 @@ void SplitVTest(tflite::TensorType tensorType,
float quantScale = 1.0f,
int quantOffset = 0)
{
- using namespace tflite;
+ using namespace delegateTestInterpreter;
std::vector<char> modelBuffer = CreateSplitVTfLiteModel(tensorType,
inputTensorShape,
splitsTensorShape,
@@ -320,51 +301,34 @@ void SplitVTest(tflite::TensorType tensorType,
numSplits,
quantScale,
quantOffset);
- const Model* tfLiteModel = GetModel(modelBuffer.data());
-
- // Create TfLite Interpreters
- std::unique_ptr<Interpreter> armnnDelegate;
- CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
- (&armnnDelegate) == kTfLiteOk);
- CHECK(armnnDelegate != nullptr);
- CHECK(armnnDelegate->AllocateTensors() == kTfLiteOk);
-
- std::unique_ptr<Interpreter> tfLiteDelegate;
- CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
- (&tfLiteDelegate) == kTfLiteOk);
- CHECK(tfLiteDelegate != nullptr);
- CHECK(tfLiteDelegate->AllocateTensors() == kTfLiteOk);
-
- // Create the ArmNN Delegate
- armnnDelegate::DelegateOptions delegateOptions(backends);
- std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
- theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
- armnnDelegate::TfLiteArmnnDelegateDelete);
- CHECK(theArmnnDelegate != nullptr);
-
- // Modify armnnDelegateInterpreter to use armnnDelegate
- CHECK(armnnDelegate->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
-
- // Set input data
- armnnDelegate::FillInput<T>(tfLiteDelegate, 0, inputValues);
- armnnDelegate::FillInput<T>(armnnDelegate, 0, inputValues);
-
- // Run EnqueWorkload
- CHECK(tfLiteDelegate->Invoke() == kTfLiteOk);
- CHECK(armnnDelegate->Invoke() == kTfLiteOk);
+
+ // Setup interpreter with just TFLite Runtime.
+ auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
+ CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
+ CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
+ CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
+
+ // Setup interpreter with Arm NN Delegate applied.
+ auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
+ CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
+ CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
+ CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
// Compare output data
for (unsigned int i = 0; i < expectedOutputValues.size(); ++i)
{
- armnnDelegate::CompareOutputData<T>(tfLiteDelegate,
- armnnDelegate,
- outputTensorShapes[i],
- expectedOutputValues[i],
- i);
+ std::vector<T> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(i);
+ std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(i);
+
+ std::vector<T> armnnOutputValues = armnnInterpreter.GetOutputResult<T>(i);
+ std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(i);
+
+ armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues[i]);
+ armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputTensorShapes[i]);
}
- tfLiteDelegate.reset(nullptr);
- armnnDelegate.reset(nullptr);
+ tfLiteInterpreter.Cleanup();
+ armnnInterpreter.Cleanup();
} // End of SPLIT_V Test
} // anonymous namespace \ No newline at end of file