diff options
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/StackLayerFixture.h | 34 |
1 files changed, 29 insertions, 5 deletions
diff --git a/tests/validation/fixtures/StackLayerFixture.h b/tests/validation/fixtures/StackLayerFixture.h index 7320a032bd..7dd8fe47dc 100644 --- a/tests/validation/fixtures/StackLayerFixture.h +++ b/tests/validation/fixtures/StackLayerFixture.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_TEST_STACK_LAYER_FIXTURE -#define ARM_COMPUTE_TEST_STACK_LAYER_FIXTURE +#ifndef ACL_TESTS_VALIDATION_FIXTURES_STACKLAYERFIXTURE_H +#define ACL_TESTS_VALIDATION_FIXTURES_STACKLAYERFIXTURE_H #include "arm_compute/core/Helpers.h" #include "arm_compute/core/TensorShape.h" @@ -54,7 +54,7 @@ class StackLayerValidationFixture : public framework::Fixture public: void setup(TensorShape shape_src, int axis, DataType data_type, int num_tensors) { - _target = compute_target(shape_src, axis, data_type, num_tensors); + _target = compute_target(shape_src, axis, data_type, num_tensors, false /* add_x_padding */); _reference = compute_reference(shape_src, axis, data_type, num_tensors); } @@ -65,7 +65,7 @@ protected: library->fill_tensor_uniform(tensor, i); } - TensorType compute_target(TensorShape shape_src, int axis, DataType data_type, int num_tensors) + TensorType compute_target(TensorShape shape_src, int axis, DataType data_type, int num_tensors, bool add_x_padding) { std::vector<TensorType> tensors(num_tensors); std::vector<AbstractTensorType *> src(num_tensors); @@ -90,6 +90,11 @@ protected: // Allocate and fill the input tensors for(int i = 0; i < num_tensors; ++i) { + if(add_x_padding) + { + add_padding_x({&tensors[i]}, DataLayout::NHWC); + } + ARM_COMPUTE_ASSERT(tensors[i].info()->is_resizable()); tensors[i].allocator()->allocate(); ARM_COMPUTE_ASSERT(!tensors[i].info()->is_resizable()); @@ -98,6 +103,11 @@ protected: fill(AccessorType(tensors[i]), i); } + if(add_x_padding) + { + add_padding_x({&dst}, DataLayout::NHWC); + } + // Allocate output tensor dst.allocator()->allocate(); @@ -131,7 +141,21 @@ protected: TensorType _target{}; SimpleTensor<T> _reference{}; }; + +template <typename TensorType, typename AbstractTensorType, typename AccessorType, typename FunctionType, typename T> +class StackLayerWithPaddingValidationFixture : + public StackLayerValidationFixture<TensorType, AbstractTensorType, AccessorType, FunctionType, T> +{ +public: + using Parent = StackLayerValidationFixture<TensorType, AbstractTensorType, AccessorType, FunctionType, T>; + + void setup(TensorShape shape_src, int axis, DataType data_type, int num_tensors) + { + Parent::_target = Parent::compute_target(shape_src, axis, data_type, num_tensors, true /* add_x_padding */); + Parent::_reference = Parent::compute_reference(shape_src, axis, data_type, num_tensors); + } +}; } // namespace validation } // namespace test } // namespace arm_compute -#endif /* ARM_COMPUTE_TEST_STACK_LAYER_FIXTURE */ +#endif // ACL_TESTS_VALIDATION_FIXTURES_STACKLAYERFIXTURE_H |