aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLLSTMLayer.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-07-17 12:28:42 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commit7d66a8e3f603f2cd363f04a750847e3f9eabdfd4 (patch)
tree0d7e1ad5bf0ecd32cd919074f756d27c351d7638 /src/runtime/CL/functions/CLLSTMLayer.cpp
parentae54e026c86aec7d6819ee3ef76372c1a3c92467 (diff)
downloadComputeLibrary-7d66a8e3f603f2cd363f04a750847e3f9eabdfd4.tar.gz
COMPMID-1386: Add support for converting weights for CL.
Change-Id: I62e3ead903366baeeb1488f233a9b8b0c388c9de Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140403 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/runtime/CL/functions/CLLSTMLayer.cpp')
-rw-r--r--src/runtime/CL/functions/CLLSTMLayer.cpp20
1 files changed, 10 insertions, 10 deletions
diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp
index 872325175d..d384400ed3 100644
--- a/src/runtime/CL/functions/CLLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLLSTMLayer.cpp
@@ -90,7 +90,7 @@ void CLLSTMLayer::configure(const ICLTensor *input, const ICLTensor *input_to_fo
// Configure block that calculates the forget gate
// forget_gate = Activation(input * input_to_forget_weights + output_state * recurrent_to_forget_weights + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
_memory_group.manage(&_forget_gate_out1);
- _fully_connected_forget_gate.configure(input, input_to_forget_weights, forget_gate_bias, &_forget_gate_out1, true, false);
+ _fully_connected_forget_gate.configure(input, input_to_forget_weights, forget_gate_bias, &_forget_gate_out1);
_memory_group.manage(&_forget_gate_out2);
_transpose_forget_gate.configure(recurrent_to_forget_weights, &_forget_gate_out2);
_memory_group.manage(&_forget_gate_out3);
@@ -142,7 +142,7 @@ void CLLSTMLayer::configure(const ICLTensor *input, const ICLTensor *input_to_fo
_input_gate_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
_memory_group.manage(&_input_gate_out1);
- _fully_connected_input_gate.configure(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), &_input_gate_out1, true, false);
+ _fully_connected_input_gate.configure(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), &_input_gate_out1);
_memory_group.manage(&_input_gate_out2);
_transpose_input_gate.configure(lstm_params.recurrent_to_input_weights(), &_input_gate_out2);
_memory_group.manage(&_input_gate_out3);
@@ -169,7 +169,7 @@ void CLLSTMLayer::configure(const ICLTensor *input, const ICLTensor *input_to_fo
// Configure block that calculates the cell state
// cell_state = Clip((PixelwiseMul(input_gate, Activation(input * input_to_cell_weights + output_state * recurrent_to_cell_weights + cell_bias)) + PixelwiseMul(forget_gate, cell_state)), cell_threshold)
_memory_group.manage(&_cell_state_out1);
- _fully_connected_cell_state.configure(input, input_to_cell_weights, cell_bias, &_cell_state_out1, true, false);
+ _fully_connected_cell_state.configure(input, input_to_cell_weights, cell_bias, &_cell_state_out1);
_memory_group.manage(&_cell_state_out2);
_transpose_cell_state.configure(recurrent_to_cell_weights, &_cell_state_out2);
_memory_group.manage(&_cell_state_out3);
@@ -204,7 +204,7 @@ void CLLSTMLayer::configure(const ICLTensor *input, const ICLTensor *input_to_fo
// Configure block that calculates the output
// output_state = Activation(input * input_to_output_weights + output_state * recurrent_to_output_weights + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
_memory_group.manage(&_output1);
- _fully_connected_output.configure(input, input_to_output_weights, output_gate_bias, &_output1, true, false);
+ _fully_connected_output.configure(input, input_to_output_weights, output_gate_bias, &_output1);
_memory_group.manage(&_output2);
_transpose_output.configure(recurrent_to_output_weights, &_output2);
_memory_group.manage(&_output3);
@@ -255,7 +255,7 @@ void CLLSTMLayer::configure(const ICLTensor *input, const ICLTensor *input_to_fo
_has_projection_weights = true;
_output_projection1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
_memory_group.manage(&_output_projection1);
- _fully_connected_output_state.configure(output_state, lstm_params.projection_weights(), lstm_params.projection_bias(), &_output_projection1, true, false);
+ _fully_connected_output_state.configure(output_state, lstm_params.projection_weights(), lstm_params.projection_bias(), &_output_projection1);
// Perform clipping
if(projection_threshold != 0.f)
{
@@ -326,7 +326,7 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_
const TensorInfo num_units_transposed_info = TensorInfo(num_units_transposed_shape, 1, input->data_type());
// Validate forget gate
- ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_forget_weights, forget_gate_bias, cell_state, true, false));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_forget_weights, forget_gate_bias, cell_state));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(output_state, &units_out_transposed_info, nullptr, cell_state, 1.f, 0.f, GEMMInfo()));
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAdditionKernel::validate(cell_state, cell_state, cell_state, ConvertPolicy::SATURATE));
if(lstm_params.has_peephole_opt())
@@ -344,7 +344,7 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_
ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);
- ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), cell_state, true, false));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), cell_state));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(cell_state, &num_units_transposed_info, nullptr, &gemmv_shape_info, 1.f, 0.f, GEMMInfo()));
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(cell_state, &gemmv_shape_info, cell_state, ConvertPolicy::SATURATE));
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
@@ -355,7 +355,7 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_
}
// Validate cell state
- ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_cell_weights, cell_bias, cell_state, true, false));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_cell_weights, cell_bias, cell_state));
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, nullptr, activation_info));
ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state, cell_state, cell_state, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
@@ -364,7 +364,7 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold)));
}
- ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_output_weights, output_gate_bias, cell_state, true, false));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_output_weights, output_gate_bias, cell_state));
if(lstm_params.has_peephole_opt())
{
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(cell_state, cell_state, cell_state, ConvertPolicy::SATURATE));
@@ -376,7 +376,7 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_
ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state, output, output_state, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
if(lstm_params.has_projection())
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(output_state, lstm_params.projection_weights(), lstm_params.projection_bias(), cell_state, true, false));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(output_state, lstm_params.projection_weights(), lstm_params.projection_bias(), cell_state));
if(projection_threshold != 0.f)
{
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, output_state, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold,