From cc34093a57ea486e48f9aa66fc8b98a7bbdefef1 Mon Sep 17 00:00:00 2001 From: James Conroy Date: Tue, 12 May 2020 18:08:52 +0100 Subject: IVGCVSW-4451 Add QLstm NEON workload * Adds QLstm workload. * Adds CreateWorkload and Layer tests. Signed-off-by: James Conroy Change-Id: I585eb2691395ee4ccd45b5a853660f90fc5cc821 --- src/backends/neon/NeonLayerSupport.cpp | 36 ++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) (limited to 'src/backends/neon/NeonLayerSupport.cpp') diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp index b095ed5629..53d0f0b633 100644 --- a/src/backends/neon/NeonLayerSupport.cpp +++ b/src/backends/neon/NeonLayerSupport.cpp @@ -48,6 +48,7 @@ #include "workloads/NeonPermuteWorkload.hpp" #include "workloads/NeonPooling2dWorkload.hpp" #include "workloads/NeonPreluWorkload.hpp" +#include "workloads/NeonQLstmWorkload.hpp" #include "workloads/NeonQuantizeWorkload.hpp" #include "workloads/NeonQuantizedLstmWorkload.hpp" #include "workloads/NeonReshapeWorkload.hpp" @@ -615,6 +616,41 @@ bool NeonLayerSupport::IsPreluSupported(const armnn::TensorInfo &input, FORWARD_WORKLOAD_VALIDATE_FUNC(NeonPreluWorkloadValidate, reasonIfUnsupported, input, alpha, output); } +bool NeonLayerSupport::IsQLstmSupported(const TensorInfo& input, + const TensorInfo& previousOutputIn, + const TensorInfo& previousCellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, + const TensorInfo& output, + const QLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported) const +{ + // Check required here in order to pass IsLayerSupported for datatypes tests + if (input.GetDataType() == armnn::DataType::QAsymmS8 && + previousOutputIn.GetDataType() == armnn::DataType::QAsymmS8 && + previousCellStateIn.GetDataType() == armnn::DataType::QSymmS16 && + outputStateOut.GetDataType() == armnn::DataType::QAsymmS8 && + cellStateOut.GetDataType() == armnn::DataType::QSymmS16 && + output.GetDataType() == armnn::DataType::QAsymmS8) + { + FORWARD_WORKLOAD_VALIDATE_FUNC(NeonQLstmWorkloadValidate, + reasonIfUnsupported, + input, + previousCellStateIn, + previousOutputIn, + cellStateOut, + outputStateOut, + output, + descriptor, + paramsInfo); + } + else + { + return false; + } +} + bool NeonLayerSupport::IsQuantizeSupported(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported) const -- cgit v1.2.1