aboutsummaryrefslogtreecommitdiff
path: root/reference_model
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-11-15 15:52:06 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2023-11-23 14:09:14 +0000
commita015001dfbd0ed48caf54fd66b0509ee344a229e (patch)
tree5f99a7d2d4aba2db2e672efb1168db961f99a544 /reference_model
parent0bbd8bcfb20ec834f18d0bb89fc69ba4e92b3019 (diff)
downloadreference_model-a015001dfbd0ed48caf54fd66b0509ee344a229e.tar.gz
Main Compliance testing support for COMPARISON ops
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Id6229cfaccad866b110630119eb045dbf6453bf5
Diffstat (limited to 'reference_model')
-rw-r--r--reference_model/src/generate/generate_pseudo_random.cc7
-rw-r--r--reference_model/src/generate/generate_utils.cc3
2 files changed, 10 insertions, 0 deletions
diff --git a/reference_model/src/generate/generate_pseudo_random.cc b/reference_model/src/generate/generate_pseudo_random.cc
index 78013eb..d8d2288 100644
--- a/reference_model/src/generate/generate_pseudo_random.cc
+++ b/reference_model/src/generate/generate_pseudo_random.cc
@@ -107,9 +107,16 @@ bool generateFP32(const TosaReference::GenerateConfig& cfg, void* data, size_t s
float* a = reinterpret_cast<float*>(data);
const auto T = TosaReference::numElementsFromShape(cfg.shape);
+ const bool comparisonOp =
+ (cfg.opType == Op::Op_EQUAL) || (cfg.opType == Op::Op_GREATER_EQUAL) || (cfg.opType == Op::Op_GREATER);
for (auto t = 0; t < T; ++t)
{
a[t] = generator->getRandomFloat();
+ if (comparisonOp && (t % 4 == 0))
+ {
+ // Set every 4th value to 0 to enable better comparison testing
+ a[t] = 0.f;
+ }
}
return true;
}
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index d2168c9..1edc79d 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -45,10 +45,13 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op,
{ Op::Op_CEIL, "CEIL" },
{ Op::Op_CLAMP, "CLAMP" },
{ Op::Op_CONV2D, "CONV2D" },
+ { Op::Op_EQUAL, "EQUAL" },
{ Op::Op_ERF, "ERF" },
{ Op::Op_EXP, "EXP" },
{ Op::Op_FLOOR, "FLOOR" },
{ Op::Op_FULLY_CONNECTED, "FULLY_CONNECTED" },
+ { Op::Op_GREATER, "GREATER" },
+ { Op::Op_GREATER_EQUAL, "GREATER_EQUAL" },
{ Op::Op_IDENTITY, "IDENTITY" },
{ Op::Op_LOG, "LOG" },
{ Op::Op_MATMUL, "MATMUL" },