aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/workloads/NeonLstmFloatWorkload.cpp')
-rw-r--r--src/backends/neon/workloads/NeonLstmFloatWorkload.cpp68
1 files changed, 9 insertions, 59 deletions
diff --git a/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
index 2f14ab9022..19c85f7f33 100644
--- a/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
@@ -6,7 +6,8 @@
#include "NeonLstmFloatWorkload.hpp"
#include "NeonWorkloadUtils.hpp"
-#include "aclCommon/ArmComputeTensorUtils.hpp"
+#include <aclCommon/ArmComputeTensorUtils.hpp>
+#include <aclCommon/ArmComputeUtils.hpp>
#include <armnn/utility/NumericCast.hpp>
@@ -16,14 +17,14 @@ namespace armnn
{
using namespace armcomputetensorutils;
-NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
+NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info)
: FloatWorkload<LstmQueueDescriptor>(descriptor, info)
{
// Report Profiling Details
ARMNN_REPORT_PROFILING_WORKLOAD_DESC("NeonLstmFloatWorkload_Construct",
descriptor.m_Parameters,
info,
- this->GetGuid());
+ GetGuid());
arm_compute::LSTMParams<arm_compute::ITensor> lstm_param;
@@ -160,36 +161,8 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript
float projection_threshold = m_Data.m_Parameters.m_ClippingThresProj;
// for preparing the object for the class ActivationLayerInfo, we need to consider 5 situations
- arm_compute::ActivationLayerInfo activationLayerInfo;
- if (m_Data.m_Parameters.m_ActivationFunc == 0)
- {
- // no activation, do nothing
- }
- else if (m_Data.m_Parameters.m_ActivationFunc == 1)
- {
- activationLayerInfo = arm_compute::ActivationLayerInfo(
- arm_compute::ActivationLayerInfo::ActivationFunction::RELU);
- }
- else if (m_Data.m_Parameters.m_ActivationFunc == 3)
- {
- activationLayerInfo = arm_compute::ActivationLayerInfo(
- arm_compute::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.0);
- }
- else if (m_Data.m_Parameters.m_ActivationFunc == 4)
- {
- activationLayerInfo = arm_compute::ActivationLayerInfo(
- arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0, 1.0);
- }
- else if (m_Data.m_Parameters.m_ActivationFunc == 6)
- {
- activationLayerInfo = arm_compute::ActivationLayerInfo(
- arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC);
- }
- else
- {
- throw armnn::Exception("Wrong Type of Activation Function!");
- }
-
+ arm_compute::ActivationLayerInfo activationLayerInfo =
+ ConvertLstmActivationFuncToAclLayerInfo(m_Data.m_Parameters.m_ActivationFunc);
m_LstmLayer.configure(&input, m_InputToForgetWeightsTensor.get(), m_InputToCellWeightsTensor.get(),
m_InputToOutputWeightsTensor.get(), m_RecurrentToForgetWeightsTensor.get(),
@@ -273,7 +246,7 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript
void NeonLstmFloatWorkload::Execute() const
{
- ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonLstmFloatWorkload_Execute", this->GetGuid());
+ ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonLstmFloatWorkload_Execute", GetGuid());
m_LstmLayer.run();
}
@@ -390,31 +363,8 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
float projection_threshold = descriptor.m_ClippingThresProj;
// for preparing the object for the class ActivationLayerInfo, we need to consider 5 situations
- arm_compute::ActivationLayerInfo activationLayerInfo;
- switch (descriptor.m_ActivationFunc)
- {
- case 0:
- // no activation, do nothing
- break;
- case 1:
- activationLayerInfo = arm_compute::ActivationLayerInfo(
- arm_compute::ActivationLayerInfo::ActivationFunction::RELU);
- break;
- case 3:
- activationLayerInfo = arm_compute::ActivationLayerInfo(
- arm_compute::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.0);
- break;
- case 4:
- activationLayerInfo = arm_compute::ActivationLayerInfo(
- arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0, 1.0);
- break;
- case 6:
- activationLayerInfo = arm_compute::ActivationLayerInfo(
- arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC);
- break;
- default:
- throw armnn::Exception("Wrong Type of Activation Function!");
- }
+ arm_compute::ActivationLayerInfo activationLayerInfo =
+ ConvertLstmActivationFuncToAclLayerInfo(descriptor.m_ActivationFunc);
return arm_compute::NELSTMLayer::validate(&aclInputInfo,
&aclInputToForgetWeightsInfo,