aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2023-08-22 08:25:57 +0100
committerEric Kunze <eric.kunze@arm.com>2023-09-07 16:03:50 +0000
commit7021ef064f7daeca260bb1f1bd61b5bbc6473aa5 (patch)
tree24a488954ab0a7c6e29e811429ad194af67c3880 /reference_model/src/verify
parent391cc5e80559e46081b6aa12c344d820dc685c95 (diff)
downloadreference_model-7021ef064f7daeca260bb1f1bd61b5bbc6473aa5.tar.gz
Rework TOSA verification API
Change verifier API to consume verification configuration in a JSON format and enable appropriate validation to be performed within the verifier code in the reference model. Also update to latest spec changes for main compliance but not yet including bias support. Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I0ceaa1714dd041b00b5b29cd937c8f05e701bc4c
Diffstat (limited to 'reference_model/src/verify')
-rw-r--r--reference_model/src/verify/verifiers.h38
-rw-r--r--reference_model/src/verify/verify_dot_product.cc143
-rw-r--r--reference_model/src/verify/verify_entry.cc92
-rw-r--r--reference_model/src/verify/verify_utils.cc121
-rw-r--r--reference_model/src/verify/verify_utils.h81
5 files changed, 475 insertions, 0 deletions
diff --git a/reference_model/src/verify/verifiers.h b/reference_model/src/verify/verifiers.h
new file mode 100644
index 0000000..afd50bf
--- /dev/null
+++ b/reference_model/src/verify/verifiers.h
@@ -0,0 +1,38 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef VERIFIERS_H_
+#define VERIFIERS_H_
+
+#include "verify_utils.h"
+
+namespace TosaReference
+{
+/// \brief Perform dot-product based verification
+///
+/// \param ref Reference tensor
+/// \param refBnd Reference tensor when ran on abs(input)
+/// \param imp Implementation resulting tensor
+/// \param dpInfo Dot-product verification meta-data
+///
+/// \return True if compliant else false
+bool verifyDotProduct(const CTensor* ref,
+ const CTensor* refBnd,
+ const CTensor* imp,
+ const DotProductVerifyInfo& dpInfo);
+
+}; // namespace TosaReference
+
+#endif // VERIFIERS_H_
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc
new file mode 100644
index 0000000..a24f83f
--- /dev/null
+++ b/reference_model/src/verify/verify_dot_product.cc
@@ -0,0 +1,143 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "func_debug.h"
+#include "verifiers.h"
+
+#include <cmath>
+#include <numeric>
+#include <optional>
+#include <type_traits>
+
+#define TOSA_REF_REQUIRE(COND, MESSAGE) \
+ if (!(COND)) \
+ { \
+ WARNING(MESSAGE); \
+ return false; \
+ }
+
+namespace TosaReference
+{
+namespace
+{
+
+// Accumulator precision
+template <typename T>
+struct AccPrecision;
+#define two_m42 1.0 / (double)(((int64_t)1) << 42) // 2^-42
+template <>
+struct AccPrecision<float>
+{
+ static constexpr double precision = (double)(1 << 24);
+ static constexpr double min_normal = two_m42 * two_m42 * two_m42; // 2^-126
+};
+#undef two_m42
+
+// Generic element validation function
+template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
+std::optional<double> validateElement(double ref, double bnd, AccType imp, size_t KS)
+{
+ double err = 0.0;
+ bool is_valid = true;
+
+ if (bnd == 0.0)
+ {
+ is_valid = (ref == 0.0) && (imp == 0.0);
+ err = 0.0;
+ }
+ else if (std::isinf(static_cast<AccType>(bnd)))
+ {
+ // dot product can overflow and there is no accuracy limit
+ is_valid = true;
+ err = 0.0;
+ }
+ else
+ {
+ // 0.0 < bnd < infinity
+ const double bnd_norm = std::max(bnd, AccPrecision<AccType>::min_normal);
+ const double imp_fp64 = static_cast<double>(imp);
+ const double acc_prec_fp64 = AccPrecision<AccType>::precision;
+ err = (imp_fp64 - ref) * acc_prec_fp64 / bnd_norm;
+ is_valid = std::abs(err) <= KS;
+ }
+
+ return is_valid ? std::optional(err) : std::nullopt;
+}
+
+// Generic data validation function
+template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
+bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg)
+{
+ const int32_t S = cfg.s;
+ // TODO - needed for other ops - (max_value(bias_abs) > 0) ? (KS + 1) : KS
+ const int32_t KS = cfg.ks;
+
+ double out_err_sum = 0.0;
+ double out_err_sumsq = 0.0;
+
+ for (size_t i = 0; i < T; ++i)
+ {
+ auto out_err = validateElement<AccType>(ref[i], bnd[i], imp[i], KS);
+ TOSA_REF_REQUIRE(out_err, "output error is 0");
+ out_err_sum += out_err.value();
+ out_err_sumsq += out_err.value() * out_err.value();
+ }
+
+ if (S >= 3 && S <= 5)
+ {
+ // Check error bias magnitude for data sets S which are not positive biased
+ TOSA_REF_REQUIRE(std::abs(out_err_sum) <= 2 * sqrt(KS * T), "bias magnitude is out of range");
+ }
+ // Check error variance magnitude
+ TOSA_REF_REQUIRE(out_err_sumsq <= 0.4 * KS * T, "error variance magnitude is out of range");
+ return true;
+}
+} // namespace
+
+bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor* imp, const DotProductVerifyInfo& dpInfo)
+{
+ // Validate that tensors are provided
+ TOSA_REF_REQUIRE(ref != nullptr, "reference tensor is missing");
+ TOSA_REF_REQUIRE(refBnd != nullptr, "reference bounds tensor is missing");
+ TOSA_REF_REQUIRE(imp != nullptr, "implementation tensor is missing");
+
+ // Validate data-type
+ TOSA_REF_REQUIRE(dpInfo.dataType == mapToDType(imp->data_type), "invalid data type in config");
+
+ // Get number of dot-product elements
+ const int64_t T = numElements(std::vector<int32_t>(ref->shape, ref->shape + ref->num_dims));
+ TOSA_REF_REQUIRE(T > 0, "invalid shape for reference tensor");
+
+ const double* refData = reinterpret_cast<const double*>(ref->data);
+ const double* refBndData = reinterpret_cast<const double*>(refBnd->data);
+ TOSA_REF_REQUIRE(refData != nullptr && refBndData != nullptr, "missing data for reference or bounds tensors");
+
+ switch (imp->data_type)
+ {
+ case tosa_datatype_fp32_t: {
+ const float* impData = reinterpret_cast<const float*>(imp->data);
+ TOSA_REF_REQUIRE(impData != nullptr, "missing data for implementation");
+ return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo);
+ break;
+ }
+ default:
+ break;
+ }
+
+ return false;
+}
+
+} // namespace TosaReference
+
+#undef TOSA_REF_REQUIRE \ No newline at end of file
diff --git a/reference_model/src/verify/verify_entry.cc b/reference_model/src/verify/verify_entry.cc
new file mode 100644
index 0000000..80ca916
--- /dev/null
+++ b/reference_model/src/verify/verify_entry.cc
@@ -0,0 +1,92 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "verify.h"
+
+#include "func_debug.h"
+#include "model_common.h"
+#include "verifiers.h"
+#include "verify_utils.h"
+
+#include <vector>
+
+namespace TosaReference
+{
+
+bool verify(const CTensor* ref, const CTensor* refBnd, const CTensor* imp, const VerifyConfig& cfg)
+{
+ switch (cfg.mode)
+ {
+ case VerifyMode::DotProduct: {
+ return verifyDotProduct(ref, refBnd, imp, cfg.dotProductInfo);
+ break;
+ }
+ default: {
+ WARNING("unsupported verification mode.");
+ break;
+ }
+ }
+ return false;
+}
+
+} // namespace TosaReference
+
+extern "C"
+{
+ bool tvf_verify_data(const tosa_tensor_t* ref,
+ const tosa_tensor_t* ref_bnd,
+ const tosa_tensor_t* imp,
+ const char* config_json)
+ {
+ // Check inputs for nullptr
+ if (!ref || !imp || !config_json)
+ {
+ WARNING("one of the inputs is missing.");
+ return false;
+ }
+
+ // Extract verification config
+ if (!ref->name)
+ {
+ WARNING("tensor name is not specified.");
+ return false;
+ }
+ auto cfg = TosaReference::parseVerifyConfig(ref->name, config_json);
+ if (!cfg)
+ {
+ WARNING("invalid json config.");
+ return false;
+ }
+
+ // Validate shape
+ if (ref->num_dims != imp->num_dims)
+ {
+ WARNING("tensors have different number of dimensions.");
+ return false;
+ }
+ if (!ref->shape || !imp->shape)
+ {
+ WARNING("one of tensors' shape is missing.");
+ return false;
+ }
+ if (std::vector(ref->shape, ref->shape + ref->num_dims) != std::vector(imp->shape, imp->shape + imp->num_dims))
+ {
+ WARNING("tensors have different shapes.");
+ return false;
+ }
+
+ // Run verification
+ return verify(ref, ref_bnd, imp, *cfg);
+ }
+} // extern "C"
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
new file mode 100644
index 0000000..bb4feaa
--- /dev/null
+++ b/reference_model/src/verify/verify_utils.cc
@@ -0,0 +1,121 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "verify_utils.h"
+
+#include <nlohmann/json.hpp>
+
+#include <algorithm>
+#include <map>
+
+namespace tosa
+{
+
+NLOHMANN_JSON_SERIALIZE_ENUM(DType,
+ {
+ { DType::DType_BOOL, "BOOL" },
+ { DType::DType_INT4, "INT4" },
+ { DType::DType_INT8, "INT8" },
+ { DType::DType_INT16, "INT16" },
+ { DType::DType_INT32, "INT32" },
+ { DType::DType_INT48, "INT48" },
+ { DType::DType_FP16, "FP16" },
+ { DType::DType_BF16, "BF16" },
+ { DType::DType_FP32, "FP32" },
+ })
+
+} // namespace tosa
+
+namespace TosaReference
+{
+
+NLOHMANN_JSON_SERIALIZE_ENUM(VerifyMode,
+ {
+ { VerifyMode::Exact, "EXACT" },
+ { VerifyMode::Ulp, "ULP" },
+ { VerifyMode::DotProduct, "DOT_PRODUCT" },
+ { VerifyMode::ReduceProduct, "REDUCE_PRODUCT" },
+ { VerifyMode::FpSpecial, "FP_SPECIAL" },
+ })
+
+void from_json(const nlohmann::json& j, UlpInfo& ulpInfo)
+{
+ j.at("ulp").get_to(ulpInfo.ulp);
+}
+
+void from_json(const nlohmann::json& j, DotProductVerifyInfo& dotProductInfo)
+{
+ j.at("data_type").get_to(dotProductInfo.dataType);
+ j.at("s").get_to(dotProductInfo.s);
+ j.at("ks").get_to(dotProductInfo.ks);
+}
+
+void from_json(const nlohmann::json& j, VerifyConfig& cfg)
+{
+ j.at("mode").get_to(cfg.mode);
+ if (j.contains("ulp_info"))
+ {
+ j.at("ulp_info").get_to(cfg.ulpInfo);
+ }
+ if (j.contains("dot_product_info"))
+ {
+ j.at("dot_product_info").get_to(cfg.dotProductInfo);
+ }
+}
+
+std::optional<VerifyConfig> parseVerifyConfig(const char* tensorName, const char* json)
+{
+ if (!tensorName)
+ return std::nullopt;
+
+ auto jsonCfg = nlohmann::json::parse(json, nullptr, /* allow exceptions */ false);
+
+ if (jsonCfg.is_discarded())
+ return std::nullopt;
+ if (!jsonCfg.contains("tensors"))
+ return std::nullopt;
+
+ const auto& tensors = jsonCfg["tensors"];
+ if (!tensors.contains(tensorName))
+ return std::nullopt;
+
+ const auto& namedTensor = tensors[tensorName];
+ return namedTensor.get<VerifyConfig>();
+}
+
+int64_t numElements(const std::vector<int32_t>& shape)
+{
+ return std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<int64_t>());
+}
+
+DType mapToDType(tosa_datatype_t dataType)
+{
+ static std::map<tosa_datatype_t, DType> typeMap = {
+ { tosa_datatype_bool_t, DType_BOOL }, { tosa_datatype_int4_t, DType_INT4 },
+ { tosa_datatype_int8_t, DType_INT8 }, { tosa_datatype_uint16_t, DType_UINT16 },
+ { tosa_datatype_int16_t, DType_INT16 }, { tosa_datatype_int32_t, DType_INT32 },
+ { tosa_datatype_int48_t, DType_INT48 }, { tosa_datatype_fp16_t, DType_FP16 },
+ { tosa_datatype_bf16_t, DType_BF16 }, { tosa_datatype_fp32_t, DType_FP32 },
+ { tosa_datatype_shape_t, DType_SHAPE },
+ };
+
+ if (typeMap.count(dataType))
+ {
+ return typeMap[dataType];
+ }
+
+ return DType_UNKNOWN;
+}
+} // namespace TosaReference
diff --git a/reference_model/src/verify/verify_utils.h b/reference_model/src/verify/verify_utils.h
new file mode 100644
index 0000000..6e51e3e
--- /dev/null
+++ b/reference_model/src/verify/verify_utils.h
@@ -0,0 +1,81 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef VERIFY_UTILS_H_
+#define VERIFY_UTILS_H_
+
+#include "dtype.h"
+#include "types.h"
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+namespace TosaReference
+{
+
+// Name alias
+using CTensor = tosa_tensor_t;
+
+/// \brief Supported verification modes
+enum class VerifyMode
+{
+ Exact,
+ Ulp,
+ DotProduct,
+ ReduceProduct,
+ FpSpecial
+};
+
+/// \brief ULP verification meta-data
+struct UlpInfo
+{
+ UlpInfo() = default;
+
+ float ulp;
+};
+
+/// \brief Dot-product verification meta-data
+struct DotProductVerifyInfo
+{
+ DotProductVerifyInfo() = default;
+
+ DType dataType;
+ int32_t s;
+ int32_t ks;
+};
+
+/// \brief Verification meta-data
+struct VerifyConfig
+{
+ VerifyConfig() = default;
+
+ VerifyMode mode;
+ UlpInfo ulpInfo;
+ DotProductVerifyInfo dotProductInfo;
+};
+
+/// \brief Parse the verification config for a tensor when given in JSON form
+std::optional<VerifyConfig> parseVerifyConfig(const char* tensorName, const char* configJson);
+
+/// \brief Extract number of total elements
+int64_t numElements(const std::vector<int32_t>& shape);
+
+/// \brief Map API data-type to DType
+DType mapToDType(tosa_datatype_t dataType);
+
+}; // namespace TosaReference
+
+#endif // VERIFY_UTILS_H_