diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2018-05-08 15:54:53 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:52:19 +0000 |
commit | 393fa4c87c84356132303170d1b9ce9a45b3c3bf (patch) | |
tree | b5d5a7ca835d625b5afd56155be8ad9de7ab6575 /tests/validation | |
parent | 1731d5133f1b081fc669d082ae8c3e744d36ab11 (diff) | |
download | ComputeLibrary-393fa4c87c84356132303170d1b9ce9a45b3c3bf.tar.gz |
COMPMID-814: NEScale NHWC support
Change-Id: Ibf5c624a5c5482faa42eb02bc8abe9ae0d65b0d1
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/130608
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/CL/Scale.cpp | 30 | ||||
-rw-r--r-- | tests/validation/GLES_COMPUTE/Scale.cpp | 7 | ||||
-rw-r--r-- | tests/validation/NEON/Scale.cpp | 70 | ||||
-rw-r--r-- | tests/validation/fixtures/ScaleFixture.h | 33 | ||||
-rw-r--r-- | tests/validation/reference/Scale.cpp | 1 |
5 files changed, 105 insertions, 36 deletions
diff --git a/tests/validation/CL/Scale.cpp b/tests/validation/CL/Scale.cpp index cc4fdb0564..3d8750ad28 100644 --- a/tests/validation/CL/Scale.cpp +++ b/tests/validation/CL/Scale.cpp @@ -118,7 +118,9 @@ using CLScaleFixture = ScaleValidationFixture<CLTensor, CLAccessor, CLScale, T>; TEST_SUITE(Float) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)), +FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), datasets::SamplingPolicies())) @@ -130,7 +132,9 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<float>, framework::DatasetMode:: // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32, tolerance_f32_absolute); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), +FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", + DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), datasets::SamplingPolicies())) @@ -144,7 +148,9 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<float>, framework::DatasetMode:: } TEST_SUITE_END() TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)), +FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), datasets::SamplingPolicies())) @@ -156,8 +162,9 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<half>, framework::DatasetMode::A // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_f16); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", +FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), datasets::SamplingPolicies())) @@ -174,7 +181,9 @@ TEST_SUITE_END() TEST_SUITE(Integer) TEST_SUITE(U8) -FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::U8)), +FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::U8)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), datasets::SamplingPolicies())) @@ -186,7 +195,9 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<uint8_t>, framework::DatasetMode // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_u8); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::U8)), +FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", + DataType::U8)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), datasets::SamplingPolicies())) @@ -200,7 +211,9 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<uint8_t>, framework::DatasetMode } TEST_SUITE_END() TEST_SUITE(S16) -FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::S16)), +FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::S16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), datasets::SamplingPolicies())) @@ -212,8 +225,9 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<int16_t>, framework::DatasetMode // Validate output validate(CLAccessor(_target), _reference, valid_region, tolerance_s16); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", +FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::S16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), datasets::SamplingPolicies())) diff --git a/tests/validation/GLES_COMPUTE/Scale.cpp b/tests/validation/GLES_COMPUTE/Scale.cpp index 9f670e4d4d..4bfa08f060 100644 --- a/tests/validation/GLES_COMPUTE/Scale.cpp +++ b/tests/validation/GLES_COMPUTE/Scale.cpp @@ -108,7 +108,9 @@ using GCScaleFixture = ScaleValidationFixture<GCTensor, GCAccessor, GCScale, T>; TEST_SUITE(Float) TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, GCScaleFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)), +FIXTURE_DATA_TEST_CASE(RunSmall, GCScaleFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR })), datasets::BorderModes()), datasets::SamplingPolicies())) @@ -120,8 +122,9 @@ FIXTURE_DATA_TEST_CASE(RunSmall, GCScaleFixture<half>, framework::DatasetMode::A // Validate output validate(GCAccessor(_target), _reference, valid_region, tolerance_f16); } -FIXTURE_DATA_TEST_CASE(RunLarge, GCScaleFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", +FIXTURE_DATA_TEST_CASE(RunLarge, GCScaleFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR })), datasets::BorderModes()), datasets::SamplingPolicies())) diff --git a/tests/validation/NEON/Scale.cpp b/tests/validation/NEON/Scale.cpp index 5f76a0ca06..b21affd9d3 100644 --- a/tests/validation/NEON/Scale.cpp +++ b/tests/validation/NEON/Scale.cpp @@ -55,6 +55,13 @@ const auto ScaleDataTypes = framework::dataset::make("DataType", DataType::F32, }); +/** Scale data types */ +const auto ScaleDataLayouts = framework::dataset::make("DataLayout", +{ + DataLayout::NCHW, + DataLayout::NHWC, +}); + /** Tolerance */ constexpr AbsoluteTolerance<uint8_t> tolerance_u8(1); constexpr AbsoluteTolerance<int16_t> tolerance_s16(1); @@ -67,29 +74,42 @@ constexpr float tolerance_num_f32 = 0.01f; TEST_SUITE(NEON) TEST_SUITE(Scale) -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(concat(datasets::SmallShapes(), datasets::LargeShapes()), ScaleDataTypes), +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(concat(datasets::SmallShapes(), datasets::LargeShapes()), ScaleDataTypes), ScaleDataLayouts), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), framework::dataset::make("SamplingPolicy", { SamplingPolicy::CENTER })), - shape, data_type, policy, border_mode, sampling_policy) + shape, data_type, data_layout, policy, border_mode, sampling_policy) { std::mt19937 generator(library->seed()); std::uniform_real_distribution<float> distribution_float(0.25, 2); const float scale_x = distribution_float(generator); const float scale_y = distribution_float(generator); uint8_t constant_border_value = 0; + TensorShape src_shape = shape; if(border_mode == BorderMode::CONSTANT) { std::uniform_int_distribution<uint8_t> distribution_u8(0, 255); constant_border_value = distribution_u8(generator); } + // Get width/height indices depending on layout + const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + + // Change shape in case of NHWC. + if(data_layout == DataLayout::NHWC) + { + permute(src_shape, PermutationVector(2U, 0U, 1U)); + } + + // Calculate scaled shape + TensorShape shape_scaled(src_shape); + shape_scaled.set(idx_width, src_shape[idx_width] * scale_x); + shape_scaled.set(idx_height, src_shape[idx_height] * scale_y); + // Create tensors - Tensor src = create_tensor<Tensor>(shape, data_type); - TensorShape shape_scaled(shape); - shape_scaled.set(0, shape[0] * scale_x); - shape_scaled.set(1, shape[1] * scale_y); - Tensor dst = create_tensor<Tensor>(shape_scaled, data_type); + Tensor src = create_tensor<Tensor>(src_shape, data_type, 1, 0, QuantizationInfo(), data_layout); + Tensor dst = create_tensor<Tensor>(shape_scaled, data_type, 1, 0, QuantizationInfo(), data_layout); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -100,14 +120,26 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combi // Validate valid region const ValidRegion dst_valid_region = calculate_valid_region_scale(*(src.info()), shape_scaled, policy, sampling_policy, (border_mode == BorderMode::UNDEFINED)); - validate(dst.info()->valid_region(), dst_valid_region); // Validate padding - PaddingCalculator calculator(shape_scaled.x(), 16); + int num_elements_processed_x = 16; + if(data_layout == DataLayout::NHWC) + { + num_elements_processed_x = (policy == InterpolationPolicy::BILINEAR) ? 1 : 16 / src.info()->element_size(); + } + PaddingCalculator calculator(shape_scaled.x(), num_elements_processed_x); calculator.set_border_mode(border_mode); - const PaddingSize read_padding(1); + PaddingSize read_padding(1); + if(data_layout == DataLayout::NHWC) + { + read_padding = calculator.required_padding(PaddingCalculator::Option::EXCLUDE_BORDER); + if(border_mode == BorderMode::CONSTANT && policy == InterpolationPolicy::BILINEAR) + { + read_padding.top = 1; + } + } const PaddingSize write_padding = calculator.required_padding(PaddingCalculator::Option::EXCLUDE_BORDER); validate(src.info()->padding(), read_padding); validate(dst.info()->padding(), write_padding); @@ -118,8 +150,9 @@ using NEScaleFixture = ScaleValidationFixture<Tensor, Accessor, NEScale, T>; TEST_SUITE(Float) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", +FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), framework::dataset::make("SamplingPolicy", { SamplingPolicy::CENTER }))) @@ -131,8 +164,9 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<float>, framework::DatasetMode:: // Validate output validate(Accessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEScaleFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", +FIXTURE_DATA_TEST_CASE(RunLarge, NEScaleFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), framework::dataset::make("SamplingPolicy", { SamplingPolicy::CENTER }))) @@ -149,8 +183,9 @@ TEST_SUITE_END() TEST_SUITE(Integer) TEST_SUITE(U8) -FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", +FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::U8)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), framework::dataset::make("SamplingPolicy", { SamplingPolicy::CENTER }))) @@ -162,8 +197,9 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<uint8_t>, framework::DatasetMode // Validate output validate(Accessor(_target), _reference, valid_region, tolerance_u8); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEScaleFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", +FIXTURE_DATA_TEST_CASE(RunLarge, NEScaleFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::U8)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), framework::dataset::make("SamplingPolicy", { SamplingPolicy::CENTER }))) @@ -177,8 +213,9 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEScaleFixture<uint8_t>, framework::DatasetMode } TEST_SUITE_END() TEST_SUITE(S16) -FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", +FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::S16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), framework::dataset::make("SamplingPolicy", { SamplingPolicy::CENTER }))) @@ -190,8 +227,9 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<int16_t>, framework::DatasetMode // Validate output validate(Accessor(_target), _reference, valid_region, tolerance_s16, tolerance_num_s16); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEScaleFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", +FIXTURE_DATA_TEST_CASE(RunLarge, NEScaleFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::S16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })), datasets::BorderModes()), framework::dataset::make("SamplingPolicy", { SamplingPolicy::CENTER }))) diff --git a/tests/validation/fixtures/ScaleFixture.h b/tests/validation/fixtures/ScaleFixture.h index 604bfb2622..ec102313c5 100644 --- a/tests/validation/fixtures/ScaleFixture.h +++ b/tests/validation/fixtures/ScaleFixture.h @@ -44,7 +44,7 @@ class ScaleValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, DataType data_type, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy) + void setup(TensorShape shape, DataType data_type, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy) { constexpr float max_width = 8192.0f; constexpr float max_height = 6384.0f; @@ -60,13 +60,16 @@ public: float scale_x = distribution_float(generator); float scale_y = distribution_float(generator); - scale_x = ((shape.x() * scale_x) > max_width) ? (max_width / shape.x()) : scale_x; - scale_y = ((shape.y() * scale_y) > max_height) ? (max_height / shape.y()) : scale_y; + const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + + scale_x = ((shape[idx_width] * scale_x) > max_width) ? (max_width / shape[idx_width]) : scale_x; + scale_y = ((shape[idx_height] * scale_y) > max_height) ? (max_height / shape[idx_height]) : scale_y; std::uniform_int_distribution<uint8_t> distribution_u8(0, 255); T constant_border_value = static_cast<T>(distribution_u8(generator)); - _target = compute_target(shape, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy); + _target = compute_target(shape, data_layout, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy); _reference = compute_reference(shape, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy); } @@ -86,15 +89,25 @@ protected: } } - TensorType compute_target(const TensorShape &shape, const float scale_x, const float scale_y, + TensorType compute_target(TensorShape shape, DataLayout data_layout, const float scale_x, const float scale_y, InterpolationPolicy policy, BorderMode border_mode, T constant_border_value, SamplingPolicy sampling_policy) { + // Change shape in case of NHWC. + if(data_layout == DataLayout::NHWC) + { + permute(shape, PermutationVector(2U, 0U, 1U)); + } + // Create tensors - TensorType src = create_tensor<TensorType>(shape, _data_type); + TensorType src = create_tensor<TensorType>(shape, _data_type, 1, 0, QuantizationInfo(), data_layout); + + const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + TensorShape shape_scaled(shape); - shape_scaled.set(0, shape[0] * scale_x); - shape_scaled.set(1, shape[1] * scale_y); - TensorType dst = create_tensor<TensorType>(shape_scaled, _data_type); + shape_scaled.set(idx_width, shape[idx_width] * scale_x); + shape_scaled.set(idx_height, shape[idx_height] * scale_y); + TensorType dst = create_tensor<TensorType>(shape_scaled, _data_type, 1, 0, QuantizationInfo(), data_layout); // Create and configure function FunctionType scale; @@ -123,7 +136,7 @@ protected: InterpolationPolicy policy, BorderMode border_mode, T constant_border_value, SamplingPolicy sampling_policy) { // Create reference - SimpleTensor<T> src{ shape, _data_type }; + SimpleTensor<T> src{ shape, _data_type, 1, 0, QuantizationInfo() }; // Fill reference fill(src); diff --git a/tests/validation/reference/Scale.cpp b/tests/validation/reference/Scale.cpp index 5c9e95633c..f8a8b88cf9 100644 --- a/tests/validation/reference/Scale.cpp +++ b/tests/validation/reference/Scale.cpp @@ -23,6 +23,7 @@ */ #include "Scale.h" + #include "Utils.h" #include "arm_compute/core/utils/misc/Utility.h" #include "support/ToolchainSupport.h" |