aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEQLSTMLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEQLSTMLayer.cpp')
-rw-r--r--src/runtime/NEON/functions/NEQLSTMLayer.cpp59
1 files changed, 21 insertions, 38 deletions
diff --git a/src/runtime/NEON/functions/NEQLSTMLayer.cpp b/src/runtime/NEON/functions/NEQLSTMLayer.cpp
index 466c41307b..beb180fda5 100644
--- a/src/runtime/NEON/functions/NEQLSTMLayer.cpp
+++ b/src/runtime/NEON/functions/NEQLSTMLayer.cpp
@@ -437,6 +437,8 @@ void NEQLSTMLayer::configure(const ITensor *input,
_projection_tensor_copy_required = (num_units != output_size);
ITensor *hidden_gate_result = output_state_out;
+ _memory_group.manage(&_hidden_gate);
+
if(_projection_tensor_copy_required)
{
_hidden_gate.allocator()->init(*output_state_out->info());
@@ -450,7 +452,7 @@ void NEQLSTMLayer::configure(const ITensor *input,
// Projection.
if(_has_projection)
{
- const TensorInfo projection_outstage_info(*hidden_gate_result->info());
+ const TensorInfo projection_outstage_info(*output_state_out->info());
const UniformQuantizationInfo qprojection = _projection_weights->info()->quantization_info().uniform();
const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
@@ -458,23 +460,13 @@ void NEQLSTMLayer::configure(const ITensor *input,
gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
- _memory_group.manage(&_projection_eff_bias_adjusted);
- ITensor *bias_to_use = &_projection_eff_bias;
-
- if(_projection_tensor_copy_required)
- {
- _projection_eff_bias_adjusted.allocator()->init(*_projection_eff_bias.info());
- _projection_eff_bias_adjusted.info()->set_tensor_shape(TensorShape(num_units));
- _projection_bias_copy.configure(_projection_eff_bias, _projection_eff_bias_adjusted);
- bias_to_use = &_projection_eff_bias_adjusted;
- }
+ TensorInfo projection_mm_out_info{ mm_out_info };
+ projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
configure_mm(_mm_projection, _projection_outstage, gemmlowp_info,
- hidden_gate_result, &_projection_weights_transposed, bias_to_use,
+ hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias,
&_mm_projection_res, &_projection_outstage_res, projection_scale,
- mm_out_info, projection_outstage_info);
-
- _projection_eff_bias_adjusted.allocator()->allocate();
+ projection_mm_out_info, projection_outstage_info);
ITensor *accumulate_destination = output_state_out;
@@ -655,10 +647,10 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
- validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info));
const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
- validate_mm(gemmlowp_info, input, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
@@ -689,10 +681,10 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
// Modulation gate.
const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
- validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info));
const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
- validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
@@ -724,10 +716,10 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
- validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info));
const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
- validate_mm(gemmlowp_info, input, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
@@ -762,10 +754,10 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
// Output gate.
const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
- validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info));
const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
- validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
if(lstm_params.has_peephole_opt())
@@ -816,17 +808,14 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
- TensorInfo projection_outstage_info(hidden_out_info);
-
- if(projection_tensor_copy_required)
- {
- TensorInfo projection_eff_bias_adjusted_info{ projection_eff_bias_info };
- projection_eff_bias_adjusted_info.set_tensor_shape(TensorShape(num_units));
+ const TensorInfo projection_outstage_info(*output_state_out);
+ const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
- ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(projection_eff_bias_info, projection_eff_bias_adjusted_info));
- }
+ TensorInfo projection_mm_out_info{ mm_out_info };
+ projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
- validate_mm(gemmlowp_info, output_state_out, &recurrent_weights_transposed, &projection_eff_bias_info, input_to_output_scale, &mm_out_info, &projection_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
+ &projection_outstage_info));
if(projection_tensor_copy_required)
{
@@ -985,11 +974,6 @@ void NEQLSTMLayer::run()
// Projection.
if(_has_projection)
{
- if(_projection_tensor_copy_required)
- {
- _projection_bias_copy.run();
- }
-
_mm_projection.run();
_projection_outstage.run();
@@ -1088,7 +1072,6 @@ void NEQLSTMLayer::prepare()
if(!_projection_tensor_copy_required)
{
_hidden_gate.mark_as_unused();
- _projection_eff_bias_adjusted.mark_as_unused();
_projection_accumulate_res.mark_as_unused();
}
}