diff options
Diffstat (limited to 'src/backends/cl/ClLayerSupport.cpp')
-rw-r--r-- | src/backends/cl/ClLayerSupport.cpp | 33 |
1 files changed, 27 insertions, 6 deletions
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index 21d191ab2c..dfac28989c 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -15,6 +15,7 @@ #include <boost/core/ignore_unused.hpp> #if defined(ARMCOMPUTECL_ENABLED) +#include <aclCommon/ArmComputeUtils.hpp> #include "workloads/ClAdditionWorkload.hpp" #include "workloads/ClActivationWorkload.hpp" #include "workloads/ClBatchNormalizationFloatWorkload.hpp" @@ -39,6 +40,7 @@ #include "workloads/ClPooling2dWorkload.hpp" #include "workloads/ClSoftmaxBaseWorkload.hpp" #include "workloads/ClSpaceToBatchNdWorkload.hpp" +#include "workloads/ClSplitterWorkload.hpp" #include "workloads/ClStridedSliceWorkload.hpp" #include "workloads/ClSubtractionWorkload.hpp" #endif @@ -612,12 +614,31 @@ bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input, const ViewsDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported) const { - ignore_unused(descriptor); - ignore_unused(outputs); - return IsSupportedForDataTypeCl(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); +#if defined(ARMCOMPUTECL_ENABLED) + // Split along the last dimension, cannot use sub-tensors + // as width and height of the sub-tensors do not match + // the width and height of the parent tensor + // in case of input with more than 2D. + std::set<unsigned int> splitAxis = ComputeSplitAxis(descriptor, input.GetShape()); + if (descriptor.GetNumDimensions() > 2 && splitAxis.size() == 1 && + *splitAxis.begin() == descriptor.GetNumDimensions() - 1 ) + { + FORWARD_WORKLOAD_VALIDATE_FUNC(ClSplitterWorkloadValidate, + reasonIfUnsupported, + input, + outputs, + *splitAxis.begin()); + } +#endif + for (auto output : outputs) + { + if (!input.IsTypeSpaceMatch(output)) // Cannot use sub-tensors if the types are not same space + { + SetValueChecked(reasonIfUnsupported, "Cl Splitter: Types and quantization parameters must match."); + return false; + } + } + return true; } bool ClLayerSupport::IsStridedSliceSupported(const TensorInfo& input, |