aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/ClLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/ClLayerSupport.cpp')
-rw-r--r--src/backends/cl/ClLayerSupport.cpp33
1 files changed, 27 insertions, 6 deletions
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index 21d191ab2c..dfac28989c 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -15,6 +15,7 @@
#include <boost/core/ignore_unused.hpp>
#if defined(ARMCOMPUTECL_ENABLED)
+#include <aclCommon/ArmComputeUtils.hpp>
#include "workloads/ClAdditionWorkload.hpp"
#include "workloads/ClActivationWorkload.hpp"
#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
@@ -39,6 +40,7 @@
#include "workloads/ClPooling2dWorkload.hpp"
#include "workloads/ClSoftmaxBaseWorkload.hpp"
#include "workloads/ClSpaceToBatchNdWorkload.hpp"
+#include "workloads/ClSplitterWorkload.hpp"
#include "workloads/ClStridedSliceWorkload.hpp"
#include "workloads/ClSubtractionWorkload.hpp"
#endif
@@ -612,12 +614,31 @@ bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input,
const ViewsDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(descriptor);
- ignore_unused(outputs);
- return IsSupportedForDataTypeCl(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+#if defined(ARMCOMPUTECL_ENABLED)
+ // Split along the last dimension, cannot use sub-tensors
+ // as width and height of the sub-tensors do not match
+ // the width and height of the parent tensor
+ // in case of input with more than 2D.
+ std::set<unsigned int> splitAxis = ComputeSplitAxis(descriptor, input.GetShape());
+ if (descriptor.GetNumDimensions() > 2 && splitAxis.size() == 1 &&
+ *splitAxis.begin() == descriptor.GetNumDimensions() - 1 )
+ {
+ FORWARD_WORKLOAD_VALIDATE_FUNC(ClSplitterWorkloadValidate,
+ reasonIfUnsupported,
+ input,
+ outputs,
+ *splitAxis.begin());
+ }
+#endif
+ for (auto output : outputs)
+ {
+ if (!input.IsTypeSpaceMatch(output)) // Cannot use sub-tensors if the types are not same space
+ {
+ SetValueChecked(reasonIfUnsupported, "Cl Splitter: Types and quantization parameters must match.");
+ return false;
+ }
+ }
+ return true;
}
bool ClLayerSupport::IsStridedSliceSupported(const TensorInfo& input,