aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/operators/ClConv2d.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/operators/ClConv2d.h')
-rw-r--r--src/gpu/cl/operators/ClConv2d.h28
1 files changed, 21 insertions, 7 deletions
diff --git a/src/gpu/cl/operators/ClConv2d.h b/src/gpu/cl/operators/ClConv2d.h
index c6c366a762..0cf3cbc1ce 100644
--- a/src/gpu/cl/operators/ClConv2d.h
+++ b/src/gpu/cl/operators/ClConv2d.h
@@ -26,6 +26,7 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/FunctionDescriptors.h"
+
#include "src/gpu/cl/ClCompileContext.h"
#include "src/gpu/cl/IClKernel.h"
#include "src/gpu/cl/IClOperator.h"
@@ -112,15 +113,24 @@ public:
* @param[in] conv2d_info Contains convolution 2d info described in @ref Conv2dInfo.
* @param[in] weights_info Specifies if the weights tensor has been reshaped with CLWeightsReshapeKernel. Data type supported: Same as @p src.
*/
- void configure(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *biases, ITensorInfo *dst, const Conv2dInfo &conv2d_info,
- const WeightsInfo &weights_info = WeightsInfo());
+ void configure(const CLCompileContext &compile_context,
+ ITensorInfo *src,
+ ITensorInfo *weights,
+ ITensorInfo *biases,
+ ITensorInfo *dst,
+ const Conv2dInfo &conv2d_info,
+ const WeightsInfo &weights_info = WeightsInfo());
/** Static function to check if given info will lead to a valid configuration of @ref ClConv2d
*
* Similar to ClConv2d::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &conv2d_info,
+ static Status validate(const ITensorInfo *src,
+ const ITensorInfo *weights,
+ const ITensorInfo *biases,
+ const ITensorInfo *dst,
+ const Conv2dInfo &conv2d_info,
const WeightsInfo &weights_info = WeightsInfo());
/** Static function to check if given info will return the convolution called by @ref ClConv2d
*
@@ -137,11 +147,15 @@ public:
*
* @return the Convolution Method Hint
*/
- static ConvolutionMethod get_convolution_method(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const Conv2dInfo &conv2d_info,
- const WeightsInfo &weights_info, const GPUTarget gpu_target);
+ static ConvolutionMethod get_convolution_method(const ITensorInfo *src,
+ const ITensorInfo *weights,
+ const ITensorInfo *dst,
+ const Conv2dInfo &conv2d_info,
+ const WeightsInfo &weights_info,
+ const GPUTarget gpu_target);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
experimental::MemoryRequirements workspace() const override;
private: