diff options
Diffstat (limited to 'reference_model/src/generate/generate_utils.cc')
-rw-r--r-- | reference_model/src/generate/generate_utils.cc | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index 1edc79d..58a3d33 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -116,6 +116,10 @@ void from_json(const nlohmann::json& j, PseudoRandomInfo& pseudoRandomInfo) { j.at("range").get_to(pseudoRandomInfo.range); } + if (j.contains("round")) + { + j.at("round").get_to(pseudoRandomInfo.round); + } } void from_json(const nlohmann::json& j, GenerateConfig& cfg) @@ -126,10 +130,22 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg) j.at("input_pos").get_to(cfg.inputPos); j.at("op").get_to(cfg.opType); j.at("generator").get_to(cfg.generatorType); + + // Set up defaults for dotProductInfo + cfg.dotProductInfo.s = -1; + cfg.dotProductInfo.ks = -1; + cfg.dotProductInfo.accType = DType_UNKNOWN; + cfg.dotProductInfo.kernel = std::vector<int32_t>(); + cfg.dotProductInfo.axis = -1; if (j.contains("dot_product_info")) { j.at("dot_product_info").get_to(cfg.dotProductInfo); } + + // Set up defaults for pseudoRandomInfo + cfg.pseudoRandomInfo.rngSeed = -1; + cfg.pseudoRandomInfo.range = std::vector<std::string>(); + cfg.pseudoRandomInfo.round = false; if (j.contains("pseudo_random_info")) { j.at("pseudo_random_info").get_to(cfg.pseudoRandomInfo); |