aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuGemmConv2d.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuGemmConv2d.h')
-rw-r--r--src/cpu/operators/CpuGemmConv2d.h24
1 files changed, 22 insertions, 2 deletions
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h
index aec4a2ffa5..f8f0bce048 100644
--- a/src/cpu/operators/CpuGemmConv2d.h
+++ b/src/cpu/operators/CpuGemmConv2d.h
@@ -117,6 +117,17 @@ public:
const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
bool enable_fast_math = false, unsigned int num_groups = 1);
+ /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters.
+ *
+ * The paramter list is the same as @ref NEGEMMConvolutionLayer::has_opt_impl
+ *
+ * @return a status.
+ */
+ static Status has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
+ const PadStrideInfo &conv_info,
+ const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
+ const bool enable_fast_math = false);
+
// Inherited methods overridden:
void run(ITensorPack &tensors) override;
void prepare(ITensorPack &tensors) override;
@@ -135,9 +146,11 @@ private:
* @param[in] enable_fast_math (Optional) Enable fast math computation. In case this flag were set, the function could dispatch the fastest implementation
* available which may introduce a drop of accuracy as well. Default is false
* @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1)
+ * @param[in] fixed_format (Optional) Select GEMM execution with variable weights.
+ * @param[in] weight_format (Optional) The layout to be used for the weights tensor when running GEMM with variable weights.
*/
void configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo(),
- bool enable_fast_math = false, int gemm_3d_depth = 1);
+ bool enable_fast_math = false, int gemm_3d_depth = 1, bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED);
/** Static function to check if given info will lead to a valid configuration of @ref NEGEMMConvolutionLayer matrix multiply routines
*
* @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32.
@@ -151,11 +164,13 @@ private:
* available which may introduce a drop of accuracy as well. Default is false
* @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1)
* @param[in] skip_im2col (Optional) Flag which specifies if im2col has to be skipped. i.e. 1x1 convolution with NHWC data layout. (Default to false)
+ * @param[in] fixed_format (Optional) Select GEMM execution with variable weights.
+ * @param[in] weight_format (Optional) The layout to be used for the weights tensor when running GEMM with variable weights.
*
* @return a status
*/
static Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act_info = ActivationLayerInfo(),
- bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false);
+ bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false, bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED);
/** Static function to check if GEMM3D is supported in @ref NEGEMM or in @ref CpuGemmMLowpMatrixMultiplyCore
*
* @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32.
@@ -187,6 +202,11 @@ private:
static SkipInfo skip_im_col_info(const ITensorInfo *src, const ITensorInfo *weights, const PadStrideInfo &conv_info,
const Size2D &dilation, const ActivationLayerInfo &act_info);
+ /** Indicates if the convolution executes in variable weights mode.
+ *
+ * Similar to @ref CpuGemm::isVarWeightsKernel
+ */
+ bool isVarWeightsKernel() const;
enum AuxTensorIdx
{
// CpuGemmLowpMatrixMultiplyCore has up to 8 internal tensors