aboutsummaryrefslogtreecommitdiff
path: root/tests/NetworkExecutionUtils
diff options
context:
space:
mode:
Diffstat (limited to 'tests/NetworkExecutionUtils')
-rw-r--r--tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp18
-rw-r--r--tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp28
-rw-r--r--tests/NetworkExecutionUtils/test/NetworkExecutionUtilsTests.cpp37
3 files changed, 57 insertions, 26 deletions
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
index e3c95d9312..6f9cdf87bc 100644
--- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
+++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
@@ -78,3 +78,21 @@ void LogAndThrow(std::string eMsg)
throw armnn::Exception(eMsg);
}
+/// Compute the root-mean-square error (RMSE) at a byte level between two tensors of the same size.
+/// @param expected
+/// @param actual
+/// @param size size of the tensor in bytes.
+/// @return float the RMSE
+double ComputeByteLevelRMSE(const void* expected, const void* actual, const size_t size)
+{
+ const uint8_t* byteExpected = reinterpret_cast<const uint8_t*>(expected);
+ const uint8_t* byteActual = reinterpret_cast<const uint8_t*>(actual);
+
+ double errorSum = 0;
+ for (unsigned int i = 0; i < size; i++)
+ {
+ int difference = byteExpected[i] - byteActual[i];
+ errorSum += std::pow(difference, 2);
+ }
+ return std::sqrt(errorSum/armnn::numeric_cast<double>(size));
+}
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
index 8c97238432..2136c446fb 100644
--- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
+++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
@@ -73,6 +73,8 @@ std::vector<unsigned int> ParseArray(std::istream& stream);
/// Splits a given string at every accurance of delimiter into a vector of string
std::vector<std::string> ParseStringList(const std::string& inputString, const char* delimiter);
+double ComputeByteLevelRMSE(const void* expected, const void* actual, const size_t size);
+
/// Dequantize an array of a given type
/// @param array Type erased array to dequantize
/// @param numElements Elements in the array
@@ -285,29 +287,3 @@ std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseEleme
return result;
}
-
-/// Compute the root-mean-square error (RMSE)
-/// @param expected
-/// @param actual
-/// @param size size of the tensor
-/// @return float the RMSE
-template<typename T>
-float ComputeRMSE(const void* expected, const void* actual, const size_t size)
-{
- auto typedExpected = reinterpret_cast<const T*>(expected);
- auto typedActual = reinterpret_cast<const T*>(actual);
-
- T errorSum = 0;
-
- for (unsigned int i = 0; i < size; i++)
- {
- if (std::abs(typedExpected[i] - typedActual[i]) != 0)
- {
- std::cout << "";
- }
- errorSum += std::pow(std::abs(typedExpected[i] - typedActual[i]), 2);
- }
-
- float rmse = std::sqrt(armnn::numeric_cast<float>(errorSum) / armnn::numeric_cast<float>(size / sizeof(T)));
- return rmse;
-} \ No newline at end of file
diff --git a/tests/NetworkExecutionUtils/test/NetworkExecutionUtilsTests.cpp b/tests/NetworkExecutionUtils/test/NetworkExecutionUtilsTests.cpp
new file mode 100644
index 0000000000..d11fe571b0
--- /dev/null
+++ b/tests/NetworkExecutionUtils/test/NetworkExecutionUtilsTests.cpp
@@ -0,0 +1,37 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "../NetworkExecutionUtils.hpp"
+
+#include <doctest/doctest.h>
+
+namespace
+{
+
+TEST_SUITE("NetworkExecutionUtilsTests")
+{
+
+TEST_CASE ("ComputeByteLevelRMSE")
+{
+ // Bytes.
+ const uint8_t expected[] = {1, 128, 255};
+ const uint8_t actual[] = {0, 127, 254};
+
+ CHECK(ComputeByteLevelRMSE(expected, expected, 3) == 0);
+ CHECK(ComputeByteLevelRMSE(expected, actual, 3) == 1.0);
+
+ // Floats.
+ const float expectedFloat[] =
+ {55.20419, 24.58061, 67.76520, 47.31617, 55.58102, 44.64565, 105.76307, 54.65538, 80.41088, 66.05208};
+ const float actualFloat[] =
+ {13.87187, 14.16160, 49.28846, 25.89192, 97.70659, 91.30055, 15.88831, 4.79960, 102.99205, 51.28290};
+ const double expectedResult = 74.059098023; // Calculated manually.
+ CHECK(ComputeByteLevelRMSE(expectedFloat, expectedFloat, sizeof(float) * 10) == 0);
+ CHECK(ComputeByteLevelRMSE(expectedFloat, actualFloat, sizeof(float) * 10) == doctest::Approx(expectedResult));
+}
+
+} // End of TEST_SUITE("NetworkExecutionUtilsTests")
+
+} // anonymous namespace \ No newline at end of file