aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2018-08-08 12:29:38 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commitdb9d46da3a8645d0c2cc71d035448999a36770ec (patch)
tree79c61d6ad845e60d86fa3059056b8ac3fffc236b /arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
parent980002bd5848f065b02a31bb105e47a5deb7bc98 (diff)
downloadComputeLibrary-db9d46da3a8645d0c2cc71d035448999a36770ec.tar.gz
COMPMID-1485 - Add support for NHWC when running NEGEMMConvolutionLayer with FP16/QASYMM8
When the GEMM3D check fails, now we fallback to the classic implementation with im2col and col2im. In this manner the function can work with QASYMM8 and FP16 Change-Id: I359e9da3a63956f33b5acbc9bca4383b14af10e2 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/143372 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h')
-rw-r--r--arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h13
1 files changed, 12 insertions, 1 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
index a362a29a82..e587cb4e6f 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
@@ -37,6 +37,7 @@
#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMLowpOutputStage.h"
+#include "arm_compute/runtime/NEON/functions/NEReshapeLayer.h"
#include "arm_compute/runtime/Tensor.h"
#include <memory>
@@ -84,7 +85,7 @@ private:
* -# @ref NEGEMMLowpMatrixMultiplyCore (if the data type is QASYMM8)
* -# @ref NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint (if the data type is QASYMM8)
* -# @ref NEArithmeticAdditionKernel (if biases != nullptr and we have a 1x1 convolution with the NHWC data layout)
- * -# @ref NECol2ImKernel
+ * -# @ref NECol2ImKernel or @ref NEReshapeLayer (if NHWC and GEMM3D is not supported)
*
*/
class NEGEMMConvolutionLayer : public IFunction
@@ -165,6 +166,15 @@ private:
* @return a status
*/
static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth = 1, bool skip_im2col = false);
+ /** Static function to check if GEMM3D is supported in @ref NEGEMM or in @ref NEGEMMLowpMatrixMultiplyCore
+ *
+ * @param[in] data_type Input data type
+ * @param[in] gemm_3d_depth Depth of GEMM 3D
+ * @param[in] skip_im2col Flag which specifies if im2col has to be skipped. i.e. 1x1 convolution with NHWC data layout
+ *
+ * @return a status
+ */
+ static Status validate_gemm3d(DataType data_type, int gemm_3d_depth, bool skip_im2col);
private:
MemoryGroup _memory_group;
@@ -176,6 +186,7 @@ private:
NECol2ImKernel _col2im_kernel;
NEActivationLayer _activationlayer_function;
NEArithmeticAdditionKernel _add_bias_kernel;
+ NEReshapeLayer _reshape_layer;
const ITensor *_original_weights;