aboutsummaryrefslogtreecommitdiff
path: root/reference_model/test
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-12-07 16:35:28 +0000
committerEric Kunze <eric.kunze@arm.com>2023-12-14 17:56:51 +0000
commita8420add949564053495ef78f3213f163c30fb9a (patch)
tree4c5e2783433e9443b2ed02e5e25c51cc5de2affd /reference_model/test
parent81db5d2f275f69cc0d3e8687af57bdba99971042 (diff)
downloadreference_model-a8420add949564053495ef78f3213f163c30fb9a.tar.gz
Main Compliance testing for SCATTER and GATHER
Added indices shuffling and random INT32 support to generate lib with testing of these new random generator modes Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I058d8b092470228075e8fe69c2ededa639163003
Diffstat (limited to 'reference_model/test')
-rw-r--r--reference_model/test/generate_tests.cpp75
1 files changed, 74 insertions, 1 deletions
diff --git a/reference_model/test/generate_tests.cpp b/reference_model/test/generate_tests.cpp
index 2c318e0..e4a6d20 100644
--- a/reference_model/test/generate_tests.cpp
+++ b/reference_model/test/generate_tests.cpp
@@ -448,7 +448,7 @@ TEST_CASE("positive - FP32 conv2d dot product (last 3 values)")
conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "5", 2, lastExpected);
}
}
-TEST_CASE("positive - pseudo random")
+TEST_CASE("positive - FP32 pseudo random")
{
std::string templateJsonCfg = R"({
"tensors" : {
@@ -823,4 +823,77 @@ TEST_CASE("positive - FP32 avg_pool2d dot product (first 3 values)")
avg_pool2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "5", expected);
}
}
+
+TEST_CASE("positive - INT32 pseudo random")
+{
+ std::string templateJsonCfg = R"({
+ "tensors" : {
+ "input0" : {
+ "generator": "PSEUDO_RANDOM",
+ "data_type": "INT32",
+ "input_type": "VARIABLE",
+ "shape" : [ 2, 12 ],
+ "input_pos": 0,
+ "op" : "SCATTER",
+ "pseudo_random_info": {
+ "rng_seed": 13,
+ "range": [ "-5", "5" ]
+ }
+ },
+ "input1" : {
+ "generator": "PSEUDO_RANDOM",
+ "data_type": "INT32",
+ "input_type": "VARIABLE",
+ "shape" : [ 2, 10 ],
+ "input_pos": 1,
+ "op" : "SCATTER",
+ "pseudo_random_info": {
+ "rng_seed": 14,
+ "range": [ "0", "9" ]
+ }
+ }
+
+ }
+ })";
+
+ const std::string tosaNameP0 = "input0";
+ const size_t tosaElementsP0 = 2 * 12;
+ const std::string tosaNameP1 = "input1";
+ const size_t tosaElementsP1 = 2 * 10;
+
+ SUBCASE("scatter - int32 random")
+ {
+ std::string jsonCfg = templateJsonCfg;
+
+ std::vector<int32_t> bufferP0(tosaElementsP0);
+ REQUIRE(tgd_generate_data(jsonCfg.c_str(), tosaNameP0.c_str(), (void*)bufferP0.data(), tosaElementsP0 * 4));
+ for (auto e = bufferP0.begin(); e < bufferP0.end(); ++e)
+ {
+ // Check the values are within range
+ bool withinRange = (*e >= -5 && *e <= 5);
+ REQUIRE(withinRange);
+ }
+ }
+
+ SUBCASE("scatter - int32 row shuffle")
+ {
+ std::string jsonCfg = templateJsonCfg;
+
+ std::vector<int32_t> bufferP1(tosaElementsP1);
+ REQUIRE(tgd_generate_data(jsonCfg.c_str(), tosaNameP1.c_str(), (void*)bufferP1.data(), tosaElementsP1 * 4));
+
+ std::vector<bool> set;
+ for (int32_t n = 0; n < 2; ++n)
+ {
+ set.assign(10, false);
+ for (int32_t i = 0; i < 10; ++i)
+ {
+ auto idx = bufferP1[i];
+ // Check that the values in the buffer only occur once
+ REQUIRE(!set[idx]);
+ set[idx] = true;
+ }
+ }
+ }
+}
TEST_SUITE_END(); // generate