aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/validation/tests/CLTensorArgumentTest.h')
-rw-r--r--compute_kernel_writer/validation/tests/CLTensorArgumentTest.h59
1 files changed, 33 insertions, 26 deletions
diff --git a/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h b/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h
index 6db1384247..d3e455cb83 100644
--- a/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h
+++ b/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h
@@ -22,12 +22,13 @@
* SOFTWARE.
*/
-#ifndef CKW_TESTS_CLTENSORARGUMENTTEST_H
-#define CKW_TESTS_CLTENSORARGUMENTTEST_H
+#ifndef CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H
+#define CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H
#include "common/Common.h"
#include "src/cl/CLHelpers.h"
#include "src/cl/CLTensorArgument.h"
+#include "src/cl/CLTensorComponent.h"
#include <string>
#include <vector>
@@ -89,7 +90,7 @@ public:
CLTensorArgument arg(tensor_name, info, false /* return_dims_by_value */);
const std::string expected_var_name = _expected_vars[i];
- const std::string actual_var_name = arg.component(_components[i]).str;
+ const std::string actual_var_name = arg.component(_components[i]).name();
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
}
@@ -200,8 +201,8 @@ public:
{
CLTensorArgument arg(tensor_name, info, true /* return_dims_by_value */);
- const std::string expected_var_val = _expected_vals[i];
- const std::string actual_var_val = arg.component(_components[i]).str;
+ const std::string expected_var_val = std::string("((int)(") + _expected_vals[i] + "))";
+ const std::string actual_var_val = arg.cl_component(_components[i]).scalar(0, 0).str;
VALIDATE_TEST(actual_var_val.compare(expected_var_val) == 0, all_tests_passed, test_idx++);
}
@@ -276,18 +277,20 @@ public:
{
// Validate variable name
const std::string expected_var_name = _expected_vars[i];
- const std::string actual_var_name = actual_vars[i].str;
+ const std::string actual_var_name = actual_vars[i]->tile().name();
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
// Validate data type
const DataType expected_var_type = DataType::Int32;
- const DataType actual_var_type = actual_vars[i].desc.dt;
+ const DataType actual_var_type = actual_vars[i]->tile().info().data_type();
VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++);
- // Validate data type length
- const int32_t expected_var_len = 1;
- const int32_t actual_var_len = actual_vars[i].desc.len;
- VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++);
+ // Validate tile shape
+ const int32_t actual_var_width = actual_vars[i]->tile().info().width();
+ const int32_t actual_var_height = actual_vars[i]->tile().info().height();
+
+ VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++);
+ VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++);
}
return all_tests_passed;
}
@@ -356,18 +359,20 @@ public:
{
// Validate variable name
const std::string expected_var_name = _expected_vars[i];
- const std::string actual_var_name = actual_vars[i].str;
+ const std::string actual_var_name = actual_vars[i]->tile().name();
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
// Validate data type
const DataType expected_var_type = DataType::Int32;
- const DataType actual_var_type = actual_vars[i].desc.dt;
+ const DataType actual_var_type = actual_vars[i]->tile().info().data_type();
VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++);
- // Validate data type length
- const int32_t expected_var_len = 1;
- const int32_t actual_var_len = actual_vars[i].desc.len;
- VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++);
+ // Validate tile shape
+ const int32_t actual_var_width = actual_vars[i]->tile().info().width();
+ const int32_t actual_var_height = actual_vars[i]->tile().info().height();
+
+ VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++);
+ VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++);
}
return all_tests_passed;
}
@@ -430,8 +435,8 @@ public:
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
// Validate storage type
- const std::string expected_var_type = cl_get_variable_storagetype_as_string(_storages[i]);
- const std::string actual_var_type = actual_vars[i].type;
+ const auto expected_var_type = _storages[i];
+ const auto actual_var_type = actual_vars[i].type;
VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++);
}
return all_tests_passed;
@@ -503,18 +508,20 @@ public:
{
// Validate variable name
const std::string expected_var_name = _expected_vars[i];
- const std::string actual_var_name = actual_vars[i].str;
+ const std::string actual_var_name = actual_vars[i]->tile().name();
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
// Validate data type
const DataType expected_var_type = DataType::Int32;
- const DataType actual_var_type = actual_vars[i].desc.dt;
+ const DataType actual_var_type = actual_vars[i]->tile().info().data_type();
VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++);
- // Validate data type length
- const int32_t expected_var_len = 1;
- const int32_t actual_var_len = actual_vars[i].desc.len;
- VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++);
+ // Validate tile shape
+ const int32_t actual_var_width = actual_vars[i]->tile().info().width();
+ const int32_t actual_var_height = actual_vars[i]->tile().info().height();
+
+ VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++);
+ VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++);
}
return all_tests_passed;
}
@@ -530,4 +537,4 @@ private:
};
} // namespace ckw
-#endif // CKW_TESTS_CLTENSORARGUMENTTEST_H
+#endif // CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H