aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/NEON/functions/NERNNLayer.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NERNNLayer.h')
-rw-r--r--arm_compute/runtime/NEON/functions/NERNNLayer.h64
1 files changed, 43 insertions, 21 deletions
diff --git a/arm_compute/runtime/NEON/functions/NERNNLayer.h b/arm_compute/runtime/NEON/functions/NERNNLayer.h
index 0bfb905e19..af7f464ac9 100644
--- a/arm_compute/runtime/NEON/functions/NERNNLayer.h
+++ b/arm_compute/runtime/NEON/functions/NERNNLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,11 +24,10 @@
#ifndef ARM_COMPUTE_NERNNLAYER_H
#define ARM_COMPUTE_NERNNLAYER_H
-#include "arm_compute/core/NEON/kernels/NEActivationLayerKernel.h"
-#include "arm_compute/core/NEON/kernels/NEArithmeticAdditionKernel.h"
-#include "arm_compute/core/NEON/kernels/NECopyKernel.h"
-
#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/NEON/functions/NEActivationLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEArithmeticAddition.h"
+#include "arm_compute/runtime/NEON/functions/NECopy.h"
#include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h"
#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
@@ -45,14 +44,26 @@ public:
NERNNLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
/** Prevent instances of this class from being copied (As this class contains pointers) */
NERNNLayer(const NERNNLayer &) = delete;
- /** Default move constructor */
- NERNNLayer(NERNNLayer &&) = default;
+ /** Prevent instances of this class from being moved (As this class contains pointers) */
+ NERNNLayer(NERNNLayer &&) = delete;
/** Prevent instances of this class from being copied (As this class contains pointers) */
NERNNLayer &operator=(const NERNNLayer &) = delete;
- /** Default move assignment operator */
- NERNNLayer &operator=(NERNNLayer &&) = default;
+ /** Prevent instances of this class from being moved (As this class contains pointers) */
+ NERNNLayer &operator=(NERNNLayer &&) = delete;
+ /** Default destructor */
+ ~NERNNLayer();
/** Initialize the function
*
+ * Valid data layouts:
+ * - NHWC
+ * - NCHW
+ *
+ * Valid data type configurations:
+ * |src0 |src1 |src2 |src3 |dst0 |dst1 |
+ * |:------|:------|:------|:------|:------|:------|
+ * |F16 |F16 |F16 |F16 |F16 |F16 |
+ * |F32 |F32 |F32 |F32 |F32 |F32 |
+ *
* @param[in] input Input is a 2-D tensor of shape [input_size, batch_size]. Data types supported: F16/F32
* @param[in] weights Weights tensor of shape [input_size, num_units] that multiplies the input. Data types supported: Same as @p input
* @param[in] recurrent_weights Weights tensor of shape [num_units, num_units] that multiplies the current 'state'. Data types supported: Same as @p input
@@ -61,7 +72,13 @@ public:
* @param[in,out] hidden_state Output tensor of shape [num_units, batch_size]. Data types supported: Same as @p input
* @param[in] info Activation layer parameter.
*/
- void configure(const ITensor *input, const ITensor *weights, const ITensor *recurrent_weights, const ITensor *bias, ITensor *hidden_state, ITensor *output, ActivationLayerInfo &info);
+ void configure(const ITensor *input,
+ const ITensor *weights,
+ const ITensor *recurrent_weights,
+ const ITensor *bias,
+ ITensor *hidden_state,
+ ITensor *output,
+ ActivationLayerInfo &info);
/** Initialize the function
*
* @param[in] input Input is a 2-D tensor of shape [input_size, batch_size]. Data types supported: F16/F32
@@ -74,7 +91,12 @@ public:
*
* @return a status
*/
- static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *recurrent_weights, const ITensorInfo *bias, const ITensorInfo *hidden_state, const ITensorInfo *output,
+ static Status validate(const ITensorInfo *input,
+ const ITensorInfo *weights,
+ const ITensorInfo *recurrent_weights,
+ const ITensorInfo *bias,
+ const ITensorInfo *hidden_state,
+ const ITensorInfo *output,
const ActivationLayerInfo &info);
// Inherited methods overridden:
@@ -82,16 +104,16 @@ public:
void prepare() override;
private:
- MemoryGroup _memory_group;
- NEGEMM _gemm_state_f;
- NEArithmeticAdditionKernel _add_kernel;
- NEActivationLayerKernel _activation_kernel;
- NEFullyConnectedLayer _fully_connected;
- NECopyKernel _copy_kernel;
- Tensor _fully_connected_out;
- Tensor _gemm_output;
- Tensor _add_output;
- bool _is_prepared;
+ MemoryGroup _memory_group;
+ NEGEMM _gemm_state_f;
+ NEArithmeticAddition _add_f;
+ NEActivationLayer _activation;
+ NEFullyConnectedLayer _fully_connected;
+ NECopy _copy_f;
+ Tensor _fully_connected_out;
+ Tensor _gemm_output;
+ Tensor _add_output;
+ bool _is_prepared;
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_NERNNLAYER_H */