diff options
Diffstat (limited to 'src/gpu/cl/operators/ClConv2d.h')
-rw-r--r-- | src/gpu/cl/operators/ClConv2d.h | 28 |
1 files changed, 21 insertions, 7 deletions
diff --git a/src/gpu/cl/operators/ClConv2d.h b/src/gpu/cl/operators/ClConv2d.h index c6c366a762..0cf3cbc1ce 100644 --- a/src/gpu/cl/operators/ClConv2d.h +++ b/src/gpu/cl/operators/ClConv2d.h @@ -26,6 +26,7 @@ #include "arm_compute/core/Types.h" #include "arm_compute/runtime/FunctionDescriptors.h" + #include "src/gpu/cl/ClCompileContext.h" #include "src/gpu/cl/IClKernel.h" #include "src/gpu/cl/IClOperator.h" @@ -112,15 +113,24 @@ public: * @param[in] conv2d_info Contains convolution 2d info described in @ref Conv2dInfo. * @param[in] weights_info Specifies if the weights tensor has been reshaped with CLWeightsReshapeKernel. Data type supported: Same as @p src. */ - void configure(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *biases, ITensorInfo *dst, const Conv2dInfo &conv2d_info, - const WeightsInfo &weights_info = WeightsInfo()); + void configure(const CLCompileContext &compile_context, + ITensorInfo *src, + ITensorInfo *weights, + ITensorInfo *biases, + ITensorInfo *dst, + const Conv2dInfo &conv2d_info, + const WeightsInfo &weights_info = WeightsInfo()); /** Static function to check if given info will lead to a valid configuration of @ref ClConv2d * * Similar to ClConv2d::configure() * * @return a status */ - static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &conv2d_info, + static Status validate(const ITensorInfo *src, + const ITensorInfo *weights, + const ITensorInfo *biases, + const ITensorInfo *dst, + const Conv2dInfo &conv2d_info, const WeightsInfo &weights_info = WeightsInfo()); /** Static function to check if given info will return the convolution called by @ref ClConv2d * @@ -137,11 +147,15 @@ public: * * @return the Convolution Method Hint */ - static ConvolutionMethod get_convolution_method(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const Conv2dInfo &conv2d_info, - const WeightsInfo &weights_info, const GPUTarget gpu_target); + static ConvolutionMethod get_convolution_method(const ITensorInfo *src, + const ITensorInfo *weights, + const ITensorInfo *dst, + const Conv2dInfo &conv2d_info, + const WeightsInfo &weights_info, + const GPUTarget gpu_target); // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; experimental::MemoryRequirements workspace() const override; private: |