From b0baff73b1574a198e57d46fcd704cedc43cea16 Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Fri, 28 Jan 2022 12:17:19 +0000 Subject: IVGCVSW-6268 Add support of Unidirectional Sequence Lstm fp32/fp16 to Neon !ComputeLibrary:7150 Signed-off-by: Cathal Corbett Change-Id: I3de48ffc8d08c95a22705e2b68d069791bddae73 --- src/backends/cl/workloads/ClLstmFloatWorkload.cpp | 71 +++-------------------- 1 file changed, 9 insertions(+), 62 deletions(-) (limited to 'src/backends/cl') diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp index 37dfab6a5f..e190f33bbc 100644 --- a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp +++ b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -19,8 +20,8 @@ namespace armnn { using namespace armcomputetensorutils; -ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor, - const WorkloadInfo &info, +ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor& descriptor, + const WorkloadInfo& info, const arm_compute::CLCompileContext& clCompileContext) : FloatWorkload(descriptor, info) { @@ -28,7 +29,7 @@ ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor, ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClLstmFloatWorkload_Construct", descriptor.m_Parameters, info, - this->GetGuid()); + GetGuid()); arm_compute::LSTMParams lstm_param; @@ -163,35 +164,8 @@ ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor, float projection_threshold = m_Data.m_Parameters.m_ClippingThresProj; // for preparing the object for the class ActivationLayerInfo, we need to consider 5 situations - arm_compute::ActivationLayerInfo activationLayerInfo; - if (m_Data.m_Parameters.m_ActivationFunc == 0) - { - // no activation, do nothing - } - else if (m_Data.m_Parameters.m_ActivationFunc == 1) - { - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::RELU); - } - else if (m_Data.m_Parameters.m_ActivationFunc == 3) - { - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.0); - } - else if (m_Data.m_Parameters.m_ActivationFunc == 4) - { - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0, 1.0); - } - else if (m_Data.m_Parameters.m_ActivationFunc == 6) - { - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC); - } - else - { - throw armnn::Exception("Wrong Type of Activation Function!"); - } + arm_compute::ActivationLayerInfo activationLayerInfo = + ConvertLstmActivationFuncToAclLayerInfo(m_Data.m_Parameters.m_ActivationFunc); { ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "ClLstmFloatWorkload_configure"); @@ -263,7 +237,7 @@ ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor, void ClLstmFloatWorkload::Execute() const { - ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClLstmFloatWorkload_Execute", this->GetGuid()); + ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClLstmFloatWorkload_Execute", GetGuid()); RunClFunction(m_LstmLayer, CHECK_LOCATION()); } @@ -354,35 +328,8 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T float projection_threshold = descriptor.m_ClippingThresProj; // for preparing the object for the class ActivationLayerInfo, we need to consider 5 situations - arm_compute::ActivationLayerInfo activationLayerInfo; - if (descriptor.m_ActivationFunc == 0) - { - // no activation, do nothing - } - else if (descriptor.m_ActivationFunc == 1) - { - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::RELU); - } - else if (descriptor.m_ActivationFunc == 3) - { - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.0); - } - else if (descriptor.m_ActivationFunc == 4) - { - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0, 1.0); - } - else if (descriptor.m_ActivationFunc == 6) - { - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC); - } - else - { - throw armnn::Exception("Wrong Type of Activation Function!"); - } + arm_compute::ActivationLayerInfo activationLayerInfo = + ConvertLstmActivationFuncToAclLayerInfo(descriptor.m_ActivationFunc); if (descriptor.m_LayerNormEnabled) { -- cgit v1.2.1