From 7bdcc2e3073accbd6e7884a4b7d0d016253e7412 Mon Sep 17 00:00:00 2001 From: Ryan OShea Date: Wed, 13 May 2020 16:36:19 +0100 Subject: IVGCVSW-4450 Add CL Enhanced Quantized LSTM Workload * Adds QLstm CL workload * Added Layer and CreateWorkload tests Signed-off-by: Ryan OShea Signed-off-by: James Conroy Change-Id: I32335e528467bfd619edb249d2971705ac2a6163 --- src/backends/cl/ClLayerSupport.cpp | 39 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) (limited to 'src/backends/cl/ClLayerSupport.cpp') diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index eb68a80765..7418dbd9e4 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -47,11 +47,12 @@ #include "workloads/ClPermuteWorkload.hpp" #include "workloads/ClPooling2dWorkload.hpp" #include "workloads/ClPreluWorkload.hpp" +#include "workloads/ClQLstmWorkload.hpp" +#include "workloads/ClQuantizedLstmWorkload.hpp" +#include "workloads/ClQuantizeWorkload.hpp" #include "workloads/ClReshapeWorkload.hpp" #include "workloads/ClResizeWorkload.hpp" #include "workloads/ClRsqrtWorkload.hpp" -#include "workloads/ClQuantizedLstmWorkload.hpp" -#include "workloads/ClQuantizeWorkload.hpp" #include "workloads/ClSliceWorkload.hpp" #include "workloads/ClSoftmaxWorkload.hpp" #include "workloads/ClSpaceToBatchNdWorkload.hpp" @@ -618,6 +619,40 @@ bool ClLayerSupport::IsPreluSupported(const armnn::TensorInfo &input, FORWARD_WORKLOAD_VALIDATE_FUNC(ClPreluWorkloadValidate, reasonIfUnsupported, input, alpha, output); } +bool ClLayerSupport::IsQLstmSupported(const TensorInfo& input, + const TensorInfo& previousOutputIn, + const TensorInfo& previousCellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, + const TensorInfo& output, + const QLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported) const +{ + if (input.GetDataType() == armnn::DataType::QAsymmS8 && + previousOutputIn.GetDataType() == armnn::DataType::QAsymmS8 && + previousCellStateIn.GetDataType() == armnn::DataType::QSymmS16 && + outputStateOut.GetDataType() == armnn::DataType::QAsymmS8 && + cellStateOut.GetDataType() == armnn::DataType::QSymmS16 && + output.GetDataType() == armnn::DataType::QAsymmS8) + { + FORWARD_WORKLOAD_VALIDATE_FUNC(ClQLstmWorkloadValidate, + reasonIfUnsupported, + input, + previousCellStateIn, + previousOutputIn, + cellStateOut, + outputStateOut, + output, + descriptor, + paramsInfo); + } + else + { + return false; + } +} + bool ClLayerSupport::IsQuantizedLstmSupported(const TensorInfo& input, const TensorInfo& previousCellStateIn, const TensorInfo& previousOutputIn, -- cgit v1.2.1