From 0b23e0e6402cb18ddf621d36454cadbb73959518 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Tue, 25 Jul 2023 14:00:46 +0100 Subject: Add TensorOperand and declare tensor argument Partially resolves: COMPMID-6391 Signed-off-by: Viet-Hoa Do Change-Id: I849d486401f99a93919015f2e173559dca5bffa2 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9972 Tested-by: Arm Jenkins Reviewed-by: Gunes Bayir Reviewed-by: Gian Marco Iodice Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- .../validation/tests/CLTensorArgumentTest.h | 59 ++++++++++++---------- 1 file changed, 33 insertions(+), 26 deletions(-) (limited to 'compute_kernel_writer/validation/tests') diff --git a/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h b/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h index 6db1384247..d3e455cb83 100644 --- a/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h +++ b/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h @@ -22,12 +22,13 @@ * SOFTWARE. */ -#ifndef CKW_TESTS_CLTENSORARGUMENTTEST_H -#define CKW_TESTS_CLTENSORARGUMENTTEST_H +#ifndef CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H +#define CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H #include "common/Common.h" #include "src/cl/CLHelpers.h" #include "src/cl/CLTensorArgument.h" +#include "src/cl/CLTensorComponent.h" #include #include @@ -89,7 +90,7 @@ public: CLTensorArgument arg(tensor_name, info, false /* return_dims_by_value */); const std::string expected_var_name = _expected_vars[i]; - const std::string actual_var_name = arg.component(_components[i]).str; + const std::string actual_var_name = arg.component(_components[i]).name(); VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++); } @@ -200,8 +201,8 @@ public: { CLTensorArgument arg(tensor_name, info, true /* return_dims_by_value */); - const std::string expected_var_val = _expected_vals[i]; - const std::string actual_var_val = arg.component(_components[i]).str; + const std::string expected_var_val = std::string("((int)(") + _expected_vals[i] + "))"; + const std::string actual_var_val = arg.cl_component(_components[i]).scalar(0, 0).str; VALIDATE_TEST(actual_var_val.compare(expected_var_val) == 0, all_tests_passed, test_idx++); } @@ -276,18 +277,20 @@ public: { // Validate variable name const std::string expected_var_name = _expected_vars[i]; - const std::string actual_var_name = actual_vars[i].str; + const std::string actual_var_name = actual_vars[i]->tile().name(); VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++); // Validate data type const DataType expected_var_type = DataType::Int32; - const DataType actual_var_type = actual_vars[i].desc.dt; + const DataType actual_var_type = actual_vars[i]->tile().info().data_type(); VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++); - // Validate data type length - const int32_t expected_var_len = 1; - const int32_t actual_var_len = actual_vars[i].desc.len; - VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++); + // Validate tile shape + const int32_t actual_var_width = actual_vars[i]->tile().info().width(); + const int32_t actual_var_height = actual_vars[i]->tile().info().height(); + + VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++); + VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++); } return all_tests_passed; } @@ -356,18 +359,20 @@ public: { // Validate variable name const std::string expected_var_name = _expected_vars[i]; - const std::string actual_var_name = actual_vars[i].str; + const std::string actual_var_name = actual_vars[i]->tile().name(); VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++); // Validate data type const DataType expected_var_type = DataType::Int32; - const DataType actual_var_type = actual_vars[i].desc.dt; + const DataType actual_var_type = actual_vars[i]->tile().info().data_type(); VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++); - // Validate data type length - const int32_t expected_var_len = 1; - const int32_t actual_var_len = actual_vars[i].desc.len; - VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++); + // Validate tile shape + const int32_t actual_var_width = actual_vars[i]->tile().info().width(); + const int32_t actual_var_height = actual_vars[i]->tile().info().height(); + + VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++); + VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++); } return all_tests_passed; } @@ -430,8 +435,8 @@ public: VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++); // Validate storage type - const std::string expected_var_type = cl_get_variable_storagetype_as_string(_storages[i]); - const std::string actual_var_type = actual_vars[i].type; + const auto expected_var_type = _storages[i]; + const auto actual_var_type = actual_vars[i].type; VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++); } return all_tests_passed; @@ -503,18 +508,20 @@ public: { // Validate variable name const std::string expected_var_name = _expected_vars[i]; - const std::string actual_var_name = actual_vars[i].str; + const std::string actual_var_name = actual_vars[i]->tile().name(); VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++); // Validate data type const DataType expected_var_type = DataType::Int32; - const DataType actual_var_type = actual_vars[i].desc.dt; + const DataType actual_var_type = actual_vars[i]->tile().info().data_type(); VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++); - // Validate data type length - const int32_t expected_var_len = 1; - const int32_t actual_var_len = actual_vars[i].desc.len; - VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++); + // Validate tile shape + const int32_t actual_var_width = actual_vars[i]->tile().info().width(); + const int32_t actual_var_height = actual_vars[i]->tile().info().height(); + + VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++); + VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++); } return all_tests_passed; } @@ -530,4 +537,4 @@ private: }; } // namespace ckw -#endif // CKW_TESTS_CLTENSORARGUMENTTEST_H +#endif // CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H -- cgit v1.2.1