aboutsummaryrefslogtreecommitdiff
path: root/tests/benchmark/fixtures/ScaleFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/benchmark/fixtures/ScaleFixture.h')
-rw-r--r--tests/benchmark/fixtures/ScaleFixture.h15
1 files changed, 12 insertions, 3 deletions
diff --git a/tests/benchmark/fixtures/ScaleFixture.h b/tests/benchmark/fixtures/ScaleFixture.h
index cd51f5778f..b2fbd9c3b6 100644
--- a/tests/benchmark/fixtures/ScaleFixture.h
+++ b/tests/benchmark/fixtures/ScaleFixture.h
@@ -41,11 +41,17 @@ class ScaleFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, DataType data_type, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy)
+ void setup(TensorShape shape, DataType data_type, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy)
{
constexpr float max_width = 8192.0f;
constexpr float max_height = 6384.0f;
+ // Change shape in case of NHWC.
+ if(data_layout == DataLayout::NHWC)
+ {
+ permute(shape, PermutationVector(2U, 0U, 1U));
+ }
+
std::mt19937 generator(library->seed());
std::uniform_real_distribution<float> distribution_float(0.25f, 3.0f);
float scale_x = distribution_float(generator);
@@ -57,9 +63,12 @@ public:
std::uniform_int_distribution<uint8_t> distribution_u8(0, 255);
uint8_t constant_border_value = static_cast<uint8_t>(distribution_u8(generator));
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+
TensorShape shape_scaled(shape);
- shape_scaled.set(0, shape[0] * scale_x);
- shape_scaled.set(1, shape[1] * scale_y);
+ shape_scaled.set(idx_width, shape[idx_width] * scale_x);
+ shape_scaled.set(idx_height, shape[idx_height] * scale_y);
// Create tensors
src = create_tensor<TensorType>(shape, data_type);