From 49895f4738e1969080b47b837a152cd6b5c8414c Mon Sep 17 00:00:00 2001 From: David Monahan Date: Tue, 21 Jul 2020 11:16:51 +0100 Subject: IVGCVSW-5085 Updates to CL and NEON TensorHandleFactory * Update the CL and Neon TensorHandleFactories to not use SubTensors if Axis is on x or y Signed-off-by: David Monahan Change-Id: I782b89f50a92b21fdcbe68dab0281ad265fb3b63 --- src/backends/cl/ClTensorHandleFactory.cpp | 12 ++++++++++++ src/backends/neon/NeonTensorHandleFactory.cpp | 12 ++++++++++++ 2 files changed, 24 insertions(+) (limited to 'src') diff --git a/src/backends/cl/ClTensorHandleFactory.cpp b/src/backends/cl/ClTensorHandleFactory.cpp index 8af97f41e2..e92913f196 100644 --- a/src/backends/cl/ClTensorHandleFactory.cpp +++ b/src/backends/cl/ClTensorHandleFactory.cpp @@ -36,6 +36,18 @@ std::unique_ptr ClTensorHandleFactory::CreateSubTensorHandle(ITen const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape( parent.GetShape()); + + // In order for ACL to support subtensors the concat axis cannot be on x or y and the values of x and y + // must match the parent shapes + if (coords.x() != 0 || coords.y() != 0) + { + return nullptr; + } + if ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y())) + { + return nullptr; + } + if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape)) { return nullptr; diff --git a/src/backends/neon/NeonTensorHandleFactory.cpp b/src/backends/neon/NeonTensorHandleFactory.cpp index ec9e0631fe..4e013a37a1 100644 --- a/src/backends/neon/NeonTensorHandleFactory.cpp +++ b/src/backends/neon/NeonTensorHandleFactory.cpp @@ -33,6 +33,18 @@ std::unique_ptr NeonTensorHandleFactory::CreateSubTensorHandle(IT } const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape()); + + // In order for ACL to support subtensors the concat axis cannot be on x or y and the values of x and y + // must match the parent shapes + if (coords.x() != 0 || coords.y() != 0) + { + return nullptr; + } + if ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y())) + { + return nullptr; + } + if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape)) { return nullptr; -- cgit v1.2.1