aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuGemmConv2d.cpp
diff options
context:
space:
mode:
authorSiCong Li <sicong.li@arm.com>2023-10-17 17:38:57 +0100
committerSiCong Li <sicong.li@arm.com>2023-11-08 09:49:56 +0000
commitc5ab4df0c11dc66db47f2070edc719923af3367e (patch)
treec04bdac32528e628b2a9b9a1c1653e300328fc1b /src/cpu/operators/CpuGemmConv2d.cpp
parent4a9dbedfbfa66c2612c7461e60cd867b8aea825b (diff)
downloadComputeLibrary-c5ab4df0c11dc66db47f2070edc719923af3367e.tar.gz
Optimize CpuGemmConv2d start-up time
When weight has no holes, we can replace CpuWeightsReshapeKernel with: - Collapse by reinterpreting weight's 3 spatial dimensions - Perform CpuTranspose For more details see the documentation in src/cpu/operators/CpuGemmConv2d.cpp This is one optimization since the CpuTranspose is better performing than CpuWeightsReshapeKernel A second optimization is to fuse this transpose with other weight transformations (e.g. pretranspose_B_array in CpuGemmAssemblyDispatch) However this second optimization depends on how the underlying gemm methods (the fall back path: CpuGemmMatrixMultiplyKernel or the assembly path: CpuGemmAssemblyDispatch) chooses to fuse the transpose. Therefore, this patch moves the transpose down from CpuGemmConv2d, to the individual gemm operators where the fusion decision needs to be made, by passing an extra "transpose_b" flag to CpuGemm New transpose_b flag in different scopes (they are all the same, but with different names because pretranspose_b has a different meaning in GemmAssemblyDispatch): GEMMInfo::pretranspose_B -> AsmGemmInfo::transpose_b New auxilliary tensors holding the transposed b result: - CpuGemm optimized path: CpuGemmAssemblyDispatch::PrePretransposedB - CpuGemm fallback path: CpuGemm::PreTransposedRHS Note that this patch does not yet have the second optimization (COMPMID-6595), but it prepares for it. Relates to COMPMID-6595 Resolves COMPMID-6499 Change-Id: I999a2da9da4b2b15369a3cc06d7872c86e0190ea Signed-off-by: SiCong Li <sicong.li@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10526 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Anitha Raj <Anitha.Raj@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/CpuGemmConv2d.cpp')
-rw-r--r--src/cpu/operators/CpuGemmConv2d.cpp372
1 files changed, 304 insertions, 68 deletions
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index 7c59d88c61..117527ccc1 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -32,7 +32,9 @@
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "src/common/utils/Log.h"
+#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/MemoryHelpers.h"
+#include "src/core/helpers/Utils.h"
#include "src/cpu/kernels/CpuCol2ImKernel.h"
#include "src/cpu/kernels/CpuIm2ColKernel.h"
#include "src/cpu/kernels/CpuWeightsReshapeKernel.h"
@@ -52,6 +54,117 @@ namespace arm_compute
{
namespace cpu
{
+
+/** @section note_CpuGemmConv2d_weight_transformation Weight Transformations in CpuGemmConv2d
+ *
+ * A. Terminology
+ * Throughout CpuGemmConv2d, we use the following terms in ways that may differ from other operators / kernels:
+ * - "Transform" or "Reshape" of the weights: they both mean all the operations that we perform on the weight
+ * tensor up until they are consumed by gemm (CpuGemm or CpuGemmLowpMatrixMultiplyCore)
+ * Note that the specific gemm operator may perform further transformations on the weights, but the
+ * transformations here only mean those performed in CpuGemmConv2d
+ * - "Transpose" of weights: The @ref CpuTranspose operation. I.e. transpose of the weights' lowest two
+ * dimensions
+ *
+ * B. Gemm-based conv2d
+ * We want to convert the 2d convolution op (ignoring bias):
+ * dst = conv2d(src, weight)
+ * into a matrix multiplication op:
+ * gemm_dst = gemm(lhs, rhs)
+ *
+ * E.g.: For data layout NHWC
+ * 3 (hi) <----------> (lo) 0
+ * src.shape = [batch, in_h , in_w, in_c]
+ * weight.shape = [out_c, k_h , k_w, in_c]
+ * dst.shape = [batch, out_h, out_w, out_c]
+ *
+ * This requires three transformations:
+ * * src -> lhs, transform conv input to gemm lhs; gemm_lhs is a 2d matrix where each row (or column,
+ * depending on the convention) is a linearized "patch" of the conv_input that corresponds to
+ * the receptive field of the corresponding output element.
+ * The convention is to use "column", but to disambiguate from the column vector of a matrix,
+ * in this documentation we shall use "patch".
+ * This transform is called im2col (for details see @ref CpuIm2ColKernel)
+ * * weight -> rhs, transform conv weight to gemm rhs, known as weight transform/reshape (wt)
+ * * gemm_dst -> dst, transform gemm output back to conv output, known as col2im (for details see
+ * @ref CpuCol2ImKernel)
+ *
+ * This section focuses on the weight transformation and assumes the im2col is already performed
+ *
+ * C. Weight Transformation
+ * After im2col, assume: lhs.shape = [num_patch, patch_size],
+ * where patch_size is the number of elements in a "patch": patch_size = k_h * k_w * in_c
+ * num_patch is the number of patches; we can ignore it here (for details see @ref CpuIm2ColKernel)
+ *
+ * After wt, rhs should have the shape: rhs = [patch_size, out_c]
+ *
+ * Therefore, the weight transformation consists of two steps:
+ * 1. Collapsing all 3 spatial dimensions: [out_c, k_h, k_w, in_c] -> [out_c, patch_size]
+ * 2. Transpose the collapsed shape: [out_c, patch_size] -> [patch_size, out_c]
+ *
+ * D. Implementation
+ * There are 4 paths for weight transformation
+ *
+ * 1. Path 1: Fixed weight format - no transformation
+ * The underlying gemm kernel may adopt fixed weight format (isVarWeightsKernel() == true), which requires
+ * that no weight transformation shall be performed
+ * Note that this no-transform requirement applies both to this op (CpuGemmConv2d) and the constituent ops, up
+ * until the fixed format kernels themselves
+ *
+ * 2. Path 2: Reinterpret then transpose later
+ * If the weight tensor has no "holes" (see @ref has_holes), there are two optimizations we can apply:
+ * - We can ignore the first step (collapsing of spatial dimensions) by simply re-interpreting the shape
+ * in TensorInfo
+ * - Instead of performing transpose here, we can pass the transpose flag to the underlying gemm. The gemm
+ * may then decide to fuse the transpose with any further transformations
+ *
+ * 3. Path 3: Reshape then transpose later
+ * If the weight tensor has holes, then we use a dedicated @ref CpuReshape, followed by transpose later
+ *
+ * 4. Path 4: Fused reshape and transpose
+ * This is only for quantized types for now (TODO: Remove (COMPMID-6596)). We fall back to a legacy
+ * non-optimized kernel @ref CpuWeightsReshapeKernel to perform a fused reshape + transpose
+ *
+ * Path 1 is the long term solution that we shall migrate to once (if) we adopt fixed weight format for all gemm
+ * kernels.
+ * In the short term, Path 2 is the favored, more performant path.
+ */
+
+namespace
+{
+/** Initialize reshaped / transformed weight info
+ *
+ * @param[in] weights Input weights
+ * @param[out] reshaped_weights Transformed weights
+ */
+void initialize_reshaped_weight_info(const ITensorInfo &weights, ITensorInfo &reshaped_weights)
+{
+ auto_init_if_empty(reshaped_weights, weights);
+ if (is_data_type_quantized(weights.data_type()))
+ {
+ // WT method: FusedReshapeAndTranspose
+ reshaped_weights.set_tensor_shape(compute_weights_reshaped_shape(weights, /* has_bias */ false));
+ }
+ else
+ {
+ TensorShape collapsed_weights = weights.tensor_shape();
+ collapsed_weights.collapse(3);
+ reshaped_weights.set_tensor_shape(collapsed_weights);
+ }
+}
+} // namespace
+
+CpuGemmConv2d::WeightTransformMethod CpuGemmConv2d::get_wt_method(const ITensorInfo &weights)
+{
+ // TODO: Extend ReinterpretThenTranspose support for quantized data types COMPMID-6596
+ if (is_data_type_quantized(weights.data_type()))
+ {
+ return WeightTransformMethod::FusedReshapeAndTranspose;
+ }
+ return has_holes(weights) ? WeightTransformMethod::ReshapeThenTranspose
+ : WeightTransformMethod::ReinterpretThenTranspose;
+}
+
CpuGemmConv2d::SkipInfo CpuGemmConv2d::skip_im_col_info(const ITensorInfo *src,
const ITensorInfo *weights,
const PadStrideInfo &conv_info,
@@ -96,7 +209,8 @@ CpuGemmConv2d::SkipInfo CpuGemmConv2d::skip_im_col_info(const ITensorInfo
}
CpuGemmConv2d::CpuGemmConv2d()
- : _weights_reshape_kernel(nullptr),
+ : _weights_reshape(nullptr),
+ _weights_reshape_and_transpose_kernel(nullptr),
_im2col_kernel(),
_mm_gemm(),
_mm_gemmlowp(),
@@ -111,6 +225,8 @@ CpuGemmConv2d::CpuGemmConv2d()
_skip_col2im(false),
_is_quantized(false),
_is_prepared(false),
+ _wt_method(WeightTransformMethod::ReshapeThenTranspose),
+ _run_wt(true),
_aux_mem(AuxTensorIdx::Count)
{
}
@@ -130,12 +246,6 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src,
ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth,
_skip_im2col, fixed_format, weight_format));
- // Create GEMMInfo structure
- const GEMMInfo &gemm_info =
- GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
- _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false, GEMMLowpOutputStageInfo(),
- false, enable_fast_math, false, act_info, fixed_format, weight_format);
-
// Supported activations in GEMM
const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = {
ActivationLayerInfo::ActivationFunction::RELU, ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
@@ -184,7 +294,8 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src,
_mm_gemmlowp = std::make_unique<CpuGemmLowpMatrixMultiplyCore>();
_mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst,
GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false,
- enable_fast_math, false, act_info, fixed_format, weight_format));
+ enable_fast_math, false, act_info, fixed_format, weight_format,
+ false /* pretranspose_B. TODO: COMPMID-6596 */));
auto mm_mem_req = _mm_gemmlowp->workspace();
for (unsigned int cont = 0; cont < mm_mem_req.size(); ++cont)
@@ -194,6 +305,13 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src,
}
else
{
+ // Create GEMMInfo structure
+ const GEMMInfo &gemm_info =
+ GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
+ _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false,
+ GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, fixed_format, weight_format,
+ true /*pretranspose_B. For fp gemm (wt path 1 - 3), We always pretranspose B (for wt path 1 this
+ flag is ignored)*/);
// Configure matrix multiply function
_mm_gemm = std::make_unique<CpuGemm>();
_mm_gemm->configure(src, weights, biases, dst, 1.0f, 1.0f, gemm_info);
@@ -220,12 +338,6 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src,
const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
const bool is_activation_enabled = act_info.enabled();
- // Create GEMMInfo structure
- const GEMMInfo gemm_info =
- GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
- skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false, GEMMLowpOutputStageInfo(),
- false, enable_fast_math, false, act_info, fixed_format, weight_format);
-
if (is_quantized)
{
// Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
@@ -266,10 +378,19 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src,
return CpuGemmLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), biases, dst,
GEMMInfo(false, false, true, gemm_3d_depth, skip_im2col, false,
- output_info, false, enable_fast_math, false, act_info));
+ output_info, false, enable_fast_math, false, act_info,
+ false /* pretranspose_B. TODO: COMPMID-6596 */));
}
else
{
+ // Create GEMMInfo structure
+ const GEMMInfo gemm_info =
+ GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
+ skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false,
+ GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, fixed_format, weight_format,
+ true /*pretranspose_B. For fp gemm (wt path 1 - 3), We always pretranspose B (for wt path 1 this
+ flag is ignored)*/);
+
// Perform validation step on Matrix multiply function
return CpuGemm::validate(src, weights, biases, dst, 1.0f, 1.0f, gemm_info);
}
@@ -353,13 +474,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src,
unsigned int stride_y = 0;
std::tie(stride_x, stride_y) = conv_info.stride();
- unsigned int mat_weights_cols = weights->dimension(idx_kernels);
-
- // _weights_reshaped will be auto configured in the kernel.
- // Just append biases and do not transpose 1xW as it will be reshaped in CpuGemm
- _weights_reshape_kernel = std::make_unique<kernels::CpuWeightsReshapeKernel>();
- _weights_reshape_kernel->configure(weights, nullptr, &_weights_reshaped);
- _weights_reshaped.set_quantization_info(weights->quantization_info());
+ // Initialize reshaped weights
+ initialize_reshaped_weight_info(*weights, _weights_reshaped);
// Create tensor to store im2col reshaped inputs
if (!_skip_im2col)
@@ -380,6 +496,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src,
gemm_input_to_use = &_im2col_output;
}
+ const unsigned int mat_weights_cols = weights->dimension(idx_kernels);
+
// Create temporary GEMM output tensor in case we cannot skip col2im
const DataType output_data_type = data_type == DataType::BFLOAT16 ? DataType::F32 : data_type;
if (!_skip_col2im)
@@ -412,9 +530,38 @@ void CpuGemmConv2d::configure(const ITensorInfo *src,
// In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the output matrix
const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0;
const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
+ /** @section note_CpuGemmConv2d_weight_use_in_configure Which weights tensor should we use to configure gemm
+ *
+ * A. The problem:
+ * In principle, we should use the weights tensor corresponding to the weights transformation path. I.e.:
+ * - If no weight transformation (_run_wt == false): Use original weights
+ * - else: Use transformed weights
+ * However in practice we have a dilemma:
+ * - We need to know _run_wt before we can configure gemm with the corresponding weights, but
+ * - _run_wt depends on isVarWeightsKernel(), which is only known after gemm is configured
+ *
+ * B. The decision:
+ * To simplify the matter, we decide to always use the transformed weights, regardless of _run_wt
+ *
+ * This decision requires the following conditions:
+ * 1. The underlying gemm where isVarWeightsKernel() == true, must guarantee that:
+ * A. Ignore the flag to transpose weights (GEMMInfo::pretranspose_B)
+ * B. Use weights/B tensor passed to it at prepare() or run() instead of that passed at configure()
+ * 2. CpuGemmConv2d where isVarWeightsKernel() == true, must guarantee that:
+ * A. Pass original weights instead of reshaped or reinterpreted weights
+ *
+ * C. Future actions:
+ * Condition 2 is a given, based on our implementation.
+ * If condition 1 cannot hold, we must make changes to the underlying gemm to:
+ * 1. Either expose isVarWeightsKernel() before gemm is configured somehow, or
+ * 2. Take in an additional "original_weights" tensor info at configure
+ */
configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math,
gemm_3d_depth, fixed_format, weights_info.weight_format());
+ // Can only decide isVarWeightsKernel after gemm is configured
+ _run_wt = !isVarWeightsKernel();
+
if (!_skip_col2im && _data_layout == DataLayout::NCHW)
{
// Configure col2im
@@ -428,18 +575,27 @@ void CpuGemmConv2d::configure(const ITensorInfo *src,
_reshape->configure(gemm_output_to_use, dst);
}
- // Check if GEMM transforms weights
- // Modernise through COMPMID-4535
- bool gemm_trans_wei = _aux_mem[1].size > 0; // Asm Pretranspose
- gemm_trans_wei = _mm_gemm != nullptr ? _aux_mem[3].size > 0 : gemm_trans_wei; // Tranpose RHS
- gemm_trans_wei = _mm_gemmlowp != nullptr ? _aux_mem[5].size > 0 : gemm_trans_wei; // Transpose RHS
-
// Check lifetime
_aux_mem[Im2ColOutput] =
MemoryInfo(offset_int_vec(Im2ColOutput), MemoryLifetime::Temporary, _im2col_output.total_size());
- _aux_mem[WeightsReshaped] = MemoryInfo(offset_int_vec(WeightsReshaped),
- gemm_trans_wei ? MemoryLifetime::Prepare : MemoryLifetime::Persistent,
- _weights_reshaped.total_size());
+ // Add WeightsReshaped memory requirement to workspace
+ // Note that in case of WeightTransformMethod::ReinterpretThenTranspose, we do not need to allocate this memory
+ // However since we cannot determine weight transformation method until prepare (see prepare()), we will have to
+ // settle with allocating more
+ if (_run_wt)
+ {
+ // Check if GEMM transforms weights
+ // If weight is further transformed by underlying gemm after ReshapeThenTranspose then we can free
+ // WeightsReshaped in prepare
+ // Otherwise WeightsReshaped is the final transformation of weights and needs to persist
+ bool gemm_trans_wei = _aux_mem[GemmAsmPretransposedRHS].size > 0;
+ gemm_trans_wei = _mm_gemm != nullptr ? _aux_mem[GemmTransposed1xWRHS].size > 0 : gemm_trans_wei;
+ gemm_trans_wei = _mm_gemmlowp != nullptr ? _aux_mem[GemmLowpTransposed1xWRHS].size > 0 : gemm_trans_wei;
+
+ _aux_mem[WeightsReshaped] = MemoryInfo(offset_int_vec(WeightsReshaped),
+ gemm_trans_wei ? MemoryLifetime::Prepare : MemoryLifetime::Persistent,
+ _weights_reshaped.total_size());
+ }
_aux_mem[GemmOutput] = MemoryInfo(offset_int_vec(GemmOutput), MemoryLifetime::Temporary, _gemm_output.total_size());
}
@@ -471,10 +627,18 @@ Status CpuGemmConv2d::has_opt_impl(arm_compute::WeightFormat &expected_weight_fo
const bool skip_col2im = skip_info.skip_col2im;
const unsigned int gemm_3d_depth = skip_col2im ? conv_h : 0;
const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
- const GEMMInfo gemm_info =
- GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
- skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false, GEMMLowpOutputStageInfo(),
- false, enable_fast_math, false, act_info, fixed_format, weights_info.weight_format());
+
+ /** @section note_CpuGemmConv2d_weight_use_in_has_opt_impl Which weights tensor should we use for has_opt_impl
+ *
+ * For the pretranspose_B flag, this shares a similar problem and thus the same decision as that of
+ * @ref note_CpuGemmConv2d_weight_use_in_configure
+ *
+ * But for the weights, we shall always use the original instead of reshaped weights here
+ */
+ const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
+ skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false,
+ GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info,
+ fixed_format, weights_info.weight_format(), true /* pretranspose_B */);
return CpuGemm::has_opt_impl(expected_weight_format, src, weights, biases, dst, gemm_info);
}
@@ -565,8 +729,10 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src,
unsigned int mat_weights_rows =
weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel);
- weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, append_bias), 1, weights->data_type());
- weights_reshaped_info.set_quantization_info(weights->quantization_info());
+ // Initialize reshaped weights
+ initialize_reshaped_weight_info(*weights, weights_reshaped_info);
+ // No need to call CpuReshape::validate() or CpuTranspose::validate() as the dst info is auto-configured from the
+ // src
weights_to_use = &weights_reshaped_info;
if (!skip_im2col)
@@ -613,6 +779,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src,
gemm_output_to_use = &info_gemm;
const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
+ // See note_CpuGemmConv2d_weight_use_in_configure regarding the choice of the weights
ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info,
enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col, fixed_format,
weights_info.weight_format()));
@@ -637,7 +804,6 @@ void CpuGemmConv2d::run(ITensorPack &tensors)
CpuAuxTensorHandler im2col_output(offset_int_vec(Im2ColOutput), _im2col_output, tensors, false);
CpuAuxTensorHandler gemm_output(offset_int_vec(GemmOutput), _gemm_output, tensors, false);
- CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors, false);
bool out_has_padding = _skip_col2im && (dst->info()->padding().bottom != 0 || dst->info()->padding().top != 0);
if (!_skip_im2col)
@@ -666,25 +832,32 @@ void CpuGemmConv2d::run(ITensorPack &tensors)
gemm_output_to_use = dst;
}
- // Runs CpuGemm or CpuGemmLowpMatrixMultiplyCore functions
- ITensorPack pack_mm = tensors;
- pack_mm.add_const_tensor(TensorType::ACL_SRC_0, gemm_input_to_use);
- if (!this->isVarWeightsKernel())
- {
- pack_mm.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
- }
- pack_mm.add_tensor(TensorType::ACL_DST, gemm_output_to_use);
- if (_is_quantized)
+ ITensorPack gemm_pack = tensors;
+ gemm_pack.add_const_tensor(TensorType::ACL_SRC_0, gemm_input_to_use);
+ gemm_pack.add_tensor(TensorType::ACL_DST, gemm_output_to_use);
+ // Allocate reshaped weights if required
+ auto weights = gemm_pack.get_const_tensor(TensorType::ACL_SRC_1);
+ CpuAuxTensorHandler reinterpreted_wei(
+ _weights_reshaped,
+ *weights); // Re-interpreted weights. Only tensor shape is changed. No allocation
+ CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
+ // Update the weights to use if it has been reshaped
+ if (_run_wt)
{
- // Run gemmlowp
- _mm_gemmlowp->run(pack_mm);
- }
- else
- {
- // Run gemm
- _mm_gemm->run(pack_mm);
+ if (_wt_method == WeightTransformMethod::ReinterpretThenTranspose)
+ {
+ gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reinterpreted_wei.get());
+ }
+ else if (_wt_method == WeightTransformMethod::ReshapeThenTranspose ||
+ _wt_method == WeightTransformMethod::FusedReshapeAndTranspose)
+ {
+ gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
+ }
}
+ // Runs CpuGemm or CpuGemmLowpMatrixMultiplyCore functions
+ _is_quantized ? _mm_gemmlowp->run(gemm_pack) : _mm_gemm->run(gemm_pack);
+
// Reshape output matrix
if (!_skip_col2im)
{
@@ -710,24 +883,87 @@ void CpuGemmConv2d::prepare(ITensorPack &tensors)
{
if (!_is_prepared)
{
- // Variable weights executions that use fixed-format kernels
- // need no reshaping of the weights.
- if (this->isVarWeightsKernel())
+ auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+ // Determine which weights reshape path to take
+ // Note that this decision can only occur at prepare instead of configure because it relies on the presence of
+ // any holes in the weight tensor, which may change after configure (e.g. from extending padding)
+ if (_run_wt)
{
- _is_quantized ? _mm_gemmlowp->prepare(tensors) : _mm_gemm->prepare(tensors);
- _is_prepared = true;
- return;
+ _wt_method = get_wt_method(*(weights->info()));
+ switch (_wt_method)
+ {
+ case (WeightTransformMethod::FusedReshapeAndTranspose):
+ {
+ ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Perform weight transformation: FusedReshapeAndTranspose");
+ _weights_reshape_and_transpose_kernel = std::make_unique<kernels::CpuWeightsReshapeKernel>();
+ _weights_reshape_and_transpose_kernel->configure(weights->info(), nullptr, &_weights_reshaped);
+ break;
+ }
+ case (WeightTransformMethod::ReshapeThenTranspose):
+ {
+ ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Perform weight transformation: ReshapeThenTranspose");
+ _weights_reshape = std::make_unique<CpuReshape>();
+ _weights_reshape->configure(weights->info(), &_weights_reshaped);
+ break;
+ }
+ case (WeightTransformMethod::ReinterpretThenTranspose):
+ {
+ ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Perform weight transformation: ReinterpretThenTranspose");
+ // Nothing to configure
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Unsupported weight transform method");
+ }
+ }
+ }
+ else
+ {
+ ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("No weight transformation is performed");
}
-
- // Run weights reshaping and mark original weights tensor as unused
- CpuAuxTensorHandler weights_reshaped(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
- auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
- ITensorPack pack = {{TensorType::ACL_SRC, weights}, {TensorType::ACL_DST, weights_reshaped.get()}};
- NEScheduler::get().schedule_op(_weights_reshape_kernel.get(), 3, _weights_reshape_kernel->window(), pack);
- weights->mark_as_unused();
ITensorPack gemm_pack = tensors;
- gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, weights_reshaped.get());
+ // Allocate reshaped weights if required
+ CpuAuxTensorHandler reinterpreted_wei(
+ _weights_reshaped,
+ *weights); // Re-interpreted weights. Only tensor shape is changed. No allocation
+ CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
+ // Run weights reshape if required
+ if (_run_wt)
+ {
+ switch (_wt_method)
+ {
+ case (WeightTransformMethod::FusedReshapeAndTranspose):
+ {
+ ITensorPack pack = {{TensorType::ACL_SRC, weights}, {TensorType::ACL_DST, reshaped_wei.get()}};
+ NEScheduler::get().schedule_op(_weights_reshape_and_transpose_kernel.get(), Window::DimW,
+ _weights_reshape_and_transpose_kernel->window(), pack);
+ weights->mark_as_unused();
+ gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
+ break;
+ }
+ case (WeightTransformMethod::ReshapeThenTranspose):
+ {
+ ITensorPack pack = {{TensorType::ACL_SRC, weights}, {TensorType::ACL_DST, reshaped_wei.get()}};
+ _weights_reshape->run(pack);
+ weights->mark_as_unused();
+ gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
+ break;
+ }
+ case (WeightTransformMethod::ReinterpretThenTranspose):
+ {
+ gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reinterpreted_wei.get());
+ // Nothing to run
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Unsupported weight transform method");
+ }
+ }
+ }
_is_quantized ? _mm_gemmlowp->prepare(gemm_pack) : _mm_gemm->prepare(gemm_pack);
+
_is_prepared = true;
}
}