diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2020-03-09 19:32:33 +0000 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2020-04-20 11:06:59 +0000 |
commit | 47a899017e67556ffffef78571c9be61dd7bc3f0 (patch) | |
tree | 9ec9c12eb912f042262fe596e225f7c7737c3a0f /src/runtime/CL | |
parent | d1d7722cfc5ee130115d8d195068a98b16102a21 (diff) | |
download | ComputeLibrary-47a899017e67556ffffef78571c9be61dd7bc3f0.tar.gz |
COMPMID-3237: Implement NEQLSTMLayer
COMPMID-3082: Extend NEQLSTMLayer with enhancements
Change-Id: I88175b7bf69494a4eae510b74176fe8a0d6cd770
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2969
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-by: Sheri Zhang <sheri.zhang@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL')
-rw-r--r-- | src/runtime/CL/functions/CLLSTMLayer.cpp | 31 |
1 files changed, 8 insertions, 23 deletions
diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp index 793d5ca1a9..3a3917784b 100644 --- a/src/runtime/CL/functions/CLLSTMLayer.cpp +++ b/src/runtime/CL/functions/CLLSTMLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -23,19 +23,17 @@ */ #include "arm_compute/runtime/CL/functions/CLLSTMLayer.h" -#include "arm_compute/core/PixelValue.h" #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" +#include "arm_compute/core/utils/misc/InfoHelpers.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "arm_compute/runtime/CL/CLScheduler.h" -#include <cmath> -#include <memory> -#include <tuple> - -using namespace arm_compute; +namespace arm_compute +{ using namespace arm_compute::misc::shape_calculator; +using namespace arm_compute::utils::info_helpers; CLLSTMLayer::CLLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager) : _memory_group(std::move(memory_manager)), _fully_connected_input_gate(), _accum_input_gate1(), _subtract_input_gate(), _pixelwise_mul_input_gate(), _activation_input_gate(), @@ -71,22 +69,8 @@ void CLLSTMLayer::configure(const ICLTensor *input, _is_layer_norm_lstm = lstm_params.use_layer_norm(); // Set lstm parameters - LSTMParams<ITensorInfo> lstm_params_info; - if(lstm_params.has_peephole_opt()) - { - lstm_params_info.set_peephole_params(lstm_params.cell_to_forget_weights()->info(), lstm_params.cell_to_output_weights()->info()); - } - if(lstm_params.has_projection()) - { - lstm_params_info.set_projection_params(lstm_params.projection_weights()->info(), - lstm_params.projection_bias() != nullptr ? lstm_params.projection_bias()->info() : nullptr); - } - if(!lstm_params.has_cifg_opt()) - { - const ITensorInfo *cell_to_input_weights_info = (lstm_params.has_peephole_opt()) ? lstm_params.cell_to_input_weights()->info() : nullptr; - lstm_params_info.set_cifg_params(lstm_params.input_to_input_weights()->info(), lstm_params.recurrent_to_input_weights()->info(), - cell_to_input_weights_info, lstm_params.input_gate_bias()->info()); - } + LSTMParams<ITensorInfo> lstm_params_info{}; + build_lstm_params_tensor_info(lstm_params, &lstm_params_info); // Validate ARM_COMPUTE_ERROR_THROW_ON(CLLSTMLayer::validate(input->info(), input_to_forget_weights->info(), @@ -729,3 +713,4 @@ void CLLSTMLayer::prepare() _is_prepared = true; } } +} // namespace arm_compute |