diff options
Diffstat (limited to 'src/backends/neon/NeonLayerSupport.cpp')
-rw-r--r-- | src/backends/neon/NeonLayerSupport.cpp | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp index b095ed5629..53d0f0b633 100644 --- a/src/backends/neon/NeonLayerSupport.cpp +++ b/src/backends/neon/NeonLayerSupport.cpp @@ -48,6 +48,7 @@ #include "workloads/NeonPermuteWorkload.hpp" #include "workloads/NeonPooling2dWorkload.hpp" #include "workloads/NeonPreluWorkload.hpp" +#include "workloads/NeonQLstmWorkload.hpp" #include "workloads/NeonQuantizeWorkload.hpp" #include "workloads/NeonQuantizedLstmWorkload.hpp" #include "workloads/NeonReshapeWorkload.hpp" @@ -615,6 +616,41 @@ bool NeonLayerSupport::IsPreluSupported(const armnn::TensorInfo &input, FORWARD_WORKLOAD_VALIDATE_FUNC(NeonPreluWorkloadValidate, reasonIfUnsupported, input, alpha, output); } +bool NeonLayerSupport::IsQLstmSupported(const TensorInfo& input, + const TensorInfo& previousOutputIn, + const TensorInfo& previousCellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, + const TensorInfo& output, + const QLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported) const +{ + // Check required here in order to pass IsLayerSupported for datatypes tests + if (input.GetDataType() == armnn::DataType::QAsymmS8 && + previousOutputIn.GetDataType() == armnn::DataType::QAsymmS8 && + previousCellStateIn.GetDataType() == armnn::DataType::QSymmS16 && + outputStateOut.GetDataType() == armnn::DataType::QAsymmS8 && + cellStateOut.GetDataType() == armnn::DataType::QSymmS16 && + output.GetDataType() == armnn::DataType::QAsymmS8) + { + FORWARD_WORKLOAD_VALIDATE_FUNC(NeonQLstmWorkloadValidate, + reasonIfUnsupported, + input, + previousCellStateIn, + previousOutputIn, + cellStateOut, + outputStateOut, + output, + descriptor, + paramsInfo); + } + else + { + return false; + } +} + bool NeonLayerSupport::IsQuantizeSupported(const TensorInfo& input, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const |