aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/NEON/functions/NEIm2Col.h
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2018-03-09 15:30:43 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:49:16 +0000
commit156fcf3f36f6168e47d65db167bba3af5037e3d9 (patch)
tree89240783068a72b918791cf18a613eb43b93035d /arm_compute/runtime/NEON/functions/NEIm2Col.h
parent8de92619e223225aabdca873c02f231d8e941fd1 (diff)
downloadComputeLibrary-156fcf3f36f6168e47d65db167bba3af5037e3d9.tar.gz
COMPMID-802 Add NHWC data format support for NEON im2col.
Change-Id: I86e678179106a2b83d1c6a7cfe562df91b0f9eb2 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/124000 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NEIm2Col.h')
-rw-r--r--arm_compute/runtime/NEON/functions/NEIm2Col.h22
1 files changed, 17 insertions, 5 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEIm2Col.h b/arm_compute/runtime/NEON/functions/NEIm2Col.h
index cf4999b5af..caa8a011f6 100644
--- a/arm_compute/runtime/NEON/functions/NEIm2Col.h
+++ b/arm_compute/runtime/NEON/functions/NEIm2Col.h
@@ -26,6 +26,7 @@
#include "arm_compute/runtime/NEON/INESimpleFunction.h"
+#include "arm_compute/core/NEON/kernels/NEIm2ColKernel.h"
#include "arm_compute/core/Size2D.h"
#include "arm_compute/core/Types.h"
@@ -34,9 +35,11 @@ namespace arm_compute
class ITensor;
/** Basic function to run @ref NEIm2ColKernel */
-class NEIm2Col : public INESimpleFunction
+class NEIm2Col : public IFunction
{
public:
+ /** Default constructor */
+ NEIm2Col();
/** Configure the im2col NEON kernel
*
* @param[in] input The input tensor to convert. 3 lower dimensions represent a single input [width, height, IFM],
@@ -46,9 +49,10 @@ public:
* @param[in] kernel_dims The kernel dimensions (width and height).
* @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo.
* @param[in] has_bias In case biases are provided expands the matrix with 1.
- * @param[in] is_fully_connected Determines whether this kernel will be called by @ref NEFullyConnectedLayer in order to validate the arguments
+ * @param[in] is_fully_connected (Optional) Determines whether this function will be called by @ref NEFullyConnectedLayer in order to validate the arguments
+ * @param[in] is_flatten (Optional) Determines whether this function will be called by @ref NEFlattenLayer in order to validate the arguments
*/
- void configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool is_fully_connected = false);
+ void configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool is_fully_connected = false, bool is_flatten = false);
/** Static function to check if given info will lead to a valid configuration of @ref NEIm2Col
*
* @param[in] input The input tensor to convert. 3 lower dimensions represent a single input [width, height, IFM],
@@ -58,11 +62,19 @@ public:
* @param[in] kernel_dims The kernel dimensions (width and height).
* @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo.
* @param[in] has_bias In case biases are provided expands the matrix with 1.
- * @param[in] is_fully_connected Determines whether this kernel will be called by @ref NEFullyConnectedLayer in order to validate the arguments
+ * @param[in] is_fully_connected Determines whether this function will be called by @ref NEFullyConnectedLayer in order to validate the arguments
+ * @param[in] is_flatten Determines whether this function will be called by @ref NEFlattenLayer in order to validate the arguments
*
* @return a status
*/
- static Status validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool is_fully_connected);
+ static Status validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool is_fully_connected, bool is_flatten);
+
+ // Inherited methods overridden:
+ void run() override;
+
+private:
+ NEIm2ColKernel _kernel;
+ unsigned int _y_dim;
};
}
#endif /* __ARM_COMPUTE_NEIM2COL_H__ */