about summary refs log tree commit diff
path: root/src/runtime/CL/functions/CLQLSTMLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLQLSTMLayer.cpp')
-rw-r--r-- src/runtime/CL/functions/CLQLSTMLayer.cpp 68
1 files changed, 34 insertions, 34 deletions
diff --git a/src/runtime/CL/functions/CLQLSTMLayer.cpp b/src/runtime/CL/functions/CLQLSTMLayer.cpp
index 8c45a98935..c5c4aa3dfa 100644
--- a/src/runtime/CL/functions/CLQLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLQLSTMLayer.cpp
@@ -213,7 +213,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
_projection_reduction.configure(compile_context, _projection_weights, &_projection_eff_bias, GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
if(_projection_bias != nullptr)
{
- _projection_bias_add.configure(compile_context, ArithmeticOperation::ADD, _projection_bias, &_projection_eff_bias, &_projection_eff_bias, ConvertPolicy::SATURATE);
+ _projection_bias_add.configure(compile_context, _projection_bias, &_projection_eff_bias, &_projection_eff_bias, ConvertPolicy::SATURATE);
}
}
@@ -255,7 +255,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
&_mm_recurrent_to_forget_res, &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale,
mm_out_info, forget_gate_outstage_info);
- _accumulate_input_recurrent_forget.configure(compile_context, ArithmeticOperation::ADD, &_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
+ _accumulate_input_recurrent_forget.configure(compile_context, &_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
ConvertPolicy::SATURATE);
_input_to_forget_outstage_res.allocator()->allocate();
@@ -270,7 +270,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
_cell_to_forget_outstage.configure(compile_context, &_mul_cell_to_forget_res, nullptr, &_cell_to_forget_outstage_res, gemmlowp_info);
_mul_cell_to_forget_res.allocator()->allocate();
- _accumulate_cell_forget.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
+ _accumulate_cell_forget.configure(compile_context, &_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
ConvertPolicy::SATURATE);
_cell_to_forget_outstage_res.allocator()->allocate();
}
@@ -307,7 +307,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
&_mm_recurrent_to_cell_res, &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale,
mm_out_info, cell_outstage_info);
- _accumulate_input_recurrent_modulation.configure(compile_context, ArithmeticOperation::ADD, &_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res,
+ _accumulate_input_recurrent_modulation.configure(compile_context, &_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res,
ConvertPolicy::SATURATE);
_input_to_cell_outstage_res.allocator()->allocate();
@@ -333,7 +333,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
if(_has_cifg)
{
_ones.allocator()->init(*_forget_gate.info());
- _input_gate_sub.configure(compile_context, ArithmeticOperation::SUB, &_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
+ _input_gate_sub.configure(compile_context, &_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
_ones.allocator()->allocate();
}
else
@@ -350,7 +350,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
&_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
mm_out_info, input_outstage_info);
- _accumulate_input_recurrent_input.configure(compile_context, ArithmeticOperation::ADD, &_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res,
+ _accumulate_input_recurrent_input.configure(compile_context, &_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res,
ConvertPolicy::SATURATE);
_input_to_input_outstage_res.allocator()->allocate();
@@ -365,7 +365,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
_memory_group.manage(&_cell_to_input_outstage_res);
_cell_to_input_outstage.configure(compile_context, &_mul_cell_to_input_res, nullptr, &_cell_to_input_outstage_res, gemmlowp_info);
_mul_cell_to_input_res.allocator()->allocate();
- _accumulate_cell_input.configure(ArithmeticOperation::ADD, &_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
+ _accumulate_cell_input.configure(&_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
_cell_to_input_outstage_res.allocator()->allocate();
}
@@ -391,7 +391,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
_mul_input_cell_res.allocator()->init(mul_input_cell_info);
_pixelwise_mul_input_cell.configure(compile_context, &_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
_cell_gate.allocator()->allocate();
- _add_forget_cell.configure(compile_context, ArithmeticOperation::ADD, &_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
+ _add_forget_cell.configure(compile_context, &_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
_mul_input_cell_res.allocator()->allocate();
_forget_gate.allocator()->allocate();
if(_has_cell_clipping)
@@ -412,7 +412,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
&_mm_recurrent_to_output_res, &_recurrent_to_output_outstage_res, recurrent_to_output_scale,
mm_out_info, output_outstage_info);
- _accumulate_input_recurrent_output.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_output_outstage_res, &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res,
+ _accumulate_input_recurrent_output.configure(compile_context, &_recurrent_to_output_outstage_res, &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res,
ConvertPolicy::SATURATE);
_input_to_output_outstage_res.allocator()->allocate();
@@ -431,7 +431,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
_cell_to_output_outstage.configure(compile_context, &_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info);
_mul_cell_to_output_res.allocator()->allocate();
- _accumulate_cell_to_output.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res,
+ _accumulate_cell_to_output.configure(compile_context, &_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res,
ConvertPolicy::SATURATE);
_cell_to_output_outstage_res.allocator()->allocate();
}
@@ -510,7 +510,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
accumulate_destination = &_projection_accumulate_res;
}
- _accumulate_projection.configure(compile_context, ArithmeticOperation::ADD, &_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
+ _accumulate_projection.configure(compile_context, &_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
_projection_outstage_res.allocator()->allocate();
if(_projection_tensor_copy_required)
@@ -647,8 +647,8 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
if(lstm_params.projection_bias() != nullptr)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.projection_bias(), 1, DataType::S32);
- ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, lstm_params.projection_bias(), &projection_eff_bias_info,
- &projection_eff_bias_info, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(lstm_params.projection_bias(), &projection_eff_bias_info,
+ &projection_eff_bias_info, ConvertPolicy::SATURATE));
}
}
@@ -691,7 +691,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));
- ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
if(lstm_params.has_peephole_opt())
{
@@ -701,7 +701,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
- ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
}
if(has_layer_norm)
@@ -726,7 +726,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &input_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
- ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
if(has_layer_norm)
{
@@ -743,7 +743,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
if(lstm_params.has_cifg_opt())
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(lstm_params.input_gate_bias() != nullptr, "Input gate bias must not be present when CIFG is used");
- ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::SUB, &input_gate_info, &forget_gate_info, &forget_gate_info, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticSubtraction::validate(&input_gate_info, &forget_gate_info, &forget_gate_info, ConvertPolicy::SATURATE));
}
else
{
@@ -762,7 +762,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));
- ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
if(lstm_params.has_peephole_opt())
{
@@ -771,7 +771,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
- ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
}
if(has_layer_norm)
@@ -786,7 +786,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
// Cell.
ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&input_gate_info, cell_state_in, &cell_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
- ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
if(quantized_cell_clip > 0)
{
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip,
@@ -801,7 +801,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));
- ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
if(lstm_params.has_peephole_opt())
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_output_weights(), 1, DataType::QSYMM16);
@@ -811,7 +811,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
// ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_out, lstm_params.cell_to_output_weights(), &output_outstage_info, 1.f, ConvertPolicy::SATURATE,
RoundingPolicy::TO_ZERO));
- ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
}
if(has_layer_norm)
@@ -866,7 +866,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(*output_state_out, projection_outstage_info));
}
- ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
if(projection_tensor_copy_required)
{
@@ -922,13 +922,13 @@ void CLQLSTMLayer::run()
_mm_recurrent_to_forget.run();
_recurrent_to_forget_outstage.run();
- CLScheduler::get().enqueue(_accumulate_input_recurrent_forget);
+ _accumulate_input_recurrent_forget.run();
if(_has_peephole)
{
CLScheduler::get().enqueue(_pixelwise_mul_cell_to_forget);
_cell_to_forget_outstage.run();
- CLScheduler::get().enqueue(_accumulate_cell_forget);
+ _accumulate_cell_forget.run();
}
if(_has_layer_norm)
@@ -944,7 +944,7 @@ void CLQLSTMLayer::run()
_mm_recurrent_to_cell.run();
_recurrent_to_cell_outstage.run();
- CLScheduler::get().enqueue(_accumulate_input_recurrent_modulation);
+ _accumulate_input_recurrent_modulation.run();
if(_has_layer_norm)
{
@@ -956,7 +956,7 @@ void CLQLSTMLayer::run()
// Input gate
if(_has_cifg)
{
- CLScheduler::get().enqueue(_input_gate_sub);
+ _input_gate_sub.run();
}
else
{
@@ -964,13 +964,13 @@ void CLQLSTMLayer::run()
_input_to_input_outstage.run();
_mm_recurrent_to_input.run();
_recurrent_to_input_outstage.run();
- CLScheduler::get().enqueue(_accumulate_input_recurrent_input);
+ _accumulate_input_recurrent_input.run();
if(_has_peephole)
{
CLScheduler::get().enqueue(_pixelwise_mul_cell_to_input);
_cell_to_input_outstage.run();
- CLScheduler::get().enqueue(_accumulate_cell_input);
+ _accumulate_cell_input.run();
}
if(_has_layer_norm)
@@ -984,7 +984,7 @@ void CLQLSTMLayer::run()
// Cell.
CLScheduler::get().enqueue(_pixelwise_mul_forget_cell);
CLScheduler::get().enqueue(_pixelwise_mul_input_cell);
- CLScheduler::get().enqueue(_add_forget_cell);
+ _add_forget_cell.run();
if(_has_cell_clipping)
{
_cell_clip.run();
@@ -995,12 +995,12 @@ void CLQLSTMLayer::run()
_input_to_output_outstage.run();
_mm_recurrent_to_output.run();
_recurrent_to_output_outstage.run();
- CLScheduler::get().enqueue(_accumulate_input_recurrent_output);
+ _accumulate_input_recurrent_output.run();
if(_has_peephole)
{
CLScheduler::get().enqueue(_pixelwise_mul_cell_to_output);
_cell_to_output_outstage.run();
- CLScheduler::get().enqueue(_accumulate_cell_to_output);
+ _accumulate_cell_to_output.run();
}
if(_has_layer_norm)
@@ -1026,7 +1026,7 @@ void CLQLSTMLayer::run()
_projection_output_to_accumulate_copy.run();
}
- CLScheduler::get().enqueue(_accumulate_projection);
+ _accumulate_projection.run();
if(_projection_tensor_copy_required)
{
@@ -1108,7 +1108,7 @@ void CLQLSTMLayer::prepare()
CLScheduler::get().enqueue(_projection_reduction);
if(_projection_bias != nullptr)
{
- CLScheduler::get().enqueue(_projection_bias_add);
+ _projection_bias_add.run();
_projection_bias->mark_as_unused();
}