aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuGemmConv2d.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuGemmConv2d.h')
-rw-r--r--src/cpu/operators/CpuGemmConv2d.h55
1 files changed, 34 insertions, 21 deletions
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h
index 118d366517..48a0d11107 100644
--- a/src/cpu/operators/CpuGemmConv2d.h
+++ b/src/cpu/operators/CpuGemmConv2d.h
@@ -42,21 +42,12 @@ class CpuGemmLowpOutputStage;
class CpuReshape;
namespace kernels
{
-class CpuWeightsReshapeKernel;
class CpuIm2ColKernel;
class CpuCol2ImKernel;
+class CpuWeightsReshapeKernel;
} // namespace kernels
-/** Basic function to compute the convolution layer. This function calls the following kernels/functions:
- *
- * -# @ref cpu::kernels::CpuIm2ColKernel
- * -# @ref CpuGemm (if the data type is BFLOAT16/FP16/FP32)
- * -# @ref CpuGemmLowpMatrixMultiplyCore (if the data type is QASYMM8/QASYMM8_SIGNED)
- * -# @ref CpuGemmLowpOutputStage (if the data type is QASYMM8/QASYMM8_SIGNED)
- * -# @ref cpu::kernels::CpuCol2ImKernel (if NCHW data layout)
- * -# @ref kernels::CpuWeightsReshapeKernel
- *
- */
+/** Basic function to compute the convolution layer. @ref note_CpuGemmConv2d_weight_transformation */
class CpuGemmConv2d : public ICpuOperator
{
public:
@@ -99,7 +90,7 @@ public:
* @param[out] dst Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
* Data types supported: Same as @p input.
* @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo.
- * @param[in] weights_info Specifies if the weights tensor has been reshaped with NEWeightsReshapeKernel. If this is not part of the fully connected layer the weights
+ * @param[in] weights_info Specifies if the weights tensor has been reshaped with CpuWeightsReshapeKernel. If this is not part of the fully connected layer the weights
* tensor has also been transposed with cpu::kernels::CpuGemmTranspose1xWKernel. Data type supported: Same as @p input.
* @param[in] dilation (Optional) Dilation, in elements, across x and y. Defaults to (1, 1).
* @param[in] act_info (Optional) Activation layer information in case of a fused activation. Only RELU, BOUNDED_RELU and LU_BOUNDED_RELU supported.
@@ -136,7 +127,7 @@ public:
/** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters.
*
- * The paramter list is the same as @ref NEGEMMConvolutionLayer::has_opt_impl
+ * The parameter list is the same as @ref NEGEMMConvolutionLayer::has_opt_impl
*
* @return a status.
*/
@@ -254,15 +245,35 @@ private:
bool isVarWeightsKernel() const;
enum AuxTensorIdx
{
- // CpuGemmLowpMatrixMultiplyCore has up to 8 internal tensors
- Im2ColOutput = 9,
+ GemmAsmPretransposedRHS = 2, // CpuGemmAssemblyDispatch::Pretranspose
+ GemmTransposed1xWRHS = 5, // CpuGemm::Transposed1xWRHS
+ GemmLowpTransposed1xWRHS = 6, // CpuGemmLowpMatrixMultiplyCore::TmpB
+ /* Slots 0 - 9 reserved and shared by CpuGemmLowpMatrixMultiplyCore and CpuGemm */
+ Im2ColOutput = 10,
WeightsReshaped,
GemmOutput,
Count
};
- std::unique_ptr<kernels::CpuWeightsReshapeKernel> _weights_reshape_kernel;
- std::unique_ptr<cpu::kernels::CpuIm2ColKernel> _im2col_kernel;
+ /** Weight transformation method. See @ref note_CpuGemmConv2d_weight_transformation */
+ enum class WeightTransformMethod
+ {
+ ReinterpretThenTranspose,
+ ReshapeThenTranspose,
+ FusedReshapeAndTranspose,
+ };
+
+ /** Select weight transformation method
+ *
+ * @param[in] weights Input weights
+ *
+ * @return WeightTransformMethod
+ */
+ static WeightTransformMethod get_wt_method(const ITensorInfo &weights);
+
+ std::unique_ptr<CpuReshape> _weights_reshape;
+ std::unique_ptr<kernels::CpuWeightsReshapeKernel> _weights_reshape_and_transpose_kernel;
+ std::unique_ptr<kernels::CpuIm2ColKernel> _im2col_kernel;
std::unique_ptr<CpuGemm> _mm_gemm;
std::unique_ptr<CpuGemmLowpMatrixMultiplyCore> _mm_gemmlowp;
std::unique_ptr<kernels::CpuCol2ImKernel> _col2im_kernel;
@@ -275,10 +286,12 @@ private:
DataLayout _data_layout;
- bool _skip_im2col;
- bool _skip_col2im;
- bool _is_quantized;
- bool _is_prepared;
+ bool _skip_im2col;
+ bool _skip_col2im;
+ bool _is_quantized;
+ bool _is_prepared;
+ WeightTransformMethod _wt_method;
+ bool _run_wt;
experimental::MemoryRequirements _aux_mem{Count};
};