aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate/generate_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/generate/generate_utils.cc')
-rw-r--r--reference_model/src/generate/generate_utils.cc15
1 files changed, 15 insertions, 0 deletions
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index c16d1c6..9eda0b6 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -33,6 +33,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(DType,
{ DType::DType_FP16, "FP16" },
{ DType::DType_BF16, "BF16" },
{ DType::DType_FP32, "FP32" },
+ { DType::DType_SHAPE, "SHAPE" },
})
NLOHMANN_JSON_SERIALIZE_ENUM(Op,
@@ -93,6 +94,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType,
{ GeneratorType::OpFullRange, "OP_FULL_RANGE" },
{ GeneratorType::OpBoundary, "OP_BOUNDARY" },
{ GeneratorType::OpSpecial, "OP_SPECIAL" },
+ { GeneratorType::FixedData, "FIXED_DATA" },
})
// NOTE: This assumes it's VARIABLE if the InputType is not recognized
@@ -130,6 +132,11 @@ void from_json(const nlohmann::json& j, PseudoRandomInfo& pseudoRandomInfo)
}
}
+void from_json(const nlohmann::json& j, FixedDataInfo& fixedDataInfo)
+{
+ j.at("data").get_to(fixedDataInfo.data);
+}
+
void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("data_type").get_to(cfg.dataType);
@@ -158,6 +165,13 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("pseudo_random_info").get_to(cfg.pseudoRandomInfo);
}
+
+ // Set up defaults for fixedDataInfo
+ cfg.fixedDataInfo.data = std::vector<int32_t>();
+ if (j.contains("fixed_data_info"))
+ {
+ j.at("fixed_data_info").get_to(cfg.fixedDataInfo);
+ }
}
std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName)
@@ -209,6 +223,7 @@ size_t elementSizeFromType(DType type)
return 2;
case DType::DType_INT32:
case DType::DType_FP32:
+ case DType::DType_SHAPE:
return 4;
default:
return 0;