aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate/generate_utils.cc
diff options
context:
space:
mode:
authorevacha01 <evan.chandler@arm.com>2024-03-08 16:39:24 +0000
committerEric Kunze <eric.kunze@arm.com>2024-04-16 16:02:16 +0000
commit4a2051146f498cb9ec35d7213720540c5c3e81e2 (patch)
tree543000b3ef22bd587c3c7702100742e4b94eb5fb /reference_model/src/generate/generate_utils.cc
parent5d0e9c7f3748e80d6f14a3eeaef858eeb912e1fd (diff)
downloadreference_model-4a2051146f498cb9ec35d7213720540c5c3e81e2.tar.gz
SPECIAL data gen mode for FP16 and FP32
Signed-off-by: evacha01 <evan.chandler@arm.com> Change-Id: I5a9a1c63345bd83ca04bc6c2a99b0ef3612971ee
Diffstat (limited to 'reference_model/src/generate/generate_utils.cc')
-rw-r--r--reference_model/src/generate/generate_utils.cc17
1 files changed, 16 insertions, 1 deletions
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index f31b443..d62c247 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -107,7 +107,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType,
{ GeneratorType::DotProduct, "DOT_PRODUCT" },
{ GeneratorType::FullRange, "FULL_RANGE" },
{ GeneratorType::Boundary, "BOUNDARY" },
- { GeneratorType::Special, "SPECIAL" },
+ { GeneratorType::FpSpecial, "FP_SPECIAL" },
{ GeneratorType::FixedData, "FIXED_DATA" },
})
@@ -159,6 +159,14 @@ void from_json(const nlohmann::json& j, FullRangeInfo& fullRangeInfo)
}
}
+void from_json(const nlohmann::json& j, FpSpecialInfo& fpSpecialInfo)
+{
+ if (j.contains("start_idx"))
+ {
+ j.at("start_idx").get_to(fpSpecialInfo.startIndex);
+ }
+}
+
void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("data_type").get_to(cfg.dataType);
@@ -201,6 +209,13 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("full_range_info").get_to(cfg.fullRangeInfo);
}
+
+ //Set up defaults for fpSpecialInfo
+ cfg.fpSpecialInfo.startIndex = 0;
+ if (j.contains("fp_special_info"))
+ {
+ j.at("fp_special_info").get_to(cfg.fpSpecialInfo);
+ }
}
std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName)