about | summary | refs | log | tree | commit | diff
path: root/src/runtime/CL/functions/CLLSTMLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLLSTMLayer.cpp')
-rw-r--r-- src/runtime/CL/functions/CLLSTMLayer.cpp | 52
1 files changed, 26 insertions, 26 deletions
diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp
index a1c412401b..058b6027c2 100644
--- a/src/runtime/CL/functions/CLLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLLSTMLayer.cpp
@@ -55,7 +55,7 @@ void CLLSTMLayer::configure(const ICLTensor *input,
const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
- const ICLTensor *output_state_in, const ICLTensor *cell_state_in,
+ const ICLTensor *output_state_in, ICLTensor *cell_state_in,
ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
{
@@ -68,7 +68,7 @@ void CLLSTMLayer::configure(const CLCompileContext &compile_context, const ICLTe
const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
- const ICLTensor *output_state_in, const ICLTensor *cell_state_in,
+ const ICLTensor *output_state_in, ICLTensor *cell_state_in,
ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
{
@@ -489,14 +489,14 @@ Status CLLSTMLayer::validate(const ITensorInfo *input,
if(lstm_params.has_peephole_opt())
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
}
if(lstm_params.use_layer_norm())
{
ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&forget_gate));
- ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE,
- RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE,
+ RoundingPolicy::TO_NEAREST_EVEN));
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE));
}
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
@@ -524,14 +524,14 @@ Status CLLSTMLayer::validate(const ITensorInfo *input,
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
- ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
}
if(lstm_params.use_layer_norm())
{
ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&input_gate));
- ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(), &input_gate, ConvertPolicy::SATURATE));
}
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
@@ -548,13 +548,13 @@ Status CLLSTMLayer::validate(const ITensorInfo *input,
if(lstm_params.use_layer_norm())
{
ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&cell_state_tmp));
- ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE,
- RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE,
+ RoundingPolicy::TO_NEAREST_EVEN));
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE));
}
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, activation_info));
- ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
- ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
if(cell_threshold != 0.f)
{
@@ -573,22 +573,22 @@ Status CLLSTMLayer::validate(const ITensorInfo *input,
if(lstm_params.has_peephole_opt())
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
- RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
+ RoundingPolicy::TO_NEAREST_EVEN));
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp, ConvertPolicy::SATURATE));
}
if(lstm_params.use_layer_norm())
{
ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&output_gate_tmp));
- ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
- RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
+ RoundingPolicy::TO_NEAREST_EVEN));
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp, ConvertPolicy::SATURATE));
}
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
// Validate output state
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
- ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
if(lstm_params.has_projection())
{
ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out));
@@ -629,13 +629,13 @@ void CLLSTMLayer::run()
if(_run_peephole_opt)
{
- CLScheduler::get().enqueue(_pixelwise_mul_forget_gate);
+ _pixelwise_mul_forget_gate.run();
_accum_forget_gate1.run();
}
if(_is_layer_norm_lstm)
{
_mean_std_norm_forget_gate.run();
- CLScheduler::get().enqueue(_pixelwise_mul_forget_gate_coeff);
+ _pixelwise_mul_forget_gate_coeff.run();
_accum_forget_gate_bias.run();
}
_activation_forget_gate.run();
@@ -651,14 +651,14 @@ void CLLSTMLayer::run()
if(_run_peephole_opt)
{
- CLScheduler::get().enqueue(_pixelwise_mul_input_gate);
+ _pixelwise_mul_input_gate.run();
_accum_input_gate1.run();
}
if(_is_layer_norm_lstm)
{
_mean_std_norm_input_gate.run();
- CLScheduler::get().enqueue(_pixelwise_mul_input_gate_coeff);
+ _pixelwise_mul_input_gate_coeff.run();
_accum_input_gate_bias.run();
}
_activation_input_gate.run();
@@ -671,12 +671,12 @@ void CLLSTMLayer::run()
if(_is_layer_norm_lstm)
{
_mean_std_norm_cell_gate.run();
- CLScheduler::get().enqueue(_pixelwise_mul_cell_gate_coeff);
+ _pixelwise_mul_cell_gate_coeff.run();
_accum_cell_gate_bias.run();
}
_activation_cell_state.run();
- CLScheduler::get().enqueue(_pixelwise_mul_cell_state1);
- CLScheduler::get().enqueue(_pixelwise_mul_cell_state2);
+ _pixelwise_mul_cell_state1.run();
+ _pixelwise_mul_cell_state2.run();
_accum_cell_state2.run();
if(_perform_cell_clipping)
@@ -688,19 +688,19 @@ void CLLSTMLayer::run()
if(_run_peephole_opt)
{
- CLScheduler::get().enqueue(_pixelwise_mul_output_state1);
+ _pixelwise_mul_output_state1.run();
_accum_output1.run();
}
if(_is_layer_norm_lstm)
{
_mean_std_norm_output_gate.run();
- CLScheduler::get().enqueue(_pixelwise_mul_output_gate_coeff);
+ _pixelwise_mul_output_gate_coeff.run();
_accum_output_gate_bias.run();
}
_activation_output.run();
_activation_output_state.run();
- CLScheduler::get().enqueue(_pixelwise_mul_output_state2);
+ _pixelwise_mul_output_state2.run();
if(_has_projection_weights)
{