aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/validation
diff options
context:
space:
mode:
authorViet-Hoa Do <viet-hoa.do@arm.com>2023-07-25 14:00:46 +0100
committerViet-Hoa Do <viet-hoa.do@arm.com>2023-07-27 14:34:04 +0000
commit0b23e0e6402cb18ddf621d36454cadbb73959518 (patch)
tree244c32e5a44a8c2a644cb6a1e965c114175d2515 /compute_kernel_writer/validation
parent9662ac062bafe454afb77a563648e5577c5a8360 (diff)
downloadComputeLibrary-0b23e0e6402cb18ddf621d36454cadbb73959518.tar.gz
Add TensorOperand and declare tensor argument
Partially resolves: COMPMID-6391 Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Change-Id: I849d486401f99a93919015f2e173559dca5bffa2 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9972 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'compute_kernel_writer/validation')
-rw-r--r--compute_kernel_writer/validation/tests/CLTensorArgumentTest.h59
1 files changed, 33 insertions, 26 deletions
diff --git a/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h b/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h
index 6db1384247..d3e455cb83 100644
--- a/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h
+++ b/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h
@@ -22,12 +22,13 @@
* SOFTWARE.
*/
-#ifndef CKW_TESTS_CLTENSORARGUMENTTEST_H
-#define CKW_TESTS_CLTENSORARGUMENTTEST_H
+#ifndef CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H
+#define CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H
#include "common/Common.h"
#include "src/cl/CLHelpers.h"
#include "src/cl/CLTensorArgument.h"
+#include "src/cl/CLTensorComponent.h"
#include <string>
#include <vector>
@@ -89,7 +90,7 @@ public:
CLTensorArgument arg(tensor_name, info, false /* return_dims_by_value */);
const std::string expected_var_name = _expected_vars[i];
- const std::string actual_var_name = arg.component(_components[i]).str;
+ const std::string actual_var_name = arg.component(_components[i]).name();
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
}
@@ -200,8 +201,8 @@ public:
{
CLTensorArgument arg(tensor_name, info, true /* return_dims_by_value */);
- const std::string expected_var_val = _expected_vals[i];
- const std::string actual_var_val = arg.component(_components[i]).str;
+ const std::string expected_var_val = std::string("((int)(") + _expected_vals[i] + "))";
+ const std::string actual_var_val = arg.cl_component(_components[i]).scalar(0, 0).str;
VALIDATE_TEST(actual_var_val.compare(expected_var_val) == 0, all_tests_passed, test_idx++);
}
@@ -276,18 +277,20 @@ public:
{
// Validate variable name
const std::string expected_var_name = _expected_vars[i];
- const std::string actual_var_name = actual_vars[i].str;
+ const std::string actual_var_name = actual_vars[i]->tile().name();
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
// Validate data type
const DataType expected_var_type = DataType::Int32;
- const DataType actual_var_type = actual_vars[i].desc.dt;
+ const DataType actual_var_type = actual_vars[i]->tile().info().data_type();
VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++);
- // Validate data type length
- const int32_t expected_var_len = 1;
- const int32_t actual_var_len = actual_vars[i].desc.len;
- VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++);
+ // Validate tile shape
+ const int32_t actual_var_width = actual_vars[i]->tile().info().width();
+ const int32_t actual_var_height = actual_vars[i]->tile().info().height();
+
+ VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++);
+ VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++);
}
return all_tests_passed;
}
@@ -356,18 +359,20 @@ public:
{
// Validate variable name
const std::string expected_var_name = _expected_vars[i];
- const std::string actual_var_name = actual_vars[i].str;
+ const std::string actual_var_name = actual_vars[i]->tile().name();
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
// Validate data type
const DataType expected_var_type = DataType::Int32;
- const DataType actual_var_type = actual_vars[i].desc.dt;
+ const DataType actual_var_type = actual_vars[i]->tile().info().data_type();
VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++);
- // Validate data type length
- const int32_t expected_var_len = 1;
- const int32_t actual_var_len = actual_vars[i].desc.len;
- VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++);
+ // Validate tile shape
+ const int32_t actual_var_width = actual_vars[i]->tile().info().width();
+ const int32_t actual_var_height = actual_vars[i]->tile().info().height();
+
+ VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++);
+ VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++);
}
return all_tests_passed;
}
@@ -430,8 +435,8 @@ public:
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
// Validate storage type
- const std::string expected_var_type = cl_get_variable_storagetype_as_string(_storages[i]);
- const std::string actual_var_type = actual_vars[i].type;
+ const auto expected_var_type = _storages[i];
+ const auto actual_var_type = actual_vars[i].type;
VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++);
}
return all_tests_passed;
@@ -503,18 +508,20 @@ public:
{
// Validate variable name
const std::string expected_var_name = _expected_vars[i];
- const std::string actual_var_name = actual_vars[i].str;
+ const std::string actual_var_name = actual_vars[i]->tile().name();
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
// Validate data type
const DataType expected_var_type = DataType::Int32;
- const DataType actual_var_type = actual_vars[i].desc.dt;
+ const DataType actual_var_type = actual_vars[i]->tile().info().data_type();
VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++);
- // Validate data type length
- const int32_t expected_var_len = 1;
- const int32_t actual_var_len = actual_vars[i].desc.len;
- VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++);
+ // Validate tile shape
+ const int32_t actual_var_width = actual_vars[i]->tile().info().width();
+ const int32_t actual_var_height = actual_vars[i]->tile().info().height();
+
+ VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++);
+ VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++);
}
return all_tests_passed;
}
@@ -530,4 +537,4 @@ private:
};
} // namespace ckw
-#endif // CKW_TESTS_CLTENSORARGUMENTTEST_H
+#endif // CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H