aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify/verify_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/verify/verify_utils.cc')
-rw-r--r--reference_model/src/verify/verify_utils.cc11
1 files changed, 11 insertions, 0 deletions
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
index abb55eb..14bc6f1 100644
--- a/reference_model/src/verify/verify_utils.cc
+++ b/reference_model/src/verify/verify_utils.cc
@@ -52,6 +52,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(VerifyMode,
{ VerifyMode::FpSpecial, "FP_SPECIAL" },
{ VerifyMode::ReduceProduct, "REDUCE_PRODUCT" },
{ VerifyMode::AbsError, "ABS_ERROR" },
+ { VerifyMode::Relative, "RELATIVE" },
})
void from_json(const nlohmann::json& j, UlpVerifyInfo& ulpInfo)
@@ -78,6 +79,12 @@ void from_json(const nlohmann::json& j, AbsErrorVerifyInfo& absErrorInfo)
}
}
+void from_json(const nlohmann::json& j, RelativeVerifyInfo& rInfo)
+{
+ j.at("max").get_to(rInfo.max);
+ j.at("scale").get_to(rInfo.scale);
+}
+
void from_json(const nlohmann::json& j, VerifyConfig& cfg)
{
j.at("mode").get_to(cfg.mode);
@@ -100,6 +107,10 @@ void from_json(const nlohmann::json& j, VerifyConfig& cfg)
{
j.at("abs_error_info").get_to(cfg.absErrorInfo);
}
+ if (j.contains("relative_info"))
+ {
+ j.at("relative_info").get_to(cfg.relativeInfo);
+ }
}
std::optional<VerifyConfig> parseVerifyConfig(const char* tensorName, const char* json)