aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--reference_model/CMakeLists.txt12
-rw-r--r--reference_model/include/operators.h43
-rw-r--r--reference_model/include/types.h75
-rw-r--r--reference_model/include/verify.h53
-rw-r--r--reference_model/src/verify.cc126
-rw-r--r--reference_model/src/verify/verifiers.h38
-rw-r--r--reference_model/src/verify/verify_dot_product.cc143
-rw-r--r--reference_model/src/verify/verify_entry.cc92
-rw-r--r--reference_model/src/verify/verify_utils.cc121
-rw-r--r--reference_model/src/verify/verify_utils.h81
-rw-r--r--reference_model/test/verify_tests.cpp122
11 files changed, 668 insertions, 238 deletions
diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt
index 5a0195c..227d19f 100644
--- a/reference_model/CMakeLists.txt
+++ b/reference_model/CMakeLists.txt
@@ -56,6 +56,10 @@ if(NOT HALF_DIR)
set(HALF_DIR "../thirdparty/serialization_lib/third_party/half")
endif()
+if(NOT JSON_DIR)
+ set(JSON_DIR "../thirdparty/json")
+endif()
+
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
# Common sources required for TOSA Reference Model library, executable and unit tests
@@ -67,7 +71,9 @@ set(CXX_SOURCE
src/operators.cc
src/subgraph_traverser.cc
src/tensor.cc
- src/verify.cc
+ src/verify/verify_dot_product.cc
+ src/verify/verify_entry.cc
+ src/verify/verify_utils.cc
src/ops/op_factory.cc
src/ops/tensor_ops.cc
src/ops/activation_funcs.cc
@@ -99,6 +105,7 @@ target_include_directories(tosa_reference_model_lib
${EIGEN_DIR}/unsupported/
${SERIALIZATION_DIR}/include
${HALF_DIR}/include
+ ${JSON_DIR}/single_include
)
target_link_libraries(tosa_reference_model_lib
@@ -116,6 +123,7 @@ list(APPEND PUBLIC_HEADERS
include/graph_status.h
include/model_common.h
include/model_runner.h
+ include/types.h
include/verify.h
include/version.h
)
@@ -140,6 +148,7 @@ if(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE)
${EIGEN_DIR}/unsupported/
${SERIALIZATION_DIR}/include
${HALF_DIR}/include
+ ${JSON_DIR}/single_include
)
target_link_libraries(tosa_reference_model
@@ -184,6 +193,7 @@ if(BUILD_TOSA_REFERENCE_MODEL_TESTS)
${SERIALIZATION_DIR}/include
${HALF_DIR}/include
${DOCTEST_DIR}
+ ${JSON_DIR}/single_include
)
target_link_libraries(unit_tests
diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h
index 14ad236..1519d20 100644
--- a/reference_model/include/operators.h
+++ b/reference_model/include/operators.h
@@ -21,6 +21,7 @@
#include "func_config.h"
#include "func_debug.h"
+#include "types.h"
#include <stddef.h>
#include <stdint.h>
@@ -30,48 +31,6 @@ extern "C"
{
#endif /* __cplusplus */
- // Note status needs to be aligned with graph_status
- enum tosa_status_t
- {
- tosa_status_valid = 0,
- tosa_status_unpredictable = 1,
- tosa_status_error = 2
- };
-
- enum tosa_mode_t
- {
- tosa_mode_unknown = 0,
- tosa_mode_nearest = 1,
- tosa_mode_bilinear = 2,
- tosa_mode_min = 3,
- tosa_mode_max = 4
- };
-
- enum tosa_datatype_t
- {
- tosa_datatype_bf16_t = 0,
- tosa_datatype_bool_t = 1,
- tosa_datatype_fp16_t = 2,
- tosa_datatype_fp32_t = 3,
- tosa_datatype_int16_t = 4,
- tosa_datatype_int32_t = 5,
- tosa_datatype_int48_t = 6,
- tosa_datatype_int4_t = 7,
- tosa_datatype_int8_t = 8,
- tosa_datatype_uint16_t = 9,
- tosa_datatype_uint8_t = 10,
- tosa_datatype_shape_t = 11,
- };
-
- struct tosa_tensor_t
- {
- int32_t* shape;
- int32_t num_dims;
- tosa_datatype_t data_type;
- uint8_t* data;
- size_t size;
- };
-
tosa_status_t tosa_run_argmax(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
diff --git a/reference_model/include/types.h b/reference_model/include/types.h
new file mode 100644
index 0000000..42040bf
--- /dev/null
+++ b/reference_model/include/types.h
@@ -0,0 +1,75 @@
+
+// Copyright (c) 2022-2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TYPES_H_
+#define TYPES_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif /* __cplusplus */
+
+ // Note status needs to be aligned with graph_status
+ enum tosa_status_t
+ {
+ tosa_status_valid = 0,
+ tosa_status_unpredictable = 1,
+ tosa_status_error = 2
+ };
+
+ enum tosa_mode_t
+ {
+ tosa_mode_unknown = 0,
+ tosa_mode_nearest = 1,
+ tosa_mode_bilinear = 2,
+ tosa_mode_min = 3,
+ tosa_mode_max = 4
+ };
+
+ enum tosa_datatype_t
+ {
+ tosa_datatype_bf16_t = 0,
+ tosa_datatype_bool_t = 1,
+ tosa_datatype_fp16_t = 2,
+ tosa_datatype_fp32_t = 3,
+ tosa_datatype_int16_t = 4,
+ tosa_datatype_int32_t = 5,
+ tosa_datatype_int48_t = 6,
+ tosa_datatype_int4_t = 7,
+ tosa_datatype_int8_t = 8,
+ tosa_datatype_uint16_t = 9,
+ tosa_datatype_uint8_t = 10,
+ tosa_datatype_shape_t = 11,
+ tosa_datatype_fp64_t = 99
+ };
+
+ struct tosa_tensor_t
+ {
+ const char* name;
+ int32_t* shape;
+ int32_t num_dims;
+ tosa_datatype_t data_type;
+ uint8_t* data;
+ size_t size;
+ };
+
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */
+
+#endif // TYPES_H_ \ No newline at end of file
diff --git a/reference_model/include/verify.h b/reference_model/include/verify.h
index d294388..e449ff7 100644
--- a/reference_model/include/verify.h
+++ b/reference_model/include/verify.h
@@ -18,54 +18,29 @@
//
//===----------------------------------------------------------------------===//
-#include <cstddef>
+#include "types.h"
#ifdef __cplusplus
extern "C"
{
#endif /* __cplusplus */
- // Check result
- //
- // Error is valid only and only if is_valid is true
- struct CheckResult
- {
- bool is_valid;
- double error;
- };
-
- /// Validate and calculate tensor element error when using an fp32 accumulator
- ///
- /// \param ref Tensor element calculated using fp64
- /// \param bnd Tensor element calculated using fp64 on abs(input, weights)
- /// \param imp Tensor element calculated through the implementation
- /// \param KS The kernel size
- ///
- /// \return Output error
- CheckResult tosa_validate_element_accfp32(double ref, double bnd, float imp, size_t KS);
-
- /// Validate the accumulated output error
- ///
- /// \param err_sum Sum of error of all the tensor elements within a tensor
- /// \param err_sum_sq Sum of error squares of all the tensor elements within a tensor
- /// \param T Number of output (dot-product) elements
- /// \param KS The kernel size
- /// \param S Test set used as a input/weight generator
+ /// \brief Perform compliance validation between a reference and a target output
///
- /// \return True if the error is within margin else false
- bool tosa_validate_output_error(double err_sum, double err_sum_sq, size_t T, size_t KS, int S);
-
- /// Validate error of whole vector of output data
+ /// A compliance configuration is expected as it provides information about
+ /// the type of validation to be performed alongside with all the relevant
+ /// meta-data. Configuration is provided in JSON format.
///
- /// \param ref Output elements calculated using fp64
- /// \param bnd Output elements calculated using fp64 on abs(input, weights)
- /// \param imp Output elements calculated using the implementation
- /// \param T Number of elements in outputs (need to match)
- /// \param KS The kernel size
- /// \param S Test set used as a input/weight generator
+ /// \param ref Reference tensor to compare against
+ /// \param ref_bnd (Optional) Reference tensor when run on absolute inputs
+ /// \param imp Implementation resulting tensor
+ /// \param config_json Compliance configuration that indicates how and what compliance need to be performed
///
- /// \return True if the error is within margin else false
- bool tosa_validate_data_fp32(const double* ref, const double* bnd, const float* imp, size_t T, size_t KS, int S);
+ /// \return True in case of successful validation else false
+ bool tvf_verify_data(const tosa_tensor_t* ref,
+ const tosa_tensor_t* ref_bnd,
+ const tosa_tensor_t* imp,
+ const char* config_json);
#ifdef __cplusplus
}
diff --git a/reference_model/src/verify.cc b/reference_model/src/verify.cc
deleted file mode 100644
index 450dcbf..0000000
--- a/reference_model/src/verify.cc
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright (c) 2023, ARM Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//===----------------------------------------------------------------------===//
-//
-// Verification functionality as per TOSA Specification
-// Output Verification : Section 1.8.2
-//
-//===----------------------------------------------------------------------===//
-
-#include "verify.h"
-
-#include <half.hpp>
-
-#include <cmath>
-#include <numeric>
-#include <optional>
-#include <type_traits>
-
-#define REQUIRE(COND) \
- if (!(COND)) \
- { \
- return false; \
- }
-
-namespace
-{
-// Accumulator precision
-template <typename T>
-struct AccPrecision;
-template <>
-struct AccPrecision<float>
-{
- static constexpr double precision = (double)(1 << 24);
-};
-
-// Generic element validation function
-template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
-std::optional<double> validate_element(double ref, double bnd, AccType imp, size_t KS)
-{
- double err = 0.0;
- bool is_valid = true;
-
- if (bnd == 0.0)
- {
- is_valid = (ref == 0.0) && (imp == 0.0);
- err = 0.0;
- }
- else
- { // bnd > 0.0
- const double imp_fp64 = static_cast<double>(imp);
- const double acc_prec_fp64 = AccPrecision<AccType>::precision;
- err = (imp_fp64 - ref) * acc_prec_fp64 / bnd;
- is_valid = std::abs(err) <= KS;
- }
-
- return is_valid ? std::optional(err) : std::nullopt;
-}
-
-// Generic data validation function
-template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
-bool validate_data(const double* ref, const double* bnd, const AccType* imp, size_t T, size_t KS, int32_t S)
-{
- double out_err_sum = 0.0;
- double out_err_sumsq = 0.0;
-
- for (size_t i = 0; i < T; ++i)
- {
- auto out_err = validate_element<AccType>(ref[i], bnd[i], imp[i], KS);
- REQUIRE(out_err);
- out_err_sum += out_err.value();
- out_err_sumsq += out_err.value() * out_err.value();
- }
-
- return tosa_validate_output_error(out_err_sum, out_err_sumsq, T, KS, S);
-}
-
-// Convert std::optional to CheckResult
-CheckResult from_optional(const std::optional<double>& res)
-{
- if (res)
- return { true, *res };
- else
- return { false, 0.0 };
-}
-} // namespace
-
-extern "C"
-{
-
- CheckResult tosa_validate_element_accfp32(double ref, double bnd, float imp, size_t KS)
- {
- auto err = validate_element<float>(ref, bnd, imp, KS);
- return from_optional(err);
- }
-
- bool tosa_validate_output_error(double err_sum, double err_sum_sq, size_t T, size_t KS, int S)
- {
- if (S != 1 && S != 2)
- {
- // Check error bias magnitude for data sets S which are not positive biased
- REQUIRE(std::abs(err_sum) <= 2 * sqrt(KS * T));
- }
- // Check error variance magnitude
- REQUIRE(err_sum_sq <= 0.4 * KS * T);
-
- return true;
- }
-
- bool tosa_validate_data_fp32(const double* ref, const double* bnd, const float* imp, size_t T, size_t KS, int S)
- {
- return validate_data<float>(ref, bnd, imp, T, KS, S);
- }
-
-} // extern "C"
-#undef REQUIRE \ No newline at end of file
diff --git a/reference_model/src/verify/verifiers.h b/reference_model/src/verify/verifiers.h
new file mode 100644
index 0000000..afd50bf
--- /dev/null
+++ b/reference_model/src/verify/verifiers.h
@@ -0,0 +1,38 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef VERIFIERS_H_
+#define VERIFIERS_H_
+
+#include "verify_utils.h"
+
+namespace TosaReference
+{
+/// \brief Perform dot-product based verification
+///
+/// \param ref Reference tensor
+/// \param refBnd Reference tensor when ran on abs(input)
+/// \param imp Implementation resulting tensor
+/// \param dpInfo Dot-product verification meta-data
+///
+/// \return True if compliant else false
+bool verifyDotProduct(const CTensor* ref,
+ const CTensor* refBnd,
+ const CTensor* imp,
+ const DotProductVerifyInfo& dpInfo);
+
+}; // namespace TosaReference
+
+#endif // VERIFIERS_H_
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc
new file mode 100644
index 0000000..a24f83f
--- /dev/null
+++ b/reference_model/src/verify/verify_dot_product.cc
@@ -0,0 +1,143 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "func_debug.h"
+#include "verifiers.h"
+
+#include <cmath>
+#include <numeric>
+#include <optional>
+#include <type_traits>
+
+#define TOSA_REF_REQUIRE(COND, MESSAGE) \
+ if (!(COND)) \
+ { \
+ WARNING(MESSAGE); \
+ return false; \
+ }
+
+namespace TosaReference
+{
+namespace
+{
+
+// Accumulator precision
+template <typename T>
+struct AccPrecision;
+#define two_m42 1.0 / (double)(((int64_t)1) << 42) // 2^-42
+template <>
+struct AccPrecision<float>
+{
+ static constexpr double precision = (double)(1 << 24);
+ static constexpr double min_normal = two_m42 * two_m42 * two_m42; // 2^-126
+};
+#undef two_m42
+
+// Generic element validation function
+template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
+std::optional<double> validateElement(double ref, double bnd, AccType imp, size_t KS)
+{
+ double err = 0.0;
+ bool is_valid = true;
+
+ if (bnd == 0.0)
+ {
+ is_valid = (ref == 0.0) && (imp == 0.0);
+ err = 0.0;
+ }
+ else if (std::isinf(static_cast<AccType>(bnd)))
+ {
+ // dot product can overflow and there is no accuracy limit
+ is_valid = true;
+ err = 0.0;
+ }
+ else
+ {
+ // 0.0 < bnd < infinity
+ const double bnd_norm = std::max(bnd, AccPrecision<AccType>::min_normal);
+ const double imp_fp64 = static_cast<double>(imp);
+ const double acc_prec_fp64 = AccPrecision<AccType>::precision;
+ err = (imp_fp64 - ref) * acc_prec_fp64 / bnd_norm;
+ is_valid = std::abs(err) <= KS;
+ }
+
+ return is_valid ? std::optional(err) : std::nullopt;
+}
+
+// Generic data validation function
+template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
+bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg)
+{
+ const int32_t S = cfg.s;
+ // TODO - needed for other ops - (max_value(bias_abs) > 0) ? (KS + 1) : KS
+ const int32_t KS = cfg.ks;
+
+ double out_err_sum = 0.0;
+ double out_err_sumsq = 0.0;
+
+ for (size_t i = 0; i < T; ++i)
+ {
+ auto out_err = validateElement<AccType>(ref[i], bnd[i], imp[i], KS);
+ TOSA_REF_REQUIRE(out_err, "output error is 0");
+ out_err_sum += out_err.value();
+ out_err_sumsq += out_err.value() * out_err.value();
+ }
+
+ if (S >= 3 && S <= 5)
+ {
+ // Check error bias magnitude for data sets S which are not positive biased
+ TOSA_REF_REQUIRE(std::abs(out_err_sum) <= 2 * sqrt(KS * T), "bias magnitude is out of range");
+ }
+ // Check error variance magnitude
+ TOSA_REF_REQUIRE(out_err_sumsq <= 0.4 * KS * T, "error variance magnitude is out of range");
+ return true;
+}
+} // namespace
+
+bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor* imp, const DotProductVerifyInfo& dpInfo)
+{
+ // Validate that tensors are provided
+ TOSA_REF_REQUIRE(ref != nullptr, "reference tensor is missing");
+ TOSA_REF_REQUIRE(refBnd != nullptr, "reference bounds tensor is missing");
+ TOSA_REF_REQUIRE(imp != nullptr, "implementation tensor is missing");
+
+ // Validate data-type
+ TOSA_REF_REQUIRE(dpInfo.dataType == mapToDType(imp->data_type), "invalid data type in config");
+
+ // Get number of dot-product elements
+ const int64_t T = numElements(std::vector<int32_t>(ref->shape, ref->shape + ref->num_dims));
+ TOSA_REF_REQUIRE(T > 0, "invalid shape for reference tensor");
+
+ const double* refData = reinterpret_cast<const double*>(ref->data);
+ const double* refBndData = reinterpret_cast<const double*>(refBnd->data);
+ TOSA_REF_REQUIRE(refData != nullptr && refBndData != nullptr, "missing data for reference or bounds tensors");
+
+ switch (imp->data_type)
+ {
+ case tosa_datatype_fp32_t: {
+ const float* impData = reinterpret_cast<const float*>(imp->data);
+ TOSA_REF_REQUIRE(impData != nullptr, "missing data for implementation");
+ return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo);
+ break;
+ }
+ default:
+ break;
+ }
+
+ return false;
+}
+
+} // namespace TosaReference
+
+#undef TOSA_REF_REQUIRE \ No newline at end of file
diff --git a/reference_model/src/verify/verify_entry.cc b/reference_model/src/verify/verify_entry.cc
new file mode 100644
index 0000000..80ca916
--- /dev/null
+++ b/reference_model/src/verify/verify_entry.cc
@@ -0,0 +1,92 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "verify.h"
+
+#include "func_debug.h"
+#include "model_common.h"
+#include "verifiers.h"
+#include "verify_utils.h"
+
+#include <vector>
+
+namespace TosaReference
+{
+
+bool verify(const CTensor* ref, const CTensor* refBnd, const CTensor* imp, const VerifyConfig& cfg)
+{
+ switch (cfg.mode)
+ {
+ case VerifyMode::DotProduct: {
+ return verifyDotProduct(ref, refBnd, imp, cfg.dotProductInfo);
+ break;
+ }
+ default: {
+ WARNING("unsupported verification mode.");
+ break;
+ }
+ }
+ return false;
+}
+
+} // namespace TosaReference
+
+extern "C"
+{
+ bool tvf_verify_data(const tosa_tensor_t* ref,
+ const tosa_tensor_t* ref_bnd,
+ const tosa_tensor_t* imp,
+ const char* config_json)
+ {
+ // Check inputs for nullptr
+ if (!ref || !imp || !config_json)
+ {
+ WARNING("one of the inputs is missing.");
+ return false;
+ }
+
+ // Extract verification config
+ if (!ref->name)
+ {
+ WARNING("tensor name is not specified.");
+ return false;
+ }
+ auto cfg = TosaReference::parseVerifyConfig(ref->name, config_json);
+ if (!cfg)
+ {
+ WARNING("invalid json config.");
+ return false;
+ }
+
+ // Validate shape
+ if (ref->num_dims != imp->num_dims)
+ {
+ WARNING("tensors have different number of dimensions.");
+ return false;
+ }
+ if (!ref->shape || !imp->shape)
+ {
+ WARNING("one of tensors' shape is missing.");
+ return false;
+ }
+ if (std::vector(ref->shape, ref->shape + ref->num_dims) != std::vector(imp->shape, imp->shape + imp->num_dims))
+ {
+ WARNING("tensors have different shapes.");
+ return false;
+ }
+
+ // Run verification
+ return verify(ref, ref_bnd, imp, *cfg);
+ }
+} // extern "C"
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
new file mode 100644
index 0000000..bb4feaa
--- /dev/null
+++ b/reference_model/src/verify/verify_utils.cc
@@ -0,0 +1,121 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "verify_utils.h"
+
+#include <nlohmann/json.hpp>
+
+#include <algorithm>
+#include <map>
+
+namespace tosa
+{
+
+NLOHMANN_JSON_SERIALIZE_ENUM(DType,
+ {
+ { DType::DType_BOOL, "BOOL" },
+ { DType::DType_INT4, "INT4" },
+ { DType::DType_INT8, "INT8" },
+ { DType::DType_INT16, "INT16" },
+ { DType::DType_INT32, "INT32" },
+ { DType::DType_INT48, "INT48" },
+ { DType::DType_FP16, "FP16" },
+ { DType::DType_BF16, "BF16" },
+ { DType::DType_FP32, "FP32" },
+ })
+
+} // namespace tosa
+
+namespace TosaReference
+{
+
+NLOHMANN_JSON_SERIALIZE_ENUM(VerifyMode,
+ {
+ { VerifyMode::Exact, "EXACT" },
+ { VerifyMode::Ulp, "ULP" },
+ { VerifyMode::DotProduct, "DOT_PRODUCT" },
+ { VerifyMode::ReduceProduct, "REDUCE_PRODUCT" },
+ { VerifyMode::FpSpecial, "FP_SPECIAL" },
+ })
+
+void from_json(const nlohmann::json& j, UlpInfo& ulpInfo)
+{
+ j.at("ulp").get_to(ulpInfo.ulp);
+}
+
+void from_json(const nlohmann::json& j, DotProductVerifyInfo& dotProductInfo)
+{
+ j.at("data_type").get_to(dotProductInfo.dataType);
+ j.at("s").get_to(dotProductInfo.s);
+ j.at("ks").get_to(dotProductInfo.ks);
+}
+
+void from_json(const nlohmann::json& j, VerifyConfig& cfg)
+{
+ j.at("mode").get_to(cfg.mode);
+ if (j.contains("ulp_info"))
+ {
+ j.at("ulp_info").get_to(cfg.ulpInfo);
+ }
+ if (j.contains("dot_product_info"))
+ {
+ j.at("dot_product_info").get_to(cfg.dotProductInfo);
+ }
+}
+
+std::optional<VerifyConfig> parseVerifyConfig(const char* tensorName, const char* json)
+{
+ if (!tensorName)
+ return std::nullopt;
+
+ auto jsonCfg = nlohmann::json::parse(json, nullptr, /* allow exceptions */ false);
+
+ if (jsonCfg.is_discarded())
+ return std::nullopt;
+ if (!jsonCfg.contains("tensors"))
+ return std::nullopt;
+
+ const auto& tensors = jsonCfg["tensors"];
+ if (!tensors.contains(tensorName))
+ return std::nullopt;
+
+ const auto& namedTensor = tensors[tensorName];
+ return namedTensor.get<VerifyConfig>();
+}
+
+int64_t numElements(const std::vector<int32_t>& shape)
+{
+ return std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<int64_t>());
+}
+
+DType mapToDType(tosa_datatype_t dataType)
+{
+ static std::map<tosa_datatype_t, DType> typeMap = {
+ { tosa_datatype_bool_t, DType_BOOL }, { tosa_datatype_int4_t, DType_INT4 },
+ { tosa_datatype_int8_t, DType_INT8 }, { tosa_datatype_uint16_t, DType_UINT16 },
+ { tosa_datatype_int16_t, DType_INT16 }, { tosa_datatype_int32_t, DType_INT32 },
+ { tosa_datatype_int48_t, DType_INT48 }, { tosa_datatype_fp16_t, DType_FP16 },
+ { tosa_datatype_bf16_t, DType_BF16 }, { tosa_datatype_fp32_t, DType_FP32 },
+ { tosa_datatype_shape_t, DType_SHAPE },
+ };
+
+ if (typeMap.count(dataType))
+ {
+ return typeMap[dataType];
+ }
+
+ return DType_UNKNOWN;
+}
+} // namespace TosaReference
diff --git a/reference_model/src/verify/verify_utils.h b/reference_model/src/verify/verify_utils.h
new file mode 100644
index 0000000..6e51e3e
--- /dev/null
+++ b/reference_model/src/verify/verify_utils.h
@@ -0,0 +1,81 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef VERIFY_UTILS_H_
+#define VERIFY_UTILS_H_
+
+#include "dtype.h"
+#include "types.h"
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+namespace TosaReference
+{
+
+// Name alias
+using CTensor = tosa_tensor_t;
+
+/// \brief Supported verification modes
+enum class VerifyMode
+{
+ Exact,
+ Ulp,
+ DotProduct,
+ ReduceProduct,
+ FpSpecial
+};
+
+/// \brief ULP verification meta-data
+struct UlpInfo
+{
+ UlpInfo() = default;
+
+ float ulp;
+};
+
+/// \brief Dot-product verification meta-data
+struct DotProductVerifyInfo
+{
+ DotProductVerifyInfo() = default;
+
+ DType dataType;
+ int32_t s;
+ int32_t ks;
+};
+
+/// \brief Verification meta-data
+struct VerifyConfig
+{
+ VerifyConfig() = default;
+
+ VerifyMode mode;
+ UlpInfo ulpInfo;
+ DotProductVerifyInfo dotProductInfo;
+};
+
+/// \brief Parse the verification config for a tensor when given in JSON form
+std::optional<VerifyConfig> parseVerifyConfig(const char* tensorName, const char* configJson);
+
+/// \brief Extract number of total elements
+int64_t numElements(const std::vector<int32_t>& shape);
+
+/// \brief Map API data-type to DType
+DType mapToDType(tosa_datatype_t dataType);
+
+}; // namespace TosaReference
+
+#endif // VERIFY_UTILS_H_
diff --git a/reference_model/test/verify_tests.cpp b/reference_model/test/verify_tests.cpp
index f43804f..81f3e8d 100644
--- a/reference_model/test/verify_tests.cpp
+++ b/reference_model/test/verify_tests.cpp
@@ -16,41 +16,103 @@
#include <doctest.h>
#include <array>
-#include <numeric>
+#include <string>
#include <vector>
-TEST_SUITE("verify")
+namespace
{
- TEST_CASE("check_element_accfp32")
+
+class TosaTensor
+{
+public:
+ TosaTensor(std::string name, tosa_datatype_t dataType, std::vector<int32_t> shape)
+ : _name(std::move(name))
+ , _shape(std::move(shape))
+ {
+ _tensor.name = _name.c_str();
+ _tensor.data_type = dataType;
+ _tensor.num_dims = _shape.size();
+ _tensor.shape = _shape.data();
+ };
+
+ const tosa_tensor_t* cTensor() const
{
- const size_t KS = 27;
-
- // Negative (bnd == 0.0)
- REQUIRE_FALSE(tosa_validate_element_accfp32(0.0, 0.0, 1.0, KS).is_valid);
- REQUIRE_FALSE(tosa_validate_element_accfp32(1.0, 0.0, 0.0, KS).is_valid);
- // Negative (bnd > 0.0)
- REQUIRE_FALSE(tosa_validate_element_accfp32(5.0, 5.0, 5.1, KS).is_valid);
-
- // Positive (bnd == 0.0 && ref == 0.0 && imp == 0.0)
- REQUIRE(tosa_validate_element_accfp32(0.0, 0.0, 0.0, KS).is_valid);
- REQUIRE(tosa_validate_element_accfp32(0.0, 0.0, 0.0, KS).error == 0.0);
-
- // Positive (bnd > 0.0)
- REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0);
- REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0);
- REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0);
+ return &_tensor;
}
- TEST_CASE("check_output_error")
+
+private:
+ std::string _name;
+ std::vector<int32_t> _shape;
+ tosa_tensor_t _tensor;
+};
+
+} // namespace
+
+TEST_SUITE_BEGIN("verify");
+
+TEST_CASE("negative - api")
+{
+ std::string json_cfg = R"({
+ "tensors" : {
+ "out1" : {
+ "mode": "DOT_PRODUCT",
+ "dot_product_info" : {
+ "data_type": "FP32",
+ "s": 2,
+ "ks": 9
+ }
+ }
+ }
+ })";
+
+ SUBCASE("invalid json")
{
- const size_t KS = 27;
- const size_t T = 1024;
-
- // Negative (S!=1 && S!=2 && (abs(err_sum) > 2*sqrt(KS*T)))
- REQUIRE_FALSE(tosa_validate_output_error(1560, 112000, KS, T, 0));
- // Negative (err_sum_sq > 0.4*KS*T))
- REQUIRE_FALSE(tosa_validate_output_error(1560, 112000, KS, T, 1));
- // Positive
- REQUIRE(tosa_validate_output_error(10, 254, KS, T, 0));
- REQUIRE(tosa_validate_output_error(10, 254, KS, T, 1));
+ std::string invalid_json_cfg = R"({
+ "tensors" : {
+ "out1" : {
+ "mode": DOT_PRODUCT,
+ },
+ }
+ })";
+
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), invalid_json_cfg.c_str()));
+ }
+ SUBCASE("mismatching dimensions")
+ {
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 4, 4 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 4, 4 });
+ const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str()));
+ }
+ SUBCASE("mismatching shapes")
+ {
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor imp("out1", tosa_datatype_fp32_t, { 4, 4, 4 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str()));
+ }
+ SUBCASE("mismatching data types")
+ {
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor imp("out1", tosa_datatype_fp16_t, { 8, 8, 8 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str()));
+ }
+ SUBCASE("missing tensor data")
+ {
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str()));
}
}
+
+TEST_SUITE_END(); // verify \ No newline at end of file