aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2023-08-22 08:25:57 +0100
committerEric Kunze <eric.kunze@arm.com>2023-11-28 16:35:26 -0800
commit9c3754715368d84567db883bdbafc31860850141 (patch)
treed4481466af42e1a19b193228e96253a7e4fbce4b
parent6e7b8b20c21a1d668b9401385d30525405e17125 (diff)
downloadreference_model-9c3754715368d84567db883bdbafc31860850141.tar.gz
Rework TOSA verification API
Change verifier API to consume verification configuration in a JSON format and enable appropriate validation to be performed within the verifier code in the reference model. Also update to latest spec changes for main compliance but not yet including bias support. Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I0ceaa1714dd041b00b5b29cd937c8f05e701bc4c
-rw-r--r--reference_model/CMakeLists.txt12
-rw-r--r--reference_model/include/operators.h43
-rw-r--r--reference_model/include/types.h75
-rw-r--r--reference_model/include/verify.h53
-rw-r--r--reference_model/src/verify.cc126
-rw-r--r--reference_model/src/verify/verifiers.h38
-rw-r--r--reference_model/src/verify/verify_dot_product.cc143
-rw-r--r--reference_model/src/verify/verify_entry.cc92
-rw-r--r--reference_model/src/verify/verify_utils.cc121
-rw-r--r--reference_model/src/verify/verify_utils.h81
-rw-r--r--reference_model/test/verify_tests.cpp122
11 files changed, 668 insertions, 238 deletions
diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt
index 5a0195c..227d19f 100644
--- a/reference_model/CMakeLists.txt
+++ b/reference_model/CMakeLists.txt
@@ -56,6 +56,10 @@ if(NOT HALF_DIR)
set(HALF_DIR "../thirdparty/serialization_lib/third_party/half")
endif()
+if(NOT JSON_DIR)
+ set(JSON_DIR "../thirdparty/json")
+endif()
+
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
# Common sources required for TOSA Reference Model library, executable and unit tests
@@ -67,7 +71,9 @@ set(CXX_SOURCE
src/operators.cc
src/subgraph_traverser.cc
src/tensor.cc
- src/verify.cc
+ src/verify/verify_dot_product.cc
+ src/verify/verify_entry.cc
+ src/verify/verify_utils.cc
src/ops/op_factory.cc
src/ops/tensor_ops.cc
src/ops/activation_funcs.cc
@@ -99,6 +105,7 @@ target_include_directories(tosa_reference_model_lib
${EIGEN_DIR}/unsupported/
${SERIALIZATION_DIR}/include
${HALF_DIR}/include
+ ${JSON_DIR}/single_include
)
target_link_libraries(tosa_reference_model_lib
@@ -116,6 +123,7 @@ list(APPEND PUBLIC_HEADERS
include/graph_status.h
include/model_common.h
include/model_runner.h
+ include/types.h
include/verify.h
include/version.h
)
@@ -140,6 +148,7 @@ if(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE)
${EIGEN_DIR}/unsupported/
${SERIALIZATION_DIR}/include
${HALF_DIR}/include
+ ${JSON_DIR}/single_include
)
target_link_libraries(tosa_reference_model
@@ -184,6 +193,7 @@ if(BUILD_TOSA_REFERENCE_MODEL_TESTS)
${SERIALIZATION_DIR}/include
${HALF_DIR}/include
${DOCTEST_DIR}
+ ${JSON_DIR}/single_include
)
target_link_libraries(unit_tests
diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h
index 14ad236..1519d20 100644
--- a/reference_model/include/operators.h
+++ b/reference_model/include/operators.h
@@ -21,6 +21,7 @@
#include "func_config.h"
#include "func_debug.h"
+#include "types.h"
#include <stddef.h>
#include <stdint.h>
@@ -30,48 +31,6 @@ extern "C"
{
#endif /* __cplusplus */
- // Note status needs to be aligned with graph_status
- enum tosa_status_t
- {
- tosa_status_valid = 0,
- tosa_status_unpredictable = 1,
- tosa_status_error = 2
- };
-
- enum tosa_mode_t
- {
- tosa_mode_unknown = 0,
- tosa_mode_nearest = 1,
- tosa_mode_bilinear = 2,
- tosa_mode_min = 3,
- tosa_mode_max = 4
- };
-
- enum tosa_datatype_t
- {
- tosa_datatype_bf16_t = 0,
- tosa_datatype_bool_t = 1,
- tosa_datatype_fp16_t = 2,
- tosa_datatype_fp32_t = 3,
- tosa_datatype_int16_t = 4,
- tosa_datatype_int32_t = 5,
- tosa_datatype_int48_t = 6,
- tosa_datatype_int4_t = 7,
- tosa_datatype_int8_t = 8,
- tosa_datatype_uint16_t = 9,
- tosa_datatype_uint8_t = 10,
- tosa_datatype_shape_t = 11,
- };
-
- struct tosa_tensor_t
- {
- int32_t* shape;
- int32_t num_dims;
- tosa_datatype_t data_type;
- uint8_t* data;
- size_t size;
- };
-
tosa_status_t tosa_run_argmax(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
diff --git a/reference_model/include/types.h b/reference_model/include/types.h
new file mode 100644
index 0000000..42040bf
--- /dev/null
+++ b/reference_model/include/types.h
@@ -0,0 +1,75 @@
+
+// Copyright (c) 2022-2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TYPES_H_
+#define TYPES_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif /* __cplusplus */
+
+ // Note status needs to be aligned with graph_status
+ enum tosa_status_t
+ {
+ tosa_status_valid = 0,
+ tosa_status_unpredictable = 1,
+ tosa_status_error = 2
+ };
+
+ enum tosa_mode_t
+ {
+ tosa_mode_unknown = 0,
+ tosa_mode_nearest = 1,
+ tosa_mode_bilinear = 2,
+ tosa_mode_min = 3,
+ tosa_mode_max = 4
+ };
+
+ enum tosa_datatype_t
+ {
+ tosa_datatype_bf16_t = 0,
+ tosa_datatype_bool_t = 1,
+ tosa_datatype_fp16_t = 2,
+ tosa_datatype_fp32_t = 3,
+ tosa_datatype_int16_t = 4,
+ tosa_datatype_int32_t = 5,
+ tosa_datatype_int48_t = 6,
+ tosa_datatype_int4_t = 7,
+ tosa_datatype_int8_t = 8,
+ tosa_datatype_uint16_t = 9,
+ tosa_datatype_uint8_t = 10,
+ tosa_datatype_shape_t = 11,
+ tosa_datatype_fp64_t = 99
+ };
+
+ struct tosa_tensor_t
+ {
+ const char* name;
+ int32_t* shape;
+ int32_t num_dims;
+ tosa_datatype_t data_type;
+ uint8_t* data;
+ size_t size;
+ };
+
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */
+
+#endif // TYPES_H_ \ No newline at end of file
diff --git a/reference_model/include/verify.h b/reference_model/include/verify.h
index d294388..e449ff7 100644
--- a/reference_model/include/verify.h
+++ b/reference_model/include/verify.h
@@ -18,54 +18,29 @@
//
//===----------------------------------------------------------------------===//
-#include <cstddef>
+#include "types.h"
#ifdef __cplusplus
extern "C"
{
#endif /* __cplusplus */
- // Check result
- //
- // Error is valid only and only if is_valid is true
- struct CheckResult
- {
- bool is_valid;
- double error;
- };
-
- /// Validate and calculate tensor element error when using an fp32 accumulator
- ///
- /// \param ref Tensor element calculated using fp64
- /// \param bnd Tensor element calculated using fp64 on abs(input, weights)
- /// \param imp Tensor element calculated through the implementation
- /// \param KS The kernel size
- ///
- /// \return Output error
- CheckResult tosa_validate_element_accfp32(double ref, double bnd, float imp, size_t KS);
-
- /// Validate the accumulated output error
- ///
- /// \param err_sum Sum of error of all the tensor elements within a tensor
- /// \param err_sum_sq Sum of error squares of all the tensor elements within a tensor
- /// \param T Number of output (dot-product) elements
- /// \param KS The kernel size
- /// \param S Test set used as a input/weight generator
+ /// \brief Perform compliance validation between a reference and a target output
///
- /// \return True if the error is within margin else false
- bool tosa_validate_output_error(double err_sum, double err_sum_sq, size_t T, size_t KS, int S);
-
- /// Validate error of whole vector of output data
+ /// A compliance configuration is expected as it provides information about
+ /// the type of validation to be performed alongside with all the relevant
+ /// meta-data. Configuration is provided in JSON format.
///
- /// \param ref Output elements calculated using fp64
- /// \param bnd Output elements calculated using fp64 on abs(input, weights)
- /// \param imp Output elements calculated using the implementation
- /// \param T Number of elements in outputs (need to match)
- /// \param KS The kernel size
- /// \param S Test set used as a input/weight generator
+ /// \param ref Reference tensor to compare against
+ /// \param ref_bnd (Optional) Reference tensor when run on absolute inputs
+ /// \param imp Implementation resulting tensor
+ /// \param config_json Compliance configuration that indicates how and what compliance need to be performed
///
- /// \return True if the error is within margin else false
- bool tosa_validate_data_fp32(const double* ref, const double* bnd, const float* imp, size_t T, size_t KS, int S);
+ /// \return True in case of successful validation else false
+ bool tvf_verify_data(const tosa_tensor_t* ref,
+ const tosa_tensor_t* ref_bnd,
+ const tosa_tensor_t* imp,
+ const char* config_json);
#ifdef __cplusplus
}
diff --git a/reference_model/src/verify.cc b/reference_model/src/verify.cc
deleted file mode 100644
index 450dcbf..0000000
--- a/reference_model/src/verify.cc
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright (c) 2023, ARM Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//===----------------------------------------------------------------------===//
-//
-// Verification functionality as per TOSA Specification
-// Output Verification : Section 1.8.2
-//
-//===----------------------------------------------------------------------===//
-
-#include "verify.h"
-
-#include <half.hpp>
-
-#include <cmath>
-#include <numeric>
-#include <optional>
-#include <type_traits>
-
-#define REQUIRE(COND) \
- if (!(COND)) \
- { \
- return false; \
- }
-
-namespace
-{
-// Accumulator precision
-template <typename T>
-struct AccPrecision;
-template <>
-struct AccPrecision<float>
-{
- static constexpr double precision = (double)(1 << 24);
-};
-
-// Generic element validation function
-template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
-std::optional<double> validate_element(double ref, double bnd, AccType imp, size_t KS)
-{
- double err = 0.0;
- bool is_valid = true;
-
- if (bnd == 0.0)
- {
- is_valid = (ref == 0.0) && (imp == 0.0);
- err = 0.0;
- }
- else
- { // bnd > 0.0
- const double imp_fp64 = static_cast<double>(imp);
- const double acc_prec_fp64 = AccPrecision<AccType>::precision;
- err = (imp_fp64 - ref) * acc_prec_fp64 / bnd;
- is_valid = std::abs(err) <= KS;
- }
-
- return is_valid ? std::optional(err) : std::nullopt;
-}
-
-// Generic data validation function
-template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
-bool validate_data(const double* ref, const double* bnd, const AccType* imp, size_t T, size_t KS, int32_t S)
-{
- double out_err_sum = 0.0;
- double out_err_sumsq = 0.0;
-
- for (size_t i = 0; i < T; ++i)
- {
- auto out_err = validate_element<AccType>(ref[i], bnd[i], imp[i], KS);
- REQUIRE(out_err);
- out_err_sum += out_err.value();
- out_err_sumsq += out_err.value() * out_err.value();
- }
-
- return tosa_validate_output_error(out_err_sum, out_err_sumsq, T, KS, S);
-}
-
-// Convert std::optional to CheckResult
-CheckResult from_optional(const std::optional<double>& res)
-{
- if (res)
- return { true, *res };
- else
- return { false, 0.0 };
-}
-} // namespace
-
-extern "C"
-{
-
- CheckResult tosa_validate_element_accfp32(double ref, double bnd, float imp, size_t KS)
- {
- auto err = validate_element<float>(ref, bnd, imp, KS);
- return from_optional(err);
- }
-
- bool tosa_validate_output_error(double err_sum, double err_sum_sq, size_t T, size_t KS, int S)
- {
- if (S != 1 && S != 2)
- {
- // Check error bias magnitude for data sets S which are not positive biased
- REQUIRE(std::abs(err_sum) <= 2 * sqrt(KS * T));
- }
- // Check error variance magnitude
- REQUIRE(err_sum_sq <= 0.4 * KS * T);
-
- return true;
- }
-
- bool tosa_validate_data_fp32(const double* ref, const double* bnd, const float* imp, size_t T, size_t KS, int S)
- {
- return validate_data<float>(ref, bnd, imp, T, KS, S);
- }
-
-} // extern "C"
-#undef REQUIRE \ No newline at end of file
diff --git a/reference_model/src/verify/verifiers.h b/reference_model/src/verify/verifiers.h
new file mode 100644
index 0000000..afd50bf
--- /dev/null
+++ b/reference_model/src/verify/verifiers.h
@@ -0,0 +1,38 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef VERIFIERS_H_
+#define VERIFIERS_H_
+
+#include "verify_utils.h"
+
+namespace TosaReference
+{
+/// \brief Perform dot-product based verification
+///
+/// \param ref Reference tensor
+/// \param refBnd Reference tensor when ran on abs(input)
+/// \param imp Implementation resulting tensor
+/// \param dpInfo Dot-product verification meta-data
+///
+/// \return True if compliant else false
+bool verifyDotProduct(const CTensor* ref,
+ const CTensor* refBnd,
+ const CTensor* imp,
+ const DotProductVerifyInfo& dpInfo);
+
+}; // namespace TosaReference
+
+#endif // VERIFIERS_H_
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc
new file mode 100644
index 0000000..a24f83f
--- /dev/null
+++ b/reference_model/src/verify/verify_dot_product.cc
@@ -0,0 +1,143 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "func_debug.h"
+#include "verifiers.h"
+
+#include <cmath>
+#include <numeric>
+#include <optional>
+#include <type_traits>
+
+#define TOSA_REF_REQUIRE(COND, MESSAGE) \
+ if (!(COND)) \
+ { \
+ WARNING(MESSAGE); \
+ return false; \
+ }
+
+namespace TosaReference
+{
+namespace
+{
+
+// Accumulator precision
+template <typename T>
+struct AccPrecision;
+#define two_m42 1.0 / (double)(((int64_t)1) << 42) // 2^-42
+template <>
+struct AccPrecision<float>
+{
+ static constexpr double precision = (double)(1 << 24);
+ static constexpr double min_normal = two_m42 * two_m42 * two_m42; // 2^-126
+};
+#undef two_m42
+
+// Generic element validation function
+template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
+std::optional<double> validateElement(double ref, double bnd, AccType imp, size_t KS)
+{
+ double err = 0.0;
+ bool is_valid = true;
+
+ if (bnd == 0.0)
+ {
+ is_valid = (ref == 0.0) && (imp == 0.0);
+ err = 0.0;
+ }
+ else if (std::isinf(static_cast<AccType>(bnd)))
+ {
+ // dot product can overflow and there is no accuracy limit
+ is_valid = true;
+ err = 0.0;
+ }
+ else
+ {
+ // 0.0 < bnd < infinity
+ const double bnd_norm = std::max(bnd, AccPrecision<AccType>::min_normal);
+ const double imp_fp64 = static_cast<double>(imp);
+ const double acc_prec_fp64 = AccPrecision<AccType>::precision;
+ err = (imp_fp64 - ref) * acc_prec_fp64 / bnd_norm;
+ is_valid = std::abs(err) <= KS;
+ }
+
+ return is_valid ? std::optional(err) : std::nullopt;
+}
+
+// Generic data validation function
+template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
+bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg)
+{
+ const int32_t S = cfg.s;
+ // TODO - needed for other ops - (max_value(bias_abs) > 0) ? (KS + 1) : KS
+ const int32_t KS = cfg.ks;
+
+ double out_err_sum = 0.0;
+ double out_err_sumsq = 0.0;
+
+ for (size_t i = 0; i < T; ++i)
+ {
+ auto out_err = validateElement<AccType>(ref[i], bnd[i], imp[i], KS);
+ TOSA_REF_REQUIRE(out_err, "output error is 0");
+ out_err_sum += out_err.value();
+ out_err_sumsq += out_err.value() * out_err.value();
+ }
+
+ if (S >= 3 && S <= 5)
+ {
+ // Check error bias magnitude for data sets S which are not positive biased
+ TOSA_REF_REQUIRE(std::abs(out_err_sum) <= 2 * sqrt(KS * T), "bias magnitude is out of range");
+ }
+ // Check error variance magnitude
+ TOSA_REF_REQUIRE(out_err_sumsq <= 0.4 * KS * T, "error variance magnitude is out of range");
+ return true;
+}
+} // namespace
+
+bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor* imp, const DotProductVerifyInfo& dpInfo)
+{
+ // Validate that tensors are provided
+ TOSA_REF_REQUIRE(ref != nullptr, "reference tensor is missing");
+ TOSA_REF_REQUIRE(refBnd != nullptr, "reference bounds tensor is missing");
+ TOSA_REF_REQUIRE(imp != nullptr, "implementation tensor is missing");
+
+ // Validate data-type
+ TOSA_REF_REQUIRE(dpInfo.dataType == mapToDType(imp->data_type), "invalid data type in config");
+
+ // Get number of dot-product elements
+ const int64_t T = numElements(std::vector<int32_t>(ref->shape, ref->shape + ref->num_dims));
+ TOSA_REF_REQUIRE(T > 0, "invalid shape for reference tensor");
+
+ const double* refData = reinterpret_cast<const double*>(ref->data);
+ const double* refBndData = reinterpret_cast<const double*>(refBnd->data);
+ TOSA_REF_REQUIRE(refData != nullptr && refBndData != nullptr, "missing data for reference or bounds tensors");
+
+ switch (imp->data_type)
+ {
+ case tosa_datatype_fp32_t: {
+ const float* impData = reinterpret_cast<const float*>(imp->data);
+ TOSA_REF_REQUIRE(impData != nullptr, "missing data for implementation");
+ return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo);
+ break;
+ }
+ default:
+ break;
+ }
+
+ return false;
+}
+
+} // namespace TosaReference
+
+#undef TOSA_REF_REQUIRE \ No newline at end of file
diff --git a/reference_model/src/verify/verify_entry.cc b/reference_model/src/verify/verify_entry.cc
new file mode 100644
index 0000000..80ca916
--- /dev/null
+++ b/reference_model/src/verify/verify_entry.cc
@@ -0,0 +1,92 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "verify.h"
+
+#include "func_debug.h"
+#include "model_common.h"
+#include "verifiers.h"
+#include "verify_utils.h"
+
+#include <vector>
+
+namespace TosaReference
+{
+
+bool verify(const CTensor* ref, const CTensor* refBnd, const CTensor* imp, const VerifyConfig& cfg)
+{
+ switch (cfg.mode)
+ {
+ case VerifyMode::DotProduct: {
+ return verifyDotProduct(ref, refBnd, imp, cfg.dotProductInfo);
+ break;
+ }
+ default: {
+ WARNING("unsupported verification mode.");
+ break;
+ }
+ }
+ return false;
+}
+
+} // namespace TosaReference
+
+extern "C"
+{
+ bool tvf_verify_data(const tosa_tensor_t* ref,
+ const tosa_tensor_t* ref_bnd,
+ const tosa_tensor_t* imp,
+ const char* config_json)
+ {
+ // Check inputs for nullptr
+ if (!ref || !imp || !config_json)
+ {
+ WARNING("one of the inputs is missing.");
+ return false;
+ }
+
+ // Extract verification config
+ if (!ref->name)
+ {
+ WARNING("tensor name is not specified.");
+ return false;
+ }
+ auto cfg = TosaReference::parseVerifyConfig(ref->name, config_json);
+ if (!cfg)
+ {
+ WARNING("invalid json config.");
+ return false;
+ }
+
+ // Validate shape
+ if (ref->num_dims != imp->num_dims)
+ {
+ WARNING("tensors have different number of dimensions.");
+ return false;
+ }
+ if (!ref->shape || !imp->shape)
+ {
+ WARNING("one of tensors' shape is missing.");
+ return false;
+ }
+ if (std::vector(ref->shape, ref->shape + ref->num_dims) != std::vector(imp->shape, imp->shape + imp->num_dims))
+ {
+ WARNING("tensors have different shapes.");
+ return false;
+ }
+
+ // Run verification
+ return verify(ref, ref_bnd, imp, *cfg);
+ }
+} // extern "C"
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
new file mode 100644
index 0000000..bb4feaa
--- /dev/null
+++ b/reference_model/src/verify/verify_utils.cc
@@ -0,0 +1,121 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "verify_utils.h"
+
+#include <nlohmann/json.hpp>
+
+#include <algorithm>
+#include <map>
+
+namespace tosa
+{
+
+NLOHMANN_JSON_SERIALIZE_ENUM(DType,
+ {
+ { DType::DType_BOOL, "BOOL" },
+ { DType::DType_INT4, "INT4" },
+ { DType::DType_INT8, "INT8" },
+ { DType::DType_INT16, "INT16" },
+ { DType::DType_INT32, "INT32" },
+ { DType::DType_INT48, "INT48" },
+ { DType::DType_FP16, "FP16" },
+ { DType::DType_BF16, "BF16" },
+ { DType::DType_FP32, "FP32" },
+ })
+
+} // namespace tosa
+
+namespace TosaReference
+{
+
+NLOHMANN_JSON_SERIALIZE_ENUM(VerifyMode,
+ {
+ { VerifyMode::Exact, "EXACT" },
+ { VerifyMode::Ulp, "ULP" },
+ { VerifyMode::DotProduct, "DOT_PRODUCT" },
+ { VerifyMode::ReduceProduct, "REDUCE_PRODUCT" },
+ { VerifyMode::FpSpecial, "FP_SPECIAL" },
+ })
+
+void from_json(const nlohmann::json& j, UlpInfo& ulpInfo)
+{
+ j.at("ulp").get_to(ulpInfo.ulp);
+}
+
+void from_json(const nlohmann::json& j, DotProductVerifyInfo& dotProductInfo)
+{
+ j.at("data_type").get_to(dotProductInfo.dataType);
+ j.at("s").get_to(dotProductInfo.s);
+ j.at("ks").get_to(dotProductInfo.ks);
+}
+
+void from_json(const nlohmann::json& j, VerifyConfig& cfg)
+{
+ j.at("mode").get_to(cfg.mode);
+ if (j.contains("ulp_info"))
+ {
+ j.at("ulp_info").get_to(cfg.ulpInfo);
+ }
+ if (j.contains("dot_product_info"))
+ {
+ j.at("dot_product_info").get_to(cfg.dotProductInfo);
+ }
+}
+
+std::optional<VerifyConfig> parseVerifyConfig(const char* tensorName, const char* json)
+{
+ if (!tensorName)
+ return std::nullopt;
+
+ auto jsonCfg = nlohmann::json::parse(json, nullptr, /* allow exceptions */ false);
+
+ if (jsonCfg.is_discarded())
+ return std::nullopt;
+ if (!jsonCfg.contains("tensors"))
+ return std::nullopt;
+
+ const auto& tensors = jsonCfg["tensors"];
+ if (!tensors.contains(tensorName))
+ return std::nullopt;
+
+ const auto& namedTensor = tensors[tensorName];
+ return namedTensor.get<VerifyConfig>();
+}
+
+int64_t numElements(const std::vector<int32_t>& shape)
+{
+ return std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<int64_t>());
+}
+
+DType mapToDType(tosa_datatype_t dataType)
+{
+ static std::map<tosa_datatype_t, DType> typeMap = {
+ { tosa_datatype_bool_t, DType_BOOL }, { tosa_datatype_int4_t, DType_INT4 },
+ { tosa_datatype_int8_t, DType_INT8 }, { tosa_datatype_uint16_t, DType_UINT16 },
+ { tosa_datatype_int16_t, DType_INT16 }, { tosa_datatype_int32_t, DType_INT32 },
+ { tosa_datatype_int48_t, DType_INT48 }, { tosa_datatype_fp16_t, DType_FP16 },
+ { tosa_datatype_bf16_t, DType_BF16 }, { tosa_datatype_fp32_t, DType_FP32 },
+ { tosa_datatype_shape_t, DType_SHAPE },
+ };
+
+ if (typeMap.count(dataType))
+ {
+ return typeMap[dataType];
+ }
+
+ return DType_UNKNOWN;
+}
+} // namespace TosaReference
diff --git a/reference_model/src/verify/verify_utils.h b/reference_model/src/verify/verify_utils.h
new file mode 100644
index 0000000..6e51e3e
--- /dev/null
+++ b/reference_model/src/verify/verify_utils.h
@@ -0,0 +1,81 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef VERIFY_UTILS_H_
+#define VERIFY_UTILS_H_
+
+#include "dtype.h"
+#include "types.h"
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+namespace TosaReference
+{
+
+// Name alias
+using CTensor = tosa_tensor_t;
+
+/// \brief Supported verification modes
+enum class VerifyMode
+{
+ Exact,
+ Ulp,
+ DotProduct,
+ ReduceProduct,
+ FpSpecial
+};
+
+/// \brief ULP verification meta-data
+struct UlpInfo
+{
+ UlpInfo() = default;
+
+ float ulp;
+};
+
+/// \brief Dot-product verification meta-data
+struct DotProductVerifyInfo
+{
+ DotProductVerifyInfo() = default;
+
+ DType dataType;
+ int32_t s;
+ int32_t ks;
+};
+
+/// \brief Verification meta-data
+struct VerifyConfig
+{
+ VerifyConfig() = default;
+
+ VerifyMode mode;
+ UlpInfo ulpInfo;
+ DotProductVerifyInfo dotProductInfo;
+};
+
+/// \brief Parse the verification config for a tensor when given in JSON form
+std::optional<VerifyConfig> parseVerifyConfig(const char* tensorName, const char* configJson);
+
+/// \brief Extract number of total elements
+int64_t numElements(const std::vector<int32_t>& shape);
+
+/// \brief Map API data-type to DType
+DType mapToDType(tosa_datatype_t dataType);
+
+}; // namespace TosaReference
+
+#endif // VERIFY_UTILS_H_
diff --git a/reference_model/test/verify_tests.cpp b/reference_model/test/verify_tests.cpp
index f43804f..81f3e8d 100644
--- a/reference_model/test/verify_tests.cpp
+++ b/reference_model/test/verify_tests.cpp
@@ -16,41 +16,103 @@
#include <doctest.h>
#include <array>
-#include <numeric>
+#include <string>
#include <vector>
-TEST_SUITE("verify")
+namespace
{
- TEST_CASE("check_element_accfp32")
+
+class TosaTensor
+{
+public:
+ TosaTensor(std::string name, tosa_datatype_t dataType, std::vector<int32_t> shape)
+ : _name(std::move(name))
+ , _shape(std::move(shape))
+ {
+ _tensor.name = _name.c_str();
+ _tensor.data_type = dataType;
+ _tensor.num_dims = _shape.size();
+ _tensor.shape = _shape.data();
+ };
+
+ const tosa_tensor_t* cTensor() const
{
- const size_t KS = 27;
-
- // Negative (bnd == 0.0)
- REQUIRE_FALSE(tosa_validate_element_accfp32(0.0, 0.0, 1.0, KS).is_valid);
- REQUIRE_FALSE(tosa_validate_element_accfp32(1.0, 0.0, 0.0, KS).is_valid);
- // Negative (bnd > 0.0)
- REQUIRE_FALSE(tosa_validate_element_accfp32(5.0, 5.0, 5.1, KS).is_valid);
-
- // Positive (bnd == 0.0 && ref == 0.0 && imp == 0.0)
- REQUIRE(tosa_validate_element_accfp32(0.0, 0.0, 0.0, KS).is_valid);
- REQUIRE(tosa_validate_element_accfp32(0.0, 0.0, 0.0, KS).error == 0.0);
-
- // Positive (bnd > 0.0)
- REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0);
- REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0);
- REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0);
+ return &_tensor;
}
- TEST_CASE("check_output_error")
+
+private:
+ std::string _name;
+ std::vector<int32_t> _shape;
+ tosa_tensor_t _tensor;
+};
+
+} // namespace
+
+TEST_SUITE_BEGIN("verify");
+
+TEST_CASE("negative - api")
+{
+ std::string json_cfg = R"({
+ "tensors" : {
+ "out1" : {
+ "mode": "DOT_PRODUCT",
+ "dot_product_info" : {
+ "data_type": "FP32",
+ "s": 2,
+ "ks": 9
+ }
+ }
+ }
+ })";
+
+ SUBCASE("invalid json")
{
- const size_t KS = 27;
- const size_t T = 1024;
-
- // Negative (S!=1 && S!=2 && (abs(err_sum) > 2*sqrt(KS*T)))
- REQUIRE_FALSE(tosa_validate_output_error(1560, 112000, KS, T, 0));
- // Negative (err_sum_sq > 0.4*KS*T))
- REQUIRE_FALSE(tosa_validate_output_error(1560, 112000, KS, T, 1));
- // Positive
- REQUIRE(tosa_validate_output_error(10, 254, KS, T, 0));
- REQUIRE(tosa_validate_output_error(10, 254, KS, T, 1));
+ std::string invalid_json_cfg = R"({
+ "tensors" : {
+ "out1" : {
+ "mode": DOT_PRODUCT,
+ },
+ }
+ })";
+
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), invalid_json_cfg.c_str()));
+ }
+ SUBCASE("mismatching dimensions")
+ {
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 4, 4 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 4, 4 });
+ const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str()));
+ }
+ SUBCASE("mismatching shapes")
+ {
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor imp("out1", tosa_datatype_fp32_t, { 4, 4, 4 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str()));
+ }
+ SUBCASE("mismatching data types")
+ {
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor imp("out1", tosa_datatype_fp16_t, { 8, 8, 8 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str()));
+ }
+ SUBCASE("missing tensor data")
+ {
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str()));
}
}
+
+TEST_SUITE_END(); // verify \ No newline at end of file