aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-01-03 10:54:12 +0000
committerEric Kunze <eric.kunze@arm.com>2024-01-08 21:40:41 +0000
commitd80ea5e11e5f92e0f7c08afeba74cb7d1719987b (patch)
tree25589c928c95de3de8bbad96dc07432bd9d289f9
parent2936f13d0e26c394333495ce909740eaf58a45cc (diff)
downloadreference_model-d80ea5e11e5f92e0f7c08afeba74cb7d1719987b.tar.gz
Main Conformance: Re-adjust TANH compliance check
Add lower bound to ABS ERROR checks to allow for cancellation of small values in error bounds checking. Re-adjust the error bounds multiplier to match the specification. Fix up naming of verify library info structs. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I3e178c3d7d59fef9c3696178646b23ed2a3ffc61
-rw-r--r--reference_model/src/ops/activation_funcs.cc6
-rw-r--r--reference_model/src/verify/verifiers.h9
-rw-r--r--reference_model/src/verify/verify_abs_error.cc21
-rw-r--r--reference_model/src/verify/verify_entry.cc4
-rw-r--r--reference_model/src/verify/verify_ulp.cc4
-rw-r--r--reference_model/src/verify/verify_utils.cc18
-rw-r--r--reference_model/src/verify/verify_utils.h17
-rw-r--r--scripts/schemavalidation/compliance-config.schema.json14
-rw-r--r--verif/generator/tosa_test_gen.py9
9 files changed, 77 insertions, 25 deletions
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 787055c..c8fdc74 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -123,8 +123,8 @@ int OpTanh<Rank, Dtype>::register_fcn()
case TOSA_REF_TYPE_FP64:
if (g_func_config.abs_mode)
{
- // ABS_ERROR bounds return 8*(1+abs(a))
- this->fcn = [](InEigenType a) -> OutEigenType { return 8.0 * (1.0 + (a > (InEigenType)0 ? a : (-a))); };
+ // ABS_ERROR bounds return 4*(1+abs(a))
+ this->fcn = [](InEigenType a) -> OutEigenType { return 4.0 * (1.0 + (a > (InEigenType)0 ? a : (-a))); };
}
else
{
diff --git a/reference_model/src/verify/verifiers.h b/reference_model/src/verify/verifiers.h
index f2590cb..152cd6a 100644
--- a/reference_model/src/verify/verifiers.h
+++ b/reference_model/src/verify/verifiers.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -55,19 +55,20 @@ bool verifyReduceProduct(const CTensor* referenceTensor, const CTensor* implemen
///
/// \param referenceTensor Reference tensor
/// \param implementationTensor Implementation resulting tensor
-/// \param ulp The ULP tolerence for the comparison of the two tensors
+/// \param ulpInfo The ULP tolerence info for the comparison of the two tensors
///
/// \return True if compliant else false
-bool verifyULP(const CTensor* referenceTensor, const CTensor* implementationTensor, const UlpInfo& ulpInfo);
+bool verifyULP(const CTensor* referenceTensor, const CTensor* implementationTensor, const UlpVerifyInfo& ulpInfo);
/// \brief Perform abs-error based verification
///
/// \param ref Reference tensor
/// \param refBnd Reference bounds tensor (according to op)
/// \param imp Implementation resulting tensor
+/// \param aeInfo Abs-error verification meta-data
///
/// \return True if compliant else false
-bool verifyAbsError(const CTensor* ref, const CTensor* refBnd, const CTensor* imp);
+bool verifyAbsError(const CTensor* ref, const CTensor* refBnd, const CTensor* imp, const AbsErrorVerifyInfo& aeInfo);
}; // namespace TosaReference
diff --git a/reference_model/src/verify/verify_abs_error.cc b/reference_model/src/verify/verify_abs_error.cc
index 5aaa0ad..25ecae4 100644
--- a/reference_model/src/verify/verify_abs_error.cc
+++ b/reference_model/src/verify/verify_abs_error.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -27,14 +27,23 @@ namespace TosaReference
namespace
{
template <typename OutDtype>
-bool validateData(const double* ref, const double* bnd, const OutDtype* imp, const std::vector<int32_t>& shape)
+bool validateData(const double* ref,
+ const double* bnd,
+ const OutDtype* imp,
+ const std::vector<int32_t>& shape,
+ const AbsErrorVerifyInfo& cfg)
{
const size_t T = static_cast<size_t>(numElements(shape));
TOSA_REF_REQUIRE(T > 0, "[AE] Invalid shape for reference tensor");
for (size_t i = 0; i < T; ++i)
{
- double errBound = std::abs(ref[i]) * exp2(-AccPrecision<OutDtype>::normal_frac) * bnd[i];
+ double valBound = std::abs(ref[i]) * bnd[i];
+ if (cfg.lowerBound > 0)
+ {
+ valBound = std::max(cfg.lowerBound, valBound);
+ }
+ double errBound = exp2(-AccPrecision<OutDtype>::normal_frac) * valBound;
bool valid = tosaCheckFloatBound(imp[i], ref[i], errBound);
if (!valid)
{
@@ -46,7 +55,7 @@ bool validateData(const double* ref, const double* bnd, const OutDtype* imp, con
return true;
}
} // namespace
-bool verifyAbsError(const CTensor* ref, const CTensor* refBnd, const CTensor* imp)
+bool verifyAbsError(const CTensor* ref, const CTensor* refBnd, const CTensor* imp, const AbsErrorVerifyInfo& aeInfo)
{
// Validate that tensors are provided
TOSA_REF_REQUIRE(ref != nullptr, "[AE] Reference tensor is missing");
@@ -64,12 +73,12 @@ bool verifyAbsError(const CTensor* ref, const CTensor* refBnd, const CTensor* im
case tosa_datatype_fp32_t: {
const auto* impData = reinterpret_cast<const float*>(imp->data);
TOSA_REF_REQUIRE(impData != nullptr, "[AE] Missing data for implementation");
- return validateData(refData, refBndData, impData, refShape);
+ return validateData(refData, refBndData, impData, refShape, aeInfo);
}
case tosa_datatype_fp16_t: {
const auto* impData = reinterpret_cast<const half_float::half*>(imp->data);
TOSA_REF_REQUIRE(impData != nullptr, "[AE] Missing data for implementation");
- return validateData(refData, refBndData, impData, refShape);
+ return validateData(refData, refBndData, impData, refShape, aeInfo);
}
default:
WARNING("[Verifier][AE] Data-type not supported.");
diff --git a/reference_model/src/verify/verify_entry.cc b/reference_model/src/verify/verify_entry.cc
index d0b31c6..2b318d1 100644
--- a/reference_model/src/verify/verify_entry.cc
+++ b/reference_model/src/verify/verify_entry.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -41,7 +41,7 @@ bool verify(const CTensor* ref, const CTensor* refBnd, const CTensor* imp, const
return verifyULP(ref, imp, cfg.ulpInfo);
}
case VerifyMode::AbsError: {
- return verifyAbsError(ref, refBnd, imp);
+ return verifyAbsError(ref, refBnd, imp, cfg.absErrorInfo);
}
default: {
WARNING("[Verifier] Unsupported verification mode.");
diff --git a/reference_model/src/verify/verify_ulp.cc b/reference_model/src/verify/verify_ulp.cc
index 1b38fe6..13bf0a9 100644
--- a/reference_model/src/verify/verify_ulp.cc
+++ b/reference_model/src/verify/verify_ulp.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -48,7 +48,7 @@ bool tosaCheckULP(OutType testValue, double referenceValue, double ulpNum)
}
} // namespace
-bool verifyULP(const CTensor* referenceTensor, const CTensor* implementationTensor, const UlpInfo& ulpInfo)
+bool verifyULP(const CTensor* referenceTensor, const CTensor* implementationTensor, const UlpVerifyInfo& ulpInfo)
{
// Validate that tensors are provided
TOSA_REF_REQUIRE(referenceTensor != nullptr, "[ULP] Reference tensor is missing");
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
index 3bdc99f..5ce646c 100644
--- a/reference_model/src/verify/verify_utils.cc
+++ b/reference_model/src/verify/verify_utils.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -53,7 +53,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(VerifyMode,
{ VerifyMode::AbsError, "ABS_ERROR" },
})
-void from_json(const nlohmann::json& j, UlpInfo& ulpInfo)
+void from_json(const nlohmann::json& j, UlpVerifyInfo& ulpInfo)
{
j.at("ulp").get_to(ulpInfo.ulp);
}
@@ -70,6 +70,14 @@ void from_json(const nlohmann::json& j, ReduceProductVerifyInfo& reduceProduceIn
j.at("n").get_to(reduceProduceInfo.n);
}
+void from_json(const nlohmann::json& j, AbsErrorVerifyInfo& absErrorInfo)
+{
+ if (j.contains("lower_bound"))
+ {
+ j.at("lower_bound").get_to(absErrorInfo.lowerBound);
+ }
+}
+
void from_json(const nlohmann::json& j, VerifyConfig& cfg)
{
j.at("mode").get_to(cfg.mode);
@@ -86,6 +94,12 @@ void from_json(const nlohmann::json& j, VerifyConfig& cfg)
{
j.at("reduce_product_info").get_to(cfg.reduceProductInfo);
}
+ // Set up defaults for optional AbsErrorVerifyInfo
+ cfg.absErrorInfo.lowerBound = 0;
+ if (j.contains("abs_error_info"))
+ {
+ j.at("abs_error_info").get_to(cfg.absErrorInfo);
+ }
}
std::optional<VerifyConfig> parseVerifyConfig(const char* tensorName, const char* json)
diff --git a/reference_model/src/verify/verify_utils.h b/reference_model/src/verify/verify_utils.h
index 45daeac..0fc68fb 100644
--- a/reference_model/src/verify/verify_utils.h
+++ b/reference_model/src/verify/verify_utils.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -50,9 +50,9 @@ enum class VerifyMode
};
/// \brief ULP verification meta-data
-struct UlpInfo
+struct UlpVerifyInfo
{
- UlpInfo() = default;
+ UlpVerifyInfo() = default;
double ulp;
};
@@ -75,6 +75,14 @@ struct ReduceProductVerifyInfo
int64_t n;
};
+/// \brief abs-error verification meta-data
+struct AbsErrorVerifyInfo
+{
+ AbsErrorVerifyInfo() = default;
+
+ double lowerBound;
+};
+
/// \brief Verification meta-data
struct VerifyConfig
{
@@ -82,9 +90,10 @@ struct VerifyConfig
VerifyMode mode;
DType dataType;
- UlpInfo ulpInfo;
+ UlpVerifyInfo ulpInfo;
DotProductVerifyInfo dotProductInfo;
ReduceProductVerifyInfo reduceProductInfo;
+ AbsErrorVerifyInfo absErrorInfo;
};
/// \brief Parse the verification config for a tensor when given in JSON form
diff --git a/scripts/schemavalidation/compliance-config.schema.json b/scripts/schemavalidation/compliance-config.schema.json
index e78d385..dd62404 100644
--- a/scripts/schemavalidation/compliance-config.schema.json
+++ b/scripts/schemavalidation/compliance-config.schema.json
@@ -1,5 +1,5 @@
{
- "$comment": "Copyright (c) 2023, ARM Limited.",
+ "$comment": "Copyright (c) 2023-2024, ARM Limited.",
"$comment": "SPDX-License-Identifier: Apache-2.0",
"$id": "compliance-config.schema.json",
"$schema": "https://json-schema.org/draft/2020-12/schema",
@@ -61,6 +61,18 @@
"s",
"ks"
]
+ },
+ "abs_error_info": {
+ "description": "info required for the ABS_ERROR mode",
+ "type": "object",
+ "properties":
+ {
+ "lower_bound": {
+ "description": "lower bound multiplier for error bounds",
+ "type": "number"
+ }
+ },
+ "additionalProperties": false
}
},
"additionalProperties": false,
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 2290c54..5129e24 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import json
import os
@@ -349,6 +349,10 @@ class TosaTestGen:
mode = gtu.ComplianceMode.REDUCE_PRODUCT
elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
mode = gtu.ComplianceMode.ABS_ERROR
+ if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
+ compliance_tens["abs_error_info"] = {
+ "lower_bound": op["compliance"]["abs_error_lower_bound"]
+ }
else:
mode = gtu.ComplianceMode.EXACT
compliance_tens["mode"] = gtu.ComplianceMode(mode).name
@@ -3262,6 +3266,9 @@ class TosaTestGen:
"data_gen": {
"fp": (gtu.DataGenType.PSEUDO_RANDOM,),
},
+ "compliance": {
+ "abs_error_lower_bound": 0.5,
+ },
},
"erf": {
"op": Op.ERF,