diff options
Diffstat (limited to 'src/cpu/operators/CpuGemmConv2d.cpp')
-rw-r--r-- | src/cpu/operators/CpuGemmConv2d.cpp | 372 |
1 files changed, 304 insertions, 68 deletions
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp index 7c59d88c61..117527ccc1 100644 --- a/src/cpu/operators/CpuGemmConv2d.cpp +++ b/src/cpu/operators/CpuGemmConv2d.cpp @@ -32,7 +32,9 @@ #include "arm_compute/runtime/NEON/NEScheduler.h" #include "src/common/utils/Log.h" +#include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/MemoryHelpers.h" +#include "src/core/helpers/Utils.h" #include "src/cpu/kernels/CpuCol2ImKernel.h" #include "src/cpu/kernels/CpuIm2ColKernel.h" #include "src/cpu/kernels/CpuWeightsReshapeKernel.h" @@ -52,6 +54,117 @@ namespace arm_compute { namespace cpu { + +/** @section note_CpuGemmConv2d_weight_transformation Weight Transformations in CpuGemmConv2d + * + * A. Terminology + * Throughout CpuGemmConv2d, we use the following terms in ways that may differ from other operators / kernels: + * - "Transform" or "Reshape" of the weights: they both mean all the operations that we perform on the weight + * tensor up until they are consumed by gemm (CpuGemm or CpuGemmLowpMatrixMultiplyCore) + * Note that the specific gemm operator may perform further transformations on the weights, but the + * transformations here only mean those performed in CpuGemmConv2d + * - "Transpose" of weights: The @ref CpuTranspose operation. I.e. transpose of the weights' lowest two + * dimensions + * + * B. 
Gemm-based conv2d + * We want to convert the 2d convolution op (ignoring bias): + * dst = conv2d(src, weight) + * into a matrix multiplication op: + * gemm_dst = gemm(lhs, rhs) + * + * E.g.: For data layout NHWC + * 3 (hi) <----------> (lo) 0 + * src.shape = [batch, in_h , in_w, in_c] + * weight.shape = [out_c, k_h , k_w, in_c] + * dst.shape = [batch, out_h, out_w, out_c] + * + * This requires three transformations: + * * src -> lhs, transform conv input to gemm lhs; lhs is a 2d matrix where each row (or column, + * depending on the convention) is a linearized "patch" of the conv input (src) that corresponds to + * the receptive field of the corresponding output element. + * The convention is to use "column", but to disambiguate from the column vector of a matrix, + * in this documentation we shall use "patch". + * This transform is called im2col (for details see @ref CpuIm2ColKernel) + * * weight -> rhs, transform conv weight to gemm rhs, known as weight transform/reshape (wt) + * * gemm_dst -> dst, transform gemm output back to conv output, known as col2im (for details see + * @ref CpuCol2ImKernel) + * + * This section focuses on the weight transformation and assumes that im2col has already been performed + * + * C. Weight Transformation + * After im2col, assume: lhs.shape = [num_patch, patch_size], + * where patch_size is the number of elements in a "patch": patch_size = k_h * k_w * in_c + * num_patch is the number of patches; we can ignore it here (for details see @ref CpuIm2ColKernel) + * + * After wt, rhs should have the shape: rhs.shape = [patch_size, out_c] + * + * Therefore, the weight transformation consists of two steps: + * 1. Collapse the 3 lowest dimensions (k_h, k_w, in_c): [out_c, k_h, k_w, in_c] -> [out_c, patch_size] + * 2. Transpose the collapsed shape: [out_c, patch_size] -> [patch_size, out_c] + * + * D. Implementation + * There are 4 paths for weight transformation + * + * 1. 
Path 1: Fixed weight format - no transformation + * The underlying gemm kernel may adopt a fixed weight format (isVarWeightsKernel() == true), which requires + * that no weight transformation be performed + * Note that this no-transform requirement applies both to this op (CpuGemmConv2d) and the constituent ops, up + * until the fixed format kernels themselves + * + * 2. Path 2: Reinterpret then transpose later + * If the weight tensor has no "holes" (see @ref has_holes), there are two optimizations we can apply: + * - We can skip the first step (collapsing k_h, k_w and in_c into patch_size) by simply re-interpreting the shape + * in TensorInfo + * - Instead of performing transpose here, we can pass the transpose flag to the underlying gemm. The gemm + * may then decide to fuse the transpose with any further transformations + * + * 3. Path 3: Reshape then transpose later + * If the weight tensor has holes, then we use a dedicated @ref CpuReshape to collapse the dimensions, and + * leave the transpose to the underlying gemm as in Path 2 + * + * 4. Path 4: Fused reshape and transpose + * This is only for quantized types for now (TODO: Remove (COMPMID-6596)). We fall back to a legacy + * non-optimized kernel @ref CpuWeightsReshapeKernel to perform a fused reshape + transpose + * + * Path 1 is the long-term solution that we shall migrate to once (if) we adopt fixed weight format for all gemm + * kernels. + * In the short term, Path 2 is the favored, more performant path. 
+ */ + +namespace +{ +/** Initialize reshaped / transformed weight info + * + * @param[in] weights Input weights + * @param[out] reshaped_weights Transformed weights + */ +void initialize_reshaped_weight_info(const ITensorInfo &weights, ITensorInfo &reshaped_weights) +{ + auto_init_if_empty(reshaped_weights, weights); + if (is_data_type_quantized(weights.data_type())) + { + // WT method: FusedReshapeAndTranspose + reshaped_weights.set_tensor_shape(compute_weights_reshaped_shape(weights, /* has_bias */ false)); + } + else + { + TensorShape collapsed_weights = weights.tensor_shape(); + collapsed_weights.collapse(3); + reshaped_weights.set_tensor_shape(collapsed_weights); + } +} +} // namespace + +CpuGemmConv2d::WeightTransformMethod CpuGemmConv2d::get_wt_method(const ITensorInfo &weights) +{ + // TODO: Extend ReinterpretThenTranspose support for quantized data types COMPMID-6596 + if (is_data_type_quantized(weights.data_type())) + { + return WeightTransformMethod::FusedReshapeAndTranspose; + } + return has_holes(weights) ? 
WeightTransformMethod::ReshapeThenTranspose + : WeightTransformMethod::ReinterpretThenTranspose; +} + CpuGemmConv2d::SkipInfo CpuGemmConv2d::skip_im_col_info(const ITensorInfo *src, const ITensorInfo *weights, const PadStrideInfo &conv_info, @@ -96,7 +209,8 @@ CpuGemmConv2d::SkipInfo CpuGemmConv2d::skip_im_col_info(const ITensorInfo } CpuGemmConv2d::CpuGemmConv2d() - : _weights_reshape_kernel(nullptr), + : _weights_reshape(nullptr), + _weights_reshape_and_transpose_kernel(nullptr), _im2col_kernel(), _mm_gemm(), _mm_gemmlowp(), @@ -111,6 +225,8 @@ CpuGemmConv2d::CpuGemmConv2d() _skip_col2im(false), _is_quantized(false), _is_prepared(false), + _wt_method(WeightTransformMethod::ReshapeThenTranspose), + _run_wt(true), _aux_mem(AuxTensorIdx::Count) { } @@ -130,12 +246,6 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col, fixed_format, weight_format)); - // Create GEMMInfo structure - const GEMMInfo &gemm_info = - GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, - _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false, GEMMLowpOutputStageInfo(), - false, enable_fast_math, false, act_info, fixed_format, weight_format); - // Supported activations in GEMM const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU, ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, @@ -184,7 +294,8 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, _mm_gemmlowp = std::make_unique<CpuGemmLowpMatrixMultiplyCore>(); _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false, - enable_fast_math, false, act_info, fixed_format, weight_format)); + enable_fast_math, false, act_info, fixed_format, weight_format, + false /* pretranspose_B. 
TODO: COMPMID-6596 */)); auto mm_mem_req = _mm_gemmlowp->workspace(); for (unsigned int cont = 0; cont < mm_mem_req.size(); ++cont) @@ -194,6 +305,13 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, } else { + // Create GEMMInfo structure + const GEMMInfo &gemm_info = + GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, + _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false, + GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, fixed_format, weight_format, + true /*pretranspose_B. For fp gemm (wt path 1 - 3), We always pretranspose B (for wt path 1 this + flag is ignored)*/); // Configure matrix multiply function _mm_gemm = std::make_unique<CpuGemm>(); _mm_gemm->configure(src, weights, biases, dst, 1.0f, 1.0f, gemm_info); @@ -220,12 +338,6 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const bool is_quantized = is_data_type_quantized_asymmetric(data_type); const bool is_activation_enabled = act_info.enabled(); - // Create GEMMInfo structure - const GEMMInfo gemm_info = - GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, - skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false, GEMMLowpOutputStageInfo(), - false, enable_fast_math, false, act_info, fixed_format, weight_format); - if (is_quantized) { // Since we need negative offsets for computing convolution, we need to change QuantizationInfo() @@ -266,10 +378,19 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, return CpuGemmLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, skip_im2col, false, - output_info, false, enable_fast_math, false, act_info)); + output_info, false, enable_fast_math, false, act_info, + false /* pretranspose_B. 
TODO: COMPMID-6596 */)); } else { + // Create GEMMInfo structure + const GEMMInfo gemm_info = + GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, + skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false, + GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, fixed_format, weight_format, + true /*pretranspose_B. For fp gemm (wt path 1 - 3), We always pretranspose B (for wt path 1 this + flag is ignored)*/); + // Perform validation step on Matrix multiply function return CpuGemm::validate(src, weights, biases, dst, 1.0f, 1.0f, gemm_info); } @@ -353,13 +474,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, unsigned int stride_y = 0; std::tie(stride_x, stride_y) = conv_info.stride(); - unsigned int mat_weights_cols = weights->dimension(idx_kernels); - - // _weights_reshaped will be auto configured in the kernel. - // Just append biases and do not transpose 1xW as it will be reshaped in CpuGemm - _weights_reshape_kernel = std::make_unique<kernels::CpuWeightsReshapeKernel>(); - _weights_reshape_kernel->configure(weights, nullptr, &_weights_reshaped); - _weights_reshaped.set_quantization_info(weights->quantization_info()); + // Initialize reshaped weights + initialize_reshaped_weight_info(*weights, _weights_reshaped); // Create tensor to store im2col reshaped inputs if (!_skip_im2col) @@ -380,6 +496,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, gemm_input_to_use = &_im2col_output; } + const unsigned int mat_weights_cols = weights->dimension(idx_kernels); + // Create temporary GEMM output tensor in case we cannot skip col2im const DataType output_data_type = data_type == DataType::BFLOAT16 ? 
DataType::F32 : data_type; if (!_skip_col2im) @@ -412,9 +530,38 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, // In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the output matrix const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0; const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED; + /** @section note_CpuGemmConv2d_weight_use_in_configure Which weights tensor should we use to configure gemm + * + * A. The problem: + * In principle, we should use the weights tensor corresponding to the weights transformation path. I.e.: + * - If no weight transformation (_run_wt == false): Use original weights + * - else: Use transformed weights + * However in practice we have a dilemma: + * - We need to know _run_wt before we can configure gemm with the corresponding weights, but + * - _run_wt depends on isVarWeightsKernel(), which is only known after gemm is configured + * + * B. The decision: + * To simplify the matter, we decide to always use the transformed weights, regardless of _run_wt + * + * This decision requires the following conditions: + * 1. The underlying gemm where isVarWeightsKernel() == true, must guarantee that: + * A. Ignore the flag to transpose weights (GEMMInfo::pretranspose_B) + * B. Use weights/B tensor passed to it at prepare() or run() instead of that passed at configure() + * 2. CpuGemmConv2d where isVarWeightsKernel() == true, must guarantee that: + * A. Pass original weights instead of reshaped or reinterpreted weights + * + * C. Future actions: + * Condition 2 is a given, based on our implementation. + * If condition 1 cannot hold, we must make changes to the underlying gemm to: + * 1. Either expose isVarWeightsKernel() before gemm is configured somehow, or + * 2. 
Take in an additional "original_weights" tensor info at configure + */ configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth, fixed_format, weights_info.weight_format()); + // Can only decide isVarWeightsKernel after gemm is configured + _run_wt = !isVarWeightsKernel(); + if (!_skip_col2im && _data_layout == DataLayout::NCHW) { // Configure col2im @@ -428,18 +575,27 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, _reshape->configure(gemm_output_to_use, dst); } - // Check if GEMM transforms weights - // Modernise through COMPMID-4535 - bool gemm_trans_wei = _aux_mem[1].size > 0; // Asm Pretranspose - gemm_trans_wei = _mm_gemm != nullptr ? _aux_mem[3].size > 0 : gemm_trans_wei; // Tranpose RHS - gemm_trans_wei = _mm_gemmlowp != nullptr ? _aux_mem[5].size > 0 : gemm_trans_wei; // Transpose RHS - // Check lifetime _aux_mem[Im2ColOutput] = MemoryInfo(offset_int_vec(Im2ColOutput), MemoryLifetime::Temporary, _im2col_output.total_size()); - _aux_mem[WeightsReshaped] = MemoryInfo(offset_int_vec(WeightsReshaped), - gemm_trans_wei ? MemoryLifetime::Prepare : MemoryLifetime::Persistent, - _weights_reshaped.total_size()); + // Add WeightsReshaped memory requirement to workspace + // Note that in case of WeightTransformMethod::ReinterpretThenTranspose, we do not need to allocate this memory + // However since we cannot determine weight transformation method until prepare (see prepare()), we will have to + // settle with allocating more + if (_run_wt) + { + // Check if GEMM transforms weights + // If weight is further transformed by underlying gemm after ReshapeThenTranspose then we can free + // WeightsReshaped in prepare + // Otherwise WeightsReshaped is the final transformation of weights and needs to persist + bool gemm_trans_wei = _aux_mem[GemmAsmPretransposedRHS].size > 0; + gemm_trans_wei = _mm_gemm != nullptr ? 
_aux_mem[GemmTransposed1xWRHS].size > 0 : gemm_trans_wei; + gemm_trans_wei = _mm_gemmlowp != nullptr ? _aux_mem[GemmLowpTransposed1xWRHS].size > 0 : gemm_trans_wei; + + _aux_mem[WeightsReshaped] = MemoryInfo(offset_int_vec(WeightsReshaped), + gemm_trans_wei ? MemoryLifetime::Prepare : MemoryLifetime::Persistent, + _weights_reshaped.total_size()); + } _aux_mem[GemmOutput] = MemoryInfo(offset_int_vec(GemmOutput), MemoryLifetime::Temporary, _gemm_output.total_size()); } @@ -471,10 +627,18 @@ Status CpuGemmConv2d::has_opt_impl(arm_compute::WeightFormat &expected_weight_fo const bool skip_col2im = skip_info.skip_col2im; const unsigned int gemm_3d_depth = skip_col2im ? conv_h : 0; const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED; - const GEMMInfo gemm_info = - GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, - skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false, GEMMLowpOutputStageInfo(), - false, enable_fast_math, false, act_info, fixed_format, weights_info.weight_format()); + + /** @section note_CpuGemmConv2d_weight_use_in_has_opt_impl Which weights tensor should we use for has_opt_impl + * + * For the pretranspose_B flag, this shares a similar problem and thus the same decision as that of + * @ref note_CpuGemmConv2d_weight_use_in_configure + * + * But for the weights, we shall always use the original instead of reshaped weights here + */ + const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, + skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false, + GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, + fixed_format, weights_info.weight_format(), true /* pretranspose_B */); return CpuGemm::has_opt_impl(expected_weight_format, src, weights, biases, dst, gemm_info); } @@ -565,8 +729,10 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, unsigned int 
mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel); - weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, append_bias), 1, weights->data_type()); - weights_reshaped_info.set_quantization_info(weights->quantization_info()); + // Initialize reshaped weights + initialize_reshaped_weight_info(*weights, weights_reshaped_info); + // No need to call CpuReshape::validate() or CpuTranspose::validate() as the dst info is auto-configured from the + // src weights_to_use = &weights_reshaped_info; if (!skip_im2col) @@ -613,6 +779,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, gemm_output_to_use = &info_gemm; const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED; + // See note_CpuGemmConv2d_weight_use_in_configure regarding the choice of the weights ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? 
conv_h : 0, skip_im2col, fixed_format, weights_info.weight_format())); @@ -637,7 +804,6 @@ void CpuGemmConv2d::run(ITensorPack &tensors) CpuAuxTensorHandler im2col_output(offset_int_vec(Im2ColOutput), _im2col_output, tensors, false); CpuAuxTensorHandler gemm_output(offset_int_vec(GemmOutput), _gemm_output, tensors, false); - CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors, false); bool out_has_padding = _skip_col2im && (dst->info()->padding().bottom != 0 || dst->info()->padding().top != 0); if (!_skip_im2col) @@ -666,25 +832,32 @@ void CpuGemmConv2d::run(ITensorPack &tensors) gemm_output_to_use = dst; } - // Runs CpuGemm or CpuGemmLowpMatrixMultiplyCore functions - ITensorPack pack_mm = tensors; - pack_mm.add_const_tensor(TensorType::ACL_SRC_0, gemm_input_to_use); - if (!this->isVarWeightsKernel()) - { - pack_mm.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get()); - } - pack_mm.add_tensor(TensorType::ACL_DST, gemm_output_to_use); - if (_is_quantized) + ITensorPack gemm_pack = tensors; + gemm_pack.add_const_tensor(TensorType::ACL_SRC_0, gemm_input_to_use); + gemm_pack.add_tensor(TensorType::ACL_DST, gemm_output_to_use); + // Allocate reshaped weights if required + auto weights = gemm_pack.get_const_tensor(TensorType::ACL_SRC_1); + CpuAuxTensorHandler reinterpreted_wei( + _weights_reshaped, + *weights); // Re-interpreted weights. Only tensor shape is changed. 
No allocation + CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors); + // Update the weights to use if it has been reshaped + if (_run_wt) { - // Run gemmlowp - _mm_gemmlowp->run(pack_mm); - } - else - { - // Run gemm - _mm_gemm->run(pack_mm); + if (_wt_method == WeightTransformMethod::ReinterpretThenTranspose) + { + gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reinterpreted_wei.get()); + } + else if (_wt_method == WeightTransformMethod::ReshapeThenTranspose || + _wt_method == WeightTransformMethod::FusedReshapeAndTranspose) + { + gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get()); + } } + // Runs CpuGemm or CpuGemmLowpMatrixMultiplyCore functions + _is_quantized ? _mm_gemmlowp->run(gemm_pack) : _mm_gemm->run(gemm_pack); + // Reshape output matrix if (!_skip_col2im) { @@ -710,24 +883,87 @@ void CpuGemmConv2d::prepare(ITensorPack &tensors) { if (!_is_prepared) { - // Variable weights executions that use fixed-format kernels - // need no reshaping of the weights. - if (this->isVarWeightsKernel()) + auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1); + // Determine which weights reshape path to take + // Note that this decision can only occur at prepare instead of configure because it relies on the presence of + // any holes in the weight tensor, which may change after configure (e.g. from extending padding) + if (_run_wt) { - _is_quantized ? 
_mm_gemmlowp->prepare(tensors) : _mm_gemm->prepare(tensors); - _is_prepared = true; - return; + _wt_method = get_wt_method(*(weights->info())); + switch (_wt_method) + { + case (WeightTransformMethod::FusedReshapeAndTranspose): + { + ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Perform weight transformation: FusedReshapeAndTranspose"); + _weights_reshape_and_transpose_kernel = std::make_unique<kernels::CpuWeightsReshapeKernel>(); + _weights_reshape_and_transpose_kernel->configure(weights->info(), nullptr, &_weights_reshaped); + break; + } + case (WeightTransformMethod::ReshapeThenTranspose): + { + ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Perform weight transformation: ReshapeThenTranspose"); + _weights_reshape = std::make_unique<CpuReshape>(); + _weights_reshape->configure(weights->info(), &_weights_reshaped); + break; + } + case (WeightTransformMethod::ReinterpretThenTranspose): + { + ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Perform weight transformation: ReinterpretThenTranspose"); + // Nothing to configure + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported weight transform method"); + } + } + } + else + { + ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("No weight transformation is performed"); } - - // Run weights reshaping and mark original weights tensor as unused - CpuAuxTensorHandler weights_reshaped(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors); - auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1); - ITensorPack pack = {{TensorType::ACL_SRC, weights}, {TensorType::ACL_DST, weights_reshaped.get()}}; - NEScheduler::get().schedule_op(_weights_reshape_kernel.get(), 3, _weights_reshape_kernel->window(), pack); - weights->mark_as_unused(); ITensorPack gemm_pack = tensors; - gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, weights_reshaped.get()); + // Allocate reshaped weights if required + CpuAuxTensorHandler reinterpreted_wei( + _weights_reshaped, + *weights); // Re-interpreted weights. Only tensor shape is changed. 
No allocation + CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors); + // Run weights reshape if required + if (_run_wt) + { + switch (_wt_method) + { + case (WeightTransformMethod::FusedReshapeAndTranspose): + { + ITensorPack pack = {{TensorType::ACL_SRC, weights}, {TensorType::ACL_DST, reshaped_wei.get()}}; + NEScheduler::get().schedule_op(_weights_reshape_and_transpose_kernel.get(), Window::DimW, + _weights_reshape_and_transpose_kernel->window(), pack); + weights->mark_as_unused(); + gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get()); + break; + } + case (WeightTransformMethod::ReshapeThenTranspose): + { + ITensorPack pack = {{TensorType::ACL_SRC, weights}, {TensorType::ACL_DST, reshaped_wei.get()}}; + _weights_reshape->run(pack); + weights->mark_as_unused(); + gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get()); + break; + } + case (WeightTransformMethod::ReinterpretThenTranspose): + { + gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reinterpreted_wei.get()); + // Nothing to run + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported weight transform method"); + } + } + } _is_quantized ? _mm_gemmlowp->prepare(gemm_pack) : _mm_gemm->prepare(gemm_pack); + _is_prepared = true; } } |