diff options
Diffstat (limited to 'tests/NetworkExecutionUtils')
3 files changed, 57 insertions, 26 deletions
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp index e3c95d9312..6f9cdf87bc 100644 --- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp +++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp @@ -78,3 +78,21 @@ void LogAndThrow(std::string eMsg) throw armnn::Exception(eMsg); } +/// Compute the root-mean-square error (RMSE) at a byte level between two tensors of the same size. +/// @param expected +/// @param actual +/// @param size size of the tensor in bytes. +/// @return float the RMSE +double ComputeByteLevelRMSE(const void* expected, const void* actual, const size_t size) +{ + const uint8_t* byteExpected = reinterpret_cast<const uint8_t*>(expected); + const uint8_t* byteActual = reinterpret_cast<const uint8_t*>(actual); + + double errorSum = 0; + for (unsigned int i = 0; i < size; i++) + { + int difference = byteExpected[i] - byteActual[i]; + errorSum += std::pow(difference, 2); + } + return std::sqrt(errorSum/armnn::numeric_cast<double>(size)); +} diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp index 8c97238432..2136c446fb 100644 --- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp +++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp @@ -73,6 +73,8 @@ std::vector<unsigned int> ParseArray(std::istream& stream); /// Splits a given string at every accurance of delimiter into a vector of string std::vector<std::string> ParseStringList(const std::string& inputString, const char* delimiter); +double ComputeByteLevelRMSE(const void* expected, const void* actual, const size_t size); + /// Dequantize an array of a given type /// @param array Type erased array to dequantize /// @param numElements Elements in the array @@ -285,29 +287,3 @@ std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseEleme return result; } - -/// Compute the root-mean-square error (RMSE) -/// @param expected -/// @param actual -/// @param size size of the tensor -/// @return float the RMSE -template<typename T> -float ComputeRMSE(const void* expected, const void* actual, const size_t size) -{ - auto typedExpected = reinterpret_cast<const T*>(expected); - auto typedActual = reinterpret_cast<const T*>(actual); - - T errorSum = 0; - - for (unsigned int i = 0; i < size; i++) - { - if (std::abs(typedExpected[i] - typedActual[i]) != 0) - { - std::cout << ""; - } - errorSum += std::pow(std::abs(typedExpected[i] - typedActual[i]), 2); - } - - float rmse = std::sqrt(armnn::numeric_cast<float>(errorSum) / armnn::numeric_cast<float>(size / sizeof(T))); - return rmse; -}
\ No newline at end of file diff --git a/tests/NetworkExecutionUtils/test/NetworkExecutionUtilsTests.cpp b/tests/NetworkExecutionUtils/test/NetworkExecutionUtilsTests.cpp new file mode 100644 index 0000000000..d11fe571b0 --- /dev/null +++ b/tests/NetworkExecutionUtils/test/NetworkExecutionUtilsTests.cpp @@ -0,0 +1,37 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "../NetworkExecutionUtils.hpp" + +#include <doctest/doctest.h> + +namespace +{ + +TEST_SUITE("NetworkExecutionUtilsTests") +{ + +TEST_CASE ("ComputeByteLevelRMSE") +{ + // Bytes. + const uint8_t expected[] = {1, 128, 255}; + const uint8_t actual[] = {0, 127, 254}; + + CHECK(ComputeByteLevelRMSE(expected, expected, 3) == 0); + CHECK(ComputeByteLevelRMSE(expected, actual, 3) == 1.0); + + // Floats. + const float expectedFloat[] = + {55.20419, 24.58061, 67.76520, 47.31617, 55.58102, 44.64565, 105.76307, 54.65538, 80.41088, 66.05208}; + const float actualFloat[] = + {13.87187, 14.16160, 49.28846, 25.89192, 97.70659, 91.30055, 15.88831, 4.79960, 102.99205, 51.28290}; + const double expectedResult = 74.059098023; // Calculated manually. + CHECK(ComputeByteLevelRMSE(expectedFloat, expectedFloat, sizeof(float) * 10) == 0); + CHECK(ComputeByteLevelRMSE(expectedFloat, actualFloat, sizeof(float) * 10) == doctest::Approx(expectedResult)); +} + +} // End of TEST_SUITE("NetworkExecutionUtilsTests") + +} // anonymous namespace
\ No newline at end of file |