aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate/generate_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/generate/generate_utils.cc')
-rw-r--r--reference_model/src/generate/generate_utils.cc16
1 files changed, 16 insertions, 0 deletions
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index 1edc79d..58a3d33 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -116,6 +116,10 @@ void from_json(const nlohmann::json& j, PseudoRandomInfo& pseudoRandomInfo)
{
j.at("range").get_to(pseudoRandomInfo.range);
}
+ if (j.contains("round"))
+ {
+ j.at("round").get_to(pseudoRandomInfo.round);
+ }
}
void from_json(const nlohmann::json& j, GenerateConfig& cfg)
@@ -126,10 +130,22 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg)
j.at("input_pos").get_to(cfg.inputPos);
j.at("op").get_to(cfg.opType);
j.at("generator").get_to(cfg.generatorType);
+
+ // Set up defaults for dotProductInfo
+ cfg.dotProductInfo.s = -1;
+ cfg.dotProductInfo.ks = -1;
+ cfg.dotProductInfo.accType = DType_UNKNOWN;
+ cfg.dotProductInfo.kernel = std::vector<int32_t>();
+ cfg.dotProductInfo.axis = -1;
if (j.contains("dot_product_info"))
{
j.at("dot_product_info").get_to(cfg.dotProductInfo);
}
+
+ // Set up defaults for pseudoRandomInfo
+ cfg.pseudoRandomInfo.rngSeed = -1;
+ cfg.pseudoRandomInfo.range = std::vector<std::string>();
+ cfg.pseudoRandomInfo.round = false;
if (j.contains("pseudo_random_info"))
{
j.at("pseudo_random_info").get_to(cfg.pseudoRandomInfo);