diff options
Diffstat (limited to 'reference_model/src/verify/verify_utils.cc')
-rw-r--r-- | reference_model/src/verify/verify_utils.cc | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc index abb55eb..14bc6f1 100644 --- a/reference_model/src/verify/verify_utils.cc +++ b/reference_model/src/verify/verify_utils.cc @@ -52,6 +52,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(VerifyMode, { VerifyMode::FpSpecial, "FP_SPECIAL" }, { VerifyMode::ReduceProduct, "REDUCE_PRODUCT" }, { VerifyMode::AbsError, "ABS_ERROR" }, + { VerifyMode::Relative, "RELATIVE" }, }) void from_json(const nlohmann::json& j, UlpVerifyInfo& ulpInfo) @@ -78,6 +79,12 @@ void from_json(const nlohmann::json& j, AbsErrorVerifyInfo& absErrorInfo) } } +void from_json(const nlohmann::json& j, RelativeVerifyInfo& rInfo) +{ + j.at("max").get_to(rInfo.max); + j.at("scale").get_to(rInfo.scale); +} + void from_json(const nlohmann::json& j, VerifyConfig& cfg) { j.at("mode").get_to(cfg.mode); @@ -100,6 +107,10 @@ void from_json(const nlohmann::json& j, VerifyConfig& cfg) { j.at("abs_error_info").get_to(cfg.absErrorInfo); } + if (j.contains("relative_info")) + { + j.at("relative_info").get_to(cfg.relativeInfo); + } } std::optional<VerifyConfig> parseVerifyConfig(const char* tensorName, const char* json) |