aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate/generate_utils.cc
diff options
context:
space:
mode:
authorevacha01 <evan.chandler@arm.com>2024-02-07 11:21:55 +0000
committerevacha01 <evan.chandler@arm.com>2024-03-07 12:06:38 +0000
commit9c96eefbaca6c85be79529bce7ff04fd7dfe3a0d (patch)
tree55647ee0216800b621bd0b27277c6f895929ef3d /reference_model/src/generate/generate_utils.cc
parent6e1e2bc06bff785e87577f24064bbc846300f8fd (diff)
downloadreference_model-9c96eefbaca6c85be79529bce7ff04fd7dfe3a0d.tar.gz
FULL data gen mode for FP16
Signed-off-by: evacha01 <evan.chandler@arm.com> Change-Id: I81bb322132daf25328a40342edc62d8e1db9edd6
Diffstat (limited to 'reference_model/src/generate/generate_utils.cc')
-rw-r--r--reference_model/src/generate/generate_utils.cc21
1 files changed, 18 insertions, 3 deletions
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index d0d0194..f31b443 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -105,9 +105,9 @@ NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType,
{ GeneratorType::Unknown, "UNKNOWN" },
{ GeneratorType::PseudoRandom, "PSEUDO_RANDOM" },
{ GeneratorType::DotProduct, "DOT_PRODUCT" },
- { GeneratorType::OpFullRange, "OP_FULL_RANGE" },
- { GeneratorType::OpBoundary, "OP_BOUNDARY" },
- { GeneratorType::OpSpecial, "OP_SPECIAL" },
+ { GeneratorType::FullRange, "FULL_RANGE" },
+ { GeneratorType::Boundary, "BOUNDARY" },
+ { GeneratorType::Special, "SPECIAL" },
{ GeneratorType::FixedData, "FIXED_DATA" },
})
@@ -151,6 +151,14 @@ void from_json(const nlohmann::json& j, FixedDataInfo& fixedDataInfo)
j.at("data").get_to(fixedDataInfo.data);
}
+void from_json(const nlohmann::json& j, FullRangeInfo& fullRangeInfo)
+{
+ if (j.contains("start_val"))
+ {
+ j.at("start_val").get_to(fullRangeInfo.startVal);
+ }
+}
+
void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("data_type").get_to(cfg.dataType);
@@ -186,6 +194,13 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("fixed_data_info").get_to(cfg.fixedDataInfo);
}
+
+ //Set up defaults for fullRangeInfo
+ cfg.fullRangeInfo.startVal = 0;
+ if (j.contains("full_range_info"))
+ {
+ j.at("full_range_info").get_to(cfg.fullRangeInfo);
+ }
}
std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName)