aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSang-Hoon Park <sang-hoon.park@arm.com>2020-05-12 11:13:30 +0100
committerSang-Hoon Park <sang-hoon.park@arm.com>2020-05-12 16:25:57 +0000
commita7431aeef244c85f621b70b946d25229e42d1708 (patch)
tree62f74403008cad9cb812202865d016addf711a18
parent1f567afcdfb2919fab417f0060155deda7132df8 (diff)
downloadComputeLibrary-a7431aeef244c85f621b70b946d25229e42d1708.tar.gz
COMPMID-3439: Fix peephole and projection in CLQLSTMLayer
The followings are essential to make it work - QSYMM16 is added as supported data type in CLGEMMLowpOutputStage - Internal TensorCopyKernel is added similar to NEQLSTMLayer The followings are fix for related things. - Projection is modified to remove copy of projection_bias from NEQLSTMLayer. - Fix wrong argument for validate_mm() - validate_mm() now returns on error. Change-Id: Icbd04e9fdb8821eb41dd3e0a6a0980965b779714 Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3177 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
-rw-r--r--arm_compute/runtime/CL/functions/CLQLSTMLayer.h50
-rw-r--r--arm_compute/runtime/NEON/functions/NEQLSTMLayer.h1
-rw-r--r--src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp2
-rw-r--r--src/runtime/CL/functions/CLQLSTMLayer.cpp212
-rw-r--r--src/runtime/NEON/functions/NEQLSTMLayer.cpp59
5 files changed, 243 insertions, 81 deletions
diff --git a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
index 219f46ee48..67e8bc7751 100644
--- a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
+++ b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
@@ -230,7 +230,8 @@ private:
Output,
Count
};
- static constexpr uint8_t _layer_norm_count = static_cast<uint8_t>(LayerNormGate::Count);
+ static constexpr uint8_t _layer_norm_count = static_cast<uint8_t>(LayerNormGate::Count);
+ static constexpr uint32_t _out_state_output_size_dimension_idx = 0;
/** Internal method to configure matrix multiplication plus output stage of each gate.
*
@@ -254,6 +255,35 @@ private:
MemoryGroup _memory_group{};
+ /** A small internel kernel do the copy between two tensors */
+ class TensorCopyKernel
+ {
+ static constexpr uint32_t max_dimension_supported = 2;
+
+ ICLTensor *_src{ nullptr };
+ ICLTensor *_dst{ nullptr };
+ size_t _row_size{};
+ Window _window{};
+
+ public:
+ /** Static function to check if given info will lead to a valid configuration of @ref CLQLSTMLayer::TensorCopyKernel
+ *
+ * @param[in] src Source tensor info.
+ * @param[in] dst Destination tensor info
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo &src, const ITensorInfo &dst);
+ /** Set the input and output tensors.
+ *
+ * @param[in] src Source tensor
+ * @param[out] dst Destination tensor
+ */
+ void configure(ICLTensor &src, ICLTensor &dst);
+ /** run the kernel */
+ void run();
+ };
+
// Functions used
CLTranspose _transpose_input_to_forget_weights{};
CLTranspose _transpose_input_to_cell_weights{};
@@ -298,7 +328,7 @@ private:
CLPixelWiseMultiplicationKernel _pixelwise_mul_cell_to_input{};
CLGEMMLowpOutputStage _cell_to_input_outstage{};
CLSaturatedArithmeticOperationKernel _accumulate_cell_input{};
- CLActivationLayer _input_gate_tanh{};
+ CLActivationLayer _input_gate_sigmoid{};
CLPixelWiseMultiplicationKernel _pixelwise_mul_forget_cell{};
CLPixelWiseMultiplicationKernel _pixelwise_mul_input_cell{};
CLSaturatedArithmeticOperationKernel _add_forget_cell{};
@@ -309,6 +339,7 @@ private:
CLGEMMLowpOutputStage _recurrent_to_output_outstage{};
CLSaturatedArithmeticOperationKernel _accumulate_input_recurrent_output{};
CLPixelWiseMultiplicationKernel _pixelwise_mul_cell_to_output{};
+ CLGEMMLowpOutputStage _cell_to_output_outstage{};
CLSaturatedArithmeticOperationKernel _accumulate_cell_to_output{};
CLActivationLayer _output_gate_sigmoid{};
CLActivationLayer _hidden_tanh{};
@@ -321,11 +352,13 @@ private:
std::array<CLQLSTMLayerNormalizationKernel, _layer_norm_count> _layer_norms{ {} };
CLCopyKernel _copy_output{};
+ TensorCopyKernel _projection_bias_copy{};
+ TensorCopyKernel _projection_output_to_accumulate_copy{};
+ TensorCopyKernel _projection_accumulate_to_output_copy{};
+ TensorCopyKernel _hidden_to_output_copy{};
+
// Tensor pointers
- const ICLTensor *_input_to_input_weights
- {
- nullptr
- };
+ const ICLTensor *_input_to_input_weights{ nullptr };
const ICLTensor *_recurrent_to_input_weights{ nullptr };
const ICLTensor *_projection_bias{ nullptr };
const ICLTensor *_input_to_forget_weights{ nullptr };
@@ -435,11 +468,15 @@ private:
CLTensor _input_to_output_outstage_res{ nullptr };
CLTensor _mm_recurrent_to_output_res{ nullptr };
CLTensor _mul_cell_to_output_res{ nullptr };
+ CLTensor _cell_to_output_outstage_res{ nullptr };
CLTensor _recurrent_to_output_outstage_res{ nullptr };
CLTensor _output_gate{ nullptr };
CLTensor _hidden_mul_res{ nullptr };
+ CLTensor _hidden_gate{ nullptr };
CLTensor _mm_projection_res{ nullptr };
CLTensor _projection_outstage_res{ nullptr };
+ CLTensor _projection_out_res{ nullptr };
+ CLTensor _projection_accumulate_res{ nullptr };
CLTensor _ones{ nullptr };
std::array<CLTensor, _layer_norm_count> _layer_norm_output{ {} };
@@ -455,6 +492,7 @@ private:
bool _has_projection_clipping{ false };
bool _has_peephole{ false };
bool _has_layer_norm{ false };
+ bool _projection_tensor_copy_required{ false };
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_CLQLSTMLAYER_H */
diff --git a/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h b/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h
index 4dde85e895..d1cc962940 100644
--- a/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h
@@ -426,7 +426,6 @@ private:
Tensor _mm_projection_res{ nullptr };
Tensor _projection_outstage_res{ nullptr };
Tensor _projection_out_res{ nullptr };
- Tensor _projection_eff_bias_adjusted{ nullptr };
Tensor _projection_accumulate_res{ nullptr };
Tensor _ones{ nullptr };
std::array<Tensor, _layer_norm_count> _layer_norm_output{ {} };
diff --git a/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp b/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp
index 18e002aa3d..9ae5d5121c 100644
--- a/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp
@@ -242,7 +242,7 @@ void CLGEMMLowpOutputStage::configure(const CLCompileContext &compile_context, c
Status CLGEMMLowpOutputStage::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo &info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16);
switch(info.type)
{
diff --git a/src/runtime/CL/functions/CLQLSTMLayer.cpp b/src/runtime/CL/functions/CLQLSTMLayer.cpp
index a20ffc6f37..60e42a500d 100644
--- a/src/runtime/CL/functions/CLQLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLQLSTMLayer.cpp
@@ -46,6 +46,44 @@ Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info, const ITensorInfo *mm
}
} // namespace
+Status CLQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported);
+ ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
+ ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y());
+ return Status{};
+}
+
+void CLQLSTMLayer::TensorCopyKernel::configure(ICLTensor &src, ICLTensor &dst)
+{
+ ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::TensorCopyKernel::validate(*src.info(), *dst.info()));
+ _src = &src;
+ _dst = &dst;
+ _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
+ _window = calculate_max_window(*_src->info(), Steps());
+}
+
+void CLQLSTMLayer::TensorCopyKernel::run()
+{
+ auto &q = CLScheduler::get().queue();
+
+ _src->map(q, true);
+ _dst->map(q, true);
+
+ Iterator input_iter{ _src, _window };
+ Iterator output_iter{ _dst, _window };
+
+ execute_window_loop(_window, [&](const Coordinates &)
+ {
+ memcpy(output_iter.ptr(), input_iter.ptr(), _row_size);
+ },
+ input_iter, output_iter);
+
+ _src->unmap(q);
+ _dst->unmap(q);
+}
+
CLQLSTMLayer::CLQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
{
_memory_group = MemoryGroup(std::move(memory_manager));
@@ -108,8 +146,9 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
lstm_params_info));
- const int batch_size = input->info()->dimension(1);
- const int num_units = input_to_output_weights->info()->dimension(1);
+ const int batch_size = input->info()->dimension(1);
+ const int num_units = input_to_output_weights->info()->dimension(1);
+ const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx);
const UniformQuantizationInfo qinput = input->info()->quantization_info().uniform();
const UniformQuantizationInfo qcell_state_in = cell_state_in->info()->quantization_info().uniform();
@@ -169,10 +208,9 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
_recurrent_to_cell_reduction.configure(compile_context, recurrent_to_cell_weights, &_recurrent_to_cell_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
_input_to_output_reduction.configure(compile_context, input_to_output_weights, &_input_to_output_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
_recurrent_to_output_reduction.configure(compile_context, recurrent_to_output_weights, &_recurrent_to_output_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
- if(_projection_bias != nullptr)
+ if(_has_projection)
{
- _projection_reduction.configure(compile_context, _projection_weights, &_projection_reduction_res, GEMMLowpReductionKernelInfo(num_units, false, lstm_params.hidden_state_zero(), true));
- _projection_bias_add.configure(compile_context, ArithmeticOperation::ADD, _projection_bias, &_projection_reduction_res, &_projection_eff_bias, ConvertPolicy::SATURATE);
+ _projection_reduction.configure(compile_context, _projection_weights, &_projection_eff_bias, GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
}
// Pre-transpose weights to be used in GEMM.
@@ -219,6 +257,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
if(_has_peephole)
{
+ _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
_memory_group.manage(&_mul_cell_to_forget_res);
_pixelwise_mul_cell_to_forget.configure(compile_context, cell_state_in, lstm_params.cell_to_forget_weights(), &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
_cell_to_forget_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)));
@@ -304,7 +343,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
const float recurrent_to_input_scale = _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
configure_mm(compile_context, _mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info,
- input, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
+ output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
&_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
mm_out_info, input_outstage_info);
_accumulate_input_recurrent_input.configure(compile_context, ArithmeticOperation::ADD, &_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res,
@@ -313,6 +352,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
if(_has_peephole)
{
+ _mul_cell_to_input_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
_memory_group.manage(&_mul_cell_to_input_res);
_pixelwise_mul_cell_to_input.configure(compile_context, cell_state_in, lstm_params.cell_to_input_weights(), &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
@@ -334,7 +374,7 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
}
- _input_gate_tanh.configure(compile_context, input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
+ _input_gate_sigmoid.configure(compile_context, input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
input_activation_input->allocator()->allocate();
}
// Cell.
@@ -376,13 +416,20 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
{
// TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel
// Here we are not using the output stage because all operations are done in float
- // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
- // quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
+ _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
_memory_group.manage(&_mul_cell_to_output_res);
_pixelwise_mul_cell_to_output.configure(compile_context, cell_state_out, lstm_params.cell_to_output_weights(), &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
- _accumulate_cell_to_output.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_output_outstage_res, &_mul_cell_to_output_res, &_recurrent_to_output_outstage_res,
- ConvertPolicy::SATURATE);
+
+ const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
+ quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
+ _cell_to_output_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0)));
+ _memory_group.manage(&_cell_to_output_outstage_res);
+ _cell_to_output_outstage.configure(compile_context, &_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info);
_mul_cell_to_output_res.allocator()->allocate();
+
+ _accumulate_cell_to_output.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res,
+ ConvertPolicy::SATURATE);
+ _cell_to_output_outstage_res.allocator()->allocate();
}
CLTensor *output_activation_input = &_recurrent_to_output_outstage_res;
@@ -413,7 +460,20 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true);
gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
gemmlowp_info.output_data_type = output_state_in->info()->data_type();
- _hidden_outstage.configure(compile_context, &_hidden_mul_res, nullptr, output_state_out, gemmlowp_info);
+
+ _projection_tensor_copy_required = (num_units != output_size);
+ ICLTensor *hidden_gate_result = output_state_out;
+
+ _memory_group.manage(&_hidden_gate);
+
+ if(_projection_tensor_copy_required)
+ {
+ _hidden_gate.allocator()->init(*output_state_out->info());
+ _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape());
+ hidden_gate_result = &_hidden_gate;
+ }
+
+ _hidden_outstage.configure(compile_context, &_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
_hidden_mul_res.allocator()->allocate();
// Projection.
@@ -427,14 +487,34 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
+ TensorInfo projection_mm_out_info{ mm_out_info };
+ projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
+
configure_mm(compile_context, _mm_projection, _projection_outstage, gemmlowp_info,
- output_state_out, &_projection_weights_transposed, &_projection_eff_bias,
+ hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias,
&_mm_projection_res, &_projection_outstage_res, projection_scale,
- mm_out_info, projection_outstage_info);
+ projection_mm_out_info, projection_outstage_info);
- _accumulate_projection.configure(compile_context, ArithmeticOperation::ADD, &_projection_outstage_res, output_state_out, output_state_out, ConvertPolicy::SATURATE);
+ ICLTensor *accumulate_destination = output_state_out;
+
+ if(_projection_tensor_copy_required)
+ {
+ _hidden_gate.allocator()->allocate();
+ _projection_accumulate_res.allocator()->init(*output_state_out->info());
+ _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape());
+ _projection_output_to_accumulate_copy.configure(*output_state_out, _projection_accumulate_res);
+ accumulate_destination = &_projection_accumulate_res;
+ }
+
+ _accumulate_projection.configure(compile_context, ArithmeticOperation::ADD, &_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
_projection_outstage_res.allocator()->allocate();
+ if(_projection_tensor_copy_required)
+ {
+ _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
+ _projection_accumulate_res.allocator()->allocate();
+ }
+
int8_t quantized_projection_clip{ 0 };
if(lstm_params.projection_clip() > 0.0f)
{
@@ -448,6 +528,14 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
_has_projection_clipping = true;
}
}
+ else
+ {
+ if(_projection_tensor_copy_required)
+ {
+ _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
+ _hidden_gate.allocator()->allocate();
+ }
+ }
// Copy output_state_out to output
_copy_output.configure(compile_context, output_state_out, output);
@@ -471,7 +559,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
const unsigned int input_size = input->dimension(0);
const unsigned int batch_size = input->dimension(1);
const unsigned int num_units = input_to_output_weights->dimension(1);
- const unsigned int output_size = recurrent_to_output_weights->dimension(0);
+ const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);
ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->dimension(0) != input_size);
@@ -534,6 +622,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
// Precompute effective bias for optimizing the matmul computations.
const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32);
+ const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32);
if(!lstm_params.has_cifg_opt())
{
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(lstm_params.input_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
@@ -546,11 +635,11 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(recurrent_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(input_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(recurrent_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
- if(lstm_params.projection_bias() != nullptr)
+ if(lstm_params.has_projection())
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, lstm_params.hidden_state_zero(),
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &projection_eff_bias_info, GEMMLowpReductionKernelInfo(output_size, false,
+ lstm_params.hidden_state_zero(),
true)));
- ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, lstm_params.projection_bias(), &eff_bias_info, &eff_bias_info, ConvertPolicy::SATURATE));
}
const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_forget_weights->data_type(), input_to_forget_weights->quantization_info());
@@ -570,7 +659,8 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
}
if(lstm_params.has_projection())
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.projection_weights(), &recurrent_weights_transposed));
+ const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
+ ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed));
}
GEMMLowpOutputStageInfo gemmlowp_info;
@@ -585,10 +675,10 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
- validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info));
const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
- validate_mm(gemmlowp_info, input, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));
ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
@@ -619,10 +709,10 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
// Modulation gate.
const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
- validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info));
const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
- validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &input_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
@@ -652,23 +742,22 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.input_gate_bias());
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, lstm_params.input_gate_bias());
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyCore::validate(input, lstm_params.input_to_input_weights(), nullptr, &mm_out_info));
const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
- validate_mm(gemmlowp_info, input, lstm_params.input_to_input_weights(), &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info));
const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
- validate_mm(gemmlowp_info, input, lstm_params.recurrent_to_input_weights(), &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));
ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
if(lstm_params.has_peephole_opt())
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_outstage_info, 1.f, ConvertPolicy::SATURATE,
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
RoundingPolicy::TO_ZERO));
const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&input_outstage_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
}
@@ -679,7 +768,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
}
- ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_outstage_info, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_outstage_info, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC, 1.f, 1.f)));
}
// Cell.
ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
@@ -693,10 +782,10 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
// Output gate.
const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
- validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info));
const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
- validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));
ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
if(lstm_params.has_peephole_opt())
@@ -724,11 +813,15 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
// Hidden.
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(cell_state_out, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32);
+ const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);
+
ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, output_state_out, gemmlowp_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info));
+
+ const bool projection_tensor_copy_required = num_units != output_size;
// Projection.
if(lstm_params.has_projection())
@@ -745,10 +838,26 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
const TensorInfo projection_outstage_info(*output_state_out);
- validate_mm(gemmlowp_info, output_state_out, &recurrent_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &projection_outstage_info);
+ const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
+
+ TensorInfo projection_mm_out_info{ mm_out_info };
+ projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
+
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
+ &projection_outstage_info));
+
+ if(projection_tensor_copy_required)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(*output_state_out, projection_outstage_info));
+ }
ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
+ if(projection_tensor_copy_required)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out));
+ }
+
int8_t quantized_projection_clip{ 0 };
if(lstm_params.projection_clip() > 0.0f)
{
@@ -761,6 +870,13 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
quantized_projection_clip)));
}
}
+ else
+ {
+ if(projection_tensor_copy_required)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out));
+ }
+ }
if(cell_state_out->total_size() > 0)
{
@@ -847,7 +963,7 @@ void CLQLSTMLayer::run()
CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Input));
}
- _input_gate_tanh.run();
+ _input_gate_sigmoid.run();
}
// Cell.
@@ -868,6 +984,7 @@ void CLQLSTMLayer::run()
if(_has_peephole)
{
CLScheduler::get().enqueue(_pixelwise_mul_cell_to_output);
+ _cell_to_output_outstage.run();
CLScheduler::get().enqueue(_accumulate_cell_to_output);
}
@@ -888,12 +1005,31 @@ void CLQLSTMLayer::run()
{
_mm_projection.run();
_projection_outstage.run();
+
+ if(_projection_tensor_copy_required)
+ {
+ _projection_output_to_accumulate_copy.run();
+ }
+
CLScheduler::get().enqueue(_accumulate_projection);
+
+ if(_projection_tensor_copy_required)
+ {
+ _projection_accumulate_to_output_copy.run();
+ }
+
if(_has_projection_clipping)
{
_projection_clip.run();
}
}
+ else
+ {
+ if(_projection_tensor_copy_required)
+ {
+ _hidden_to_output_copy.run();
+ }
+ }
// Copy output_state_out to output
CLScheduler::get().enqueue(_copy_output);
@@ -963,6 +1099,12 @@ void CLQLSTMLayer::prepare()
_projection_weights_transposed.allocator()->allocate();
_transpose_projection_weights.run();
_projection_weights->mark_as_unused();
+
+ if(!_projection_tensor_copy_required)
+ {
+ _hidden_gate.mark_as_unused();
+ _projection_accumulate_res.mark_as_unused();
+ }
}
// Mark weights as unused
diff --git a/src/runtime/NEON/functions/NEQLSTMLayer.cpp b/src/runtime/NEON/functions/NEQLSTMLayer.cpp
index 466c41307b..beb180fda5 100644
--- a/src/runtime/NEON/functions/NEQLSTMLayer.cpp
+++ b/src/runtime/NEON/functions/NEQLSTMLayer.cpp
@@ -437,6 +437,8 @@ void NEQLSTMLayer::configure(const ITensor *input,
_projection_tensor_copy_required = (num_units != output_size);
ITensor *hidden_gate_result = output_state_out;
+ _memory_group.manage(&_hidden_gate);
+
if(_projection_tensor_copy_required)
{
_hidden_gate.allocator()->init(*output_state_out->info());
@@ -450,7 +452,7 @@ void NEQLSTMLayer::configure(const ITensor *input,
// Projection.
if(_has_projection)
{
- const TensorInfo projection_outstage_info(*hidden_gate_result->info());
+ const TensorInfo projection_outstage_info(*output_state_out->info());
const UniformQuantizationInfo qprojection = _projection_weights->info()->quantization_info().uniform();
const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
@@ -458,23 +460,13 @@ void NEQLSTMLayer::configure(const ITensor *input,
gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
- _memory_group.manage(&_projection_eff_bias_adjusted);
- ITensor *bias_to_use = &_projection_eff_bias;
-
- if(_projection_tensor_copy_required)
- {
- _projection_eff_bias_adjusted.allocator()->init(*_projection_eff_bias.info());
- _projection_eff_bias_adjusted.info()->set_tensor_shape(TensorShape(num_units));
- _projection_bias_copy.configure(_projection_eff_bias, _projection_eff_bias_adjusted);
- bias_to_use = &_projection_eff_bias_adjusted;
- }
+ TensorInfo projection_mm_out_info{ mm_out_info };
+ projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
configure_mm(_mm_projection, _projection_outstage, gemmlowp_info,
- hidden_gate_result, &_projection_weights_transposed, bias_to_use,
+ hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias,
&_mm_projection_res, &_projection_outstage_res, projection_scale,
- mm_out_info, projection_outstage_info);
-
- _projection_eff_bias_adjusted.allocator()->allocate();
+ projection_mm_out_info, projection_outstage_info);
ITensor *accumulate_destination = output_state_out;
@@ -655,10 +647,10 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
- validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info));
const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
- validate_mm(gemmlowp_info, input, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
@@ -689,10 +681,10 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
// Modulation gate.
const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
- validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info));
const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
- validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
@@ -724,10 +716,10 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
- validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info));
const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
- validate_mm(gemmlowp_info, input, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
@@ -762,10 +754,10 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
// Output gate.
const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
- validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info));
const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
- validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
if(lstm_params.has_peephole_opt())
@@ -816,17 +808,14 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
- TensorInfo projection_outstage_info(hidden_out_info);
-
- if(projection_tensor_copy_required)
- {
- TensorInfo projection_eff_bias_adjusted_info{ projection_eff_bias_info };
- projection_eff_bias_adjusted_info.set_tensor_shape(TensorShape(num_units));
+ const TensorInfo projection_outstage_info(*output_state_out);
+ const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
- ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(projection_eff_bias_info, projection_eff_bias_adjusted_info));
- }
+ TensorInfo projection_mm_out_info{ mm_out_info };
+ projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
- validate_mm(gemmlowp_info, output_state_out, &recurrent_weights_transposed, &projection_eff_bias_info, input_to_output_scale, &mm_out_info, &projection_outstage_info);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
+ &projection_outstage_info));
if(projection_tensor_copy_required)
{
@@ -985,11 +974,6 @@ void NEQLSTMLayer::run()
// Projection.
if(_has_projection)
{
- if(_projection_tensor_copy_required)
- {
- _projection_bias_copy.run();
- }
-
_mm_projection.run();
_projection_outstage.run();
@@ -1088,7 +1072,6 @@ void NEQLSTMLayer::prepare()
if(!_projection_tensor_copy_required)
{
_hidden_gate.mark_as_unused();
- _projection_eff_bias_adjusted.mark_as_unused();
_projection_accumulate_res.mark_as_unused();
}
}