diff options
Diffstat (limited to 'reference_model/src/generate/generate_utils.cc')
-rw-r--r-- | reference_model/src/generate/generate_utils.cc | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index c16d1c6..9eda0b6 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -33,6 +33,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(DType, { DType::DType_FP16, "FP16" }, { DType::DType_BF16, "BF16" }, { DType::DType_FP32, "FP32" }, + { DType::DType_SHAPE, "SHAPE" }, }) NLOHMANN_JSON_SERIALIZE_ENUM(Op, @@ -93,6 +94,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType, { GeneratorType::OpFullRange, "OP_FULL_RANGE" }, { GeneratorType::OpBoundary, "OP_BOUNDARY" }, { GeneratorType::OpSpecial, "OP_SPECIAL" }, + { GeneratorType::FixedData, "FIXED_DATA" }, }) // NOTE: This assumes it's VARIABLE if the InputType is not recognized @@ -130,6 +132,11 @@ void from_json(const nlohmann::json& j, PseudoRandomInfo& pseudoRandomInfo) } } +void from_json(const nlohmann::json& j, FixedDataInfo& fixedDataInfo) +{ + j.at("data").get_to(fixedDataInfo.data); +} + void from_json(const nlohmann::json& j, GenerateConfig& cfg) { j.at("data_type").get_to(cfg.dataType); @@ -158,6 +165,13 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg) { j.at("pseudo_random_info").get_to(cfg.pseudoRandomInfo); } + + // Set up defaults for fixedDataInfo + cfg.fixedDataInfo.data = std::vector<int32_t>(); + if (j.contains("fixed_data_info")) + { + j.at("fixed_data_info").get_to(cfg.fixedDataInfo); + } } std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName) @@ -209,6 +223,7 @@ size_t elementSizeFromType(DType type) return 2; case DType::DType_INT32: case DType::DType_FP32: + case DType::DType_SHAPE: return 4; default: return 0; |