diff options
Diffstat (limited to 'delegate/src/test/TestUtils.hpp')
-rw-r--r-- | delegate/src/test/TestUtils.hpp | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/delegate/src/test/TestUtils.hpp b/delegate/src/test/TestUtils.hpp index 162d62f3bb..9bbab8f62b 100644 --- a/delegate/src/test/TestUtils.hpp +++ b/delegate/src/test/TestUtils.hpp @@ -7,6 +7,8 @@ #include <tensorflow/lite/interpreter.h> +#include <doctest/doctest.h> + namespace armnnDelegate { @@ -23,4 +25,35 @@ void FillInput(std::unique_ptr<tflite::Interpreter>& interpreter, int inputIndex } } +// Can be used to compare the output tensor shape and values +// from armnnDelegateInterpreter and tfLiteInterpreter. +// Example usage can be found in ControlTestHelper.hpp +template <typename T> +void CompareOutputData(std::unique_ptr<tflite::Interpreter>& tfLiteInterpreter, + std::unique_ptr<tflite::Interpreter>& armnnDelegateInterpreter, + std::vector<int32_t>& expectedOutputShape, + std::vector<T>& expectedOutputValues) +{ + auto tfLiteDelegateOutputId = tfLiteInterpreter->outputs()[0]; + auto tfLiteDelegateOutputTensor = tfLiteInterpreter->tensor(tfLiteDelegateOutputId); + auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<T>(tfLiteDelegateOutputId); + auto armnnDelegateOutputId = armnnDelegateInterpreter->outputs()[0]; + auto armnnDelegateOutputTensor = armnnDelegateInterpreter->tensor(armnnDelegateOutputId); + auto armnnDelegateOutputData = armnnDelegateInterpreter->typed_tensor<T>(armnnDelegateOutputId); + + for (size_t i = 0; i < expectedOutputShape.size(); i++) + { + CHECK(expectedOutputShape[i] == armnnDelegateOutputTensor->dims->data[i]); + CHECK(tfLiteDelegateOutputTensor->dims->data[i] == expectedOutputShape[i]); + CHECK(tfLiteDelegateOutputTensor->dims->data[i] == armnnDelegateOutputTensor->dims->data[i]); + } + + for (size_t i = 0; i < expectedOutputValues.size(); i++) + { + CHECK(expectedOutputValues[i] == armnnDelegateOutputData[i]); + CHECK(tfLiteDelageOutputData[i] == expectedOutputValues[i]); + CHECK(tfLiteDelageOutputData[i] == armnnDelegateOutputData[i]); + } +} + } // namespace armnnDelegate |