aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/operators/ClWinogradConv2d.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/operators/ClWinogradConv2d.h')
-rw-r--r--src/gpu/cl/operators/ClWinogradConv2d.h26
1 files changed, 19 insertions, 7 deletions
diff --git a/src/gpu/cl/operators/ClWinogradConv2d.h b/src/gpu/cl/operators/ClWinogradConv2d.h
index eb2f7a72b2..54ec1a1737 100644
--- a/src/gpu/cl/operators/ClWinogradConv2d.h
+++ b/src/gpu/cl/operators/ClWinogradConv2d.h
@@ -25,6 +25,7 @@
#define ARM_COMPUTE_CL_WINOGRADCONV2D_H
#include "arm_compute/runtime/CL/CLTensor.h"
+
#include "src/core/CL/kernels/CLFillBorderKernel.h"
#include "src/gpu/cl/ClCompileContext.h"
#include "src/gpu/cl/IClOperator.h"
@@ -41,7 +42,7 @@ namespace kernels
class ClWinogradInputTransformKernel;
class ClWinogradFilterTransformKernel;
class ClWinogradOutputTransformKernel;
-} // kernels
+} // namespace kernels
/** Basic function to execute Winograd-based convolution on OpenCL. This function calls the following OpenCL functions/kernels:
*
* -# @ref kernels::ClWinogradInputTransformKernel
@@ -93,20 +94,31 @@ public:
* @param[in] enable_fast_math (Optional) Enable fast math computation. In case this flag were set, the function could dispatch the fastest implementation
* available which may introduce a drop of accuracy as well. Default is false
*/
- void configure(const ClCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *biases, ITensorInfo *dst, const PadStrideInfo &conv_info,
- const ActivationLayerInfo &act_info = ActivationLayerInfo(), bool enable_fast_math = false);
+ void configure(const ClCompileContext &compile_context,
+ ITensorInfo *src,
+ ITensorInfo *weights,
+ ITensorInfo *biases,
+ ITensorInfo *dst,
+ const PadStrideInfo &conv_info,
+ const ActivationLayerInfo &act_info = ActivationLayerInfo(),
+ bool enable_fast_math = false);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to ClWinogradConv2d::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const PadStrideInfo &conv_info,
- const ActivationLayerInfo &act_info = ActivationLayerInfo(), bool enable_fast_math = false);
+ static Status validate(const ITensorInfo *src,
+ const ITensorInfo *weights,
+ const ITensorInfo *biases,
+ const ITensorInfo *dst,
+ const PadStrideInfo &conv_info,
+ const ActivationLayerInfo &act_info = ActivationLayerInfo(),
+ bool enable_fast_math = false);
// Inherited method overridden
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
experimental::MemoryRequirements workspace() const override;
private: