aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/UNIT/TensorInfo.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/UNIT/TensorInfo.cpp')
-rw-r--r--tests/validation/UNIT/TensorInfo.cpp48
1 files changed, 47 insertions, 1 deletions
diff --git a/tests/validation/UNIT/TensorInfo.cpp b/tests/validation/UNIT/TensorInfo.cpp
index 2a6c3365ea..dd7ae6d18c 100644
--- a/tests/validation/UNIT/TensorInfo.cpp
+++ b/tests/validation/UNIT/TensorInfo.cpp
@@ -36,10 +36,11 @@ namespace test
namespace validation
{
TEST_SUITE(UNIT)
-TEST_SUITE(TensorInfoValidation)
+TEST_SUITE(TensorInfo)
// *INDENT-OFF*
// clang-format off
+/** Validates TensorInfo Autopadding */
DATA_TEST_CASE(AutoPadding, framework::DatasetMode::ALL, zip(zip(zip(
framework::dataset::make("TensorShape", {
TensorShape{},
@@ -82,6 +83,51 @@ DATA_TEST_CASE(AutoPadding, framework::DatasetMode::ALL, zip(zip(zip(
// clang-format on
// *INDENT-ON*
+/** Validates that TensorInfo is clonable */
+TEST_CASE(Clone, framework::DatasetMode::ALL)
+{
+ // Create tensor info
+ TensorInfo info(TensorShape(23U, 17U, 3U), // tensor shape
+ 1, // number of channels
+ DataType::F32); // data type
+
+ // Get clone of current tensor info
+ std::unique_ptr<ITensorInfo> info_clone = info.clone();
+ ARM_COMPUTE_EXPECT(info_clone != nullptr, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(info_clone->total_size() == info.total_size(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(info_clone->num_channels() == info.num_channels(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(info_clone->data_type() == info.data_type(), framework::LogLevel::ERRORS);
+}
+
+/** Validates that TensorInfo can chain multiple set commands */
+TEST_CASE(TensorInfoBuild, framework::DatasetMode::ALL)
+{
+ // Create tensor info
+ TensorInfo info(TensorShape(23U, 17U, 3U), // tensor shape
+ 1, // number of channels
+ DataType::F32); // data type
+
+ // Update data type and number of channels
+ info.set_data_type(DataType::S32).set_num_channels(3);
+ ARM_COMPUTE_EXPECT(info.data_type() == DataType::S32, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(info.num_channels() == 3, framework::LogLevel::ERRORS);
+
+ // Update data type channels and set fixed point position
+ info.set_data_type(DataType::QS8).set_num_channels(1).set_fixed_point_position(3);
+ ARM_COMPUTE_EXPECT(info.data_type() == DataType::QS8, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(info.num_channels() == 1, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(info.fixed_point_position() == 3, framework::LogLevel::ERRORS);
+
+ // Update data type and set quantization info
+ info.set_data_type(DataType::QASYMM8).set_quantization_info(QuantizationInfo(0.5f, 15));
+ ARM_COMPUTE_EXPECT(info.data_type() == DataType::QASYMM8, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(info.quantization_info() == QuantizationInfo(0.5f, 15), framework::LogLevel::ERRORS);
+
+ // Update tensor shape
+ info.set_tensor_shape(TensorShape(13U, 15U));
+ ARM_COMPUTE_EXPECT(info.tensor_shape() == TensorShape(13U, 15U), framework::LogLevel::ERRORS);
+}
+
TEST_SUITE_END() // TensorInfoValidation
TEST_SUITE_END()
} // namespace validation