diff options
author | Ryan OShea <Ryan.OShea2@arm.com> | 2020-05-13 16:36:19 +0100 |
---|---|---|
committer | Jan Eilers <jan.eilers@arm.com> | 2020-05-20 12:59:34 +0000 |
commit | 2323af4542db474d42e643051e38bbc65f5844e7 (patch) | |
tree | 0c081bc35b11d283d39f5a9e9edaf419e457ff83 /src/backends/cl/ClLayerSupport.cpp | |
parent | cc34093a57ea486e48f9aa66fc8b98a7bbdefef1 (diff) | |
download | armnn-2323af4542db474d42e643051e38bbc65f5844e7.tar.gz |
IVGCVSW-4450 Add CL Enhanced Quantized LSTM Workload
* Adds QLstm CL workload
* Added Layer and CreateWorkload tests
Signed-off-by: Ryan OShea <Ryan.OShea2@arm.com>
Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: I32335e528467bfd619edb249d2971705ac2a6163
Diffstat (limited to 'src/backends/cl/ClLayerSupport.cpp')
-rw-r--r-- | src/backends/cl/ClLayerSupport.cpp | 39 |
1 files changed, 37 insertions, 2 deletions
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index eb68a80765..7418dbd9e4 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -47,11 +47,12 @@ #include "workloads/ClPermuteWorkload.hpp" #include "workloads/ClPooling2dWorkload.hpp" #include "workloads/ClPreluWorkload.hpp" +#include "workloads/ClQLstmWorkload.hpp" +#include "workloads/ClQuantizedLstmWorkload.hpp" +#include "workloads/ClQuantizeWorkload.hpp" #include "workloads/ClReshapeWorkload.hpp" #include "workloads/ClResizeWorkload.hpp" #include "workloads/ClRsqrtWorkload.hpp" -#include "workloads/ClQuantizedLstmWorkload.hpp" -#include "workloads/ClQuantizeWorkload.hpp" #include "workloads/ClSliceWorkload.hpp" #include "workloads/ClSoftmaxWorkload.hpp" #include "workloads/ClSpaceToBatchNdWorkload.hpp" @@ -618,6 +619,40 @@ bool ClLayerSupport::IsPreluSupported(const armnn::TensorInfo &input, FORWARD_WORKLOAD_VALIDATE_FUNC(ClPreluWorkloadValidate, reasonIfUnsupported, input, alpha, output); } +bool ClLayerSupport::IsQLstmSupported(const TensorInfo& input, + const TensorInfo& previousOutputIn, + const TensorInfo& previousCellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, + const TensorInfo& output, + const QLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported) const +{ + if (input.GetDataType() == armnn::DataType::QAsymmS8 && + previousOutputIn.GetDataType() == armnn::DataType::QAsymmS8 && + previousCellStateIn.GetDataType() == armnn::DataType::QSymmS16 && + outputStateOut.GetDataType() == armnn::DataType::QAsymmS8 && + cellStateOut.GetDataType() == armnn::DataType::QSymmS16 && + output.GetDataType() == armnn::DataType::QAsymmS8) + { + FORWARD_WORKLOAD_VALIDATE_FUNC(ClQLstmWorkloadValidate, + reasonIfUnsupported, + input, + previousCellStateIn, + previousOutputIn, + cellStateOut, + outputStateOut, + output, + descriptor, + paramsInfo); + } + else + { + return false; + } +} + bool ClLayerSupport::IsQuantizedLstmSupported(const TensorInfo& input, const TensorInfo& previousCellStateIn, const TensorInfo& previousOutputIn, |