aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/generate')
-rw-r--r--reference_model/src/generate/generate_dot_product.cc7
-rw-r--r--reference_model/src/generate/generate_pseudo_random.cc5
-rw-r--r--reference_model/src/generate/generate_utils.cc16
-rw-r--r--reference_model/src/generate/generate_utils.h1
4 files changed, 28 insertions, 1 deletions
diff --git a/reference_model/src/generate/generate_dot_product.cc b/reference_model/src/generate/generate_dot_product.cc
index 4054472..c8a2b13 100644
--- a/reference_model/src/generate/generate_dot_product.cc
+++ b/reference_model/src/generate/generate_dot_product.cc
@@ -387,7 +387,12 @@ bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size)
if (!generator)
{
WARNING("[Generator][DP] Requested generator could not be created!");
- return 0;
+ return false;
+ }
+ if (cfg.dotProductInfo.ks <= 0)
+ {
+ WARNING("[Generator][DP] Invalid test set kernel size %d.", cfg.dotProductInfo.ks);
+ return false;
}
// Select which generator to use
diff --git a/reference_model/src/generate/generate_pseudo_random.cc b/reference_model/src/generate/generate_pseudo_random.cc
index d8d2288..b51424d 100644
--- a/reference_model/src/generate/generate_pseudo_random.cc
+++ b/reference_model/src/generate/generate_pseudo_random.cc
@@ -93,6 +93,7 @@ bool generateFP32(const TosaReference::GenerateConfig& cfg, void* data, size_t s
const TosaReference::PseudoRandomInfo& prinfo = cfg.pseudoRandomInfo;
PseudoRandomGeneratorFloat<float>* generator;
+ bool roundMode = prinfo.round;
if (prinfo.range.size() == 2)
{
@@ -117,6 +118,10 @@ bool generateFP32(const TosaReference::GenerateConfig& cfg, void* data, size_t s
// Set every 4th value to 0 to enable better comparison testing
a[t] = 0.f;
}
+ else if (roundMode)
+ {
+ a[t] = std::roundf(a[t]);
+ }
}
return true;
}
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index 1edc79d..58a3d33 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -116,6 +116,10 @@ void from_json(const nlohmann::json& j, PseudoRandomInfo& pseudoRandomInfo)
{
j.at("range").get_to(pseudoRandomInfo.range);
}
+ if (j.contains("round"))
+ {
+ j.at("round").get_to(pseudoRandomInfo.round);
+ }
}
void from_json(const nlohmann::json& j, GenerateConfig& cfg)
@@ -126,10 +130,22 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg)
j.at("input_pos").get_to(cfg.inputPos);
j.at("op").get_to(cfg.opType);
j.at("generator").get_to(cfg.generatorType);
+
+ // Set up defaults for dotProductInfo
+ cfg.dotProductInfo.s = -1;
+ cfg.dotProductInfo.ks = -1;
+ cfg.dotProductInfo.accType = DType_UNKNOWN;
+ cfg.dotProductInfo.kernel = std::vector<int32_t>();
+ cfg.dotProductInfo.axis = -1;
if (j.contains("dot_product_info"))
{
j.at("dot_product_info").get_to(cfg.dotProductInfo);
}
+
+ // Set up defaults for pseudoRandomInfo
+ cfg.pseudoRandomInfo.rngSeed = -1;
+ cfg.pseudoRandomInfo.range = std::vector<std::string>();
+ cfg.pseudoRandomInfo.round = false;
if (j.contains("pseudo_random_info"))
{
j.at("pseudo_random_info").get_to(cfg.pseudoRandomInfo);
diff --git a/reference_model/src/generate/generate_utils.h b/reference_model/src/generate/generate_utils.h
index 8d0f654..f9ec713 100644
--- a/reference_model/src/generate/generate_utils.h
+++ b/reference_model/src/generate/generate_utils.h
@@ -62,6 +62,7 @@ struct PseudoRandomInfo
int64_t rngSeed;
std::vector<std::string> range;
+ bool round;
};
/// \brief Generator configuration