aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h27
1 files changed, 16 insertions, 11 deletions
diff --git a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
index 3dde52989b..2c1f7a9d5e 100644
--- a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
@@ -158,22 +158,24 @@ public:
private:
/** Configures the appropriate matrix multiply routine
*
- * @param input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32.
- * @param weights Weights tensor. Data type supported: Same as @p input.
- * @param output Output tensor. Data types supported: Same as @p input,
- * except for input of QASYMM8 type where output should be of S32 type.
+ * @param[in] input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32.
+ * @param[in] weights Weights tensor. Data type supported: Same as @p input.
+ * @param[in, out] output Output tensor. Data types supported: Same as @p input,
+ * except for input of QASYMM8 type where output should be of S32 type.
+ * @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1)
*/
- void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output);
+ void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, int gemm_3d_depth = 1);
/** Static function to check if given info will lead to a valid configuration of @ref CLGEMMConvolutionLayer matrix multiply routines
*
- * @param[in] input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32.
- * @param[in] weights Weights tensor. Data type supported: Same as @p input.
- * @param[in] output Output tensor. Data types supported: Same as @p input,
- * except for input of QASYMM8 type where output should be of S32 type.
+ * @param[in] input Input tensor. Data types supported: QS8/QASYMM8/QS16/F16/F32.
+ * @param[in] weights Weights tensor. Data type supported: Same as @p input.
+ * @param[in] output Output tensor. Data types supported: Same as @p input,
+ * except for input of QASYMM8 type where output should be of S32 type.
+ * @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1)
*
* @return a status
*/
- static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output);
+ static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth = 1);
private:
CLMemoryGroup _memory_group;
@@ -192,9 +194,12 @@ private:
CLTensor _gemm_output;
CLTensor _tmp_output;
+ DataLayout _data_layout;
+
+ bool _skip_im2col;
bool _is_quantized;
bool _is_activationlayer_enabled;
bool _is_prepared;
};
-}
+} // namespace arm_compute
#endif /* __ARM_COMPUTE_CLGEMMCONVOLUTIONLAYER_H__ */