aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/verify')
-rw-r--r--reference_model/src/verify/verifiers.h38
-rw-r--r--reference_model/src/verify/verify_dot_product.cc143
-rw-r--r--reference_model/src/verify/verify_entry.cc92
-rw-r--r--reference_model/src/verify/verify_utils.cc121
-rw-r--r--reference_model/src/verify/verify_utils.h81
5 files changed, 475 insertions, 0 deletions
diff --git a/reference_model/src/verify/verifiers.h b/reference_model/src/verify/verifiers.h
new file mode 100644
index 0000000..afd50bf
--- /dev/null
+++ b/reference_model/src/verify/verifiers.h
@@ -0,0 +1,38 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef VERIFIERS_H_
+#define VERIFIERS_H_
+
+#include "verify_utils.h"
+
+namespace TosaReference
+{
+/// \brief Perform dot-product based verification
+///
+/// \param ref Reference tensor
+/// \param refBnd Reference tensor when ran on abs(input)
+/// \param imp Implementation resulting tensor
+/// \param dpInfo Dot-product verification meta-data
+///
+/// \return True if compliant else false
+bool verifyDotProduct(const CTensor* ref,
+ const CTensor* refBnd,
+ const CTensor* imp,
+ const DotProductVerifyInfo& dpInfo);
+
+}; // namespace TosaReference
+
+#endif // VERIFIERS_H_
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc
new file mode 100644
index 0000000..a24f83f
--- /dev/null
+++ b/reference_model/src/verify/verify_dot_product.cc
@@ -0,0 +1,143 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "func_debug.h"
+#include "verifiers.h"
+
+#include <cmath>
+#include <numeric>
+#include <optional>
+#include <type_traits>
+
+#define TOSA_REF_REQUIRE(COND, MESSAGE) \
+ if (!(COND)) \
+ { \
+ WARNING(MESSAGE); \
+ return false; \
+ }
+
+namespace TosaReference
+{
+namespace
+{
+
+// Accumulator precision
+template <typename T>
+struct AccPrecision;
+#define two_m42 1.0 / (double)(((int64_t)1) << 42) // 2^-42
+template <>
+struct AccPrecision<float>
+{
+ static constexpr double precision = (double)(1 << 24);
+ static constexpr double min_normal = two_m42 * two_m42 * two_m42; // 2^-126
+};
+#undef two_m42
+
+// Generic element validation function
+template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
+std::optional<double> validateElement(double ref, double bnd, AccType imp, size_t KS)
+{
+ double err = 0.0;
+ bool is_valid = true;
+
+ if (bnd == 0.0)
+ {
+ is_valid = (ref == 0.0) && (imp == 0.0);
+ err = 0.0;
+ }
+ else if (std::isinf(static_cast<AccType>(bnd)))
+ {
+ // dot product can overflow and there is no accuracy limit
+ is_valid = true;
+ err = 0.0;
+ }
+ else
+ {
+ // 0.0 < bnd < infinity
+ const double bnd_norm = std::max(bnd, AccPrecision<AccType>::min_normal);
+ const double imp_fp64 = static_cast<double>(imp);
+ const double acc_prec_fp64 = AccPrecision<AccType>::precision;
+ err = (imp_fp64 - ref) * acc_prec_fp64 / bnd_norm;
+ is_valid = std::abs(err) <= KS;
+ }
+
+ return is_valid ? std::optional(err) : std::nullopt;
+}
+
+// Generic data validation function
+template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
+bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg)
+{
+ const int32_t S = cfg.s;
+ // TODO - needed for other ops - (max_value(bias_abs) > 0) ? (KS + 1) : KS
+ const int32_t KS = cfg.ks;
+
+ double out_err_sum = 0.0;
+ double out_err_sumsq = 0.0;
+
+ for (size_t i = 0; i < T; ++i)
+ {
+ auto out_err = validateElement<AccType>(ref[i], bnd[i], imp[i], KS);
+ TOSA_REF_REQUIRE(out_err, "output error is 0");
+ out_err_sum += out_err.value();
+ out_err_sumsq += out_err.value() * out_err.value();
+ }
+
+ if (S >= 3 && S <= 5)
+ {
+ // Check error bias magnitude for data sets S which are not positive biased
+ TOSA_REF_REQUIRE(std::abs(out_err_sum) <= 2 * sqrt(KS * T), "bias magnitude is out of range");
+ }
+ // Check error variance magnitude
+ TOSA_REF_REQUIRE(out_err_sumsq <= 0.4 * KS * T, "error variance magnitude is out of range");
+ return true;
+}
+} // namespace
+
+bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor* imp, const DotProductVerifyInfo& dpInfo)
+{
+ // Validate that tensors are provided
+ TOSA_REF_REQUIRE(ref != nullptr, "reference tensor is missing");
+ TOSA_REF_REQUIRE(refBnd != nullptr, "reference bounds tensor is missing");
+ TOSA_REF_REQUIRE(imp != nullptr, "implementation tensor is missing");
+
+ // Validate data-type
+ TOSA_REF_REQUIRE(dpInfo.dataType == mapToDType(imp->data_type), "invalid data type in config");
+
+ // Get number of dot-product elements
+ const int64_t T = numElements(std::vector<int32_t>(ref->shape, ref->shape + ref->num_dims));
+ TOSA_REF_REQUIRE(T > 0, "invalid shape for reference tensor");
+
+ const double* refData = reinterpret_cast<const double*>(ref->data);
+ const double* refBndData = reinterpret_cast<const double*>(refBnd->data);
+ TOSA_REF_REQUIRE(refData != nullptr && refBndData != nullptr, "missing data for reference or bounds tensors");
+
+ switch (imp->data_type)
+ {
+ case tosa_datatype_fp32_t: {
+ const float* impData = reinterpret_cast<const float*>(imp->data);
+ TOSA_REF_REQUIRE(impData != nullptr, "missing data for implementation");
+ return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo);
+ break;
+ }
+ default:
+ break;
+ }
+
+ return false;
+}
+
+} // namespace TosaReference
+
+#undef TOSA_REF_REQUIRE \ No newline at end of file
diff --git a/reference_model/src/verify/verify_entry.cc b/reference_model/src/verify/verify_entry.cc
new file mode 100644
index 0000000..80ca916
--- /dev/null
+++ b/reference_model/src/verify/verify_entry.cc
@@ -0,0 +1,92 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "verify.h"
+
+#include "func_debug.h"
+#include "model_common.h"
+#include "verifiers.h"
+#include "verify_utils.h"
+
+#include <vector>
+
+namespace TosaReference
+{
+
+bool verify(const CTensor* ref, const CTensor* refBnd, const CTensor* imp, const VerifyConfig& cfg)
+{
+ switch (cfg.mode)
+ {
+ case VerifyMode::DotProduct: {
+ return verifyDotProduct(ref, refBnd, imp, cfg.dotProductInfo);
+ break;
+ }
+ default: {
+ WARNING("unsupported verification mode.");
+ break;
+ }
+ }
+ return false;
+}
+
+} // namespace TosaReference
+
+extern "C"
+{
+ bool tvf_verify_data(const tosa_tensor_t* ref,
+ const tosa_tensor_t* ref_bnd,
+ const tosa_tensor_t* imp,
+ const char* config_json)
+ {
+ // Check inputs for nullptr
+ if (!ref || !imp || !config_json)
+ {
+ WARNING("one of the inputs is missing.");
+ return false;
+ }
+
+ // Extract verification config
+ if (!ref->name)
+ {
+ WARNING("tensor name is not specified.");
+ return false;
+ }
+ auto cfg = TosaReference::parseVerifyConfig(ref->name, config_json);
+ if (!cfg)
+ {
+ WARNING("invalid json config.");
+ return false;
+ }
+
+ // Validate shape
+ if (ref->num_dims != imp->num_dims)
+ {
+ WARNING("tensors have different number of dimensions.");
+ return false;
+ }
+ if (!ref->shape || !imp->shape)
+ {
+ WARNING("one of tensors' shape is missing.");
+ return false;
+ }
+ if (std::vector(ref->shape, ref->shape + ref->num_dims) != std::vector(imp->shape, imp->shape + imp->num_dims))
+ {
+ WARNING("tensors have different shapes.");
+ return false;
+ }
+
+ // Run verification
+ return verify(ref, ref_bnd, imp, *cfg);
+ }
+} // extern "C"
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
new file mode 100644
index 0000000..bb4feaa
--- /dev/null
+++ b/reference_model/src/verify/verify_utils.cc
@@ -0,0 +1,121 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "verify_utils.h"
+
+#include <nlohmann/json.hpp>
+
+#include <algorithm>
+#include <map>
+
+namespace tosa
+{
+
+NLOHMANN_JSON_SERIALIZE_ENUM(DType,
+ {
+ { DType::DType_BOOL, "BOOL" },
+ { DType::DType_INT4, "INT4" },
+ { DType::DType_INT8, "INT8" },
+ { DType::DType_INT16, "INT16" },
+ { DType::DType_INT32, "INT32" },
+ { DType::DType_INT48, "INT48" },
+ { DType::DType_FP16, "FP16" },
+ { DType::DType_BF16, "BF16" },
+ { DType::DType_FP32, "FP32" },
+ })
+
+} // namespace tosa
+
+namespace TosaReference
+{
+
+NLOHMANN_JSON_SERIALIZE_ENUM(VerifyMode,
+ {
+ { VerifyMode::Exact, "EXACT" },
+ { VerifyMode::Ulp, "ULP" },
+ { VerifyMode::DotProduct, "DOT_PRODUCT" },
+ { VerifyMode::ReduceProduct, "REDUCE_PRODUCT" },
+ { VerifyMode::FpSpecial, "FP_SPECIAL" },
+ })
+
+void from_json(const nlohmann::json& j, UlpInfo& ulpInfo)
+{
+ j.at("ulp").get_to(ulpInfo.ulp);
+}
+
+void from_json(const nlohmann::json& j, DotProductVerifyInfo& dotProductInfo)
+{
+ j.at("data_type").get_to(dotProductInfo.dataType);
+ j.at("s").get_to(dotProductInfo.s);
+ j.at("ks").get_to(dotProductInfo.ks);
+}
+
+void from_json(const nlohmann::json& j, VerifyConfig& cfg)
+{
+ j.at("mode").get_to(cfg.mode);
+ if (j.contains("ulp_info"))
+ {
+ j.at("ulp_info").get_to(cfg.ulpInfo);
+ }
+ if (j.contains("dot_product_info"))
+ {
+ j.at("dot_product_info").get_to(cfg.dotProductInfo);
+ }
+}
+
+std::optional<VerifyConfig> parseVerifyConfig(const char* tensorName, const char* json)
+{
+ if (!tensorName)
+ return std::nullopt;
+
+ auto jsonCfg = nlohmann::json::parse(json, nullptr, /* allow exceptions */ false);
+
+ if (jsonCfg.is_discarded())
+ return std::nullopt;
+ if (!jsonCfg.contains("tensors"))
+ return std::nullopt;
+
+ const auto& tensors = jsonCfg["tensors"];
+ if (!tensors.contains(tensorName))
+ return std::nullopt;
+
+ const auto& namedTensor = tensors[tensorName];
+ return namedTensor.get<VerifyConfig>();
+}
+
+int64_t numElements(const std::vector<int32_t>& shape)
+{
+ return std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<int64_t>());
+}
+
+DType mapToDType(tosa_datatype_t dataType)
+{
+ static std::map<tosa_datatype_t, DType> typeMap = {
+ { tosa_datatype_bool_t, DType_BOOL }, { tosa_datatype_int4_t, DType_INT4 },
+ { tosa_datatype_int8_t, DType_INT8 }, { tosa_datatype_uint16_t, DType_UINT16 },
+ { tosa_datatype_int16_t, DType_INT16 }, { tosa_datatype_int32_t, DType_INT32 },
+ { tosa_datatype_int48_t, DType_INT48 }, { tosa_datatype_fp16_t, DType_FP16 },
+ { tosa_datatype_bf16_t, DType_BF16 }, { tosa_datatype_fp32_t, DType_FP32 },
+ { tosa_datatype_shape_t, DType_SHAPE },
+ };
+
+ if (typeMap.count(dataType))
+ {
+ return typeMap[dataType];
+ }
+
+ return DType_UNKNOWN;
+}
+} // namespace TosaReference
diff --git a/reference_model/src/verify/verify_utils.h b/reference_model/src/verify/verify_utils.h
new file mode 100644
index 0000000..6e51e3e
--- /dev/null
+++ b/reference_model/src/verify/verify_utils.h
@@ -0,0 +1,81 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef VERIFY_UTILS_H_
+#define VERIFY_UTILS_H_
+
+#include "dtype.h"
+#include "types.h"
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+namespace TosaReference
+{
+
+// Name alias
+using CTensor = tosa_tensor_t;
+
+/// \brief Supported verification modes
+enum class VerifyMode
+{
+ Exact,
+ Ulp,
+ DotProduct,
+ ReduceProduct,
+ FpSpecial
+};
+
+/// \brief ULP verification meta-data
+struct UlpInfo
+{
+ UlpInfo() = default;
+
+ float ulp;
+};
+
+/// \brief Dot-product verification meta-data
+struct DotProductVerifyInfo
+{
+ DotProductVerifyInfo() = default;
+
+ DType dataType;
+ int32_t s;
+ int32_t ks;
+};
+
+/// \brief Verification meta-data
+struct VerifyConfig
+{
+ VerifyConfig() = default;
+
+ VerifyMode mode;
+ UlpInfo ulpInfo;
+ DotProductVerifyInfo dotProductInfo;
+};
+
+/// \brief Parse the verification config for a tensor when given in JSON form
+std::optional<VerifyConfig> parseVerifyConfig(const char* tensorName, const char* configJson);
+
+/// \brief Extract number of total elements
+int64_t numElements(const std::vector<int32_t>& shape);
+
+/// \brief Map API data-type to DType
+DType mapToDType(tosa_datatype_t dataType);
+
+}; // namespace TosaReference
+
+#endif // VERIFY_UTILS_H_