Diffstat (limited to 'delegate/test/LogicalTestHelper.hpp')
-rw-r--r--  delegate/test/LogicalTestHelper.hpp | 91
1 files changed, 34 insertions, 57 deletions
diff --git a/delegate/test/LogicalTestHelper.hpp b/delegate/test/LogicalTestHelper.hpp
index 2f2ae7bf40..7da8ad9bfc 100644
--- a/delegate/test/LogicalTestHelper.hpp
+++ b/delegate/test/LogicalTestHelper.hpp
@@ -8,14 +8,14 @@
 #include "TestUtils.hpp"
 
 #include <armnn_delegate.hpp>
+#include <DelegateTestInterpreter.hpp>
 
 #include <flatbuffers/flatbuffers.h>
-#include <tensorflow/lite/interpreter.h>
 #include <tensorflow/lite/kernels/register.h>
-#include <tensorflow/lite/model.h>
-#include <schema_generated.h>
 #include <tensorflow/lite/version.h>
 
+#include <schema_generated.h>
+
 #include <doctest/doctest.h>
 
 namespace
@@ -120,26 +120,25 @@ std::vector<char> CreateLogicalBinaryTfLiteModel(tflite::BuiltinOperator logical
                                              modelDescription,
                                              flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
-    flatBufferBuilder.Finish(flatbufferModel);
+    flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
 
     return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                              flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
 }
 
-template <typename T>
 void LogicalBinaryTest(tflite::BuiltinOperator logicalOperatorCode,
                        tflite::TensorType tensorType,
                        std::vector<armnn::BackendId>& backends,
                        std::vector<int32_t>& input0Shape,
                        std::vector<int32_t>& input1Shape,
                        std::vector<int32_t>& expectedOutputShape,
-                       std::vector<T>& input0Values,
-                       std::vector<T>& input1Values,
-                       std::vector<T>& expectedOutputValues,
+                       std::vector<bool>& input0Values,
+                       std::vector<bool>& input1Values,
+                       std::vector<bool>& expectedOutputValues,
                        float quantScale = 1.0f,
                        int quantOffset  = 0)
 {
-    using namespace tflite;
+    using namespace delegateTestInterpreter;
 
     std::vector<char> modelBuffer = CreateLogicalBinaryTfLiteModel(logicalOperatorCode,
                                                                    tensorType,
                                                                    input0Shape,
@@ -148,54 +147,32 @@ void LogicalBinaryTest(tflite::BuiltinOperator logicalOperatorCode,
                                                                    quantScale,
                                                                    quantOffset);
 
-    const Model* tfLiteModel = GetModel(modelBuffer.data());
-    // Create TfLite Interpreters
-    std::unique_ptr<Interpreter> armnnDelegateInterpreter;
-    CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
-              (&armnnDelegateInterpreter) == kTfLiteOk);
-    CHECK(armnnDelegateInterpreter != nullptr);
-    CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);
-
-    std::unique_ptr<Interpreter> tfLiteInterpreter;
-    CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
-              (&tfLiteInterpreter) == kTfLiteOk);
-    CHECK(tfLiteInterpreter != nullptr);
-    CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);
-
-    // Create the ArmNN Delegate
-    armnnDelegate::DelegateOptions delegateOptions(backends);
-    std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
-        theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
-                         armnnDelegate::TfLiteArmnnDelegateDelete);
-    CHECK(theArmnnDelegate != nullptr);
-    // Modify armnnDelegateInterpreter to use armnnDelegate
-    CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
-
-    // Set input data for the armnn interpreter
-    armnnDelegate::FillInput(armnnDelegateInterpreter, 0, input0Values);
-    armnnDelegate::FillInput(armnnDelegateInterpreter, 1, input1Values);
-
-    // Set input data for the tflite interpreter
-    armnnDelegate::FillInput(tfLiteInterpreter, 0, input0Values);
-    armnnDelegate::FillInput(tfLiteInterpreter, 1, input1Values);
-
-    // Run EnqueWorkload
-    CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
-    CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);
-
-    // Compare output data, comparing Boolean values is handled differently and needs to call the CompareData function
-    // directly. This is because Boolean types get converted to a bit representation in a vector.
-    auto tfLiteDelegateOutputId = tfLiteInterpreter->outputs()[0];
-    auto tfLiteDelegateOutputData = tfLiteInterpreter->typed_tensor<T>(tfLiteDelegateOutputId);
-    auto armnnDelegateOutputId = armnnDelegateInterpreter->outputs()[0];
-    auto armnnDelegateOutputData = armnnDelegateInterpreter->typed_tensor<T>(armnnDelegateOutputId);
-
-    armnnDelegate::CompareData(expectedOutputValues, armnnDelegateOutputData, expectedOutputValues.size());
-    armnnDelegate::CompareData(expectedOutputValues, tfLiteDelegateOutputData, expectedOutputValues.size());
-    armnnDelegate::CompareData(tfLiteDelegateOutputData, armnnDelegateOutputData, expectedOutputValues.size());
-
-    armnnDelegateInterpreter.reset(nullptr);
-    tfLiteInterpreter.reset(nullptr);
+    // Setup interpreter with just TFLite Runtime.
+    auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
+    CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
+    CHECK(tfLiteInterpreter.FillInputTensor(input0Values, 0) == kTfLiteOk);
+    CHECK(tfLiteInterpreter.FillInputTensor(input1Values, 1) == kTfLiteOk);
+    CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
+    std::vector<bool>    tfLiteOutputValues = tfLiteInterpreter.GetOutputResult(0);
+    std::vector<int32_t> tfLiteOutputShape  = tfLiteInterpreter.GetOutputShape(0);
+
+    // Setup interpreter with Arm NN Delegate applied.
+    auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
+    CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
+    CHECK(armnnInterpreter.FillInputTensor(input0Values, 0) == kTfLiteOk);
+    CHECK(armnnInterpreter.FillInputTensor(input1Values, 1) == kTfLiteOk);
+    CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
+    std::vector<bool>    armnnOutputValues = armnnInterpreter.GetOutputResult(0);
+    std::vector<int32_t> armnnOutputShape  = armnnInterpreter.GetOutputShape(0);
+
+    armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, expectedOutputShape);
+
+    armnnDelegate::CompareData(expectedOutputValues, armnnOutputValues, expectedOutputValues.size());
+    armnnDelegate::CompareData(expectedOutputValues, tfLiteOutputValues, expectedOutputValues.size());
+    armnnDelegate::CompareData(tfLiteOutputValues, armnnOutputValues, expectedOutputValues.size());
+
+    tfLiteInterpreter.Cleanup();
+    armnnInterpreter.Cleanup();
 }
 
 } // anonymous namespace
\ No newline at end of file
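
For reference, below is a minimal sketch of how a test case might drive the reworked helper after this change. The test name, tensor shapes, data values and the CpuRef backend choice are illustrative assumptions and are not part of this patch; only the LogicalBinaryTest signature and the tflite enums come from the diff above.

// Hypothetical doctest case exercising LogicalBinaryTest with LOGICAL_AND.
// All literal values here are made up for illustration.
TEST_CASE("LogicalBinaryAnd_Bool_Sketch")
{
    std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };

    std::vector<int32_t> input0Shape         { 1, 2, 2, 1 };
    std::vector<int32_t> input1Shape         { 1, 2, 2, 1 };
    std::vector<int32_t> expectedOutputShape { 1, 2, 2, 1 };

    std::vector<bool> input0Values         { false, false, true,  true };
    std::vector<bool> input1Values         { false, true,  false, true };
    std::vector<bool> expectedOutputValues { false, false, false, true };

    LogicalBinaryTest(tflite::BuiltinOperator_LOGICAL_AND,
                      ::tflite::TensorType_BOOL,
                      backends,
                      input0Shape,
                      input1Shape,
                      expectedOutputShape,
                      input0Values,
                      input1Values,
                      expectedOutputValues);
}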