aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ScaleFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ScaleFixture.h')
-rw-r--r--tests/validation/fixtures/ScaleFixture.h33
1 files changed, 23 insertions, 10 deletions
diff --git a/tests/validation/fixtures/ScaleFixture.h b/tests/validation/fixtures/ScaleFixture.h
index 604bfb2622..ec102313c5 100644
--- a/tests/validation/fixtures/ScaleFixture.h
+++ b/tests/validation/fixtures/ScaleFixture.h
@@ -44,7 +44,7 @@ class ScaleValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, DataType data_type, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy)
+ void setup(TensorShape shape, DataType data_type, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy)
{
constexpr float max_width = 8192.0f;
constexpr float max_height = 6384.0f;
@@ -60,13 +60,16 @@ public:
float scale_x = distribution_float(generator);
float scale_y = distribution_float(generator);
- scale_x = ((shape.x() * scale_x) > max_width) ? (max_width / shape.x()) : scale_x;
- scale_y = ((shape.y() * scale_y) > max_height) ? (max_height / shape.y()) : scale_y;
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+
+ scale_x = ((shape[idx_width] * scale_x) > max_width) ? (max_width / shape[idx_width]) : scale_x;
+ scale_y = ((shape[idx_height] * scale_y) > max_height) ? (max_height / shape[idx_height]) : scale_y;
std::uniform_int_distribution<uint8_t> distribution_u8(0, 255);
T constant_border_value = static_cast<T>(distribution_u8(generator));
- _target = compute_target(shape, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy);
+ _target = compute_target(shape, data_layout, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy);
_reference = compute_reference(shape, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy);
}
@@ -86,15 +89,25 @@ protected:
}
}
- TensorType compute_target(const TensorShape &shape, const float scale_x, const float scale_y,
+ TensorType compute_target(TensorShape shape, DataLayout data_layout, const float scale_x, const float scale_y,
InterpolationPolicy policy, BorderMode border_mode, T constant_border_value, SamplingPolicy sampling_policy)
{
+ // Change shape in case of NHWC.
+ if(data_layout == DataLayout::NHWC)
+ {
+ permute(shape, PermutationVector(2U, 0U, 1U));
+ }
+
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, _data_type);
+ TensorType src = create_tensor<TensorType>(shape, _data_type, 1, 0, QuantizationInfo(), data_layout);
+
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+
TensorShape shape_scaled(shape);
- shape_scaled.set(0, shape[0] * scale_x);
- shape_scaled.set(1, shape[1] * scale_y);
- TensorType dst = create_tensor<TensorType>(shape_scaled, _data_type);
+ shape_scaled.set(idx_width, shape[idx_width] * scale_x);
+ shape_scaled.set(idx_height, shape[idx_height] * scale_y);
+ TensorType dst = create_tensor<TensorType>(shape_scaled, _data_type, 1, 0, QuantizationInfo(), data_layout);
// Create and configure function
FunctionType scale;
@@ -123,7 +136,7 @@ protected:
InterpolationPolicy policy, BorderMode border_mode, T constant_border_value, SamplingPolicy sampling_policy)
{
// Create reference
- SimpleTensor<T> src{ shape, _data_type };
+ SimpleTensor<T> src{ shape, _data_type, 1, 0, QuantizationInfo() };
// Fill reference
fill(src);