aboutsummaryrefslogtreecommitdiff
path: root/delegate/test/UnidirectionalSequenceLstmTestHelper.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/test/UnidirectionalSequenceLstmTestHelper.hpp')
-rw-r--r--delegate/test/UnidirectionalSequenceLstmTestHelper.hpp88
1 files changed, 33 insertions, 55 deletions
diff --git a/delegate/test/UnidirectionalSequenceLstmTestHelper.hpp b/delegate/test/UnidirectionalSequenceLstmTestHelper.hpp
index 0ff04e7949..c058d83bc6 100644
--- a/delegate/test/UnidirectionalSequenceLstmTestHelper.hpp
+++ b/delegate/test/UnidirectionalSequenceLstmTestHelper.hpp
@@ -8,14 +8,13 @@
#include "TestUtils.hpp"
#include <armnn_delegate.hpp>
+#include <DelegateTestInterpreter.hpp>
#include <flatbuffers/flatbuffers.h>
-#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
-#include <tensorflow/lite/model.h>
-#include <schema_generated.h>
#include <tensorflow/lite/version.h>
-#include <tensorflow/lite/c/common.h>
+
+#include <schema_generated.h>
#include <doctest/doctest.h>
@@ -569,7 +568,7 @@ std::vector<char> CreateUnidirectionalSequenceLstmTfLiteModel(tflite::TensorType
modelDescription,
flatBufferBuilder.CreateVector(buffers));
- flatBufferBuilder.Finish(flatbufferModel);
+ flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
@@ -624,7 +623,7 @@ void UnidirectionalSequenceLstmTestImpl(std::vector<armnn::BackendId>& backends,
bool isTimeMajor,
float quantScale = 0.1f)
{
- using namespace tflite;
+ using namespace delegateTestInterpreter;
std::vector<char> modelBuffer = CreateUnidirectionalSequenceLstmTfLiteModel(tensorType,
batchSize,
@@ -671,72 +670,51 @@ void UnidirectionalSequenceLstmTestImpl(std::vector<armnn::BackendId>& backends,
isTimeMajor,
quantScale);
- const Model* tfLiteModel = GetModel(modelBuffer.data());
- // Create TfLite Interpreters
- std::unique_ptr<Interpreter> armnnDelegateInterpreter;
- CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
- (&armnnDelegateInterpreter) == kTfLiteOk);
- CHECK(armnnDelegateInterpreter != nullptr);
- CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);
-
- std::unique_ptr<Interpreter> tfLiteInterpreter;
- CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
- (&tfLiteInterpreter) == kTfLiteOk);
- CHECK(tfLiteInterpreter != nullptr);
- CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);
-
- // Create the ArmNN Delegate
- armnnDelegate::DelegateOptions delegateOptions(backends);
- std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
- theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
- armnnDelegate::TfLiteArmnnDelegateDelete);
- CHECK(theArmnnDelegate != nullptr);
- // Modify armnnDelegateInterpreter to use armnnDelegate
- CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
-
- // Set input data
- auto tfLiteDelegateInputId = tfLiteInterpreter->inputs()[0];
- auto tfLiteDelageInputData = tfLiteInterpreter->typed_tensor<float>(tfLiteDelegateInputId);
- for (unsigned int i = 0; i < inputValues.size(); ++i)
+ std::vector<int32_t> outputShape;
+ if (isTimeMajor)
{
- tfLiteDelageInputData[i] = inputValues[i];
+ outputShape = {timeSize, batchSize, outputSize};
}
-
- auto armnnDelegateInputId = armnnDelegateInterpreter->inputs()[0];
- auto armnnDelegateInputData = armnnDelegateInterpreter->typed_tensor<float>(armnnDelegateInputId);
- for (unsigned int i = 0; i < inputValues.size(); ++i)
+ else
{
- armnnDelegateInputData[i] = inputValues[i];
+ outputShape = {batchSize, timeSize, outputSize};
}
- // Run EnqueueWorkload
- CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
- CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);
+ // Setup interpreter with just TFLite Runtime.
+ auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
+ CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
+ CHECK(tfLiteInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
+ CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
+ std::vector<float> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<float>(0);
+ std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);
+
+ // Setup interpreter with Arm NN Delegate applied.
+ auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
+ CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
+ CHECK(armnnInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
+ CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
+ std::vector<float> armnnOutputValues = armnnInterpreter.GetOutputResult<float>(0);
+ std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);
- // Compare output data
- auto tfLiteDelegateOutputId = tfLiteInterpreter->outputs()[0];
- auto tfLiteDelagateOutputData = tfLiteInterpreter->typed_tensor<float>(tfLiteDelegateOutputId);
- auto armnnDelegateOutputId = armnnDelegateInterpreter->outputs()[0];
- auto armnnDelegateOutputData = armnnDelegateInterpreter->typed_tensor<float>(armnnDelegateOutputId);
+ armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputShape);
if (tensorType == ::tflite::TensorType_INT8)
{
// Allow 2% tolerance for Quantized weights
- armnnDelegate::CompareData(expectedOutputValues.data(), armnnDelegateOutputData,
+ armnnDelegate::CompareData(expectedOutputValues.data(), armnnOutputValues.data(),
expectedOutputValues.size(), 2);
- armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteDelagateOutputData,
+ armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteOutputValues.data(),
expectedOutputValues.size(), 2);
- armnnDelegate::CompareData(tfLiteDelagateOutputData, armnnDelegateOutputData,
+ armnnDelegate::CompareData(tfLiteOutputValues.data(), armnnOutputValues.data(),
expectedOutputValues.size(), 2);
}
else
{
- armnnDelegate::CompareData(expectedOutputValues.data(), armnnDelegateOutputData,
- expectedOutputValues.size());
- armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteDelagateOutputData,
- expectedOutputValues.size());
- armnnDelegate::CompareData(tfLiteDelagateOutputData, armnnDelegateOutputData, expectedOutputValues.size());
+ armnnDelegate::CompareOutputData<float>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
}
+
+ tfLiteInterpreter.Cleanup();
+ armnnInterpreter.Cleanup();
}
} // anonymous namespace \ No newline at end of file