diff options
Diffstat (limited to 'src/cpu/operators/CpuGemmConv2d.h')
-rw-r--r-- | src/cpu/operators/CpuGemmConv2d.h | 55 |
1 file changed, 34 insertions, 21 deletions
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h index 118d366517..48a0d11107 100644 --- a/src/cpu/operators/CpuGemmConv2d.h +++ b/src/cpu/operators/CpuGemmConv2d.h @@ -42,21 +42,12 @@ class CpuGemmLowpOutputStage; class CpuReshape; namespace kernels { -class CpuWeightsReshapeKernel; class CpuIm2ColKernel; class CpuCol2ImKernel; +class CpuWeightsReshapeKernel; } // namespace kernels -/** Basic function to compute the convolution layer. This function calls the following kernels/functions: - * - * -# @ref cpu::kernels::CpuIm2ColKernel - * -# @ref CpuGemm (if the data type is BFLOAT16/FP16/FP32) - * -# @ref CpuGemmLowpMatrixMultiplyCore (if the data type is QASYMM8/QASYMM8_SIGNED) - * -# @ref CpuGemmLowpOutputStage (if the data type is QASYMM8/QASYMM8_SIGNED) - * -# @ref cpu::kernels::CpuCol2ImKernel (if NCHW data layout) - * -# @ref kernels::CpuWeightsReshapeKernel - * - */ +/** Basic function to compute the convolution layer. @ref note_CpuGemmConv2d_weight_transformation */ class CpuGemmConv2d : public ICpuOperator { public: @@ -99,7 +90,7 @@ public: * @param[out] dst Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs. * Data types supported: Same as @p input. * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. - * @param[in] weights_info Specifies if the weights tensor has been reshaped with NEWeightsReshapeKernel. If this is not part of the fully connected layer the weights + * @param[in] weights_info Specifies if the weights tensor has been reshaped with CpuWeightsReshapeKernel. If this is not part of the fully connected layer the weights * tensor has also been transposed with cpu::kernels::CpuGemmTranspose1xWKernel. Data type supported: Same as @p input. * @param[in] dilation (Optional) Dilation, in elements, across x and y. Defaults to (1, 1). 
* @param[in] act_info (Optional) Activation layer information in case of a fused activation. Only RELU, BOUNDED_RELU and LU_BOUNDED_RELU supported. @@ -136,7 +127,7 @@ public: /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters. * - * The paramter list is the same as @ref NEGEMMConvolutionLayer::has_opt_impl + * The parameter list is the same as @ref NEGEMMConvolutionLayer::has_opt_impl * * @return a status. */ @@ -254,15 +245,35 @@ private: bool isVarWeightsKernel() const; enum AuxTensorIdx { - // CpuGemmLowpMatrixMultiplyCore has up to 8 internal tensors - Im2ColOutput = 9, + GemmAsmPretransposedRHS = 2, // CpuGemmAssemblyDispatch::Pretranspose + GemmTransposed1xWRHS = 5, // CpuGemm::Transposed1xWRHS + GemmLowpTransposed1xWRHS = 6, // CpuGemmLowpMatrixMultiplyCore::TmpB + /* Slots 0 - 9 reserved and shared by CpuGemmLowpMatrixMultiplyCore and CpuGemm */ + Im2ColOutput = 10, WeightsReshaped, GemmOutput, Count }; - std::unique_ptr<kernels::CpuWeightsReshapeKernel> _weights_reshape_kernel; - std::unique_ptr<cpu::kernels::CpuIm2ColKernel> _im2col_kernel; + /** Weight transformation method. 
See @ref note_CpuGemmConv2d_weight_transformation */ + enum class WeightTransformMethod + { + ReinterpretThenTranspose, + ReshapeThenTranspose, + FusedReshapeAndTranspose, + }; + + /** Select weight transformation method + * + * @param[in] weights Input weights + * + * @return WeightTransformMethod + */ + static WeightTransformMethod get_wt_method(const ITensorInfo &weights); + + std::unique_ptr<CpuReshape> _weights_reshape; + std::unique_ptr<kernels::CpuWeightsReshapeKernel> _weights_reshape_and_transpose_kernel; + std::unique_ptr<kernels::CpuIm2ColKernel> _im2col_kernel; std::unique_ptr<CpuGemm> _mm_gemm; std::unique_ptr<CpuGemmLowpMatrixMultiplyCore> _mm_gemmlowp; std::unique_ptr<kernels::CpuCol2ImKernel> _col2im_kernel; @@ -275,10 +286,12 @@ private: DataLayout _data_layout; - bool _skip_im2col; - bool _skip_col2im; - bool _is_quantized; - bool _is_prepared; + bool _skip_im2col; + bool _skip_col2im; + bool _is_quantized; + bool _is_prepared; + WeightTransformMethod _wt_method; + bool _run_wt; experimental::MemoryRequirements _aux_mem{Count}; }; |