diff options
Diffstat (limited to 'src/backends/cl/ClLayerSupport.cpp')
-rw-r--r-- | src/backends/cl/ClLayerSupport.cpp | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index 9a50f4aabd..afcaf566a9 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -32,6 +32,7 @@ #include "workloads/ClConvertFp16ToFp32Workload.hpp" #include "workloads/ClConvertFp32ToFp16Workload.hpp" #include "workloads/ClConvolution2dWorkload.hpp" +#include "workloads/ClConvolution3dWorkload.hpp" #include "workloads/ClDepthToSpaceWorkload.hpp" #include "workloads/ClDepthwiseConvolutionWorkload.hpp" #include "workloads/ClDequantizeWorkload.hpp" @@ -385,6 +386,39 @@ bool ClLayerSupport::IsConvolution2dSupported(const TensorInfo& input, nullptr); } +bool ClLayerSupport::IsConvolution3dSupported(const TensorInfo& input, + const TensorInfo& output, + const Convolution3dDescriptor& descriptor, + const TensorInfo& weights, + const Optional<TensorInfo>& biases, + Optional<std::string&> reasonIfUnsupported) const +{ + bool isFastMathEnabled = false; +#if defined(ARMCOMPUTECL_ENABLED) + if (m_ModelContextPtr) +{ + if (m_ModelContextPtr.get() != nullptr) + { + auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get()); + if (modelOptions) + { + isFastMathEnabled = modelOptions->IsFastMathEnabled(); + } + } +} +#endif + + FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution3dWorkloadValidate, + reasonIfUnsupported, + input, + output, + descriptor, + weights, + biases, + isFastMathEnabled, + nullptr); +} + bool ClLayerSupport::IsDequantizeSupported(const TensorInfo& input, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const |