diff options
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NERNNLayer.h')
-rw-r--r-- | arm_compute/runtime/NEON/functions/NERNNLayer.h | 49 |
1 files changed, 35 insertions, 14 deletions
diff --git a/arm_compute/runtime/NEON/functions/NERNNLayer.h b/arm_compute/runtime/NEON/functions/NERNNLayer.h index c42b303a89..af7f464ac9 100644 --- a/arm_compute/runtime/NEON/functions/NERNNLayer.h +++ b/arm_compute/runtime/NEON/functions/NERNNLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -27,6 +27,7 @@ #include "arm_compute/core/Types.h" #include "arm_compute/runtime/NEON/functions/NEActivationLayer.h" #include "arm_compute/runtime/NEON/functions/NEArithmeticAddition.h" +#include "arm_compute/runtime/NEON/functions/NECopy.h" #include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h" #include "arm_compute/runtime/NEON/functions/NEGEMM.h" @@ -34,7 +35,6 @@ namespace arm_compute { // Forward declarations class ITensor; -class NECopyKernel; /** Basic function to run @ref NERNNLayer */ class NERNNLayer : public IFunction @@ -54,6 +54,16 @@ public: ~NERNNLayer(); /** Initialize the function * + * Valid data layouts: + * - NHWC + * - NCHW + * + * Valid data type configurations: + * |src0 |src1 |src2 |src3 |dst0 |dst1 | + * |:------|:------|:------|:------|:------|:------| + * |F16 |F16 |F16 |F16 |F16 |F16 | + * |F32 |F32 |F32 |F32 |F32 |F32 | + * * @param[in] input Input is a 2-D tensor of shape [input_size, batch_size]. Data types supported: F16/F32 * @param[in] weights Weights tensor of shape [input_size, num_units] that multiplies the input. Data types supported: Same as @p input * @param[in] recurrent_weights Weights tensor of shape [num_units, num_units] that multiplies the current 'state'. Data types supported: Same as @p input @@ -62,7 +72,13 @@ public: * @param[in,out] hidden_state Output tensor of shape [num_units, batch_size]. Data types supported: Same as @p input * @param[in] info Activation layer parameter. */ - void configure(const ITensor *input, const ITensor *weights, const ITensor *recurrent_weights, const ITensor *bias, ITensor *hidden_state, ITensor *output, ActivationLayerInfo &info); + void configure(const ITensor *input, + const ITensor *weights, + const ITensor *recurrent_weights, + const ITensor *bias, + ITensor *hidden_state, + ITensor *output, + ActivationLayerInfo &info); /** Initialize the function * * @param[in] input Input is a 2-D tensor of shape [input_size, batch_size]. Data types supported: F16/F32 @@ -75,7 +91,12 @@ public: * * @return a status */ - static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *recurrent_weights, const ITensorInfo *bias, const ITensorInfo *hidden_state, const ITensorInfo *output, + static Status validate(const ITensorInfo *input, + const ITensorInfo *weights, + const ITensorInfo *recurrent_weights, + const ITensorInfo *bias, + const ITensorInfo *hidden_state, + const ITensorInfo *output, const ActivationLayerInfo &info); // Inherited methods overridden: @@ -83,16 +104,16 @@ public: void prepare() override; private: - MemoryGroup _memory_group; - NEGEMM _gemm_state_f; - NEArithmeticAddition _add_f; - NEActivationLayer _activation; - NEFullyConnectedLayer _fully_connected; - std::unique_ptr<NECopyKernel> _copy_kernel; - Tensor _fully_connected_out; - Tensor _gemm_output; - Tensor _add_output; - bool _is_prepared; + MemoryGroup _memory_group; + NEGEMM _gemm_state_f; + NEArithmeticAddition _add_f; + NEActivationLayer _activation; + NEFullyConnectedLayer _fully_connected; + NECopy _copy_f; + Tensor _fully_connected_out; + Tensor _gemm_output; + Tensor _add_output; + bool _is_prepared; }; } // namespace arm_compute #endif /* ARM_COMPUTE_NERNNLAYER_H */ |