author      Michalis Spyrou <michalis.spyrou@arm.com>    2018-06-05 11:45:48 +0100
committer   Anthony Barbier <anthony.barbier@arm.com>    2018-11-02 16:53:09 +0000
commit      542e92d95536f2ab7fc6f1cc1aa1bd4f1d471212 (patch)
tree        a4c03807d9731c1305b6f446282d3f4b97cfb595
parent      72219330fd85b1271e714d4ba894d6d8e26340c9 (diff)
download    ComputeLibrary-542e92d95536f2ab7fc6f1cc1aa1bd4f1d471212.tar.gz
COMPMID-1067 NEON RNN FP32 / FP16
Change-Id: I440df2b2af512fd874651baf28428caa6f8e0b41
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/134433
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
-rw-r--r--  arm_compute/runtime/NEON/NEFunctions.h                       |   1
-rw-r--r--  arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h   |   2
-rw-r--r--  arm_compute/runtime/NEON/functions/NERNNLayer.h              |  96
-rw-r--r--  src/runtime/NEON/functions/NERNNLayer.cpp                    | 125
-rw-r--r--  tests/datasets/RNNLayerDataset.h                             |   2
-rw-r--r--  tests/validation/NEON/RNNLayer.cpp                           | 147
6 files changed, 371 insertions, 2 deletions
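
For context, here is a minimal usage sketch of the new NERNNLayer function. This example is hypothetical and not part of the patch: the tensor shapes (input_size = 8, num_units = 16, batch_size = 2) and the TANH activation are illustrative choices only.

// Hypothetical usage sketch of the NERNNLayer function added by this patch.
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NERNNLayer.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

int main()
{
    Tensor input, weights, recurrent_weights, bias, hidden_state, output;

    // Shapes follow the library's [width, height] convention.
    input.allocator()->init(TensorInfo(TensorShape(8U, 2U), 1, DataType::F32));               // [input_size, batch_size]
    weights.allocator()->init(TensorInfo(TensorShape(8U, 16U), 1, DataType::F32));            // [input_size, num_units]
    recurrent_weights.allocator()->init(TensorInfo(TensorShape(16U, 16U), 1, DataType::F32)); // [num_units, num_units]
    bias.allocator()->init(TensorInfo(TensorShape(16U), 1, DataType::F32));                   // [num_units]
    hidden_state.allocator()->init(TensorInfo(TensorShape(16U, 2U), 1, DataType::F32));       // [num_units, batch_size]
    output.allocator()->init(TensorInfo(TensorShape(16U, 2U), 1, DataType::F32));             // [num_units, batch_size]

    ActivationLayerInfo act(ActivationLayerInfo::ActivationFunction::TANH);

    NERNNLayer rnn;
    rnn.configure(&input, &weights, &recurrent_weights, &bias, &hidden_state, &output, act);

    // Allocate backing memory, fill the tensors with data, then run one step.
    input.allocator()->allocate();
    weights.allocator()->allocate();
    recurrent_weights.allocator()->allocate();
    bias.allocator()->allocate();
    hidden_state.allocator()->allocate();
    output.allocator()->allocate();

    rnn.run();
    return 0;
}
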
diff --git a/arm_compute/runtime/NEON/NEFunctions.h b/arm_compute/runtime/NEON/NEFunctions.h
index bd4097c6d3..b83675b779 100644
--- a/arm_compute/runtime/NEON/NEFunctions.h
+++ b/arm_compute/runtime/NEON/NEFunctions.h
@@ -95,6 +95,7 @@
#include "arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h"
#include "arm_compute/runtime/NEON/functions/NEPoolingLayer.h"
#include "arm_compute/runtime/NEON/functions/NEQuantizationLayer.h"
+#include "arm_compute/runtime/NEON/functions/NERNNLayer.h"
#include "arm_compute/runtime/NEON/functions/NEROIPoolingLayer.h"
#include "arm_compute/runtime/NEON/functions/NEReductionOperation.h"
#include "arm_compute/runtime/NEON/functions/NERemap.h"
diff --git a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
index 2739f5ebef..42c9e2d3e9 100644
--- a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
@@ -104,7 +104,7 @@ public:
NEFullyConnectedLayer &operator=(NEFullyConnectedLayer &&) = default;
/** Set the input and output tensors.
*
- * @param[in] input Source tensor. Data type supported: QS8/QS16/F32.
+ * @param[in] input Source tensor. Data type supported: QS8/QS16/F16/F32.
* @param[in] weights Weights tensor. The weights must be 2 dimensional. Data type supported: Same as @p input.
* @param[in] biases Bias tensor. Can be nullptr. Data type supported:Same as @p input.
* @param[out] output Destination tensor. Data type supported: Same as @p input.
diff --git a/arm_compute/runtime/NEON/functions/NERNNLayer.h b/arm_compute/runtime/NEON/functions/NERNNLayer.h
new file mode 100644
index 0000000000..f1398eb3cc
--- /dev/null
+++ b/arm_compute/runtime/NEON/functions/NERNNLayer.h
@@ -0,0 +1,96 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_NERNNLAYER_H__
+#define __ARM_COMPUTE_NERNNLAYER_H__
+
+#include "arm_compute/core/NEON/kernels/NEActivationLayerKernel.h"
+#include "arm_compute/core/NEON/kernels/NEArithmeticAdditionKernel.h"
+#include "arm_compute/runtime/NEON/INESimpleFunction.h"
+
+#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
+
+namespace arm_compute
+{
+// Forward declarations
+class ITensor;
+
+/** Basic function to run @ref NERNNLayer */
+class NERNNLayer : public IFunction
+{
+public:
+ /** Default constructor */
+ NERNNLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ NERNNLayer(const NERNNLayer &) = delete;
+ /** Default move constructor */
+ NERNNLayer(NERNNLayer &&) = default;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ NERNNLayer &operator=(const NERNNLayer &) = delete;
+ /** Default move assignment operator */
+ NERNNLayer &operator=(NERNNLayer &&) = default;
+ /** Initialize the function
+ *
+ * @param[in] input Input is a 2-D tensor of shape [input_size, batch_size]. Data types supported: F16/F32
+ * @param[in] weights Weights tensor of shape [input_size, num_units] that multiplies the input. Data types supported: Same as @p input
+ * @param[in] recurrent_weights Weights tensor of shape [num_units, num_units] that multiplies the current 'state'. Data types supported: Same as @p input
+ * @param[in] bias Bias vector of shape [num_units]. Data types supported: Same as @p input
+ * @param[in,out] hidden_state Hidden state tensor of shape [num_units, batch_size]; holds the current state on entry and is updated with the new state. Data types supported: Same as @p input
+ * @param[out] output Output tensor of shape [num_units, batch_size]. Data types supported: Same as @p input
+ * @param[in] info Activation layer parameter.
+ */
+ void configure(const ITensor *input, const ITensor *weights, const ITensor *recurrent_weights, const ITensor *bias, ITensor *hidden_state, ITensor *output, ActivationLayerInfo &info);
+ /** Static function to check if given info will lead to a valid configuration of @ref NERNNLayer
+ *
+ * @param[in] input Input is a 2-D tensor of shape [input_size, batch_size]. Data types supported: F16/F32
+ * @param[in] weights Weights tensor of shape [input_size, num_units] that multiplies the input. Data types supported: Same as @p input
+ * @param[in] recurrent_weights Weights tensor of shape [num_units, num_units] that multiplies the current 'state'. Data types supported: Same as @p input
+ * @param[in] bias Bias vector of shape [num_units]. Data types supported: Same as @p input
+ * @param[in] hidden_state Hidden state tensor of shape [num_units, batch_size]. Data types supported: Same as @p input
+ * @param[in] output Output tensor of shape [num_units, batch_size]. Data types supported: Same as @p input
+ * @param[in] info Activation layer parameter.
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *recurrent_weights, const ITensorInfo *bias, const ITensorInfo *hidden_state, const ITensorInfo *output,
+ const ActivationLayerInfo &info);
+
+ // Inherited methods overridden:
+ void run() override;
+
+private:
+ MemoryGroup _memory_group;
+ NEGEMM _gemm_state_f;
+ NEArithmeticAdditionKernel _add_kernel;
+ NEActivationLayerKernel _activation_kernel;
+ NEFullyConnectedLayer _fully_connected_kernel;
+ Tensor _fully_connected_out;
+ Tensor _gemm_output;
+ Tensor _add_output;
+ ITensor *_hidden_state;
+ ITensor *_output;
+};
+} // namespace arm_compute
+#endif /* __ARM_COMPUTE_NERNNLAYER_H__ */
diff --git a/src/runtime/NEON/functions/NERNNLayer.cpp b/src/runtime/NEON/functions/NERNNLayer.cpp
new file mode 100644
index 0000000000..08017e20c3
--- /dev/null
+++ b/src/runtime/NEON/functions/NERNNLayer.cpp
@@ -0,0 +1,125 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/runtime/NEON/functions/NERNNLayer.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+
+namespace arm_compute
+{
+NERNNLayer::NERNNLayer(std::shared_ptr<IMemoryManager> memory_manager)
+ : _memory_group(std::move(memory_manager)), _gemm_state_f(), _add_kernel(), _activation_kernel(), _fully_connected_kernel(), _fully_connected_out(), _gemm_output(), _add_output(), _hidden_state(),
+ _output()
+{
+}
+
+Status NERNNLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *recurrent_weights, const ITensorInfo *bias, const ITensorInfo *hidden_state,
+ const ITensorInfo *output, const ActivationLayerInfo &info)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, recurrent_weights, bias, hidden_state, output);
+
+ const int idx_width = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(idx_width) != weights->dimension(idx_width));
+ ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_height) != recurrent_weights->dimension(idx_width));
+ ARM_COMPUTE_RETURN_ERROR_ON(recurrent_weights->dimension(idx_width) != recurrent_weights->dimension(idx_height));
+ ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() != 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(idx_width) != weights->dimension(idx_height));
+ ARM_COMPUTE_RETURN_ERROR_ON(hidden_state->dimension(idx_width) != weights->dimension(idx_height));
+ ARM_COMPUTE_RETURN_ERROR_ON(hidden_state->dimension(idx_height) != input->dimension(idx_height));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), hidden_state->tensor_shape());
+
+ auto shape_info = TensorInfo(misc::shape_calculator::compute_rnn_shape(recurrent_weights, hidden_state->dimension(idx_height)), 1, input->data_type());
+
+ ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(input, weights, bias, &shape_info, true, false));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&shape_info, &shape_info, &shape_info, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&shape_info, &shape_info, info));
+
+ return Status{};
+}
+
+void NERNNLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *recurrent_weights, const ITensor *bias, ITensor *hidden_state, ITensor *output,
+ ActivationLayerInfo &info)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, recurrent_weights, bias, hidden_state, output);
+ ARM_COMPUTE_ERROR_THROW_ON(NERNNLayer::validate(input->info(), weights->info(), recurrent_weights->info(), bias->info(), hidden_state->info(), output->info(), info));
+
+ _hidden_state = hidden_state;
+ _output = output;
+
+ const int idx_height = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::HEIGHT);
+ TensorShape shape = misc::shape_calculator::compute_rnn_shape(recurrent_weights->info(), hidden_state->info()->dimension(idx_height));
+
+ // Manage intermediate buffers and configure
+ _fully_connected_out.allocator()->init(TensorInfo(shape, 1, input->info()->data_type()));
+ _memory_group.manage(&_fully_connected_out);
+ _fully_connected_kernel.configure(input, weights, bias, &_fully_connected_out, true, false);
+
+ _gemm_output.allocator()->init(TensorInfo(shape, 1, input->info()->data_type()));
+ _memory_group.manage(&_gemm_output);
+ _gemm_state_f.configure(hidden_state, recurrent_weights, nullptr, &_gemm_output, 1.f, 0.f);
+
+ _add_output.allocator()->init(TensorInfo(shape, 1, input->info()->data_type()));
+ _memory_group.manage(&_add_output);
+ _add_kernel.configure(&_fully_connected_out, &_gemm_output, &_add_output, ConvertPolicy::SATURATE);
+
+ _fully_connected_out.allocator()->allocate();
+ _gemm_output.allocator()->allocate();
+
+ _activation_kernel.configure(&_add_output, hidden_state, info);
+ _add_output.allocator()->allocate();
+}
+
+void NERNNLayer::run()
+{
+ _memory_group.acquire();
+
+ _fully_connected_kernel.run();
+ _gemm_state_f.run();
+ NEScheduler::get().schedule(&_add_kernel, Window::DimY);
+ NEScheduler::get().schedule(&_activation_kernel, Window::DimY);
+
+ // Copy the updated hidden state into the output tensor, row by row
+ Window hidden_state_window;
+ Window output_window;
+ hidden_state_window.use_tensor_dimensions(_hidden_state->info()->tensor_shape(), Window::DimY);
+ output_window.use_tensor_dimensions(_output->info()->tensor_shape(), Window::DimY);
+
+ Iterator hidden_state_it(_hidden_state, hidden_state_window);
+ Iterator output_it(_output, output_window);
+
+ execute_window_loop(output_window, [&](const Coordinates & id)
+ {
+ memcpy(output_it.ptr(), hidden_state_it.ptr(), _output->info()->dimension(0) * _output->info()->element_size());
+ },
+ hidden_state_it, output_it);
+
+ _memory_group.release();
+}
+} // namespace arm_compute
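
For orientation, run() chains a fully connected layer on the input, a GEMM on the previous hidden state, an element-wise addition, an activation that overwrites the hidden state, and a final row-wise copy into the output. Below is a scalar reference sketch of that per-timestep computation; it is illustrative only (plain nested loops, simplified index conventions, and tanh chosen as an example activation), not code from this patch or the library.

// Scalar sketch of one vanilla RNN step: h' = act(W*x + R*h + b), output = h'.
#include <cmath>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

Matrix rnn_step(const Matrix &x,             // [batch][input_size]
                const Matrix &W,             // [num_units][input_size]
                const Matrix &R,             // [num_units][num_units]
                const std::vector<float> &b, // [num_units]
                Matrix &h)                   // [batch][num_units], updated in place
{
    const size_t batch     = x.size();
    const size_t num_units = b.size();
    Matrix out(batch, std::vector<float>(num_units, 0.f));

    for(size_t bi = 0; bi < batch; ++bi)
    {
        for(size_t n = 0; n < num_units; ++n)
        {
            float acc = b[n];              // bias (added by the fully connected stage)
            for(size_t i = 0; i < x[bi].size(); ++i)
            {
                acc += W[n][i] * x[bi][i]; // fully connected: W * x
            }
            for(size_t m = 0; m < num_units; ++m)
            {
                acc += R[n][m] * h[bi][m]; // recurrent GEMM: R * h
            }
            out[bi][n] = std::tanh(acc);   // addition + activation
        }
        h[bi] = out[bi];                   // hidden state is overwritten with the result
    }
    return out;                            // run() finally copies the hidden state to the output
}
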
diff --git a/tests/datasets/RNNLayerDataset.h b/tests/datasets/RNNLayerDataset.h
index 616a69e213..40d1b934f3 100644
--- a/tests/datasets/RNNLayerDataset.h
+++ b/tests/datasets/RNNLayerDataset.h
@@ -131,7 +131,7 @@ class SmallRNNLayerDataset final : public RNNLayerDataset
public:
SmallRNNLayerDataset()
{
- add_config(TensorShape(8U, 2U), TensorShape(8U, 16U), TensorShape(16U, 16U), TensorShape(16U), TensorShape(16U, 2U), ActivationLayerInfo());
+ add_config(TensorShape(128U, 16U), TensorShape(128U, 32U), TensorShape(32U, 32U), TensorShape(32U), TensorShape(32U, 16U), ActivationLayerInfo());
}
};
diff --git a/tests/validation/NEON/RNNLayer.cpp b/tests/validation/NEON/RNNLayer.cpp
new file mode 100644
index 0000000000..7aa3befd03
--- /dev/null
+++ b/tests/validation/NEON/RNNLayer.cpp
@@ -0,0 +1,147 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/NEON/functions/NERNNLayer.h"
+#include "tests/NEON/Accessor.h"
+#include "tests/PaddingCalculator.h"
+#include "tests/datasets/RNNLayerDataset.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/validation/Validation.h"
+#include "tests/validation/fixtures/RNNLayerFixture.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace
+{
+RelativeTolerance<float> tolerance_f32(0.001f);
+RelativeTolerance<half> tolerance_f16(half(0.1));
+} // namespace
+
+TEST_SUITE(NEON)
+TEST_SUITE(RNNLayer)
+
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
+ framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 13U), 1, DataType::U8, 0), // Wrong data type
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Wrong input size
+ TensorInfo(TensorShape(27U, 13U), 1, DataType::F32, 0), // Wrong weights size
+ TensorInfo(TensorShape(27U, 13U), 1, DataType::F32, 0), // Wrong recurrent weights size
+ TensorInfo(TensorShape(27U, 13U), 1, DataType::F32, 0), // Wrong bias size
+ TensorInfo(TensorShape(27U, 13U), 1, DataType::F32, 0), // Wrong output size
+ TensorInfo(TensorShape(27U, 13U), 1, DataType::F32, 0), // Wrong hidden output size
+ TensorInfo(TensorShape(32U, 32U), 1, DataType::F32, 0),
+ }),
+ framework::dataset::make("WeightsInfo", { TensorInfo(TensorShape(27U, 11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(27U, 11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(27U, 11U, 2U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(27U, 11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(27U, 11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(27U, 11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(27U, 11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(32U, 32U), 1, DataType::F32, 0),
+ })),
+ framework::dataset::make("RecurrentWeightsInfo", { TensorInfo(TensorShape(11U, 11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(25U, 11U, 2U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(32U, 32U), 1, DataType::F32, 0),
+ })),
+ framework::dataset::make("BiasInfo", { TensorInfo(TensorShape(11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(30U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(32U), 1, DataType::F32, 0),
+ })),
+ framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(11U, 13U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 13U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 13U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 13U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 13U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 13U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(32U, 32U), 1, DataType::F32, 0),
+ })),
+ framework::dataset::make("HiddenStateInfo", { TensorInfo(TensorShape(11U, 13U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 13U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 13U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 13U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 13U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 13U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(11U, 13U, 2U), 1, DataType::F32, 0),
+ TensorInfo(TensorShape(32U, 32U), 1, DataType::F32, 0),
+ })),
+ framework::dataset::make("ActivationInfo", { ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+ })),
+ framework::dataset::make("Expected", { false, false, false, false, false, false, false, true })),
+ input_info, weights_info, recurrent_weights_info, bias_info, output_info, hidden_output_info, info, expected)
+{
+ ARM_COMPUTE_EXPECT(bool(NERNNLayer::validate(&input_info.clone()->set_is_resizable(false), &weights_info.clone()->set_is_resizable(false), &recurrent_weights_info.clone()->set_is_resizable(false), &bias_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), &hidden_output_info.clone()->set_is_resizable(false), info)) == expected, framework::LogLevel::ERRORS);
+}
+// clang-format on
+// *INDENT-ON*
+
+template <typename T>
+using NERNNLayerFixture = RNNLayerValidationFixture<Tensor, Accessor, NERNNLayer, T>;
+
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall, NERNNLayerFixture<float>, framework::DatasetMode::ALL, combine(datasets::SmallRNNLayerDataset(), framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_f32);
+}
+TEST_SUITE_END() // FP32
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmall, NERNNLayerFixture<half>, framework::DatasetMode::ALL, combine(datasets::SmallRNNLayerDataset(), framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_f16);
+}
+TEST_SUITE_END() // FP16
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+TEST_SUITE_END() // RNNLayer
+TEST_SUITE_END() // NEON
+} // namespace validation
+} // namespace test
+} // namespace arm_compute