From ad9171701e6032b3ddf3573f85780bae30c512c6 Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Wed, 9 Feb 2022 23:21:35 +0000 Subject: IVGCVSW-6267 Add support of Unidirectional Sequence Lstm fp32/fp16 to Cl !ComputeLibrary:7150 Signed-off-by: Cathal Corbett Change-Id: I01690e6555978d93c41d09bbe5378683bc925f61 --- src/backends/cl/ClLayerSupport.cpp | 188 ++++++++++++++++++++++--------------- 1 file changed, 111 insertions(+), 77 deletions(-) (limited to 'src/backends/cl/ClLayerSupport.cpp') diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index e5204e4d5b..e52f578bc0 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -78,6 +78,7 @@ #include "workloads/ClSubtractionWorkload.hpp" #include "workloads/ClTransposeConvolution2dWorkload.hpp" #include "workloads/ClTransposeWorkload.hpp" +#include "workloads/ClUnidirectionalSequenceLstmFloatWorkload.hpp" #endif @@ -212,6 +213,13 @@ bool ClLayerSupport::IsLayerSupported(const LayerType& type, infos[1], *(PolymorphicDowncast(&descriptor)), reasonIfUnsupported); + case LayerType::Cast: + return IsCastSupported(infos[0], infos[1], reasonIfUnsupported); + case LayerType::ChannelShuffle: + return IsChannelShuffleSupported(infos[0], + infos[1], + *(PolymorphicDowncast(&descriptor)), + reasonIfUnsupported); case LayerType::Comparison: return IsComparisonSupported(infos[0], infos[1], @@ -236,6 +244,14 @@ bool ClLayerSupport::IsLayerSupported(const LayerType& type, return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported); case LayerType::ConvertFp32ToFp16: return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported); + case LayerType::ConvertBf16ToFp32: + return LayerSupportBase::IsConvertBf16ToFp32Supported(infos[0], + infos[1], + reasonIfUnsupported); + case LayerType::ConvertFp32ToBf16: + return LayerSupportBase::IsConvertFp32ToBf16Supported(infos[0], + infos[1], + reasonIfUnsupported); case LayerType::Convolution2d: { if (infos.size() != 4) @@ -264,6 +280,34 @@ bool ClLayerSupport::IsLayerSupported(const LayerType& type, reasonIfUnsupported); } } + case LayerType::Convolution3d: + { + if (infos.size() != 4) + { + throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. " + "TensorInfos should be of format: {input, output, weights, biases}."); + } + + auto desc = *(PolymorphicDowncast(&descriptor)); + if (infos[3] == TensorInfo()) + { + return IsConvolution3dSupported(infos[0], + infos[1], + desc, + infos[2], + EmptyOptional(), + reasonIfUnsupported); + } + else + { + return IsConvolution3dSupported(infos[0], + infos[1], + desc, + infos[2], + infos[3], + reasonIfUnsupported); + } + } case LayerType::DepthToSpace: return IsDepthToSpaceSupported(infos[0], infos[1], @@ -361,16 +405,17 @@ bool ClLayerSupport::IsLayerSupported(const LayerType& type, *(PolymorphicDowncast(&descriptor)), lstmParamsInfo.value(), reasonIfUnsupported); - case LayerType::QLstm: - return IsQLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - infos[4], - infos[5], - *(PolymorphicDowncast(&descriptor)), - lstmParamsInfo.value(), - reasonIfUnsupported); + case LayerType::Map: + return true; + case LayerType::MemCopy: + return LayerSupportBase::IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported); + case LayerType::MemImport: + return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported); + case LayerType::Merge: + return LayerSupportBase::IsMergeSupported(infos[0], + infos[1], + infos[2], + reasonIfUnsupported); case LayerType::Maximum: return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported); case LayerType::Mean: @@ -406,6 +451,16 @@ bool ClLayerSupport::IsLayerSupported(const LayerType& type, reasonIfUnsupported); case LayerType::Prelu: return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported); + case LayerType::QLstm: + return IsQLstmSupported(infos[0], + infos[1], + infos[2], + infos[3], + infos[4], + infos[5], + *(PolymorphicDowncast(&descriptor)), + lstmParamsInfo.value(), + reasonIfUnsupported); case LayerType::Quantize: return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported); case LayerType::QuantizedLstm: @@ -416,6 +471,13 @@ bool ClLayerSupport::IsLayerSupported(const LayerType& type, infos[4], quantizedLstmParamsInfo.value(), reasonIfUnsupported); + case LayerType::Rank: + return true; + case LayerType::Reduce: + return IsReduceSupported(infos[0], + infos[1], + *(PolymorphicDowncast(&descriptor)), + reasonIfUnsupported); case LayerType::Reshape: return IsReshapeSupported(infos[0], infos[1], @@ -426,11 +488,10 @@ bool ClLayerSupport::IsLayerSupported(const LayerType& type, infos[1], *(PolymorphicDowncast(&descriptor)), reasonIfUnsupported); - case LayerType::Reduce: - return IsReduceSupported(infos[0], - infos[1], - *(PolymorphicDowncast(&descriptor)), - reasonIfUnsupported); + case LayerType::Shape: + return LayerSupportBase::IsShapeSupported(infos[0], + infos[1], + reasonIfUnsupported); case LayerType::Slice: return IsSliceSupported(infos[0], infos[1], @@ -515,72 +576,23 @@ bool ClLayerSupport::IsLayerSupported(const LayerType& type, reasonIfUnsupported); } } - case LayerType::Cast: - return IsCastSupported(infos[0], infos[1], reasonIfUnsupported); - case LayerType::ChannelShuffle: - return IsChannelShuffleSupported(infos[0], - infos[1], - *(PolymorphicDowncast(&descriptor)), - reasonIfUnsupported); - case LayerType::Convolution3d: - { - if (infos.size() != 4) - { - throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. " - "TensorInfos should be of format: {input, output, weights, biases}."); - } - - auto desc = *(PolymorphicDowncast(&descriptor)); - if (infos[3] == TensorInfo()) - { - return IsConvolution3dSupported(infos[0], - infos[1], - desc, - infos[2], - EmptyOptional(), - reasonIfUnsupported); - } - else - { - return IsConvolution3dSupported(infos[0], - infos[1], - desc, - infos[2], - infos[3], - reasonIfUnsupported); - } - } - case LayerType::MemCopy: - return LayerSupportBase::IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported); - case LayerType::MemImport: - return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported); - case LayerType::Map: - return true; + case LayerType::UnidirectionalSequenceLstm: + return IsUnidirectionalSequenceLstmSupported(infos[0], + infos[1], + infos[2], + infos[3], + infos[4], + infos[5], + *(PolymorphicDowncast(&descriptor)), + lstmParamsInfo.value(), + reasonIfUnsupported); case LayerType::Unmap: return true; - case LayerType::Merge: - return LayerSupportBase::IsMergeSupported(infos[0], - infos[1], - infos[2], - reasonIfUnsupported); - case LayerType::Rank: - return true; - case LayerType::Shape: - return LayerSupportBase::IsShapeSupported(infos[0], - infos[1], - reasonIfUnsupported); - case LayerType::ConvertBf16ToFp32: - return LayerSupportBase::IsConvertBf16ToFp32Supported(infos[0], - infos[1], - reasonIfUnsupported); - case LayerType::ConvertFp32ToBf16: - return LayerSupportBase::IsConvertFp32ToBf16Supported(infos[0], - infos[1], - reasonIfUnsupported); default: // layers not supported in cl by default: - // debug, detectionpostprocess, fakequantization, precompiled, - // standin, switch, unidirectionalsequencelstm, pooling3d + // debug, detectionpostprocess, fakequantization, + // precompiled, standin, switch, pooling3d return false; } } @@ -1415,4 +1427,26 @@ bool ClLayerSupport::IsTransposeSupported(const TensorInfo& input, FORWARD_WORKLOAD_VALIDATE_FUNC(ClTransposeWorkloadValidate, reasonIfUnsupported, input, output, descriptor); } +bool ClLayerSupport::IsUnidirectionalSequenceLstmSupported(const TensorInfo& input, + const TensorInfo& outputStateIn, + const TensorInfo& cellStateIn, + const TensorInfo& output, + const Optional& hiddenStateOutput, + const Optional& cellStateOutput, + const UnidirectionalSequenceLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported) const +{ + FORWARD_WORKLOAD_VALIDATE_FUNC(ClUnidirectionalSequenceLstmFloatWorkloadValidate, + reasonIfUnsupported, + input, + outputStateIn, + cellStateIn, + output, + hiddenStateOutput, + cellStateOutput, + descriptor, + paramsInfo); +} + } // namespace armnn -- cgit v1.2.1