Diffstat (limited to 'delegate/test/UnidirectionalSequenceLstmTestHelper.hpp')
-rw-r--r-- | delegate/test/UnidirectionalSequenceLstmTestHelper.hpp | 88 |
1 file changed, 33 insertions, 55 deletions
diff --git a/delegate/test/UnidirectionalSequenceLstmTestHelper.hpp b/delegate/test/UnidirectionalSequenceLstmTestHelper.hpp
index 0ff04e7949..c058d83bc6 100644
--- a/delegate/test/UnidirectionalSequenceLstmTestHelper.hpp
+++ b/delegate/test/UnidirectionalSequenceLstmTestHelper.hpp
@@ -8,14 +8,13 @@
 #include "TestUtils.hpp"
 
 #include <armnn_delegate.hpp>
+#include <DelegateTestInterpreter.hpp>
 
 #include <flatbuffers/flatbuffers.h>
-#include <tensorflow/lite/interpreter.h>
 #include <tensorflow/lite/kernels/register.h>
-#include <tensorflow/lite/model.h>
-#include <schema_generated.h>
 #include <tensorflow/lite/version.h>
-#include <tensorflow/lite/c/common.h>
+
+#include <schema_generated.h>
 
 #include <doctest/doctest.h>
 
@@ -569,7 +568,7 @@ std::vector<char> CreateUnidirectionalSequenceLstmTfLiteModel(tflite::TensorType
                             modelDescription,
                             flatBufferBuilder.CreateVector(buffers));
 
-    flatBufferBuilder.Finish(flatbufferModel);
+    flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
 
     return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                              flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
@@ -624,7 +623,7 @@ void UnidirectionalSequenceLstmTestImpl(std::vector<armnn::BackendId>& backends,
                                         bool isTimeMajor,
                                         float quantScale = 0.1f)
 {
-    using namespace tflite;
+    using namespace delegateTestInterpreter;
 
     std::vector<char> modelBuffer = CreateUnidirectionalSequenceLstmTfLiteModel(tensorType,
                                                                                 batchSize,
@@ -671,72 +670,51 @@ void UnidirectionalSequenceLstmTestImpl(std::vector<armnn::BackendId>& backends,
                                                                                 isTimeMajor,
                                                                                 quantScale);
 
-    const Model* tfLiteModel = GetModel(modelBuffer.data());
-    // Create TfLite Interpreters
-    std::unique_ptr<Interpreter> armnnDelegateInterpreter;
-    CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
-              (&armnnDelegateInterpreter) == kTfLiteOk);
-    CHECK(armnnDelegateInterpreter != nullptr);
-    CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);
-
-    std::unique_ptr<Interpreter> tfLiteInterpreter;
-    CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
-              (&tfLiteInterpreter) == kTfLiteOk);
-    CHECK(tfLiteInterpreter != nullptr);
-    CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);
-
-    // Create the ArmNN Delegate
-    armnnDelegate::DelegateOptions delegateOptions(backends);
-    std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
-        theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
-                         armnnDelegate::TfLiteArmnnDelegateDelete);
-    CHECK(theArmnnDelegate != nullptr);
-    // Modify armnnDelegateInterpreter to use armnnDelegate
-    CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
-
-    // Set input data
-    auto tfLiteDelegateInputId = tfLiteInterpreter->inputs()[0];
-    auto tfLiteDelageInputData = tfLiteInterpreter->typed_tensor<float>(tfLiteDelegateInputId);
-    for (unsigned int i = 0; i < inputValues.size(); ++i)
+    std::vector<int32_t> outputShape;
+    if (isTimeMajor)
     {
-        tfLiteDelageInputData[i] = inputValues[i];
+        outputShape = {timeSize, batchSize, outputSize};
     }
-
-    auto armnnDelegateInputId = armnnDelegateInterpreter->inputs()[0];
-    auto armnnDelegateInputData = armnnDelegateInterpreter->typed_tensor<float>(armnnDelegateInputId);
-    for (unsigned int i = 0; i < inputValues.size(); ++i)
+    else
     {
-        armnnDelegateInputData[i] = inputValues[i];
+        outputShape = {batchSize, timeSize, outputSize};
     }
 
-    // Run EnqueueWorkload
-    CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
-    CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);
+    // Setup interpreter with just TFLite Runtime.
+    auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
+    CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
+    CHECK(tfLiteInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
+    CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
+    std::vector<float>   tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<float>(0);
+    std::vector<int32_t> tfLiteOutputShape  = tfLiteInterpreter.GetOutputShape(0);
+
+    // Setup interpreter with Arm NN Delegate applied.
+    auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
+    CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
+    CHECK(armnnInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
+    CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
+    std::vector<float>   armnnOutputValues = armnnInterpreter.GetOutputResult<float>(0);
+    std::vector<int32_t> armnnOutputShape  = armnnInterpreter.GetOutputShape(0);
 
-    // Compare output data
-    auto tfLiteDelegateOutputId = tfLiteInterpreter->outputs()[0];
-    auto tfLiteDelagateOutputData = tfLiteInterpreter->typed_tensor<float>(tfLiteDelegateOutputId);
-    auto armnnDelegateOutputId = armnnDelegateInterpreter->outputs()[0];
-    auto armnnDelegateOutputData = armnnDelegateInterpreter->typed_tensor<float>(armnnDelegateOutputId);
+    armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputShape);
 
     if (tensorType == ::tflite::TensorType_INT8)
     {
         // Allow 2% tolerance for Quantized weights
-        armnnDelegate::CompareData(expectedOutputValues.data(), armnnDelegateOutputData,
+        armnnDelegate::CompareData(expectedOutputValues.data(), armnnOutputValues.data(),
                                    expectedOutputValues.size(), 2);
-        armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteDelagateOutputData,
+        armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteOutputValues.data(),
                                    expectedOutputValues.size(), 2);
-        armnnDelegate::CompareData(tfLiteDelagateOutputData, armnnDelegateOutputData,
+        armnnDelegate::CompareData(tfLiteOutputValues.data(), armnnOutputValues.data(),
                                    expectedOutputValues.size(), 2);
     }
     else
    {
-        armnnDelegate::CompareData(expectedOutputValues.data(), armnnDelegateOutputData,
-                                   expectedOutputValues.size());
-        armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteDelagateOutputData,
-                                   expectedOutputValues.size());
-        armnnDelegate::CompareData(tfLiteDelagateOutputData, armnnDelegateOutputData, expectedOutputValues.size());
+        armnnDelegate::CompareOutputData<float>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
     }
+
+    tfLiteInterpreter.Cleanup();
+    armnnInterpreter.Cleanup();
 }
 
 } // anonymous namespace
\ No newline at end of file
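
The refactored test body above reduces to one comparison pattern: build the FlatBuffer model, run it once on the plain TfLite runtime and once with the Arm NN delegate attached, then compare output shapes and values. A minimal standalone sketch of that pattern follows; the DelegateTestInterpreter and armnnDelegate calls are taken from the + lines of the diff, while the wrapper function name RunDelegateComparison, its parameter list, and the extra includes are illustrative assumptions and not part of the file.

// Sketch only: RunDelegateComparison is a hypothetical wrapper, not a helper in this file.
#include "TestUtils.hpp"
#include <armnn_delegate.hpp>
#include <DelegateTestInterpreter.hpp>
#include <doctest/doctest.h>
#include <vector>

void RunDelegateComparison(std::vector<armnn::BackendId>& backends,
                           const std::vector<char>& modelBuffer,
                           const std::vector<float>& inputValues,
                           const std::vector<float>& expectedOutputValues,
                           const std::vector<int32_t>& expectedOutputShape)
{
    using namespace delegateTestInterpreter;

    // Reference run: plain TfLite runtime, no delegate.
    auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
    CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(tfLiteInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
    CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
    std::vector<float>   tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<float>(0);
    std::vector<int32_t> tfLiteOutputShape  = tfLiteInterpreter.GetOutputShape(0);

    // Second run with the Arm NN delegate applied to the given backends.
    auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
    CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(armnnInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
    CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
    std::vector<float>   armnnOutputValues = armnnInterpreter.GetOutputResult<float>(0);
    std::vector<int32_t> armnnOutputShape  = armnnInterpreter.GetOutputShape(0);

    // Compare TfLite vs Arm NN vs expected shapes and values.
    armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, expectedOutputShape);
    armnnDelegate::CompareOutputData<float>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);

    tfLiteInterpreter.Cleanup();
    armnnInterpreter.Cleanup();
}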