aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/test/TestUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/test/TestUtils.hpp')
-rw-r--r--delegate/src/test/TestUtils.hpp33
1 files changed, 33 insertions, 0 deletions
diff --git a/delegate/src/test/TestUtils.hpp b/delegate/src/test/TestUtils.hpp
index 162d62f3bb..9bbab8f62b 100644
--- a/delegate/src/test/TestUtils.hpp
+++ b/delegate/src/test/TestUtils.hpp
@@ -7,6 +7,8 @@
#include <tensorflow/lite/interpreter.h>
+#include <doctest/doctest.h>
+
namespace armnnDelegate
{
@@ -23,4 +25,35 @@ void FillInput(std::unique_ptr<tflite::Interpreter>& interpreter, int inputIndex
}
}
+// Can be used to compare the output tensor shape and values
+// from armnnDelegateInterpreter and tfLiteInterpreter.
+// Example usage can be found in ControlTestHelper.hpp
+template <typename T>
+void CompareOutputData(std::unique_ptr<tflite::Interpreter>& tfLiteInterpreter,
+ std::unique_ptr<tflite::Interpreter>& armnnDelegateInterpreter,
+ std::vector<int32_t>& expectedOutputShape,
+ std::vector<T>& expectedOutputValues)
+{
+ auto tfLiteDelegateOutputId = tfLiteInterpreter->outputs()[0];
+ auto tfLiteDelegateOutputTensor = tfLiteInterpreter->tensor(tfLiteDelegateOutputId);
+ auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<T>(tfLiteDelegateOutputId);
+ auto armnnDelegateOutputId = armnnDelegateInterpreter->outputs()[0];
+ auto armnnDelegateOutputTensor = armnnDelegateInterpreter->tensor(armnnDelegateOutputId);
+ auto armnnDelegateOutputData = armnnDelegateInterpreter->typed_tensor<T>(armnnDelegateOutputId);
+
+ for (size_t i = 0; i < expectedOutputShape.size(); i++)
+ {
+ CHECK(expectedOutputShape[i] == armnnDelegateOutputTensor->dims->data[i]);
+ CHECK(tfLiteDelegateOutputTensor->dims->data[i] == expectedOutputShape[i]);
+ CHECK(tfLiteDelegateOutputTensor->dims->data[i] == armnnDelegateOutputTensor->dims->data[i]);
+ }
+
+ for (size_t i = 0; i < expectedOutputValues.size(); i++)
+ {
+ CHECK(expectedOutputValues[i] == armnnDelegateOutputData[i]);
+ CHECK(tfLiteDelageOutputData[i] == expectedOutputValues[i]);
+ CHECK(tfLiteDelageOutputData[i] == armnnDelegateOutputData[i]);
+ }
+}
+
} // namespace armnnDelegate