diff options
Diffstat (limited to 'src/cpu/operators/CpuGemmConv2d.h')
-rw-r--r-- | src/cpu/operators/CpuGemmConv2d.h | 25 |
1 files changed, 22 insertions, 3 deletions
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h index e63e7169b0..aec4a2ffa5 100644 --- a/src/cpu/operators/CpuGemmConv2d.h +++ b/src/cpu/operators/CpuGemmConv2d.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -118,8 +118,8 @@ public: bool enable_fast_math = false, unsigned int num_groups = 1); // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; experimental::MemoryRequirements workspace() const override; private: @@ -168,6 +168,25 @@ private: */ static Status validate_gemm3d(const ITensorInfo *src, const ITensorInfo *weights, const ActivationLayerInfo &act_info, int gemm_3d_depth, bool skip_im2col); + struct SkipInfo + { + bool skip_im2col; + bool skip_col2im; + }; + + /** Static function to provide skip_im2col and skip_col2im information. + * + * @param[in] src Input tensor info. + * @param[in] weights Weights tensor info. + * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. + * @param[in] dilation Dilation, in elements, across x and y. + * @param[in] act_info Activation layer information in case of a fused activation. + * + * @return a SkipInfo instance. + */ + static SkipInfo skip_im_col_info(const ITensorInfo *src, const ITensorInfo *weights, const PadStrideInfo &conv_info, + const Size2D &dilation, const ActivationLayerInfo &act_info); + enum AuxTensorIdx { // CpuGemmLowpMatrixMultiplyCore has up to 8 internal tensors |