diff options
Diffstat (limited to 'arm_compute/runtime')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h | 27 |
1 files changed, 16 insertions, 11 deletions
diff --git a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h index 3dde52989b..2c1f7a9d5e 100644 --- a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h +++ b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h @@ -158,22 +158,24 @@ public: private: /** Configures the appropriate matrix multiply routine * - * @param input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32. - * @param weights Weights tensor. Data type supported: Same as @p input. - * @param output Output tensor. Data types supported: Same as @p input, - * except for input of QASYMM8 type where output should be of S32 type. + * @param[in] input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32. + * @param[in] weights Weights tensor. Data type supported: Same as @p input. + * @param[in, out] output Output tensor. Data types supported: Same as @p input, + * except for input of QASYMM8 type where output should be of S32 type. + * @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1) */ - void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output); + void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, int gemm_3d_depth = 1); /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMConvolutionLayer matrix multiply routines * - * @param[in] input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32. - * @param[in] weights Weights tensor. Data type supported: Same as @p input. - * @param[in] output Output tensor. Data types supported: Same as @p input, - * except for input of QASYMM8 type where output should be of S32 type. + * @param[in] input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32. + * @param[in] weights Weights tensor. Data type supported: Same as @p input. + * @param[in] output Output tensor. Data types supported: Same as @p input, + * except for input of QASYMM8 type where output should be of S32 type. + * @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1) * * @return a status */ - static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output); + static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth = 1); private: CLMemoryGroup _memory_group; @@ -192,9 +194,12 @@ private: CLTensor _gemm_output; CLTensor _tmp_output; + DataLayout _data_layout; + + bool _skip_im2col; bool _is_quantized; bool _is_activationlayer_enabled; bool _is_prepared; }; -} +} // namespace arm_compute #endif /* __ARM_COMPUTE_CLGEMMCONVOLUTIONLAYER_H__ */ |