From 7d66a8e3f603f2cd363f04a750847e3f9eabdfd4 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 17 Jul 2018 12:28:42 +0100 Subject: COMPMID-1386: Add support for converting weights for CL. Change-Id: I62e3ead903366baeeb1488f233a9b8b0c388c9de Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140403 Tested-by: Jenkins Reviewed-by: Giorgio Arena Reviewed-by: Anthony Barbier --- src/runtime/CL/functions/CLLSTMLayer.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) (limited to 'src/runtime/CL/functions/CLLSTMLayer.cpp') diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp index 872325175d..d384400ed3 100644 --- a/src/runtime/CL/functions/CLLSTMLayer.cpp +++ b/src/runtime/CL/functions/CLLSTMLayer.cpp @@ -90,7 +90,7 @@ void CLLSTMLayer::configure(const ICLTensor *input, const ICLTensor *input_to_fo // Configure block that calculates the forget gate // forget_gate = Activation(input * input_to_forget_weights + output_state * recurrent_to_forget_weights + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias) _memory_group.manage(&_forget_gate_out1); - _fully_connected_forget_gate.configure(input, input_to_forget_weights, forget_gate_bias, &_forget_gate_out1, true, false); + _fully_connected_forget_gate.configure(input, input_to_forget_weights, forget_gate_bias, &_forget_gate_out1); _memory_group.manage(&_forget_gate_out2); _transpose_forget_gate.configure(recurrent_to_forget_weights, &_forget_gate_out2); _memory_group.manage(&_forget_gate_out3); @@ -142,7 +142,7 @@ void CLLSTMLayer::configure(const ICLTensor *input, const ICLTensor *input_to_fo _input_gate_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type())); _memory_group.manage(&_input_gate_out1); - _fully_connected_input_gate.configure(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), &_input_gate_out1, true, false); + _fully_connected_input_gate.configure(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), &_input_gate_out1); _memory_group.manage(&_input_gate_out2); _transpose_input_gate.configure(lstm_params.recurrent_to_input_weights(), &_input_gate_out2); _memory_group.manage(&_input_gate_out3); @@ -169,7 +169,7 @@ void CLLSTMLayer::configure(const ICLTensor *input, const ICLTensor *input_to_fo // Configure block that calculates the cell state // cell_state = Clip((PixelwiseMul(input_gate, Activation(input * input_to_cell_weights + output_state * recurrent_to_cell_weights + cell_bias)) + PixelwiseMul(forget_gate, cell_state)), cell_threshold) _memory_group.manage(&_cell_state_out1); - _fully_connected_cell_state.configure(input, input_to_cell_weights, cell_bias, &_cell_state_out1, true, false); + _fully_connected_cell_state.configure(input, input_to_cell_weights, cell_bias, &_cell_state_out1); _memory_group.manage(&_cell_state_out2); _transpose_cell_state.configure(recurrent_to_cell_weights, &_cell_state_out2); _memory_group.manage(&_cell_state_out3); @@ -204,7 +204,7 @@ void CLLSTMLayer::configure(const ICLTensor *input, const ICLTensor *input_to_fo // Configure block that calculates the output // output_state = Activation(input * input_to_output_weights + output_state * recurrent_to_output_weights + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias) _memory_group.manage(&_output1); - _fully_connected_output.configure(input, input_to_output_weights, output_gate_bias, &_output1, true, false); + _fully_connected_output.configure(input, input_to_output_weights, output_gate_bias, &_output1); _memory_group.manage(&_output2); _transpose_output.configure(recurrent_to_output_weights, &_output2); _memory_group.manage(&_output3); @@ -255,7 +255,7 @@ void CLLSTMLayer::configure(const ICLTensor *input, const ICLTensor *input_to_fo _has_projection_weights = true; _output_projection1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type())); _memory_group.manage(&_output_projection1); - _fully_connected_output_state.configure(output_state, lstm_params.projection_weights(), lstm_params.projection_bias(), &_output_projection1, true, false); + _fully_connected_output_state.configure(output_state, lstm_params.projection_weights(), lstm_params.projection_bias(), &_output_projection1); // Perform clipping if(projection_threshold != 0.f) { @@ -326,7 +326,7 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_ const TensorInfo num_units_transposed_info = TensorInfo(num_units_transposed_shape, 1, input->data_type()); // Validate forget gate - ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_forget_weights, forget_gate_bias, cell_state, true, false)); + ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_forget_weights, forget_gate_bias, cell_state)); ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(output_state, &units_out_transposed_info, nullptr, cell_state, 1.f, 0.f, GEMMInfo())); ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAdditionKernel::validate(cell_state, cell_state, cell_state, ConvertPolicy::SATURATE)); if(lstm_params.has_peephole_opt()) @@ -344,7 +344,7 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2); ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1); ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1); - ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), cell_state, true, false)); + ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), cell_state)); ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(cell_state, &num_units_transposed_info, nullptr, &gemmv_shape_info, 1.f, 0.f, GEMMInfo())); ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(cell_state, &gemmv_shape_info, cell_state, ConvertPolicy::SATURATE)); ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); @@ -355,7 +355,7 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_ } // Validate cell state - ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_cell_weights, cell_bias, cell_state, true, false)); + ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_cell_weights, cell_bias, cell_state)); ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, nullptr, activation_info)); ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state, cell_state, cell_state, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN)); @@ -364,7 +364,7 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_ ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold))); } - ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_output_weights, output_gate_bias, cell_state, true, false)); + ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_output_weights, output_gate_bias, cell_state)); if(lstm_params.has_peephole_opt()) { ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(cell_state, cell_state, cell_state, ConvertPolicy::SATURATE)); @@ -376,7 +376,7 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state, output, output_state, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN)); if(lstm_params.has_projection()) { - ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(output_state, lstm_params.projection_weights(), lstm_params.projection_bias(), cell_state, true, false)); + ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(output_state, lstm_params.projection_weights(), lstm_params.projection_bias(), cell_state)); if(projection_threshold != 0.f) { ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, output_state, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, -- cgit v1.2.1