Diffstat (limited to 'src/runtime/NEON/functions/NEQLSTMLayer.cpp')
 src/runtime/NEON/functions/NEQLSTMLayer.cpp | 208 ++++++++++++++++++++++-----
 1 file changed, 179 insertions(+), 29 deletions(-)
diff --git a/src/runtime/NEON/functions/NEQLSTMLayer.cpp b/src/runtime/NEON/functions/NEQLSTMLayer.cpp
index a279bba2ab..9c78ea8b75 100644
--- a/src/runtime/NEON/functions/NEQLSTMLayer.cpp
+++ b/src/runtime/NEON/functions/NEQLSTMLayer.cpp
@@ -46,6 +46,36 @@ Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info, const ITensorInfo *mm
}
} // namespace
+Status NEQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported);
+ ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
+ ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y());
+ return Status{};
+}
+
+void NEQLSTMLayer::TensorCopyKernel::configure(ITensor &src, ITensor &dst)
+{
+ ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::TensorCopyKernel::validate(*src.info(), *dst.info()));
+ _src = &src;
+ _dst = &dst;
+ _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
+ _window = calculate_max_window(*_src->info(), Steps());
+}
+
+void NEQLSTMLayer::TensorCopyKernel::run()
+{
+ Iterator input_iter{ _src, _window };
+ Iterator output_iter{ _dst, _window };
+
+ execute_window_loop(_window, [&](const Coordinates &)
+ {
+ memcpy(output_iter.ptr(), input_iter.ptr(), _row_size);
+ },
+ input_iter, output_iter);
+}
+
NEQLSTMLayer::NEQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
{
_memory_group = MemoryGroup(std::move(memory_manager));
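
The new TensorCopyKernel copies a 2D tensor row by row, moving only as much of each row as the narrower tensor can hold, which is what lets buffers of width num_units and width output_size exchange data in the hunks below. A minimal standalone sketch of the same idea, using raw byte buffers instead of the Arm Compute Window/Iterator machinery (names and strides here are illustrative):

#include <algorithm>
#include <cstddef>
#include <cstring>

// Copy 'rows' rows from src to dst, where the two buffers may have
// different row strides; each row transfers min(src_row, dst_row)
// bytes, mirroring the _row_size logic in TensorCopyKernel::run().
void copy_rows(const unsigned char *src, std::size_t src_row,
               unsigned char *dst, std::size_t dst_row, std::size_t rows)
{
    const std::size_t row_size = std::min(src_row, dst_row);
    for(std::size_t y = 0; y < rows; ++y)
    {
        std::memcpy(dst + y * dst_row, src + y * src_row, row_size);
    }
}
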
@@ -93,8 +123,9 @@ void NEQLSTMLayer::configure(const ITensor *input,
forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), lstm_params_info));
- const int batch_size = input->info()->dimension(1);
- const int num_units = input_to_output_weights->info()->dimension(1);
+ const int batch_size = input->info()->dimension(1);
+ const int num_units = input_to_output_weights->info()->dimension(1);
+ const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx);
const UniformQuantizationInfo qinput = input->info()->quantization_info().uniform();
const UniformQuantizationInfo qcell_state_in = cell_state_in->info()->quantization_info().uniform();
@@ -154,10 +185,9 @@ void NEQLSTMLayer::configure(const ITensor *input,
_recurrent_to_cell_reduction.configure(recurrent_to_cell_weights, &_recurrent_to_cell_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
_input_to_output_reduction.configure(input_to_output_weights, &_input_to_output_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
_recurrent_to_output_reduction.configure(recurrent_to_output_weights, &_recurrent_to_output_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
- if(_projection_bias != nullptr)
+ if(_has_projection)
{
- _projection_reduction.configure(_projection_weights, &_projection_reduction_res, GEMMLowpReductionKernelInfo(num_units, false, lstm_params.hidden_state_zero(), true));
- _projection_bias_add.configure(_projection_bias, &_projection_reduction_res, &_projection_eff_bias, ConvertPolicy::SATURATE);
+ _projection_reduction.configure(_projection_weights, &_projection_eff_bias, GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
}
// Pre-transpose weights to be used in GEMM.
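
The projection reduction now writes its result directly into _projection_eff_bias (sized for output_size rather than num_units), so the separate bias-add kernel is gone. The reduction implements the usual GEMMLowp effective-bias trick: the zero-point correction term is constant per output row, so it can be folded into the bias once at prepare time. A hedged scalar sketch with plain containers (illustrative types; the library does this with NEGEMMLowpMatrixAReductionKernel):

#include <cstddef>
#include <cstdint>
#include <vector>

// For y = W * x where x has zero point zp, the term zp * sum_j W[i][j]
// is input-independent, so it can be folded into the bias ahead of time.
// Sign conventions follow the scalar passed in GEMMLowpReductionKernelInfo.
std::vector<int32_t> effective_bias(const std::vector<std::vector<int8_t>> &W,
                                    const std::vector<int32_t> &bias,
                                    int32_t zp)
{
    std::vector<int32_t> eff(W.size(), 0);
    for(std::size_t i = 0; i < W.size(); ++i)
    {
        int32_t row_sum = 0;
        for(const int8_t w : W[i])
        {
            row_sum += w;
        }
        eff[i] = bias[i] - zp * row_sum;
    }
    return eff;
}
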
@@ -203,6 +233,7 @@ void NEQLSTMLayer::configure(const ITensor *input,
if(_has_peephole)
{
+ _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
_memory_group.manage(&_mul_cell_to_forget_res);
_pixelwise_mul_cell_to_forget.configure(cell_state_in, lstm_params.cell_to_forget_weights(), &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
_cell_to_forget_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)));
@@ -287,7 +318,7 @@ void NEQLSTMLayer::configure(const ITensor *input,
const float recurrent_to_input_scale = _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
configure_mm(_mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info,
- input, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
+ output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
&_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
mm_out_info, input_outstage_info);
_accumulate_input_recurrent_input.configure(&_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
@@ -295,6 +326,7 @@ void NEQLSTMLayer::configure(const ITensor *input,
if(_has_peephole)
{
+ _mul_cell_to_input_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
_memory_group.manage(&_mul_cell_to_input_res);
_pixelwise_mul_cell_to_input.configure(cell_state_in, lstm_params.cell_to_input_weights(), &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
@@ -316,7 +348,7 @@ void NEQLSTMLayer::configure(const ITensor *input,
input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
}
- _input_gate_tanh.configure(input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
+ _input_gate_sigmoid.configure(input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
input_activation_input->allocator()->allocate();
}
// Cell.
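
Renaming _input_gate_tanh to _input_gate_sigmoid fixes the input-gate activation: in the standard LSTM equations the gates use the logistic function, i_t = sigmoid(W_i x_t + R_i h_{t-1} + b_i), while tanh is reserved for the cell candidate and the output path. A float reference of the distinction (the kernels above do this in the QSYMM16 domain):

#include <cmath>

// Gates (input/forget/output) squash to (0, 1); candidate and cell
// activations squash to (-1, 1).
inline float gate(float x)      { return 1.f / (1.f + std::exp(-x)); }
inline float candidate(float x) { return std::tanh(x); }
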
@@ -357,12 +389,19 @@ void NEQLSTMLayer::configure(const ITensor *input,
{
// TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplicationKernel
// Here we are not using the output stage because all operations are done in float
- // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
- // quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
+ _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
_memory_group.manage(&_mul_cell_to_output_res);
_pixelwise_mul_cell_to_output.configure(cell_state_out, lstm_params.cell_to_output_weights(), &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
- _accumulate_cell_to_output.configure(&_recurrent_to_output_outstage_res, &_mul_cell_to_output_res, &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
+
+ const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
+ quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
+ _cell_to_output_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0)));
+ _memory_group.manage(&_cell_to_output_outstage_res);
+ _cell_to_output_outstage.configure(&_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info);
_mul_cell_to_output_res.allocator()->allocate();
+
+ _accumulate_cell_to_output.configure(&_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
+ _cell_to_output_outstage_res.allocator()->allocate();
}
Tensor *output_activation_input = &_recurrent_to_output_outstage_res;
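
calculate_quantized_multiplier turns the real-valued cell_to_output_scale into an integer multiplier plus a power-of-two shift, so the newly added out-stage can rescale in pure fixed-point arithmetic. A sketch of the usual gemmlowp-style decomposition (an assumption about the scheme, not the exact library routine):

#include <cmath>
#include <cstdint>

// Decompose real_multiplier as q * 2^shift with q a Q0.31 fixed-point
// value in [2^30, 2^31); the output stage then applies a 64-bit
// multiply-and-round followed by an arithmetic shift.
void quantize_multiplier(double real_multiplier, int32_t *quantized_multiplier, int *shift)
{
    if(real_multiplier == 0.)
    {
        *quantized_multiplier = 0;
        *shift                = 0;
        return;
    }
    const double q = std::frexp(real_multiplier, shift); // q in [0.5, 1)
    auto q_fixed   = static_cast<int64_t>(std::round(q * (1LL << 31)));
    if(q_fixed == (1LL << 31)) // rounding pushed q up to exactly 1.0
    {
        q_fixed /= 2;
        ++(*shift);
    }
    *quantized_multiplier = static_cast<int32_t>(q_fixed);
}
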
@@ -393,13 +432,24 @@ void NEQLSTMLayer::configure(const ITensor *input,
quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true);
gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
gemmlowp_info.output_data_type = output_state_in->info()->data_type();
- _hidden_outstage.configure(&_hidden_mul_res, nullptr, output_state_out, gemmlowp_info);
+
+ _projection_tensor_copy_required = (num_units != output_size);
+ ITensor *hidden_gate_result = output_state_out;
+
+ if(_projection_tensor_copy_required)
+ {
+ _hidden_gate.allocator()->init(*output_state_out->info());
+ _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape());
+ hidden_gate_result = &_hidden_gate;
+ }
+
+ _hidden_outstage.configure(&_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
_hidden_mul_res.allocator()->allocate();
// Projection.
if(_has_projection)
{
- const TensorInfo projection_outstage_info(*output_state_out->info());
+ const TensorInfo projection_outstage_info(*hidden_gate_result->info());
const UniformQuantizationInfo qprojection = _projection_weights->info()->quantization_info().uniform();
const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
@@ -407,14 +457,44 @@ void NEQLSTMLayer::configure(const ITensor *input,
gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
+ _memory_group.manage(&_projection_eff_bias_adjusted);
+ ITensor *bias_to_use = &_projection_eff_bias;
+
+ if(_projection_tensor_copy_required)
+ {
+ _projection_eff_bias_adjusted.allocator()->init(*_projection_eff_bias.info());
+ _projection_eff_bias_adjusted.info()->set_tensor_shape(TensorShape(num_units));
+ _projection_bias_copy.configure(_projection_eff_bias, _projection_eff_bias_adjusted);
+ bias_to_use = &_projection_eff_bias_adjusted;
+ }
+
configure_mm(_mm_projection, _projection_outstage, gemmlowp_info,
- output_state_out, &_projection_weights_transposed, &_projection_eff_bias,
+ hidden_gate_result, &_projection_weights_transposed, bias_to_use,
&_mm_projection_res, &_projection_outstage_res, projection_scale,
mm_out_info, projection_outstage_info);
- _accumulate_projection.configure(&_projection_outstage_res, output_state_out, output_state_out, ConvertPolicy::SATURATE);
+ _projection_eff_bias_adjusted.allocator()->allocate();
+
+ ITensor *accumulate_destination = output_state_out;
+
+ if(_projection_tensor_copy_required)
+ {
+ _hidden_gate.allocator()->allocate();
+ _projection_accumulate_res.allocator()->init(*output_state_out->info());
+ _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape());
+ _projection_output_to_accumulate_copy.configure(*output_state_out, _projection_accumulate_res);
+ accumulate_destination = &_projection_accumulate_res;
+ }
+
+ _accumulate_projection.configure(&_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
_projection_outstage_res.allocator()->allocate();
+ if(_projection_tensor_copy_required)
+ {
+ _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
+ _projection_accumulate_res.allocator()->allocate();
+ }
+
int8_t quantized_projection_clip{ 0 };
if(lstm_params.projection_clip() > 0.0f)
{
@@ -427,6 +507,14 @@ void NEQLSTMLayer::configure(const ITensor *input,
_has_projection_clipping = true;
}
}
+ else
+ {
+ if(_projection_tensor_copy_required)
+ {
+ _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
+ _hidden_gate.allocator()->allocate();
+ }
+ }
}
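
All of the copy kernels configured above exist because projection can change the row width: the hidden gate is num_units elements wide while output_state_out is output_size wide, so when the two differ the out-stage result, the effective bias, and the accumulation buffer each need a staging tensor of the right width. Ignoring quantization, the path being wired up is h = clip(W_proj * hidden + b_proj); a hedged float reference with illustrative containers:

#include <algorithm>
#include <cstddef>
#include <vector>

// Float sketch of the projection: maps a num_units-wide hidden vector
// to an output_size-wide state, with optional symmetric clipping.
std::vector<float> project(const std::vector<std::vector<float>> &W_proj, // output_size x num_units
                           const std::vector<float>              &hidden, // num_units
                           const std::vector<float>              &bias,   // output_size
                           float                                  clip)
{
    std::vector<float> out(W_proj.size(), 0.f);
    for(std::size_t i = 0; i < W_proj.size(); ++i)
    {
        float acc = bias[i];
        for(std::size_t j = 0; j < hidden.size(); ++j)
        {
            acc += W_proj[i][j] * hidden[j];
        }
        out[i] = (clip > 0.f) ? std::max(-clip, std::min(clip, acc)) : acc;
    }
    return out;
}
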
Status NEQLSTMLayer::validate(const ITensorInfo *input,
@@ -446,7 +534,7 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
const unsigned int input_size = input->dimension(0);
const unsigned int batch_size = input->dimension(1);
const unsigned int num_units = input_to_output_weights->dimension(1);
- const unsigned int output_size = recurrent_to_output_weights->dimension(0);
+ const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);
ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->dimension(0) != input_size);
@@ -509,6 +597,7 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
// Precompute effective bias for optimizing the matmul computations.
const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32);
+ const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32);
if(!lstm_params.has_cifg_opt())
{
ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(lstm_params.input_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
@@ -521,11 +610,11 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(recurrent_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(input_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(recurrent_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
- if(lstm_params.projection_bias() != nullptr)
+ if(lstm_params.has_projection())
{
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, lstm_params.hidden_state_zero(),
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &projection_eff_bias_info, GEMMLowpReductionKernelInfo(output_size, false,
+ lstm_params.hidden_state_zero(),
true)));
- ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(lstm_params.projection_bias(), &eff_bias_info, &eff_bias_info, ConvertPolicy::SATURATE));
}
const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_forget_weights->data_type(), input_to_forget_weights->quantization_info());
@@ -545,7 +634,8 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
}
if(lstm_params.has_projection())
{
- ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.projection_weights(), &recurrent_weights_transposed));
+ const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
+ ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed));
}
GEMMLowpOutputStageInfo gemmlowp_info;
@@ -627,23 +717,22 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.input_gate_bias());
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, lstm_params.input_gate_bias());
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyCore::validate(input, lstm_params.input_to_input_weights(), nullptr, &mm_out_info));
const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
- validate_mm(gemmlowp_info, input, lstm_params.input_to_input_weights(), &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info);
+ validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info);
const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
- validate_mm(gemmlowp_info, input, lstm_params.recurrent_to_input_weights(), &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info);
+ validate_mm(gemmlowp_info, input, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info);
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
if(lstm_params.has_peephole_opt())
{
- ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_outstage_info, 1.f, ConvertPolicy::SATURATE,
+ ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
RoundingPolicy::TO_ZERO));
const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&input_outstage_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
}
@@ -654,7 +743,7 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(input_outstage_info, *w_info, *b_info));
}
- ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_outstage_info, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_outstage_info, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
}
// Cell.
ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
@@ -699,11 +788,14 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
// Hidden.
ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(cell_state_out, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32);
+ const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);
ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, output_state_out, gemmlowp_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info));
+
+ const bool projection_tensor_copy_required = num_units != output_size;
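
The hidden_state_scale expression above encodes a requantization: the out-gate and tanh(cell) operands are both QSYMM16 with scale 2^-15, so their product carries scale 2^-30, and rescaling into the hidden state's own scale means multiplying by 2^-30 / hidden_state_scale. A worked check with an illustrative scale value:

#include <cmath>
#include <cstdio>

int main()
{
    const double hidden_state_scale = 0.007; // illustrative; comes from lstm_params
    // Matches the expression above: pow(2, -15) / scale * pow(2, -15).
    const double multiplier = std::pow(2, -15) * std::pow(2, -15) / hidden_state_scale;
    std::printf("requantization multiplier: %g\n", multiplier); // ~1.33e-07
    return 0;
}
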
// Projection.
if(lstm_params.has_projection())
@@ -719,11 +811,30 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
- const TensorInfo projection_outstage_info(*output_state_out);
- validate_mm(gemmlowp_info, output_state_out, &recurrent_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &projection_outstage_info);
+ TensorInfo projection_outstage_info(hidden_out_info);
+
+ if(projection_tensor_copy_required)
+ {
+ TensorInfo projection_eff_bias_adjusted_info{ projection_eff_bias_info };
+ projection_eff_bias_adjusted_info.set_tensor_shape(TensorShape(num_units));
+
+ ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(projection_eff_bias_info, projection_eff_bias_adjusted_info));
+ }
+
+ validate_mm(gemmlowp_info, output_state_out, &recurrent_weights_transposed, &projection_eff_bias_info, input_to_output_scale, &mm_out_info, &projection_outstage_info);
+
+ if(projection_tensor_copy_required)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(*output_state_out, projection_outstage_info));
+ }
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
+ if(projection_tensor_copy_required)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out));
+ }
+
int8_t quantized_projection_clip{ 0 };
if(lstm_params.projection_clip() > 0.0f)
{
@@ -736,6 +847,13 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
quantized_projection_clip)));
}
}
+ else
+ {
+ if(projection_tensor_copy_required)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out));
+ }
+ }
if(cell_state_out->total_size() > 0)
{
@@ -821,7 +939,7 @@ void NEQLSTMLayer::run()
NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Input), Window::DimY);
}
- _input_gate_tanh.run();
+ _input_gate_sigmoid.run();
}
// Cell.
@@ -842,6 +960,7 @@ void NEQLSTMLayer::run()
if(_has_peephole)
{
NEScheduler::get().schedule(&_pixelwise_mul_cell_to_output, Window::DimY);
+ _cell_to_output_outstage.run();
NEScheduler::get().schedule(&_accumulate_cell_to_output, Window::DimY);
}
@@ -860,14 +979,38 @@ void NEQLSTMLayer::run()
// Projection.
if(_has_projection)
{
+ if(_projection_tensor_copy_required)
+ {
+ _projection_bias_copy.run();
+ }
+
_mm_projection.run();
_projection_outstage.run();
+
+ if(_projection_tensor_copy_required)
+ {
+ _projection_output_to_accumulate_copy.run();
+ }
+
NEScheduler::get().schedule(&_accumulate_projection, Window::DimY);
+
+ if(_projection_tensor_copy_required)
+ {
+ _projection_accumulate_to_output_copy.run();
+ }
+
if(_has_projection_clipping)
{
_projection_clip.run();
}
}
+ else
+ {
+ if(_projection_tensor_copy_required)
+ {
+ _hidden_to_output_copy.run();
+ }
+ }
}
void NEQLSTMLayer::prepare()
@@ -932,6 +1075,13 @@ void NEQLSTMLayer::prepare()
_projection_weights_transposed.allocator()->allocate();
_transpose_projection_weights.run();
_projection_weights->mark_as_unused();
+
+ if(!_projection_tensor_copy_required)
+ {
+ _hidden_gate.mark_as_unused();
+ _projection_eff_bias_adjusted.mark_as_unused();
+ _projection_accumulate_res.mark_as_unused();
+ }
}
// Mark weights as unused