aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate/generate_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/generate/generate_utils.cc')
-rw-r--r--reference_model/src/generate/generate_utils.cc17
1 files changed, 16 insertions, 1 deletions
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index f31b443..d62c247 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -107,7 +107,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType,
{ GeneratorType::DotProduct, "DOT_PRODUCT" },
{ GeneratorType::FullRange, "FULL_RANGE" },
{ GeneratorType::Boundary, "BOUNDARY" },
- { GeneratorType::Special, "SPECIAL" },
+ { GeneratorType::FpSpecial, "FP_SPECIAL" },
{ GeneratorType::FixedData, "FIXED_DATA" },
})
@@ -159,6 +159,14 @@ void from_json(const nlohmann::json& j, FullRangeInfo& fullRangeInfo)
}
}
+void from_json(const nlohmann::json& j, FpSpecialInfo& fpSpecialInfo)
+{
+ if (j.contains("start_idx"))
+ {
+ j.at("start_idx").get_to(fpSpecialInfo.startIndex);
+ }
+}
+
void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("data_type").get_to(cfg.dataType);
@@ -201,6 +209,13 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("full_range_info").get_to(cfg.fullRangeInfo);
}
+
+ //Set up defaults for fpSpecialInfo
+ cfg.fpSpecialInfo.startIndex = 0;
+ if (j.contains("fp_special_info"))
+ {
+ j.at("fp_special_info").get_to(cfg.fpSpecialInfo);
+ }
}
std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName)