diff options
Diffstat (limited to 'reference_model/src/generate/generate_utils.cc')
-rw-r--r-- | reference_model/src/generate/generate_utils.cc | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index f31b443..d62c247 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -107,7 +107,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType, { GeneratorType::DotProduct, "DOT_PRODUCT" }, { GeneratorType::FullRange, "FULL_RANGE" }, { GeneratorType::Boundary, "BOUNDARY" }, - { GeneratorType::Special, "SPECIAL" }, + { GeneratorType::FpSpecial, "FP_SPECIAL" }, { GeneratorType::FixedData, "FIXED_DATA" }, }) @@ -159,6 +159,14 @@ void from_json(const nlohmann::json& j, FullRangeInfo& fullRangeInfo) } } +void from_json(const nlohmann::json& j, FpSpecialInfo& fpSpecialInfo) +{ + if (j.contains("start_idx")) + { + j.at("start_idx").get_to(fpSpecialInfo.startIndex); + } +} + void from_json(const nlohmann::json& j, GenerateConfig& cfg) { j.at("data_type").get_to(cfg.dataType); @@ -201,6 +209,13 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg) { j.at("full_range_info").get_to(cfg.fullRangeInfo); } + + //Set up defaults for fpSpecialInfo + cfg.fpSpecialInfo.startIndex = 0; + if (j.contains("fp_special_info")) + { + j.at("fp_special_info").get_to(cfg.fpSpecialInfo); + } } std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName) |