diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2018-06-19 13:09:53 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:53:34 +0000 |
commit | 19ea419e7f14d02aeb208c2fbd5a4ac55f4cb101 (patch) | |
tree | fe04ed9d40ebb8b717f63490f672a28c5b27d01e /arm_compute/runtime/CL | |
parent | bb71fe50930f5669a7a325e0fa95fee559856793 (diff) | |
download | ComputeLibrary-19ea419e7f14d02aeb208c2fbd5a4ac55f4cb101.tar.gz |
COMPMID-809: Add NHWC data format on CLGEMMConvolutionLayer.
Change-Id: I50e4f5e7d47e21c300f754bee2c216863075b5cf
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/136191
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'arm_compute/runtime/CL')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h | 27 |
1 files changed, 16 insertions, 11 deletions
diff --git a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h index 3dde52989b..2c1f7a9d5e 100644 --- a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h +++ b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h @@ -158,22 +158,24 @@ public: private: /** Configures the appropriate matrix multiply routine * - * @param input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32. - * @param weights Weights tensor. Data type supported: Same as @p input. - * @param output Output tensor. Data types supported: Same as @p input, - * except for input of QASYMM8 type where output should be of S32 type. + * @param[in] input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32. + * @param[in] weights Weights tensor. Data type supported: Same as @p input. + * @param[in, out] output Output tensor. Data types supported: Same as @p input, + * except for input of QASYMM8 type where output should be of S32 type. + * @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1) */ - void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output); + void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, int gemm_3d_depth = 1); /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMConvolutionLayer matrix multiply routines * - * @param[in] input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32. - * @param[in] weights Weights tensor. Data type supported: Same as @p input. - * @param[in] output Output tensor. Data types supported: Same as @p input, - * except for input of QASYMM8 type where output should be of S32 type. + * @param[in] input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32. + * @param[in] weights Weights tensor. Data type supported: Same as @p input. + * @param[in] output Output tensor. Data types supported: Same as @p input, + * except for input of QASYMM8 type where output should be of S32 type. + * @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1) * * @return a status */ - static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output); + static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth = 1); private: CLMemoryGroup _memory_group; @@ -192,9 +194,12 @@ private: CLTensor _gemm_output; CLTensor _tmp_output; + DataLayout _data_layout; + + bool _skip_im2col; bool _is_quantized; bool _is_activationlayer_enabled; bool _is_prepared; }; -} +} // namespace arm_compute #endif /* __ARM_COMPUTE_CLGEMMCONVOLUTIONLAYER_H__ */ |