aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/StackLayerFixture.h
diff options
context:
space:
mode:
authorGunes Bayir <gunes.bayir@arm.com>2023-10-07 23:52:48 +0100
committerGunes Bayir <gunes.bayir@arm.com>2023-10-10 09:48:53 +0000
commit0b72aa4b2abdba7ab48aaa8a45c624ba1e27a411 (patch)
treeea14c31a15c623cfa07db1dba722cd4ae61621b0 /tests/validation/fixtures/StackLayerFixture.h
parentc6137d2be4fb781b63831138970146a4eb8550a1 (diff)
downloadComputeLibrary-0b72aa4b2abdba7ab48aaa8a45c624ba1e27a411.tar.gz
Optimize NEStackLayer
Optimize the stack operation in Cpu by leveraging block memcpy. Resolves: COMPMID-6498 Change-Id: I49d79d179f0375a73d654edd59fb33072112569b Signed-off-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10451 Reviewed-by: SiCong Li <sicong.li@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures/StackLayerFixture.h')
-rw-r--r--tests/validation/fixtures/StackLayerFixture.h34
1 files changed, 29 insertions, 5 deletions
diff --git a/tests/validation/fixtures/StackLayerFixture.h b/tests/validation/fixtures/StackLayerFixture.h
index 7320a032bd..7dd8fe47dc 100644
--- a/tests/validation/fixtures/StackLayerFixture.h
+++ b/tests/validation/fixtures/StackLayerFixture.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_STACK_LAYER_FIXTURE
-#define ARM_COMPUTE_TEST_STACK_LAYER_FIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_STACKLAYERFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_STACKLAYERFIXTURE_H
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorShape.h"
@@ -54,7 +54,7 @@ class StackLayerValidationFixture : public framework::Fixture
public:
void setup(TensorShape shape_src, int axis, DataType data_type, int num_tensors)
{
- _target = compute_target(shape_src, axis, data_type, num_tensors);
+ _target = compute_target(shape_src, axis, data_type, num_tensors, false /* add_x_padding */);
_reference = compute_reference(shape_src, axis, data_type, num_tensors);
}
@@ -65,7 +65,7 @@ protected:
library->fill_tensor_uniform(tensor, i);
}
- TensorType compute_target(TensorShape shape_src, int axis, DataType data_type, int num_tensors)
+ TensorType compute_target(TensorShape shape_src, int axis, DataType data_type, int num_tensors, bool add_x_padding)
{
std::vector<TensorType> tensors(num_tensors);
std::vector<AbstractTensorType *> src(num_tensors);
@@ -90,6 +90,11 @@ protected:
// Allocate and fill the input tensors
for(int i = 0; i < num_tensors; ++i)
{
+ if(add_x_padding)
+ {
+ add_padding_x({&tensors[i]}, DataLayout::NHWC);
+ }
+
ARM_COMPUTE_ASSERT(tensors[i].info()->is_resizable());
tensors[i].allocator()->allocate();
ARM_COMPUTE_ASSERT(!tensors[i].info()->is_resizable());
@@ -98,6 +103,11 @@ protected:
fill(AccessorType(tensors[i]), i);
}
+ if(add_x_padding)
+ {
+ add_padding_x({&dst}, DataLayout::NHWC);
+ }
+
// Allocate output tensor
dst.allocator()->allocate();
@@ -131,7 +141,21 @@ protected:
TensorType _target{};
SimpleTensor<T> _reference{};
};
+
+template <typename TensorType, typename AbstractTensorType, typename AccessorType, typename FunctionType, typename T>
+class StackLayerWithPaddingValidationFixture :
+ public StackLayerValidationFixture<TensorType, AbstractTensorType, AccessorType, FunctionType, T>
+{
+public:
+ using Parent = StackLayerValidationFixture<TensorType, AbstractTensorType, AccessorType, FunctionType, T>;
+
+ void setup(TensorShape shape_src, int axis, DataType data_type, int num_tensors)
+ {
+ Parent::_target = Parent::compute_target(shape_src, axis, data_type, num_tensors, true /* add_x_padding */);
+ Parent::_reference = Parent::compute_reference(shape_src, axis, data_type, num_tensors);
+ }
+};
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_STACK_LAYER_FIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_STACKLAYERFIXTURE_H