From a7431aeef244c85f621b70b946d25229e42d1708 Mon Sep 17 00:00:00 2001 From: Sang-Hoon Park Date: Tue, 12 May 2020 11:13:30 +0100 Subject: COMPMID-3439: Fix peephole and projection in CLQLSTMLayer The followings are essential to make it work - QSYMM16 is added as supported data type in CLGEMMLowpOutputStage - Internal TensorCopyKernel is added similar to NEQLSTMLayer The followings are fix for related things. - Projection is modified to remove copy of projection_bias from NEQLSTMLayer. - Fix wrong argument for validate_mm() - validate_mm() now returns on error. Change-Id: Icbd04e9fdb8821eb41dd3e0a6a0980965b779714 Signed-off-by: Sang-Hoon Park Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3177 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio --- arm_compute/runtime/CL/functions/CLQLSTMLayer.h | 50 ++++- arm_compute/runtime/NEON/functions/NEQLSTMLayer.h | 1 - src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp | 2 +- src/runtime/CL/functions/CLQLSTMLayer.cpp | 212 +++++++++++++++++---- src/runtime/NEON/functions/NEQLSTMLayer.cpp | 59 ++---- 5 files changed, 243 insertions(+), 81 deletions(-) diff --git a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h index 219f46ee48..67e8bc7751 100644 --- a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h +++ b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h @@ -230,7 +230,8 @@ private: Output, Count }; - static constexpr uint8_t _layer_norm_count = static_cast(LayerNormGate::Count); + static constexpr uint8_t _layer_norm_count = static_cast(LayerNormGate::Count); + static constexpr uint32_t _out_state_output_size_dimension_idx = 0; /** Internal method to configure matrix multiplication plus output stage of each gate. * @@ -254,6 +255,35 @@ private: MemoryGroup _memory_group{}; + /** A small internel kernel do the copy between two tensors */ + class TensorCopyKernel + { + static constexpr uint32_t max_dimension_supported = 2; + + ICLTensor *_src{ nullptr }; + ICLTensor *_dst{ nullptr }; + size_t _row_size{}; + Window _window{}; + + public: + /** Static function to check if given info will lead to a valid configuration of @ref CLQLSTMLayer::TensorCopyKernel + * + * @param[in] src Source tensor info. + * @param[in] dst Destination tensor info + * + * @return a status + */ + static Status validate(const ITensorInfo &src, const ITensorInfo &dst); + /** Set the input and output tensors. + * + * @param[in] src Source tensor + * @param[out] dst Destination tensor + */ + void configure(ICLTensor &src, ICLTensor &dst); + /** run the kernel */ + void run(); + }; + // Functions used CLTranspose _transpose_input_to_forget_weights{}; CLTranspose _transpose_input_to_cell_weights{}; @@ -298,7 +328,7 @@ private: CLPixelWiseMultiplicationKernel _pixelwise_mul_cell_to_input{}; CLGEMMLowpOutputStage _cell_to_input_outstage{}; CLSaturatedArithmeticOperationKernel _accumulate_cell_input{}; - CLActivationLayer _input_gate_tanh{}; + CLActivationLayer _input_gate_sigmoid{}; CLPixelWiseMultiplicationKernel _pixelwise_mul_forget_cell{}; CLPixelWiseMultiplicationKernel _pixelwise_mul_input_cell{}; CLSaturatedArithmeticOperationKernel _add_forget_cell{}; @@ -309,6 +339,7 @@ private: CLGEMMLowpOutputStage _recurrent_to_output_outstage{}; CLSaturatedArithmeticOperationKernel _accumulate_input_recurrent_output{}; CLPixelWiseMultiplicationKernel _pixelwise_mul_cell_to_output{}; + CLGEMMLowpOutputStage _cell_to_output_outstage{}; CLSaturatedArithmeticOperationKernel _accumulate_cell_to_output{}; CLActivationLayer _output_gate_sigmoid{}; CLActivationLayer _hidden_tanh{}; @@ -321,11 +352,13 @@ private: std::array _layer_norms{ {} }; CLCopyKernel _copy_output{}; + TensorCopyKernel _projection_bias_copy{}; + TensorCopyKernel _projection_output_to_accumulate_copy{}; + TensorCopyKernel _projection_accumulate_to_output_copy{}; + TensorCopyKernel _hidden_to_output_copy{}; + // Tensor pointers - const ICLTensor *_input_to_input_weights - { - nullptr - }; + const ICLTensor *_input_to_input_weights{ nullptr }; const ICLTensor *_recurrent_to_input_weights{ nullptr }; const ICLTensor *_projection_bias{ nullptr }; const ICLTensor *_input_to_forget_weights{ nullptr }; @@ -435,11 +468,15 @@ private: CLTensor _input_to_output_outstage_res{ nullptr }; CLTensor _mm_recurrent_to_output_res{ nullptr }; CLTensor _mul_cell_to_output_res{ nullptr }; + CLTensor _cell_to_output_outstage_res{ nullptr }; CLTensor _recurrent_to_output_outstage_res{ nullptr }; CLTensor _output_gate{ nullptr }; CLTensor _hidden_mul_res{ nullptr }; + CLTensor _hidden_gate{ nullptr }; CLTensor _mm_projection_res{ nullptr }; CLTensor _projection_outstage_res{ nullptr }; + CLTensor _projection_out_res{ nullptr }; + CLTensor _projection_accumulate_res{ nullptr }; CLTensor _ones{ nullptr }; std::array _layer_norm_output{ {} }; @@ -455,6 +492,7 @@ private: bool _has_projection_clipping{ false }; bool _has_peephole{ false }; bool _has_layer_norm{ false }; + bool _projection_tensor_copy_required{ false }; }; } // namespace arm_compute #endif /* ARM_COMPUTE_CLQLSTMLAYER_H */ diff --git a/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h b/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h index 4dde85e895..d1cc962940 100644 --- a/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h +++ b/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h @@ -426,7 +426,6 @@ private: Tensor _mm_projection_res{ nullptr }; Tensor _projection_outstage_res{ nullptr }; Tensor _projection_out_res{ nullptr }; - Tensor _projection_eff_bias_adjusted{ nullptr }; Tensor _projection_accumulate_res{ nullptr }; Tensor _ones{ nullptr }; std::array _layer_norm_output{ {} }; diff --git a/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp b/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp index 18e002aa3d..9ae5d5121c 100644 --- a/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp +++ b/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp @@ -242,7 +242,7 @@ void CLGEMMLowpOutputStage::configure(const CLCompileContext &compile_context, c Status CLGEMMLowpOutputStage::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo &info) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16); switch(info.type) { diff --git a/src/runtime/CL/functions/CLQLSTMLayer.cpp b/src/runtime/CL/functions/CLQLSTMLayer.cpp index a20ffc6f37..60e42a500d 100644 --- a/src/runtime/CL/functions/CLQLSTMLayer.cpp +++ b/src/runtime/CL/functions/CLQLSTMLayer.cpp @@ -46,6 +46,44 @@ Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info, const ITensorInfo *mm } } // namespace +Status CLQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst) +{ + ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported); + ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported); + ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst); + ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y()); + return Status{}; +} + +void CLQLSTMLayer::TensorCopyKernel::configure(ICLTensor &src, ICLTensor &dst) +{ + ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::TensorCopyKernel::validate(*src.info(), *dst.info())); + _src = &src; + _dst = &dst; + _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x()); + _window = calculate_max_window(*_src->info(), Steps()); +} + +void CLQLSTMLayer::TensorCopyKernel::run() +{ + auto &q = CLScheduler::get().queue(); + + _src->map(q, true); + _dst->map(q, true); + + Iterator input_iter{ _src, _window }; + Iterator output_iter{ _dst, _window }; + + execute_window_loop(_window, [&](const Coordinates &) + { + memcpy(output_iter.ptr(), input_iter.ptr(), _row_size); + }, + input_iter, output_iter); + + _src->unmap(q); + _dst->unmap(q); +} + CLQLSTMLayer::CLQLSTMLayer(std::shared_ptr memory_manager) { _memory_group = MemoryGroup(std::move(memory_manager)); @@ -108,8 +146,9 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(), lstm_params_info)); - const int batch_size = input->info()->dimension(1); - const int num_units = input_to_output_weights->info()->dimension(1); + const int batch_size = input->info()->dimension(1); + const int num_units = input_to_output_weights->info()->dimension(1); + const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx); const UniformQuantizationInfo qinput = input->info()->quantization_info().uniform(); const UniformQuantizationInfo qcell_state_in = cell_state_in->info()->quantization_info().uniform(); @@ -169,10 +208,9 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT _recurrent_to_cell_reduction.configure(compile_context, recurrent_to_cell_weights, &_recurrent_to_cell_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)); _input_to_output_reduction.configure(compile_context, input_to_output_weights, &_input_to_output_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)); _recurrent_to_output_reduction.configure(compile_context, recurrent_to_output_weights, &_recurrent_to_output_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)); - if(_projection_bias != nullptr) + if(_has_projection) { - _projection_reduction.configure(compile_context, _projection_weights, &_projection_reduction_res, GEMMLowpReductionKernelInfo(num_units, false, lstm_params.hidden_state_zero(), true)); - _projection_bias_add.configure(compile_context, ArithmeticOperation::ADD, _projection_bias, &_projection_reduction_res, &_projection_eff_bias, ConvertPolicy::SATURATE); + _projection_reduction.configure(compile_context, _projection_weights, &_projection_eff_bias, GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true)); } // Pre-transpose weights to be used in GEMM. @@ -219,6 +257,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT if(_has_peephole) { + _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32)); _memory_group.manage(&_mul_cell_to_forget_res); _pixelwise_mul_cell_to_forget.configure(compile_context, cell_state_in, lstm_params.cell_to_forget_weights(), &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO); _cell_to_forget_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0))); @@ -304,7 +343,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT const float recurrent_to_input_scale = _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale(); configure_mm(compile_context, _mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info, - input, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias, + output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias, &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale, mm_out_info, input_outstage_info); _accumulate_input_recurrent_input.configure(compile_context, ArithmeticOperation::ADD, &_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res, @@ -313,6 +352,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT if(_has_peephole) { + _mul_cell_to_input_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32)); _memory_group.manage(&_mul_cell_to_input_res); _pixelwise_mul_cell_to_input.configure(compile_context, cell_state_in, lstm_params.cell_to_input_weights(), &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO); const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale(); @@ -334,7 +374,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT input_activation_input = &get_layer_norm_output(LayerNormGate::Input); } - _input_gate_tanh.configure(compile_context, input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)); + _input_gate_sigmoid.configure(compile_context, input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); input_activation_input->allocator()->allocate(); } // Cell. @@ -376,13 +416,20 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT { // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel // Here we are not using the output stage because all operations are done in float - // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale(); - // quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift); + _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32)); _memory_group.manage(&_mul_cell_to_output_res); _pixelwise_mul_cell_to_output.configure(compile_context, cell_state_out, lstm_params.cell_to_output_weights(), &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO); - _accumulate_cell_to_output.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_output_outstage_res, &_mul_cell_to_output_res, &_recurrent_to_output_outstage_res, - ConvertPolicy::SATURATE); + + const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale(); + quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift); + _cell_to_output_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0))); + _memory_group.manage(&_cell_to_output_outstage_res); + _cell_to_output_outstage.configure(compile_context, &_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info); _mul_cell_to_output_res.allocator()->allocate(); + + _accumulate_cell_to_output.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res, + ConvertPolicy::SATURATE); + _cell_to_output_outstage_res.allocator()->allocate(); } CLTensor *output_activation_input = &_recurrent_to_output_outstage_res; @@ -413,7 +460,20 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true); gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero(); gemmlowp_info.output_data_type = output_state_in->info()->data_type(); - _hidden_outstage.configure(compile_context, &_hidden_mul_res, nullptr, output_state_out, gemmlowp_info); + + _projection_tensor_copy_required = (num_units != output_size); + ICLTensor *hidden_gate_result = output_state_out; + + _memory_group.manage(&_hidden_gate); + + if(_projection_tensor_copy_required) + { + _hidden_gate.allocator()->init(*output_state_out->info()); + _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape()); + hidden_gate_result = &_hidden_gate; + } + + _hidden_outstage.configure(compile_context, &_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info); _hidden_mul_res.allocator()->allocate(); // Projection. @@ -427,14 +487,34 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT gemmlowp_info.gemmlowp_max_bound = std::numeric_limits::max(); gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED; + TensorInfo projection_mm_out_info{ mm_out_info }; + projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size)); + configure_mm(compile_context, _mm_projection, _projection_outstage, gemmlowp_info, - output_state_out, &_projection_weights_transposed, &_projection_eff_bias, + hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias, &_mm_projection_res, &_projection_outstage_res, projection_scale, - mm_out_info, projection_outstage_info); + projection_mm_out_info, projection_outstage_info); - _accumulate_projection.configure(compile_context, ArithmeticOperation::ADD, &_projection_outstage_res, output_state_out, output_state_out, ConvertPolicy::SATURATE); + ICLTensor *accumulate_destination = output_state_out; + + if(_projection_tensor_copy_required) + { + _hidden_gate.allocator()->allocate(); + _projection_accumulate_res.allocator()->init(*output_state_out->info()); + _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape()); + _projection_output_to_accumulate_copy.configure(*output_state_out, _projection_accumulate_res); + accumulate_destination = &_projection_accumulate_res; + } + + _accumulate_projection.configure(compile_context, ArithmeticOperation::ADD, &_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE); _projection_outstage_res.allocator()->allocate(); + if(_projection_tensor_copy_required) + { + _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out); + _projection_accumulate_res.allocator()->allocate(); + } + int8_t quantized_projection_clip{ 0 }; if(lstm_params.projection_clip() > 0.0f) { @@ -448,6 +528,14 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT _has_projection_clipping = true; } } + else + { + if(_projection_tensor_copy_required) + { + _hidden_to_output_copy.configure(_hidden_gate, *output_state_out); + _hidden_gate.allocator()->allocate(); + } + } // Copy output_state_out to output _copy_output.configure(compile_context, output_state_out, output); @@ -471,7 +559,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, const unsigned int input_size = input->dimension(0); const unsigned int batch_size = input->dimension(1); const unsigned int num_units = input_to_output_weights->dimension(1); - const unsigned int output_size = recurrent_to_output_weights->dimension(0); + const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx); ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2); ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->dimension(0) != input_size); @@ -534,6 +622,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, // Precompute effective bias for optimizing the matmul computations. const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32); + const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32); if(!lstm_params.has_cifg_opt()) { ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(lstm_params.input_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true))); @@ -546,11 +635,11 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(recurrent_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true))); ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(input_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true))); ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(recurrent_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true))); - if(lstm_params.projection_bias() != nullptr) + if(lstm_params.has_projection()) { - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, lstm_params.hidden_state_zero(), + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &projection_eff_bias_info, GEMMLowpReductionKernelInfo(output_size, false, + lstm_params.hidden_state_zero(), true))); - ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, lstm_params.projection_bias(), &eff_bias_info, &eff_bias_info, ConvertPolicy::SATURATE)); } const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_forget_weights->data_type(), input_to_forget_weights->quantization_info()); @@ -570,7 +659,8 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, } if(lstm_params.has_projection()) { - ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.projection_weights(), &recurrent_weights_transposed)); + const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info()); + ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed)); } GEMMLowpOutputStageInfo gemmlowp_info; @@ -585,10 +675,10 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)); const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32); const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale(); - validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info)); const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale(); - validate_mm(gemmlowp_info, input, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info)); ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE)); @@ -619,10 +709,10 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, // Modulation gate. const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0)); const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale(); - validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info)); const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale(); - validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &input_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info)); ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE)); @@ -652,23 +742,22 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.input_gate_bias()); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, lstm_params.input_gate_bias()); - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyCore::validate(input, lstm_params.input_to_input_weights(), nullptr, &mm_out_info)); const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0)); const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale(); - validate_mm(gemmlowp_info, input, lstm_params.input_to_input_weights(), &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info)); const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale(); - validate_mm(gemmlowp_info, input, lstm_params.recurrent_to_input_weights(), &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info)); ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE)); if(lstm_params.has_peephole_opt()) { - ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_outstage_info, 1.f, ConvertPolicy::SATURATE, + ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale(); ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift)); - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&input_outstage_info, &eff_bias_info, &input_outstage_info, gemmlowp_info)); + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info)); ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE)); } @@ -679,7 +768,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info)); } - ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_outstage_info, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f))); + ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_outstage_info, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC, 1.f, 1.f))); } // Cell. ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); @@ -693,10 +782,10 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, // Output gate. const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0)); const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale(); - validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info)); const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale(); - validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info)); ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE)); if(lstm_params.has_peephole_opt()) @@ -724,11 +813,15 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, // Hidden. ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(cell_state_out, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f))); const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32); + const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED); + ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15); ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true)); gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero(); - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, output_state_out, gemmlowp_info)); + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info)); + + const bool projection_tensor_copy_required = num_units != output_size; // Projection. if(lstm_params.has_projection()) @@ -745,10 +838,26 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED; const TensorInfo projection_outstage_info(*output_state_out); - validate_mm(gemmlowp_info, output_state_out, &recurrent_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &projection_outstage_info); + const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info()); + + TensorInfo projection_mm_out_info{ mm_out_info }; + projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size)); + + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info, + &projection_outstage_info)); + + if(projection_tensor_copy_required) + { + ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(*output_state_out, projection_outstage_info)); + } ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE)); + if(projection_tensor_copy_required) + { + ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out)); + } + int8_t quantized_projection_clip{ 0 }; if(lstm_params.projection_clip() > 0.0f) { @@ -761,6 +870,13 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, quantized_projection_clip))); } } + else + { + if(projection_tensor_copy_required) + { + ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out)); + } + } if(cell_state_out->total_size() > 0) { @@ -847,7 +963,7 @@ void CLQLSTMLayer::run() CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Input)); } - _input_gate_tanh.run(); + _input_gate_sigmoid.run(); } // Cell. @@ -868,6 +984,7 @@ void CLQLSTMLayer::run() if(_has_peephole) { CLScheduler::get().enqueue(_pixelwise_mul_cell_to_output); + _cell_to_output_outstage.run(); CLScheduler::get().enqueue(_accumulate_cell_to_output); } @@ -888,12 +1005,31 @@ void CLQLSTMLayer::run() { _mm_projection.run(); _projection_outstage.run(); + + if(_projection_tensor_copy_required) + { + _projection_output_to_accumulate_copy.run(); + } + CLScheduler::get().enqueue(_accumulate_projection); + + if(_projection_tensor_copy_required) + { + _projection_accumulate_to_output_copy.run(); + } + if(_has_projection_clipping) { _projection_clip.run(); } } + else + { + if(_projection_tensor_copy_required) + { + _hidden_to_output_copy.run(); + } + } // Copy output_state_out to output CLScheduler::get().enqueue(_copy_output); @@ -963,6 +1099,12 @@ void CLQLSTMLayer::prepare() _projection_weights_transposed.allocator()->allocate(); _transpose_projection_weights.run(); _projection_weights->mark_as_unused(); + + if(!_projection_tensor_copy_required) + { + _hidden_gate.mark_as_unused(); + _projection_accumulate_res.mark_as_unused(); + } } // Mark weights as unused diff --git a/src/runtime/NEON/functions/NEQLSTMLayer.cpp b/src/runtime/NEON/functions/NEQLSTMLayer.cpp index 466c41307b..beb180fda5 100644 --- a/src/runtime/NEON/functions/NEQLSTMLayer.cpp +++ b/src/runtime/NEON/functions/NEQLSTMLayer.cpp @@ -437,6 +437,8 @@ void NEQLSTMLayer::configure(const ITensor *input, _projection_tensor_copy_required = (num_units != output_size); ITensor *hidden_gate_result = output_state_out; + _memory_group.manage(&_hidden_gate); + if(_projection_tensor_copy_required) { _hidden_gate.allocator()->init(*output_state_out->info()); @@ -450,7 +452,7 @@ void NEQLSTMLayer::configure(const ITensor *input, // Projection. if(_has_projection) { - const TensorInfo projection_outstage_info(*hidden_gate_result->info()); + const TensorInfo projection_outstage_info(*output_state_out->info()); const UniformQuantizationInfo qprojection = _projection_weights->info()->quantization_info().uniform(); const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale; gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset; @@ -458,23 +460,13 @@ void NEQLSTMLayer::configure(const ITensor *input, gemmlowp_info.gemmlowp_max_bound = std::numeric_limits::max(); gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED; - _memory_group.manage(&_projection_eff_bias_adjusted); - ITensor *bias_to_use = &_projection_eff_bias; - - if(_projection_tensor_copy_required) - { - _projection_eff_bias_adjusted.allocator()->init(*_projection_eff_bias.info()); - _projection_eff_bias_adjusted.info()->set_tensor_shape(TensorShape(num_units)); - _projection_bias_copy.configure(_projection_eff_bias, _projection_eff_bias_adjusted); - bias_to_use = &_projection_eff_bias_adjusted; - } + TensorInfo projection_mm_out_info{ mm_out_info }; + projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size)); configure_mm(_mm_projection, _projection_outstage, gemmlowp_info, - hidden_gate_result, &_projection_weights_transposed, bias_to_use, + hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias, &_mm_projection_res, &_projection_outstage_res, projection_scale, - mm_out_info, projection_outstage_info); - - _projection_eff_bias_adjusted.allocator()->allocate(); + projection_mm_out_info, projection_outstage_info); ITensor *accumulate_destination = output_state_out; @@ -655,10 +647,10 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input, const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)); const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32); const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale(); - validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info)); const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale(); - validate_mm(gemmlowp_info, input, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE)); @@ -689,10 +681,10 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input, // Modulation gate. const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0)); const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale(); - validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info)); const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale(); - validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE)); @@ -724,10 +716,10 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input, const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0)); const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale(); - validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info)); const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale(); - validate_mm(gemmlowp_info, input, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE)); @@ -762,10 +754,10 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input, // Output gate. const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0)); const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale(); - validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info)); const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale(); - validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE)); if(lstm_params.has_peephole_opt()) @@ -816,17 +808,14 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input, gemmlowp_info.gemmlowp_max_bound = std::numeric_limits::max(); gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED; - TensorInfo projection_outstage_info(hidden_out_info); - - if(projection_tensor_copy_required) - { - TensorInfo projection_eff_bias_adjusted_info{ projection_eff_bias_info }; - projection_eff_bias_adjusted_info.set_tensor_shape(TensorShape(num_units)); + const TensorInfo projection_outstage_info(*output_state_out); + const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info()); - ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(projection_eff_bias_info, projection_eff_bias_adjusted_info)); - } + TensorInfo projection_mm_out_info{ mm_out_info }; + projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size)); - validate_mm(gemmlowp_info, output_state_out, &recurrent_weights_transposed, &projection_eff_bias_info, input_to_output_scale, &mm_out_info, &projection_outstage_info); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info, + &projection_outstage_info)); if(projection_tensor_copy_required) { @@ -985,11 +974,6 @@ void NEQLSTMLayer::run() // Projection. if(_has_projection) { - if(_projection_tensor_copy_required) - { - _projection_bias_copy.run(); - } - _mm_projection.run(); _projection_outstage.run(); @@ -1088,7 +1072,6 @@ void NEQLSTMLayer::prepare() if(!_projection_tensor_copy_required) { _hidden_gate.mark_as_unused(); - _projection_eff_bias_adjusted.mark_as_unused(); _projection_accumulate_res.mark_as_unused(); } } -- cgit v1.2.1