aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuGemmConv2d.h
diff options
context:
space:
mode:
authorFrancesco.Petrogalli@arm.com <francesco.petrogalli@arm.com>2022-04-13 09:28:25 +0000
committerFrancesco Petrogalli <francesco.petrogalli@arm.com>2022-04-22 09:28:41 +0000
commitfa6877f94b12ec80235e55bcfe5a9b6fdc009cf0 (patch)
tree50ee27b12700d095bbd3147ea28ad1c89265170d /src/cpu/operators/CpuGemmConv2d.h
parent50e48aaa15bdf39a2a7ad39daee93b7217d26d32 (diff)
downloadComputeLibrary-fa6877f94b12ec80235e55bcfe5a9b6fdc009cf0.tar.gz
[CpuGemmConv2d] Extract skip_im2col and skip_col2im computation.
This is just refactoring some duplicate code. No functional changes intented. Change-Id: Iff96798b03d25b490341598e676d0e4f2ebd132b Signed-off-by: Francesco.Petrogalli@arm.com <francesco.petrogalli@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7408 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/CpuGemmConv2d.h')
-rw-r--r--src/cpu/operators/CpuGemmConv2d.h25
1 files changed, 22 insertions, 3 deletions
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h
index e63e7169b0..aec4a2ffa5 100644
--- a/src/cpu/operators/CpuGemmConv2d.h
+++ b/src/cpu/operators/CpuGemmConv2d.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -118,8 +118,8 @@ public:
bool enable_fast_math = false, unsigned int num_groups = 1);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
experimental::MemoryRequirements workspace() const override;
private:
@@ -168,6 +168,25 @@ private:
*/
static Status validate_gemm3d(const ITensorInfo *src, const ITensorInfo *weights, const ActivationLayerInfo &act_info, int gemm_3d_depth, bool skip_im2col);
+ struct SkipInfo
+ {
+ bool skip_im2col;
+ bool skip_col2im;
+ };
+
+ /** Static function to provide skip_im2col and skip_col2im information.
+ *
+ * @param[in] src Input tensor info.
+ * @param[in] weights Weights tensor info.
+ * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo.
+ * @param[in] dilation Dilation, in elements, across x and y.
+ * @param[in] act_info Activation layer information in case of a fused activation.
+ *
+ * @return a SkipInfo instance.
+ */
+ static SkipInfo skip_im_col_info(const ITensorInfo *src, const ITensorInfo *weights, const PadStrideInfo &conv_info,
+ const Size2D &dilation, const ActivationLayerInfo &act_info);
+
enum AuxTensorIdx
{
// CpuGemmLowpMatrixMultiplyCore has up to 8 internal tensors