aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/validation/CL/ReduceMean.cpp2
-rw-r--r--tests/validation/fixtures/ReduceMeanFixture.h3
2 files changed, 3 insertions, 2 deletions
diff --git a/tests/validation/CL/ReduceMean.cpp b/tests/validation/CL/ReduceMean.cpp
index 07e859f391..cfd4a2730c 100644
--- a/tests/validation/CL/ReduceMean.cpp
+++ b/tests/validation/CL/ReduceMean.cpp
@@ -48,7 +48,7 @@ constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1); /**< Tolerance value
const auto axis_keep = combine(framework::dataset::make("Axis", { Coordinates(0), Coordinates(1, 0), Coordinates(1, 2), Coordinates(0, 2), Coordinates(1, 3), Coordinates(0, 1, 2, 3) }),
framework::dataset::make("KeepDims", { true }));
-const auto axis_drop = combine(framework::dataset::make("Axis", { Coordinates(0), Coordinates(1), Coordinates(3) }), framework::dataset::make("KeepDims", { false }));
+const auto axis_drop = combine(framework::dataset::make("Axis", { Coordinates(0), Coordinates(1), Coordinates(3), Coordinates(1, 2), Coordinates(2, 1) }), framework::dataset::make("KeepDims", { false }));
} // namespace
TEST_SUITE(CL)
TEST_SUITE(ReduceMean)
diff --git a/tests/validation/fixtures/ReduceMeanFixture.h b/tests/validation/fixtures/ReduceMeanFixture.h
index 8692213641..769d7f674f 100644
--- a/tests/validation/fixtures/ReduceMeanFixture.h
+++ b/tests/validation/fixtures/ReduceMeanFixture.h
@@ -119,9 +119,10 @@ protected:
if(!keep_dims)
{
TensorShape output_shape = src_shape;
+ std::sort(axis.begin(), axis.begin() + axis.num_dimensions());
for(unsigned int i = 0; i < axis.num_dimensions(); ++i)
{
- output_shape.remove_dimension(axis[i]);
+ output_shape.remove_dimension(axis[i] - i);
}
out = reference::reshape_layer(out, output_shape);