aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2023-05-30 12:20:31 +0100
committerDominic Symes <dominic.symes@arm.com>2023-06-15 13:49:08 +0000
commit41df428ed5e3b07f0a497fc504f1eddb8e115188 (patch)
tree48dcbabf230d30c3bd1fdb2602164dadcaf3cb05
parent6168047ef0354927cb175ad295722924dfc3053c (diff)
downloadreference_model-41df428ed5e3b07f0a497fc504f1eddb8e115188.tar.gz
Add TOSA MI verification methods
Adds utility functions that enable compliance verification of TOSA MI operators; as per section 1.8.2 in the TOSA specification. Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Change-Id: I0cced0ff8875ac8d78b1943211438713d1c51b88
-rw-r--r--reference_model/CMakeLists.txt13
-rw-r--r--reference_model/include/verify.h71
-rw-r--r--reference_model/src/verify.cc130
-rw-r--r--reference_model/test/verify_tests.cpp56
4 files changed, 265 insertions, 5 deletions
diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt
index a086f4b..5a0195c 100644
--- a/reference_model/CMakeLists.txt
+++ b/reference_model/CMakeLists.txt
@@ -60,13 +60,14 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
# Common sources required for TOSA Reference Model library, executable and unit tests
set(CXX_SOURCE
- src/model_runner.cc
- src/model_runner_impl.cc
- src/tensor.cc
- src/graph_node.cc
- src/subgraph_traverser.cc
src/func_debug.cc
+ src/graph_node.cc
+ src/model_runner_impl.cc
+ src/model_runner.cc
src/operators.cc
+ src/subgraph_traverser.cc
+ src/tensor.cc
+ src/verify.cc
src/ops/op_factory.cc
src/ops/tensor_ops.cc
src/ops/activation_funcs.cc
@@ -115,6 +116,7 @@ list(APPEND PUBLIC_HEADERS
include/graph_status.h
include/model_common.h
include/model_runner.h
+ include/verify.h
include/version.h
)
@@ -162,6 +164,7 @@ if(BUILD_TOSA_REFERENCE_MODEL_TESTS)
# Sources only required for unit tests.
set(CXX_SOURCE_TESTS
test/model_runner_tests.cpp
+ test/verify_tests.cpp
${DOCTEST_DIR}/doctest.h
)
diff --git a/reference_model/include/verify.h b/reference_model/include/verify.h
new file mode 100644
index 0000000..0cf6b6c
--- /dev/null
+++ b/reference_model/include/verify.h
@@ -0,0 +1,71 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//===----------------------------------------------------------------------===//
+//
+// Verification functionality as per TOSA Specification
+// Output Verification : Section 1.8.2
+//
+//===----------------------------------------------------------------------===//
+
+#include <cstddef>
+
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+// Check result
+//
+// Error is valid only and only if is_valid is true
+struct CheckResult
+{
+ bool is_valid;
+ double error;
+};
+
+/// Validate and calculate tensor element error when using an fp32 accumulator
+///
+/// \param ref Tensor element calculated using fp64
+/// \param bnd Tensor element calculated using fp64 on abs(input, weights)
+/// \param imp Tensor element calculated through the implementation
+/// \param KS The kernel size
+///
+/// \return Output error
+CheckResult tosa_validate_element_accfp32(double ref, double bnd, float imp, size_t KS);
+
+/// Validate the accumulated output error
+///
+/// \param err_sum Sum of error of all the tensor elements within a tensor
+/// \param err_sum_sq Sum of error squares of all the tensor elements within a tensor
+/// \param T Number of output (dot-product) elements
+/// \param KS The kernel size
+/// \param S Test set used as a input/weight generator
+///
+/// \return True if the error is within margin else false
+bool tosa_validate_output_error(double err_sum, double err_sum_sq, size_t T, size_t KS, int S);
+
+/// Validate error of whole vector of output data
+///
+/// \param ref Output elements calculated using fp64
+/// \param bnd Output elements calculated using fp64 on abs(input, weights)
+/// \param imp Output elements calculated using the implementation
+/// \param T Number of elements in outputs (need to match)
+/// \param KS The kernel size
+/// \param S Test set used as a input/weight generator
+///
+/// \return True if the error is within margin else false
+bool tosa_validate_data_fp32(double* ref, double* bnd, float* imp, size_t T, size_t KS, int S);
+
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */ \ No newline at end of file
diff --git a/reference_model/src/verify.cc b/reference_model/src/verify.cc
new file mode 100644
index 0000000..940275f
--- /dev/null
+++ b/reference_model/src/verify.cc
@@ -0,0 +1,130 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//===----------------------------------------------------------------------===//
+//
+// Verification functionality as per TOSA Specification
+// Output Verification : Section 1.8.2
+//
+//===----------------------------------------------------------------------===//
+
+#include "verify.h"
+
+#include <half.hpp>
+
+#include <cmath>
+#include <numeric>
+#include <optional>
+#include <type_traits>
+
+#define REQUIRE(COND) \
+ if (!(COND)) \
+ { \
+ return false; \
+ }
+
+namespace
+{
+// Accumulator precision
+template <typename T>
+struct AccPrecision;
+template <>
+struct AccPrecision<float>
+{
+ static constexpr double precision = (double)(1 << 24);
+};
+template <>
+struct AccPrecision<half_float::half>
+{
+ static constexpr double precision = (double)(1 << 11);
+};
+
+// Generic element validation function
+template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
+std::optional<double> validate_element(double ref, double bnd, AccType imp, size_t KS)
+{
+ double err = 0.0;
+ bool is_valid = true;
+
+ if (bnd == 0.0)
+ {
+ is_valid = (ref == 0.0) && (imp == 0.0);
+ err = 0.0;
+ }
+ else
+ { // bnd > 0.0
+ const double imp_fp64 = static_cast<double>(imp);
+ const double acc_prec_fp64 = AccPrecision<AccType>::precision;
+ err = (imp_fp64 - ref) * acc_prec_fp64 / bnd;
+ is_valid = std::abs(err) <= KS;
+ }
+
+ return is_valid ? std::optional(err) : std::nullopt;
+}
+
+// Generic data validation function
+template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
+bool validate_data(double* ref, double* bnd, AccType* imp, size_t T, size_t KS, int32_t S)
+{
+ double out_err_sum = 0.0;
+ double out_err_sumsq = 0.0;
+
+ for (size_t i = 0; i < T; ++i)
+ {
+ auto out_err = validate_element<AccType>(ref[i], bnd[i], imp[i], KS);
+ REQUIRE(out_err);
+ out_err_sum += out_err.value();
+ out_err_sumsq += out_err.value() * out_err.value();
+ }
+
+ return tosa_validate_output_error(out_err_sum, out_err_sumsq, T, KS, S);
+}
+
+// Convert std::optional to CheckResult
+CheckResult from_optional(const std::optional<double>& res)
+{
+ if (res)
+ return { true, *res };
+ else
+ return { false, 0.0 };
+}
+} // namespace
+
+extern "C" {
+
+CheckResult tosa_validate_element_accfp32(double ref, double bnd, float imp, size_t KS)
+{
+ auto err = validate_element<float>(ref, bnd, imp, KS);
+ return from_optional(err);
+}
+
+bool tosa_validate_output_error(double err_sum, double err_sum_sq, size_t T, size_t KS, int S)
+{
+ if (S != 1 && S != 2)
+ {
+ // Check error bias magnitude for data sets S which are not positive biased
+ REQUIRE(abs(err_sum) <= 2 * sqrt(KS * T));
+ }
+ // Check error variance magnitude
+ REQUIRE(err_sum_sq <= 0.4 * KS * T);
+
+ return true;
+}
+
+bool tosa_validate_data_fp32(double* ref, double* bnd, float* imp, size_t T, size_t KS, int S)
+{
+ return validate_data<float>(ref, bnd, imp, T, KS, S);
+}
+
+} // extern "C"
+#undef REQUIRE \ No newline at end of file
diff --git a/reference_model/test/verify_tests.cpp b/reference_model/test/verify_tests.cpp
new file mode 100644
index 0000000..f43804f
--- /dev/null
+++ b/reference_model/test/verify_tests.cpp
@@ -0,0 +1,56 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "verify.h"
+
+#include <doctest.h>
+
+#include <array>
+#include <numeric>
+#include <vector>
+
+TEST_SUITE("verify")
+{
+ TEST_CASE("check_element_accfp32")
+ {
+ const size_t KS = 27;
+
+ // Negative (bnd == 0.0)
+ REQUIRE_FALSE(tosa_validate_element_accfp32(0.0, 0.0, 1.0, KS).is_valid);
+ REQUIRE_FALSE(tosa_validate_element_accfp32(1.0, 0.0, 0.0, KS).is_valid);
+ // Negative (bnd > 0.0)
+ REQUIRE_FALSE(tosa_validate_element_accfp32(5.0, 5.0, 5.1, KS).is_valid);
+
+ // Positive (bnd == 0.0 && ref == 0.0 && imp == 0.0)
+ REQUIRE(tosa_validate_element_accfp32(0.0, 0.0, 0.0, KS).is_valid);
+ REQUIRE(tosa_validate_element_accfp32(0.0, 0.0, 0.0, KS).error == 0.0);
+
+ // Positive (bnd > 0.0)
+ REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0);
+ REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0);
+ REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0);
+ }
+ TEST_CASE("check_output_error")
+ {
+ const size_t KS = 27;
+ const size_t T = 1024;
+
+ // Negative (S!=1 && S!=2 && (abs(err_sum) > 2*sqrt(KS*T)))
+ REQUIRE_FALSE(tosa_validate_output_error(1560, 112000, KS, T, 0));
+ // Negative (err_sum_sq > 0.4*KS*T))
+ REQUIRE_FALSE(tosa_validate_output_error(1560, 112000, KS, T, 1));
+ // Positive
+ REQUIRE(tosa_validate_output_error(10, 254, KS, T, 0));
+ REQUIRE(tosa_validate_output_error(10, 254, KS, T, 1));
+ }
+}