From 283c1790da45ab562ecfb2aa7741297191886d85 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 10 Nov 2017 18:14:06 +0000 Subject: COMPMID-676: Rework TensorInfo building Change-Id: Ic98f64ffe30739437a1fe31ef98d83ee900741e3 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/95512 Reviewed-by: Michalis Spyrou Tested-by: Kaizen Reviewed-by: Anthony Barbier --- tests/validation/UNIT/TensorInfo.cpp | 48 +++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) (limited to 'tests/validation/UNIT') diff --git a/tests/validation/UNIT/TensorInfo.cpp b/tests/validation/UNIT/TensorInfo.cpp index 2a6c3365ea..dd7ae6d18c 100644 --- a/tests/validation/UNIT/TensorInfo.cpp +++ b/tests/validation/UNIT/TensorInfo.cpp @@ -36,10 +36,11 @@ namespace test namespace validation { TEST_SUITE(UNIT) -TEST_SUITE(TensorInfoValidation) +TEST_SUITE(TensorInfo) // *INDENT-OFF* // clang-format off +/** Validates TensorInfo Autopadding */ DATA_TEST_CASE(AutoPadding, framework::DatasetMode::ALL, zip(zip(zip( framework::dataset::make("TensorShape", { TensorShape{}, @@ -82,6 +83,51 @@ DATA_TEST_CASE(AutoPadding, framework::DatasetMode::ALL, zip(zip(zip( // clang-format on // *INDENT-ON* +/** Validates that TensorInfo is clonable */ +TEST_CASE(Clone, framework::DatasetMode::ALL) +{ + // Create tensor info + TensorInfo info(TensorShape(23U, 17U, 3U), // tensor shape + 1, // number of channels + DataType::F32); // data type + + // Get clone of current tensor info + std::unique_ptr info_clone = info.clone(); + ARM_COMPUTE_EXPECT(info_clone != nullptr, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(info_clone->total_size() == info.total_size(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(info_clone->num_channels() == info.num_channels(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(info_clone->data_type() == info.data_type(), framework::LogLevel::ERRORS); +} + +/** Validates that TensorInfo can chain multiple set commands */ +TEST_CASE(TensorInfoBuild, framework::DatasetMode::ALL) +{ + // Create tensor info + TensorInfo info(TensorShape(23U, 17U, 3U), // tensor shape + 1, // number of channels + DataType::F32); // data type + + // Update data type and number of channels + info.set_data_type(DataType::S32).set_num_channels(3); + ARM_COMPUTE_EXPECT(info.data_type() == DataType::S32, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(info.num_channels() == 3, framework::LogLevel::ERRORS); + + // Update data type channels and set fixed point position + info.set_data_type(DataType::QS8).set_num_channels(1).set_fixed_point_position(3); + ARM_COMPUTE_EXPECT(info.data_type() == DataType::QS8, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(info.num_channels() == 1, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(info.fixed_point_position() == 3, framework::LogLevel::ERRORS); + + // Update data type and set quantization info + info.set_data_type(DataType::QASYMM8).set_quantization_info(QuantizationInfo(0.5f, 15)); + ARM_COMPUTE_EXPECT(info.data_type() == DataType::QASYMM8, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(info.quantization_info() == QuantizationInfo(0.5f, 15), framework::LogLevel::ERRORS); + + // Update tensor shape + info.set_tensor_shape(TensorShape(13U, 15U)); + ARM_COMPUTE_EXPECT(info.tensor_shape() == TensorShape(13U, 15U), framework::LogLevel::ERRORS); +} + TEST_SUITE_END() // TensorInfoValidation TEST_SUITE_END() } // namespace validation -- cgit v1.2.1