diff options
Diffstat (limited to 'src/backends/neon/workloads/NeonLstmFloatWorkload.cpp')
-rw-r--r-- | src/backends/neon/workloads/NeonLstmFloatWorkload.cpp | 68 |
1 files changed, 9 insertions, 59 deletions
diff --git a/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp index 2f14ab9022..19c85f7f33 100644 --- a/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp +++ b/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp @@ -6,7 +6,8 @@ #include "NeonLstmFloatWorkload.hpp" #include "NeonWorkloadUtils.hpp" -#include "aclCommon/ArmComputeTensorUtils.hpp" +#include <aclCommon/ArmComputeTensorUtils.hpp> +#include <aclCommon/ArmComputeUtils.hpp> #include <armnn/utility/NumericCast.hpp> @@ -16,14 +17,14 @@ namespace armnn { using namespace armcomputetensorutils; -NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) +NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info) : FloatWorkload<LstmQueueDescriptor>(descriptor, info) { // Report Profiling Details ARMNN_REPORT_PROFILING_WORKLOAD_DESC("NeonLstmFloatWorkload_Construct", descriptor.m_Parameters, info, - this->GetGuid()); + GetGuid()); arm_compute::LSTMParams<arm_compute::ITensor> lstm_param; @@ -160,36 +161,8 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript float projection_threshold = m_Data.m_Parameters.m_ClippingThresProj; // for preparing the object for the class ActivationLayerInfo, we need to consider 5 situations - arm_compute::ActivationLayerInfo activationLayerInfo; - if (m_Data.m_Parameters.m_ActivationFunc == 0) - { - // no activation, do nothing - } - else if (m_Data.m_Parameters.m_ActivationFunc == 1) - { - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::RELU); - } - else if (m_Data.m_Parameters.m_ActivationFunc == 3) - { - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.0); - } - else if (m_Data.m_Parameters.m_ActivationFunc == 4) - { - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0, 1.0); - } - else if (m_Data.m_Parameters.m_ActivationFunc == 6) - { - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC); - } - else - { - throw armnn::Exception("Wrong Type of Activation Function!"); - } - + arm_compute::ActivationLayerInfo activationLayerInfo = + ConvertLstmActivationFuncToAclLayerInfo(m_Data.m_Parameters.m_ActivationFunc); m_LstmLayer.configure(&input, m_InputToForgetWeightsTensor.get(), m_InputToCellWeightsTensor.get(), m_InputToOutputWeightsTensor.get(), m_RecurrentToForgetWeightsTensor.get(), @@ -273,7 +246,7 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript void NeonLstmFloatWorkload::Execute() const { - ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonLstmFloatWorkload_Execute", this->GetGuid()); + ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonLstmFloatWorkload_Execute", GetGuid()); m_LstmLayer.run(); } @@ -390,31 +363,8 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input, float projection_threshold = descriptor.m_ClippingThresProj; // for preparing the object for the class ActivationLayerInfo, we need to consider 5 situations - arm_compute::ActivationLayerInfo activationLayerInfo; - switch (descriptor.m_ActivationFunc) - { - case 0: - // no activation, do nothing - break; - case 1: - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::RELU); - break; - case 3: - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.0); - break; - case 4: - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0, 1.0); - break; - case 6: - activationLayerInfo = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC); - break; - default: - throw armnn::Exception("Wrong Type of Activation Function!"); - } + arm_compute::ActivationLayerInfo activationLayerInfo = + ConvertLstmActivationFuncToAclLayerInfo(descriptor.m_ActivationFunc); return arm_compute::NELSTMLayer::validate(&aclInputInfo, &aclInputToForgetWeightsInfo, |