Diffstat (limited to 'tests/validation/fixtures/PadLayerFixture.h')
-rw-r--r--  tests/validation/fixtures/PadLayerFixture.h | 48
 1 file changed, 37 insertions(+), 11 deletions(-)
diff --git a/tests/validation/fixtures/PadLayerFixture.h b/tests/validation/fixtures/PadLayerFixture.h
index 839313a118..3538cabfeb 100644
--- a/tests/validation/fixtures/PadLayerFixture.h
+++ b/tests/validation/fixtures/PadLayerFixture.h
@@ -45,30 +45,54 @@ class PaddingFixture : public framework::Fixture
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, DataType data_type, const PaddingList &padding)
+    void setup(TensorShape shape, DataType data_type, const PaddingList &padding, const PaddingMode mode)
     {
-        _target    = compute_target(shape, data_type, padding);
-        _reference = compute_reference(shape, data_type, padding);
+        PaddingList clamped_padding = padding;
+        if(mode != PaddingMode::CONSTANT)
+        {
+            // Clamp padding to prevent applying more than is possible.
+            for(uint32_t i = 0; i < padding.size(); ++i)
+            {
+                if(mode == PaddingMode::REFLECT)
+                {
+                    clamped_padding[i].first  = std::min(static_cast<uint64_t>(padding[i].first), static_cast<uint64_t>(shape[i] - 1));
+                    clamped_padding[i].second = std::min(static_cast<uint64_t>(padding[i].second), static_cast<uint64_t>(shape[i] - 1));
+                }
+                else
+                {
+                    clamped_padding[i].first  = std::min(static_cast<uint64_t>(padding[i].first), static_cast<uint64_t>(shape[i]));
+                    clamped_padding[i].second = std::min(static_cast<uint64_t>(padding[i].second), static_cast<uint64_t>(shape[i]));
+                }
+            }
+        }
+        _target    = compute_target(shape, data_type, clamped_padding, mode);
+        _reference = compute_reference(shape, data_type, clamped_padding, mode);
     }
 
 protected:
     template <typename U>
-    void fill(U &&tensor)
+    void fill(U &&tensor, int i)
     {
-        library->fill_tensor_uniform(tensor, 0);
+        library->fill_tensor_uniform(tensor, i);
     }
 
     TensorType compute_target(const TensorShape &shape,
                               DataType           data_type,
-                              const PaddingList &paddings)
+                              const PaddingList &paddings,
+                              const PaddingMode  mode)
     {
         // Create tensors
         TensorType src = create_tensor<TensorType>(shape, data_type);
         TensorType dst;
 
+        TensorType const_val = create_tensor<TensorType>(TensorShape(1), data_type);
+        const_val.allocator()->allocate();
+        fill(AccessorType(const_val), 1);
+        T const_value = *static_cast<T *>(AccessorType(const_val)(Coordinates(0)));
+
         // Create and configure function
         FunctionType padding;
-        padding.configure(&src, &dst, paddings);
+        padding.configure(&src, &dst, paddings, const_value, mode);
 
         ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
         ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -81,7 +105,7 @@ protected:
         ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
 
         // Fill tensors
-        fill(AccessorType(src));
+        fill(AccessorType(src), 0);
 
         // Compute function
         padding.run();
@@ -90,15 +114,17 @@ protected:
     }
 
     SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type,
-                                      const PaddingList &paddings)
+                                      const PaddingList &paddings, const PaddingMode mode)
    {
         // Create reference tensor
         SimpleTensor<T> src{ shape, data_type };
+        SimpleTensor<T> const_val{ TensorShape(1), data_type };
 
         // Fill reference tensor
-        fill(src);
+        fill(src, 0);
+        fill(const_val, 1);
 
-        return reference::pad_layer(src, paddings);
+        return reference::pad_layer(src, paddings, const_val[0], mode);
     }
 
     TensorType _target{};
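
Note on the clamping introduced in setup() above: REFLECT padding mirrors the tensor without repeating the edge element, so a dimension of size N can contribute at most N - 1 padded values per side, while SYMMETRIC padding repeats the edge and therefore allows up to N. The following is a minimal, self-contained sketch of that same rule outside the fixture; it uses plain standard-library types as stand-ins for arm_compute's PaddingList, PaddingMode and TensorShape, and the helper name clamp_padding is hypothetical, chosen only to illustrate the std::min logic.

// Sketch of the clamping applied in PaddingFixture::setup(), under the
// assumption that a padding entry is a (before, after) pair per dimension.
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

enum class Mode { CONSTANT, REFLECT, SYMMETRIC };
using Padding = std::vector<std::pair<uint32_t, uint32_t>>;

Padding clamp_padding(const Padding &padding, const std::vector<uint32_t> &shape, Mode mode)
{
    Padding clamped = padding;
    if(mode == Mode::CONSTANT)
    {
        return clamped; // CONSTANT padding is not limited by the input extent.
    }
    for(std::size_t i = 0; i < padding.size(); ++i)
    {
        // REFLECT mirrors without repeating the edge element -> at most shape[i] - 1 per side.
        // SYMMETRIC repeats the edge element                 -> at most shape[i] per side.
        const uint32_t limit = (mode == Mode::REFLECT) ? shape[i] - 1 : shape[i];
        clamped[i].first  = std::min(padding[i].first, limit);
        clamped[i].second = std::min(padding[i].second, limit);
    }
    return clamped;
}

As a worked example, for a row {1, 2, 3} (size 3), REFLECT left-padding of 2 yields {3, 2, 1, 2, 3} and cannot exceed 2 = 3 - 1, whereas SYMMETRIC left-padding of 3 yields {3, 2, 1, 1, 2, 3}, i.e. up to the full dimension size. The fixture applies the clamp before calling compute_target() and compute_reference(), so both paths are configured with the same, valid padding.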