aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate/generate_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/generate/generate_utils.cc')
-rw-r--r--reference_model/src/generate/generate_utils.cc21
1 files changed, 18 insertions, 3 deletions
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index d0d0194..f31b443 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -105,9 +105,9 @@ NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType,
{ GeneratorType::Unknown, "UNKNOWN" },
{ GeneratorType::PseudoRandom, "PSEUDO_RANDOM" },
{ GeneratorType::DotProduct, "DOT_PRODUCT" },
- { GeneratorType::OpFullRange, "OP_FULL_RANGE" },
- { GeneratorType::OpBoundary, "OP_BOUNDARY" },
- { GeneratorType::OpSpecial, "OP_SPECIAL" },
+ { GeneratorType::FullRange, "FULL_RANGE" },
+ { GeneratorType::Boundary, "BOUNDARY" },
+ { GeneratorType::Special, "SPECIAL" },
{ GeneratorType::FixedData, "FIXED_DATA" },
})
@@ -151,6 +151,14 @@ void from_json(const nlohmann::json& j, FixedDataInfo& fixedDataInfo)
j.at("data").get_to(fixedDataInfo.data);
}
+void from_json(const nlohmann::json& j, FullRangeInfo& fullRangeInfo)
+{
+ if (j.contains("start_val"))
+ {
+ j.at("start_val").get_to(fullRangeInfo.startVal);
+ }
+}
+
void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("data_type").get_to(cfg.dataType);
@@ -186,6 +194,13 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("fixed_data_info").get_to(cfg.fixedDataInfo);
}
+
+ //Set up defaults for fullRangeInfo
+ cfg.fullRangeInfo.startVal = 0;
+ if (j.contains("full_range_info"))
+ {
+ j.at("full_range_info").get_to(cfg.fullRangeInfo);
+ }
}
std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName)