diff options
author | Viet-Hoa Do <viet-hoa.do@arm.com> | 2023-07-25 14:00:46 +0100 |
---|---|---|
committer | Viet-Hoa Do <viet-hoa.do@arm.com> | 2023-07-27 14:34:04 +0000 |
commit | 0b23e0e6402cb18ddf621d36454cadbb73959518 (patch) | |
tree | 244c32e5a44a8c2a644cb6a1e965c114175d2515 /compute_kernel_writer/validation | |
parent | 9662ac062bafe454afb77a563648e5577c5a8360 (diff) | |
download | ComputeLibrary-0b23e0e6402cb18ddf621d36454cadbb73959518.tar.gz |
Add TensorOperand and declare tensor argument
Partially resolves: COMPMID-6391
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: I849d486401f99a93919015f2e173559dca5bffa2
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9972
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'compute_kernel_writer/validation')
-rw-r--r-- | compute_kernel_writer/validation/tests/CLTensorArgumentTest.h | 59 |
1 files changed, 33 insertions, 26 deletions
diff --git a/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h b/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h index 6db1384247..d3e455cb83 100644 --- a/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h +++ b/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h @@ -22,12 +22,13 @@ * SOFTWARE. */ -#ifndef CKW_TESTS_CLTENSORARGUMENTTEST_H -#define CKW_TESTS_CLTENSORARGUMENTTEST_H +#ifndef CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H +#define CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H #include "common/Common.h" #include "src/cl/CLHelpers.h" #include "src/cl/CLTensorArgument.h" +#include "src/cl/CLTensorComponent.h" #include <string> #include <vector> @@ -89,7 +90,7 @@ public: CLTensorArgument arg(tensor_name, info, false /* return_dims_by_value */); const std::string expected_var_name = _expected_vars[i]; - const std::string actual_var_name = arg.component(_components[i]).str; + const std::string actual_var_name = arg.component(_components[i]).name(); VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++); } @@ -200,8 +201,8 @@ public: { CLTensorArgument arg(tensor_name, info, true /* return_dims_by_value */); - const std::string expected_var_val = _expected_vals[i]; - const std::string actual_var_val = arg.component(_components[i]).str; + const std::string expected_var_val = std::string("((int)(") + _expected_vals[i] + "))"; + const std::string actual_var_val = arg.cl_component(_components[i]).scalar(0, 0).str; VALIDATE_TEST(actual_var_val.compare(expected_var_val) == 0, all_tests_passed, test_idx++); } @@ -276,18 +277,20 @@ public: { // Validate variable name const std::string expected_var_name = _expected_vars[i]; - const std::string actual_var_name = actual_vars[i].str; + const std::string actual_var_name = actual_vars[i]->tile().name(); VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++); // Validate data type const DataType expected_var_type = DataType::Int32; - const DataType actual_var_type = actual_vars[i].desc.dt; + const DataType actual_var_type = actual_vars[i]->tile().info().data_type(); VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++); - // Validate data type length - const int32_t expected_var_len = 1; - const int32_t actual_var_len = actual_vars[i].desc.len; - VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++); + // Validate tile shape + const int32_t actual_var_width = actual_vars[i]->tile().info().width(); + const int32_t actual_var_height = actual_vars[i]->tile().info().height(); + + VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++); + VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++); } return all_tests_passed; } @@ -356,18 +359,20 @@ public: { // Validate variable name const std::string expected_var_name = _expected_vars[i]; - const std::string actual_var_name = actual_vars[i].str; + const std::string actual_var_name = actual_vars[i]->tile().name(); VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++); // Validate data type const DataType expected_var_type = DataType::Int32; - const DataType actual_var_type = actual_vars[i].desc.dt; + const DataType actual_var_type = actual_vars[i]->tile().info().data_type(); VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++); - // Validate data type length - const int32_t expected_var_len = 1; - const int32_t actual_var_len = actual_vars[i].desc.len; - VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++); + // Validate tile shape + const int32_t actual_var_width = actual_vars[i]->tile().info().width(); + const int32_t actual_var_height = actual_vars[i]->tile().info().height(); + + VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++); + VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++); } return all_tests_passed; } @@ -430,8 +435,8 @@ public: VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++); // Validate storage type - const std::string expected_var_type = cl_get_variable_storagetype_as_string(_storages[i]); - const std::string actual_var_type = actual_vars[i].type; + const auto expected_var_type = _storages[i]; + const auto actual_var_type = actual_vars[i].type; VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++); } return all_tests_passed; @@ -503,18 +508,20 @@ public: { // Validate variable name const std::string expected_var_name = _expected_vars[i]; - const std::string actual_var_name = actual_vars[i].str; + const std::string actual_var_name = actual_vars[i]->tile().name(); VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++); // Validate data type const DataType expected_var_type = DataType::Int32; - const DataType actual_var_type = actual_vars[i].desc.dt; + const DataType actual_var_type = actual_vars[i]->tile().info().data_type(); VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++); - // Validate data type length - const int32_t expected_var_len = 1; - const int32_t actual_var_len = actual_vars[i].desc.len; - VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++); + // Validate tile shape + const int32_t actual_var_width = actual_vars[i]->tile().info().width(); + const int32_t actual_var_height = actual_vars[i]->tile().info().height(); + + VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++); + VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++); } return all_tests_passed; } @@ -530,4 +537,4 @@ private: }; } // namespace ckw -#endif // CKW_TESTS_CLTENSORARGUMENTTEST_H +#endif // CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H |