aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLLSTMLayer.cpp
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2020-03-09 19:32:33 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2020-04-20 11:06:59 +0000
commit47a899017e67556ffffef78571c9be61dd7bc3f0 (patch)
tree9ec9c12eb912f042262fe596e225f7c7737c3a0f /src/runtime/CL/functions/CLLSTMLayer.cpp
parentd1d7722cfc5ee130115d8d195068a98b16102a21 (diff)
downloadComputeLibrary-47a899017e67556ffffef78571c9be61dd7bc3f0.tar.gz
COMPMID-3237: Implement NEQLSTMLayer
COMPMID-3082: Extend NEQLSTMLayer with enhancements Change-Id: I88175b7bf69494a4eae510b74176fe8a0d6cd770 Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2969 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com> Reviewed-by: Sheri Zhang <sheri.zhang@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL/functions/CLLSTMLayer.cpp')
-rw-r--r--src/runtime/CL/functions/CLLSTMLayer.cpp31
1 files changed, 8 insertions, 23 deletions
diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp
index 793d5ca1a9..3a3917784b 100644
--- a/src/runtime/CL/functions/CLLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLLSTMLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,19 +23,17 @@
*/
#include "arm_compute/runtime/CL/functions/CLLSTMLayer.h"
-#include "arm_compute/core/PixelValue.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/InfoHelpers.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
-#include <cmath>
-#include <memory>
-#include <tuple>
-
-using namespace arm_compute;
+namespace arm_compute
+{
using namespace arm_compute::misc::shape_calculator;
+using namespace arm_compute::utils::info_helpers;
CLLSTMLayer::CLLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_group(std::move(memory_manager)), _fully_connected_input_gate(), _accum_input_gate1(), _subtract_input_gate(), _pixelwise_mul_input_gate(), _activation_input_gate(),
@@ -71,22 +69,8 @@ void CLLSTMLayer::configure(const ICLTensor *input,
_is_layer_norm_lstm = lstm_params.use_layer_norm();
// Set lstm parameters
- LSTMParams<ITensorInfo> lstm_params_info;
- if(lstm_params.has_peephole_opt())
- {
- lstm_params_info.set_peephole_params(lstm_params.cell_to_forget_weights()->info(), lstm_params.cell_to_output_weights()->info());
- }
- if(lstm_params.has_projection())
- {
- lstm_params_info.set_projection_params(lstm_params.projection_weights()->info(),
- lstm_params.projection_bias() != nullptr ? lstm_params.projection_bias()->info() : nullptr);
- }
- if(!lstm_params.has_cifg_opt())
- {
- const ITensorInfo *cell_to_input_weights_info = (lstm_params.has_peephole_opt()) ? lstm_params.cell_to_input_weights()->info() : nullptr;
- lstm_params_info.set_cifg_params(lstm_params.input_to_input_weights()->info(), lstm_params.recurrent_to_input_weights()->info(),
- cell_to_input_weights_info, lstm_params.input_gate_bias()->info());
- }
+ LSTMParams<ITensorInfo> lstm_params_info{};
+ build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
// Validate
ARM_COMPUTE_ERROR_THROW_ON(CLLSTMLayer::validate(input->info(), input_to_forget_weights->info(),
@@ -729,3 +713,4 @@ void CLLSTMLayer::prepare()
_is_prepared = true;
}
}
+} // namespace arm_compute