aboutsummaryrefslogtreecommitdiff
path: root/reference_model/test
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2023-12-01 12:18:15 +0000
committerDmitrii Agibov <dmitrii.agibov@arm.com>2023-12-07 17:19:33 +0000
commit0de60f38e828e13359acbcfac51b6c179a34d042 (patch)
tree532cda49e0c9101440715b17a3e18ddaebc75858 /reference_model/test
parente9059775c0486de4a96d42b41104496f4aefe8e8 (diff)
downloadreference_model-0de60f38e828e13359acbcfac51b6c179a34d042.tar.gz
Add support for list of tensors as input parameter
Some operators (e.g. Concat) expect list of tensor as an input parameter. Currently operators API does not support passing such parameters from the client code. In order to enable it: - Add new type tensor_list_t - Update operators API generation script to support new type - Add unit test for operator Concat Signed-off-by: Dmitrii Agibov <dmitrii.agibov@arm.com> Change-Id: Ib2f61bcea5e5ecabf56ce031d905cb46a4cc68ea
Diffstat (limited to 'reference_model/test')
-rw-r--r--reference_model/test/model_runner_tests.cpp47
1 files changed, 47 insertions, 0 deletions
diff --git a/reference_model/test/model_runner_tests.cpp b/reference_model/test/model_runner_tests.cpp
index 35e3aa2..5292dd8 100644
--- a/reference_model/test/model_runner_tests.cpp
+++ b/reference_model/test/model_runner_tests.cpp
@@ -152,6 +152,53 @@ TEST_SUITE("model_runner")
compareOutput(dstData, expectedData, expectedData.size());
}
+ TEST_CASE("op_entry_concat")
+ {
+ // Concat parameters
+ const int32_t axis = 2;
+
+ // Inputs/Outputs
+ tosa_datatype_t dt = tosa_datatype_fp32_t;
+ std::vector<int32_t> input1_shape = { 1, 2, 3, 4 };
+ std::vector<int32_t> input2_shape = { 1, 2, 5, 4 };
+ std::vector<int32_t> output_shape = { 1, 2, 8, 4 };
+ std::vector<float> src1Data(24, 1.0f);
+ std::vector<float> src2Data(40, 1.0f);
+ std::vector<float> dstData(64, 0.f);
+
+ tosa_tensor_t input1;
+ input1.shape = input1_shape.data();
+ input1.num_dims = input1_shape.size();
+ input1.data_type = dt;
+ input1.data = reinterpret_cast<uint8_t*>(src1Data.data());
+ input1.size = src1Data.size() * sizeof(float);
+
+ tosa_tensor_t input2;
+ input2.shape = input2_shape.data();
+ input2.num_dims = input2_shape.size();
+ input2.data_type = dt;
+ input2.data = reinterpret_cast<uint8_t*>(src2Data.data());
+ input2.size = src2Data.size() * sizeof(float);
+
+ tosa_tensor_list_t input_list;
+ tosa_tensor_t inputs[]{ input1, input2 };
+ input_list.size = 2;
+ input_list.tensors = inputs;
+
+ tosa_tensor_t output;
+ output.shape = output_shape.data();
+ output.num_dims = output_shape.size();
+ output.data_type = dt;
+ output.data = reinterpret_cast<uint8_t*>(dstData.data());
+ output.size = dstData.size() * sizeof(float);
+
+ auto status = tosa_run_concat(input_list, axis, output, {});
+ CHECK((status == tosa_status_valid));
+
+ std::vector<float> expectedData(64, 1.0f);
+ compareOutput(dstData, expectedData, expectedData.size());
+ }
+
TEST_CASE("op_entry_conv2d")
{
// Conv parameters