aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/operators/ClGemmConv2d.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/operators/ClGemmConv2d.h')
-rw-r--r--src/gpu/cl/operators/ClGemmConv2d.h35
1 files changed, 28 insertions, 7 deletions
diff --git a/src/gpu/cl/operators/ClGemmConv2d.h b/src/gpu/cl/operators/ClGemmConv2d.h
index 8a46ee2dc3..e8f3147ac3 100644
--- a/src/gpu/cl/operators/ClGemmConv2d.h
+++ b/src/gpu/cl/operators/ClGemmConv2d.h
@@ -27,6 +27,7 @@
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/FunctionDescriptors.h"
+
#include "src/gpu/cl/ClCompileContext.h"
#include "src/gpu/cl/IClOperator.h"
@@ -100,15 +101,24 @@ public:
* @param[in] weights_info Specifies if the weights tensor has been reshaped with CLWeightsReshapeKernel. If this is not part of the fully connected layer the weights
* tensor has also been transposed with CLGEMMReshapeRHSMatrixKernel. Data type supported: Same as @p input.
*/
- void configure(const ClCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *biases, ITensorInfo *dst, const Conv2dInfo &conv2d_info,
- const WeightsInfo &weights_info = WeightsInfo());
+ void configure(const ClCompileContext &compile_context,
+ ITensorInfo *src,
+ ITensorInfo *weights,
+ ITensorInfo *biases,
+ ITensorInfo *dst,
+ const Conv2dInfo &conv2d_info,
+ const WeightsInfo &weights_info = WeightsInfo());
/** Static function to check if given info will lead to a valid configuration
*
* Similar to ClGemmConvolution::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const Conv2dInfo &conv2d_info,
+ static Status validate(const ITensorInfo *input,
+ const ITensorInfo *weights,
+ const ITensorInfo *biases,
+ const ITensorInfo *output,
+ const Conv2dInfo &conv2d_info,
const WeightsInfo &weights_info = WeightsInfo());
// Inherited methods overridden:
@@ -130,9 +140,14 @@ private:
* @param[in] gemm_3d_depth Depth of GEMM 3D
* @param[in] act_info Activation to apply after the matrix multiplication
*/
- void configure_mm(const CLCompileContext &compile_context, const ITensorInfo *src, ITensorInfo *weights, ITensorInfo *biases, ITensorInfo *dst,
+ void configure_mm(const CLCompileContext &compile_context,
+ const ITensorInfo *src,
+ ITensorInfo *weights,
+ ITensorInfo *biases,
+ ITensorInfo *dst,
const GEMMLowpOutputStageInfo &gemmlowp_output_stage,
- int gemm_3d_depth, const ActivationLayerInfo &act_info);
+ int gemm_3d_depth,
+ const ActivationLayerInfo &act_info);
/** Static function to check if given info will lead to a valid configuration of @ref CLGEMMConvolutionLayer matrix multiply routines
*
* @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
@@ -148,8 +163,14 @@ private:
*
* @return a status
*/
- static Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const GEMMLowpOutputStageInfo &gemmlowp_output_stage,
- int gemm_3d_depth, bool skip_im2col, const ActivationLayerInfo &act_info);
+ static Status validate_mm(const ITensorInfo *src,
+ const ITensorInfo *weights,
+ const ITensorInfo *biases,
+ const ITensorInfo *dst,
+ const GEMMLowpOutputStageInfo &gemmlowp_output_stage,
+ int gemm_3d_depth,
+ bool skip_im2col,
+ const ActivationLayerInfo &act_info);
enum AuxTensorIdx
{