aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NELSTMLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NELSTMLayer.cpp')
-rw-r--r--src/runtime/NEON/functions/NELSTMLayer.cpp27
1 files changed, 15 insertions, 12 deletions
diff --git a/src/runtime/NEON/functions/NELSTMLayer.cpp b/src/runtime/NEON/functions/NELSTMLayer.cpp
index 3d3c6a12fa..42b805794b 100644
--- a/src/runtime/NEON/functions/NELSTMLayer.cpp
+++ b/src/runtime/NEON/functions/NELSTMLayer.cpp
@@ -107,14 +107,14 @@ void NELSTMLayer::configure(const ITensor *input,
inputs_vector.emplace_back(output_state_in);
_memory_group.manage(&_forget_gate_out2);
- _concat_inputs_forget_gate.configure(inputs_vector, &_forget_gate_out2);
+ _concat_inputs_forget_gate.configure(inputs_vector, &_forget_gate_out2, Window::DimX);
std::vector<const ITensor *> weights_vector;
weights_vector.emplace_back(input_to_forget_weights);
weights_vector.emplace_back(recurrent_to_forget_weights);
- _concat_weights_forget_gate.configure(weights_vector, &_forget_gate_out6);
+ _concat_weights_forget_gate.configure(weights_vector, &_forget_gate_out6, Window::DimX);
_memory_group.manage(&_forget_gate_out5);
_fully_connected_forget_gate.configure(&_forget_gate_out2, &_forget_gate_out6, forget_gate_bias, &_forget_gate_out5);
@@ -165,7 +165,7 @@ void NELSTMLayer::configure(const ITensor *input,
lstm_weights.emplace_back(lstm_params.input_to_input_weights());
lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
- _concat_weights_input_gate.configure(lstm_weights, &_input_gate_out2);
+ _concat_weights_input_gate.configure(lstm_weights, &_input_gate_out2, Window::DimX);
_memory_group.manage(&_input_gate_out1);
_memory_group.manage(&_input_gate_out4);
@@ -234,7 +234,7 @@ void NELSTMLayer::configure(const ITensor *input,
in_out_weights.emplace_back(input_to_output_weights);
in_out_weights.emplace_back(recurrent_to_output_weights);
- _concat_weights_output.configure(in_out_weights, &_output2);
+ _concat_weights_output.configure(in_out_weights, &_output2, Window::DimX);
_memory_group.manage(&_output1);
_memory_group.manage(&_output4);
@@ -308,7 +308,7 @@ void NELSTMLayer::configure(const ITensor *input,
scratch_inputs.emplace_back(&_cell_state_out1);
scratch_inputs.emplace_back(forget_gate_out);
scratch_inputs.emplace_back(output_gate_out);
- _concat_scratch_buffer.configure(scratch_inputs, scratch_buffer);
+ _concat_scratch_buffer.configure(scratch_inputs, scratch_buffer, Window::DimX);
input_gate_out->allocator()->allocate();
_cell_state_out1.allocator()->allocate();
forget_gate_out->allocator()->allocate();
@@ -383,8 +383,9 @@ Status NELSTMLayer::validate(const ITensorInfo *input,
std::vector<const ITensorInfo *> inputs_vector;
inputs_vector.emplace_back(input);
inputs_vector.emplace_back(output_state_in);
- TensorInfo forget_gate_concat;
- ARM_COMPUTE_RETURN_ON_ERROR(NEWidthConcatenateLayer::validate(inputs_vector, &forget_gate_concat));
+ const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
+ TensorInfo forget_gate_concat = TensorInfo(concat_shape, 1, input->data_type());
+ ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(inputs_vector, &forget_gate_concat, Window::DimX));
// Validate forget gate
ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(input, input_to_forget_weights, forget_gate_bias, &forget_gate));
@@ -409,8 +410,9 @@ Status NELSTMLayer::validate(const ITensorInfo *input,
std::vector<const ITensorInfo *> lstm_weights;
lstm_weights.emplace_back(lstm_params.input_to_input_weights());
lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
- TensorInfo lstm_gate_concat;
- ARM_COMPUTE_RETURN_ON_ERROR(NEWidthConcatenateLayer::validate(lstm_weights, &lstm_gate_concat));
+ TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
+ TensorInfo lstm_gate_concat = TensorInfo(lstm_weights_concat_shape, 1, input->data_type());
+ ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(lstm_weights, &lstm_gate_concat, Window::DimX));
ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), &input_gate));
if(lstm_params.has_peephole_opt())
@@ -445,8 +447,9 @@ Status NELSTMLayer::validate(const ITensorInfo *input,
std::vector<const ITensorInfo *> in_out_weights;
in_out_weights.emplace_back(input_to_output_weights);
in_out_weights.emplace_back(recurrent_to_output_weights);
- TensorInfo in_out_gate_concat;
- ARM_COMPUTE_RETURN_ON_ERROR(NEWidthConcatenateLayer::validate(in_out_weights, &in_out_gate_concat));
+ TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
+ TensorInfo in_out_gate_concat = TensorInfo(in_out_weights_concat_shape, 1, input->data_type());
+ ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(in_out_weights, &in_out_gate_concat, Window::DimX));
ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(input, input_to_output_weights, output_gate_bias, &output_gate_tmp));
@@ -485,7 +488,7 @@ Status NELSTMLayer::validate(const ITensorInfo *input,
inputs_vector_info_raw.push_back(&forget_gate);
inputs_vector_info_raw.push_back(&output_gate_tmp);
- ARM_COMPUTE_RETURN_ON_ERROR(NEWidthConcatenateLayer::validate(inputs_vector_info_raw, scratch_buffer));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(inputs_vector_info_raw, scratch_buffer, Window::DimX));
return Status{};
}