diff options
Diffstat (limited to 'reference_model')
-rw-r--r-- | reference_model/CMakeLists.txt | 13 | ||||
-rw-r--r-- | reference_model/include/verify.h | 71 | ||||
-rw-r--r-- | reference_model/src/verify.cc | 130 | ||||
-rw-r--r-- | reference_model/test/verify_tests.cpp | 56 |
4 files changed, 265 insertions, 5 deletions
diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt index a086f4b..5a0195c 100644 --- a/reference_model/CMakeLists.txt +++ b/reference_model/CMakeLists.txt @@ -60,13 +60,14 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) # Common sources required for TOSA Reference Model library, executable and unit tests set(CXX_SOURCE - src/model_runner.cc - src/model_runner_impl.cc - src/tensor.cc - src/graph_node.cc - src/subgraph_traverser.cc src/func_debug.cc + src/graph_node.cc + src/model_runner_impl.cc + src/model_runner.cc src/operators.cc + src/subgraph_traverser.cc + src/tensor.cc + src/verify.cc src/ops/op_factory.cc src/ops/tensor_ops.cc src/ops/activation_funcs.cc @@ -115,6 +116,7 @@ list(APPEND PUBLIC_HEADERS include/graph_status.h include/model_common.h include/model_runner.h + include/verify.h include/version.h ) @@ -162,6 +164,7 @@ if(BUILD_TOSA_REFERENCE_MODEL_TESTS) # Sources only required for unit tests. set(CXX_SOURCE_TESTS test/model_runner_tests.cpp + test/verify_tests.cpp ${DOCTEST_DIR}/doctest.h ) diff --git a/reference_model/include/verify.h b/reference_model/include/verify.h new file mode 100644 index 0000000..0cf6b6c --- /dev/null +++ b/reference_model/include/verify.h @@ -0,0 +1,71 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// +// +// Verification functionality as per TOSA Specification +// Output Verification : Section 1.8.2 +// +//===----------------------------------------------------------------------===// + +#include <cstddef> + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +// Check result +// +// Error is valid only and only if is_valid is true +struct CheckResult +{ + bool is_valid; + double error; +}; + +/// Validate and calculate tensor element error when using an fp32 accumulator +/// +/// \param ref Tensor element calculated using fp64 +/// \param bnd Tensor element calculated using fp64 on abs(input, weights) +/// \param imp Tensor element calculated through the implementation +/// \param KS The kernel size +/// +/// \return Output error +CheckResult tosa_validate_element_accfp32(double ref, double bnd, float imp, size_t KS); + +/// Validate the accumulated output error +/// +/// \param err_sum Sum of error of all the tensor elements within a tensor +/// \param err_sum_sq Sum of error squares of all the tensor elements within a tensor +/// \param T Number of output (dot-product) elements +/// \param KS The kernel size +/// \param S Test set used as a input/weight generator +/// +/// \return True if the error is within margin else false +bool tosa_validate_output_error(double err_sum, double err_sum_sq, size_t T, size_t KS, int S); + +/// Validate error of whole vector of output data +/// +/// \param ref Output elements calculated using fp64 +/// \param bnd Output elements calculated using fp64 on abs(input, weights) +/// \param imp Output elements calculated using the implementation +/// \param T Number of elements in outputs (need to match) +/// \param KS The kernel size +/// \param S Test set used as a input/weight generator +/// +/// \return True if the error is within margin else false +bool tosa_validate_data_fp32(double* ref, double* bnd, float* imp, size_t T, size_t KS, int S); + +#ifdef __cplusplus +} +#endif /* __cplusplus */
\ No newline at end of file diff --git a/reference_model/src/verify.cc b/reference_model/src/verify.cc new file mode 100644 index 0000000..940275f --- /dev/null +++ b/reference_model/src/verify.cc @@ -0,0 +1,130 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// +// +// Verification functionality as per TOSA Specification +// Output Verification : Section 1.8.2 +// +//===----------------------------------------------------------------------===// + +#include "verify.h" + +#include <half.hpp> + +#include <cmath> +#include <numeric> +#include <optional> +#include <type_traits> + +#define REQUIRE(COND) \ + if (!(COND)) \ + { \ + return false; \ + } + +namespace +{ +// Accumulator precision +template <typename T> +struct AccPrecision; +template <> +struct AccPrecision<float> +{ + static constexpr double precision = (double)(1 << 24); +}; +template <> +struct AccPrecision<half_float::half> +{ + static constexpr double precision = (double)(1 << 11); +}; + +// Generic element validation function +template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0> +std::optional<double> validate_element(double ref, double bnd, AccType imp, size_t KS) +{ + double err = 0.0; + bool is_valid = true; + + if (bnd == 0.0) + { + is_valid = (ref == 0.0) && (imp == 0.0); + err = 0.0; + } + else + { // bnd > 0.0 + const double imp_fp64 = static_cast<double>(imp); + const double acc_prec_fp64 = AccPrecision<AccType>::precision; + err = (imp_fp64 - ref) * acc_prec_fp64 / bnd; + is_valid = std::abs(err) <= KS; + } + + return is_valid ? std::optional(err) : std::nullopt; +} + +// Generic data validation function +template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0> +bool validate_data(double* ref, double* bnd, AccType* imp, size_t T, size_t KS, int32_t S) +{ + double out_err_sum = 0.0; + double out_err_sumsq = 0.0; + + for (size_t i = 0; i < T; ++i) + { + auto out_err = validate_element<AccType>(ref[i], bnd[i], imp[i], KS); + REQUIRE(out_err); + out_err_sum += out_err.value(); + out_err_sumsq += out_err.value() * out_err.value(); + } + + return tosa_validate_output_error(out_err_sum, out_err_sumsq, T, KS, S); +} + +// Convert std::optional to CheckResult +CheckResult from_optional(const std::optional<double>& res) +{ + if (res) + return { true, *res }; + else + return { false, 0.0 }; +} +} // namespace + +extern "C" { + +CheckResult tosa_validate_element_accfp32(double ref, double bnd, float imp, size_t KS) +{ + auto err = validate_element<float>(ref, bnd, imp, KS); + return from_optional(err); +} + +bool tosa_validate_output_error(double err_sum, double err_sum_sq, size_t T, size_t KS, int S) +{ + if (S != 1 && S != 2) + { + // Check error bias magnitude for data sets S which are not positive biased + REQUIRE(abs(err_sum) <= 2 * sqrt(KS * T)); + } + // Check error variance magnitude + REQUIRE(err_sum_sq <= 0.4 * KS * T); + + return true; +} + +bool tosa_validate_data_fp32(double* ref, double* bnd, float* imp, size_t T, size_t KS, int S) +{ + return validate_data<float>(ref, bnd, imp, T, KS, S); +} + +} // extern "C" +#undef REQUIRE
\ No newline at end of file diff --git a/reference_model/test/verify_tests.cpp b/reference_model/test/verify_tests.cpp new file mode 100644 index 0000000..f43804f --- /dev/null +++ b/reference_model/test/verify_tests.cpp @@ -0,0 +1,56 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "verify.h" + +#include <doctest.h> + +#include <array> +#include <numeric> +#include <vector> + +TEST_SUITE("verify") +{ + TEST_CASE("check_element_accfp32") + { + const size_t KS = 27; + + // Negative (bnd == 0.0) + REQUIRE_FALSE(tosa_validate_element_accfp32(0.0, 0.0, 1.0, KS).is_valid); + REQUIRE_FALSE(tosa_validate_element_accfp32(1.0, 0.0, 0.0, KS).is_valid); + // Negative (bnd > 0.0) + REQUIRE_FALSE(tosa_validate_element_accfp32(5.0, 5.0, 5.1, KS).is_valid); + + // Positive (bnd == 0.0 && ref == 0.0 && imp == 0.0) + REQUIRE(tosa_validate_element_accfp32(0.0, 0.0, 0.0, KS).is_valid); + REQUIRE(tosa_validate_element_accfp32(0.0, 0.0, 0.0, KS).error == 0.0); + + // Positive (bnd > 0.0) + REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0); + REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0); + REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0); + } + TEST_CASE("check_output_error") + { + const size_t KS = 27; + const size_t T = 1024; + + // Negative (S!=1 && S!=2 && (abs(err_sum) > 2*sqrt(KS*T))) + REQUIRE_FALSE(tosa_validate_output_error(1560, 112000, KS, T, 0)); + // Negative (err_sum_sq > 0.4*KS*T)) + REQUIRE_FALSE(tosa_validate_output_error(1560, 112000, KS, T, 1)); + // Positive + REQUIRE(tosa_validate_output_error(10, 254, KS, T, 0)); + REQUIRE(tosa_validate_output_error(10, 254, KS, T, 1)); + } +} |