aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSiCong Li <sicong.li@arm.com>2023-10-17 17:38:57 +0100
committerSiCong Li <sicong.li@arm.com>2023-11-08 09:49:56 +0000
commitc5ab4df0c11dc66db47f2070edc719923af3367e (patch)
treec04bdac32528e628b2a9b9a1c1653e300328fc1b
parent4a9dbedfbfa66c2612c7461e60cd867b8aea825b (diff)
downloadComputeLibrary-c5ab4df0c11dc66db47f2070edc719923af3367e.tar.gz
Optimize CpuGemmConv2d start-up time
When weight has no holes, we can replace CpuWeightsReshapeKernel with: - Collapse by reinterpreting weight's 3 spatial dimensions - Perform CpuTranspose For more details see the documentation in src/cpu/operators/CpuGemmConv2d.cpp This is one optimization since the CpuTranspose is better performing than CpuWeightsReshapeKernel A second optimization is to fuse this transpose with other weight transformations (e.g. pretranspose_B_array in CpuGemmAssemblyDispatch) However this second optimization depends on how the underlying gemm methods (the fall back path: CpuGemmMatrixMultiplyKernel or the assembly path: CpuGemmAssemblyDispatch) chooses to fuse the transpose. Therefore, this patch moves the transpose down from CpuGemmConv2d, to the individual gemm operators where the fusion decision needs to be made, by passing an extra "transpose_b" flag to CpuGemm New transpose_b flag in different scopes (they are all the same, but with different names because pretranspose_b has a different meaning in GemmAssemblyDispatch): GEMMInfo::pretranspose_B -> AsmGemmInfo::transpose_b New auxilliary tensors holding the transposed b result: - CpuGemm optimized path: CpuGemmAssemblyDispatch::PrePretransposedB - CpuGemm fallback path: CpuGemm::PreTransposedRHS Note that this patch does not yet have the second optimization (COMPMID-6595), but it prepares for it. Relates to COMPMID-6595 Resolves COMPMID-6499 Change-Id: I999a2da9da4b2b15369a3cc06d7872c86e0190ea Signed-off-by: SiCong Li <sicong.li@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10526 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Anitha Raj <Anitha.Raj@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/function_info/GEMMInfo.h8
-rw-r--r--arm_compute/runtime/NEON/functions/NEConvolutionLayer.h8
-rw-r--r--docs/user_guide/release_version_and_change_log.dox1
-rw-r--r--src/core/helpers/Utils.cpp5
-rw-r--r--src/core/helpers/Utils.h12
-rw-r--r--src/cpu/operators/CpuFullyConnected.h22
-rw-r--r--src/cpu/operators/CpuGemm.cpp195
-rw-r--r--src/cpu/operators/CpuGemm.h21
-rw-r--r--src/cpu/operators/CpuGemmConv2d.cpp372
-rw-r--r--src/cpu/operators/CpuGemmConv2d.h55
-rw-r--r--src/cpu/operators/CpuGemmDirectConv2d.cpp8
-rw-r--r--src/cpu/operators/CpuGemmDirectConv2d.h12
-rw-r--r--src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp8
-rw-r--r--src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h11
-rw-r--r--src/cpu/operators/CpuMatMul.h11
-rw-r--r--src/cpu/operators/CpuWinogradConv2d.cpp13
-rw-r--r--src/cpu/operators/CpuWinogradConv2d.h32
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp104
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.h13
-rw-r--r--src/cpu/utils/CpuAuxTensorHandler.h16
-rw-r--r--tests/validation/NEON/ConvolutionLayer.cpp18
-rw-r--r--tests/validation/fixtures/ConvolutionLayerFixture.h29
22 files changed, 735 insertions, 239 deletions
diff --git a/arm_compute/function_info/GEMMInfo.h b/arm_compute/function_info/GEMMInfo.h
index c24762c0aa..a827c79fda 100644
--- a/arm_compute/function_info/GEMMInfo.h
+++ b/arm_compute/function_info/GEMMInfo.h
@@ -105,6 +105,7 @@ public:
* @param[in] activation_info (Optional) Activation to apply after the matrix multiplication
* @param[in] fixed_format (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in amemory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
* @param[in] weight_format (Optional) arm_gemm:WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
+ * @param[in] pretranspose_B (Optional) Pretranspose matrix B (transposition of its lowest 2 dimensions), in addition to and before, any further transformations of B
*/
GEMMInfo(bool is_a_reshaped,
bool is_b_reshaped,
@@ -118,7 +119,8 @@ public:
bool broadcast_bias = false,
const ActivationLayerInfo &activation_info = ActivationLayerInfo(),
bool fixed_format = false,
- arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED) noexcept
+ arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED,
+ bool pretranspose_B = false) noexcept
: _is_a_reshaped(is_a_reshaped),
_is_b_reshaped(is_b_reshaped),
_reshape_b_only_on_first_run(reshape_b_only_on_first_run),
@@ -130,7 +132,7 @@ public:
_fp_mixed_precision(fp_mixed_precision),
_broadcast_bias(broadcast_bias),
_pretranspose_A(false),
- _pretranspose_B(false),
+ _pretranspose_B(pretranspose_B),
_activation_info(activation_info),
_fixed_format(fixed_format),
_weight_format(weight_format)
@@ -251,6 +253,8 @@ public:
_pretranspose_A = flag;
}
/** Flag which specifies whether b should be pre-transposed if supported.
+ * More concretely, the "pre-transpose" is the transposition of the b tensor's lowest 2 dimensions
+ * If specified true, this pre-transpose will occur in addition to and before, any further transformations of the b matrix
*
* @return True if b should be pre-transposed else false.
*/
diff --git a/arm_compute/runtime/NEON/functions/NEConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEConvolutionLayer.h
index cdf0f652e1..2d07980ade 100644
--- a/arm_compute/runtime/NEON/functions/NEConvolutionLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEConvolutionLayer.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_NECONVOLUTIONLAYER_H
-#define ARM_COMPUTE_NECONVOLUTIONLAYER_H
+#ifndef ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NECONVOLUTIONLAYER_H
+#define ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NECONVOLUTIONLAYER_H
#include "arm_compute/core/ITensorInfo.h"
#include "arm_compute/core/Types.h"
@@ -38,7 +38,7 @@ namespace arm_compute
class ITensor;
/** Basic function to simulate a convolution layer. This function calls one of the following functions:
- * -# @ref cpu::CpuGemm (executed only in case GEMM is required for the operation)
+ * -# @ref cpu::CpuGemmConv2d (executed only in case GEMM is required for the operation)
* -# @ref cpu::CpuWinogradConv2d (executed only in case Winograd is required for the operation)
* -# @ref cpu::CpuDirectConv2d (executed only in case Direct Convolution is required for the operation)
* -# @ref NEFFTConvolutionLayer (executed only in case FFT is required for the operation)
@@ -196,4 +196,4 @@ private:
std::unique_ptr<Impl> _impl;
};
} // namespace arm_compute
-#endif /* ARM_COMPUTE_NECONVOLUTIONLAYER_H */
+#endif // ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NECONVOLUTIONLAYER_H
diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox
index c07cf88d80..b6627a9701 100644
--- a/docs/user_guide/release_version_and_change_log.dox
+++ b/docs/user_guide/release_version_and_change_log.dox
@@ -59,6 +59,7 @@ v23.11 Public major release
- Optimize @ref NEStackLayer
- Optimize @ref CLReductionOperation.
- Optimize @ref CLSoftmaxLayer.
+ - Optimize start-up time of @ref NEConvolutionLayer for some input configurations where GeMM is selected as the convolution algorithm
- Add new OpenCL™ kernels:
- @ref opencl::kernels::ClMatMulLowpNativeMMULKernel support for QASYMM8 and QASYMM8_SIGNED, with batch support
- Deprecate support for Bfloat16 in @ref cpu::CpuCast.
diff --git a/src/core/helpers/Utils.cpp b/src/core/helpers/Utils.cpp
index 6ca29d180d..f8895d8a3c 100644
--- a/src/core/helpers/Utils.cpp
+++ b/src/core/helpers/Utils.cpp
@@ -25,6 +25,11 @@
namespace arm_compute
{
+bool has_holes(const ITensorInfo &info)
+{
+ return has_holes(info, info.num_dimensions() - 1);
+}
+
bool has_holes(const ITensorInfo &info, size_t dimension)
{
const auto &shape = info.tensor_shape();
diff --git a/src/core/helpers/Utils.h b/src/core/helpers/Utils.h
index 2e7224c55b..a17a78f7ee 100644
--- a/src/core/helpers/Utils.h
+++ b/src/core/helpers/Utils.h
@@ -95,6 +95,18 @@ inline unsigned int get_next_power_two(unsigned int x)
/** Check if the tensor has any holes.
*
+ * A hole is defined as any gap in the tensor between two consecutive values. This can be a result of extending
+ * the paddings or manipulating the strides of the tensor
+ *
+ * @param[in] info Tensor info object defining the shape of the input tensor.
+ *
+ * @note This function checks for holes in all dimensions.
+ *
+ */
+bool has_holes(const ITensorInfo &info);
+
+/** Check if the tensor has any holes.
+ *
* @param[in] info Tensor info object defining the shape of the input tensor.
* @param[in] dimension Highest dimension to check.
*
diff --git a/src/cpu/operators/CpuFullyConnected.h b/src/cpu/operators/CpuFullyConnected.h
index 7073fb9f7c..b72f77e5c4 100644
--- a/src/cpu/operators/CpuFullyConnected.h
+++ b/src/cpu/operators/CpuFullyConnected.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_FULLY_CONNECTED_H
-#define ARM_COMPUTE_CPU_FULLY_CONNECTED_H
+#ifndef ACL_SRC_CPU_OPERATORS_CPUFULLYCONNECTED_H
+#define ACL_SRC_CPU_OPERATORS_CPUFULLYCONNECTED_H
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/function_info/FullyConnectedLayerInfo.h"
@@ -145,13 +145,15 @@ private:
{
AsmGemmWorkspace = 0,
Pretranspose,
- GemmTemp1, // Both CpuGemm and CpuGemmLowpMatrixMultiplyCore
- GemmTemp2, // Both CpuGemm and CpuGemmLowpMatrixMultiplyCore
- GemmTemp3, // Both CpuGemm and CpuGemmLowpMatrixMultiplyCore
- GemmTemp4, // CpuGemmLowpMatrixMultiplyCore only
- GemmTemp5, // CpuGemmLowpMatrixMultiplyCore only
- GemmTemp6, // CpuGemmLowpMatrixMultiplyCore only
- GemmTemp7, // CpuGemmLowpMatrixMultiplyCore only
+ GemmTemp1,
+ GemmTemp2,
+ GemmTemp3,
+ GemmTemp4,
+ GemmTemp5,
+ GemmTemp6,
+ GemmTemp7,
+ GemmTemp8,
+ // Slots above (0-9) reserved for either CpuGemm or CpuGemmLowpMatrixMultiplyCore
TransposedWeights,
ConvertedWeights,
FlattenedSrc,
@@ -189,4 +191,4 @@ private:
};
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_FULLY_CONNECTED_H */
+#endif // ACL_SRC_CPU_OPERATORS_CPUFULLYCONNECTED_H
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index 8da166dbef..e035de0131 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -53,6 +53,8 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.fast_mode = info.fast_math();
asm_info.fixed_format = info.fixed_format();
asm_info.weight_format = info.weight_format();
+ asm_info.transpose_b =
+ info.pretranspose_B(); // The "pretranspose_B" flag here is not the same as the pretranspose_B_array method. The flag here signals to pretranspose_B_array method if we want to perform additional transpose on B before the pretranspose_B_array method
return asm_info;
}
@@ -72,7 +74,7 @@ void CpuGemm::configure(const ITensorInfo *a,
const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
const bool is_c_bias = beta == 1 && c != nullptr;
- bool run_optimised =
+ const bool run_optimised =
bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) &&
(c == nullptr || beta == 0.f || beta == 1.f) && // Optimized GeMM doesn't support beta coefficient.
!(!b->are_values_constant() &&
@@ -92,14 +94,17 @@ void CpuGemm::configure(const ITensorInfo *a,
if (run_optimised)
{
+ _run_interleave_transpose = false;
const ITensorInfo *c_to_use = is_c_bias ? c : nullptr;
_asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
_asm_glue->configure(a, b, c_to_use, d, asm_info);
ARM_COMPUTE_ERROR_ON(!_asm_glue->is_configured());
- auto asm_mem_req = _asm_glue->workspace();
- _aux_mem[AsmGemmWorkspace] = asm_mem_req[AsmGemmWorkspace];
- _aux_mem[Pretraspose] = asm_mem_req[Pretraspose];
+ const auto asm_mem_req = _asm_glue->workspace();
+ for (unsigned int slot = 0; slot < asm_mem_req.size(); ++slot)
+ {
+ _aux_mem[slot] = asm_mem_req[slot];
+ }
// Scale product by alpha
if (_run_alpha_scale)
@@ -111,37 +116,74 @@ void CpuGemm::configure(const ITensorInfo *a,
}
else
{
+ _run_interleave_transpose = !_run_vector_matrix_multiplication;
// Pick output tensor in case bias addition should be performed
ITensorInfo *gemm_output_to_use = (_run_bias_addition) ? &_tmp_d : d;
+ // Pick b tensor in case pretranspose should be performed
+ const ITensorInfo *b_to_use = b;
_mm_kernel = std::make_unique<cpu::kernels::CpuGemmMatrixMultiplyKernel>();
+ // Configure rhs pretranspose
+ if (gemm_info.pretranspose_B())
+ {
+ _pretranspose_b_func = std::make_unique<CpuTranspose>();
+ _pretranspose_b_func->configure(b_to_use, &_pretransposed_b);
+ MemoryLifetime lifetime;
+ if (_reshape_b_only_on_first_run)
+ {
+ if (_run_interleave_transpose)
+ {
+ // PreTransposedRHS tensor is only used in prepare(), but is then succeeded by Transposed1xWRHS
+ // So PreTransposedRHS can be freed inside prepare()
+ lifetime = MemoryLifetime::Prepare;
+ }
+ else
+ {
+ // PreTransposedRHS tensor is only used in prepare(), but is the final transformation of rhs
+ // So PreTransposedRHS needs to persist beyond prepare()
+ lifetime = MemoryLifetime::Persistent;
+ }
+ }
+ else
+ {
+ // PreTransposedRHS tensor is always used in run() and doesn't need to persist
+ lifetime = MemoryLifetime::Temporary;
+ }
+ _aux_mem[PreTransposedRHS] =
+ MemoryInfo(offset_int_vec(PreTransposedRHS), lifetime, _pretransposed_b.total_size());
+ b_to_use = &_pretransposed_b;
+ }
+
// Select between GEMV and GEMM
if (_run_vector_matrix_multiplication)
{
// Configure the matrix multiply kernel
- _mm_kernel->configure(a, b, gemm_output_to_use, alpha, false);
+ _mm_kernel->configure(a, b_to_use, gemm_output_to_use, alpha, false);
}
else
{
- const int m = a->dimension(1);
- const int n = b->dimension(0);
- const int k = a->dimension(0);
-
+ ARM_COMPUTE_ERROR_ON(!_run_interleave_transpose);
// Configure interleave kernel
_interleave_kernel = std::make_unique<cpu::kernels::CpuGemmInterleave4x4Kernel>();
_interleave_kernel->configure(a, &_tmp_a);
_aux_mem[InterleavedLHS] =
MemoryInfo(offset_int_vec(InterleavedLHS), MemoryLifetime::Temporary, _tmp_a.total_size());
- // Configure transpose kernel
- _transpose_kernel = std::make_unique<cpu::kernels::CpuGemmTranspose1xWKernel>();
- _transpose_kernel->configure(b, &_tmp_b);
- _aux_mem[TransposedRHS] =
- MemoryInfo(offset_int_vec(TransposedRHS), MemoryLifetime::Persistent, _tmp_b.total_size());
+ // Configure rhs transpose1xw kernel
+ _transpose1xW_b_kernel = std::make_unique<cpu::kernels::CpuGemmTranspose1xWKernel>();
+ _transpose1xW_b_kernel->configure(b_to_use, &_tmp_b);
+ _aux_mem[Transposed1xWRHS] =
+ MemoryInfo(offset_int_vec(Transposed1xWRHS), MemoryLifetime::Persistent, _tmp_b.total_size());
+
+ // Use a and b here instead of _tmp_a and _tmp_b because CpuGemmMatrixMultiplyKernel requires the original m,n,k in case of interleaved a and transposed1xw b
+ const int m = a->dimension(1);
+ const int n = b_to_use->dimension(0);
+ const int k = a->dimension(0);
// Configure matrix multiplication kernel
- _mm_kernel->configure(&_tmp_a, &_tmp_b, gemm_output_to_use, alpha, true, GEMMReshapeInfo(m, n, k));
+ _mm_kernel->configure(&_tmp_a, &_tmp_b, gemm_output_to_use, alpha, _run_interleave_transpose,
+ GEMMReshapeInfo(m, n, k));
}
if (_run_bias_addition)
@@ -179,6 +221,16 @@ Status CpuGemm::validate(const ITensorInfo *a,
ARM_COMPUTE_UNUSED(alpha);
const bool is_c_bias = beta == 1 && c != nullptr;
const bool run_addition = c != nullptr && beta != 0 && beta != 1;
+ // Check if we should use the pretransposed_b or original b
+ // TODO: COMPMID-6597
+ // Note that this check should only apply to the non-optimized path. The reason we brought this at the beginning
+ // instead of only for the fallback path is because of the checks performed below, between here and the run_optimised decision
+ // We should simplify this by
+ // 1. Moving the checks between "fix-start" and "fix-end" into their corresponding ops / kernels (e.g. the weights format checks can and should be moved into CpuGemmAssemblyDispatch)
+ // 2. Moving this b_to_use check back into the non-optimized path
+ TensorInfo pretransposed_b = b->clone()->set_tensor_shape(misc::shape_calculator::compute_transposed_shape(*b));
+ const ITensorInfo *b_to_use = gemm_info.pretranspose_B() ? &pretransposed_b : b;
+ // TODO: COMPMID-6597 fix-start
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
@@ -187,16 +239,16 @@ Status CpuGemm::validate(const ITensorInfo *a,
if (is_fixed_format_fast_math(gemm_info.weight_format()))
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b_to_use, DataType::BFLOAT16);
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b_to_use);
}
const int block_by = arm_compute::block_by(gemm_info.weight_format());
// test if im2col has changed the dimensions that are needed for padding
- if (a->dimension(0) != b->dimension(1) && block_by > 1)
+ if (a->dimension(0) != b_to_use->dimension(1) && block_by > 1)
{
// have to verify bias
const size_t dim0_sz = a->dimension(0);
@@ -204,18 +256,18 @@ Status CpuGemm::validate(const ITensorInfo *a,
(dim0_sz % block_by) != 0,
("The matrix A number of columns must be a multiple of block_by=" + std::to_string(block_by)).c_str());
// a->dimension(0) = kernel_area * input_channel + kernel_area * input_pad_right
- // b->dimension(1) = kernel_area * input_channel
- // a->dimension(0) = b->dimension(1) + kernel_area * input_pad_right
- const size_t input_pad_right = (dim0_sz - b->dimension(1)) % block_by;
- const size_t kernel_area = (dim0_sz - b->dimension(1)) / input_pad_right;
+ // b_to_use->dimension(1) = kernel_area * input_channel
+ // a->dimension(0) = b_to_use->dimension(1) + kernel_area * input_pad_right
+ const size_t input_pad_right = (dim0_sz - b_to_use->dimension(1)) % block_by;
+ const size_t kernel_area = (dim0_sz - b_to_use->dimension(1)) / input_pad_right;
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- (dim0_sz - kernel_area * input_pad_right) != b->dimension(1),
+ (dim0_sz - kernel_area * input_pad_right) != b_to_use->dimension(1),
"The product AB is defined only if A number of columns and B number of rows are related");
}
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- a->dimension(0) != b->dimension(1),
+ a->dimension(0) != b_to_use->dimension(1),
"The product AB is defined only if the number of columns in A is equal to the number of rows in B");
}
@@ -233,14 +285,14 @@ Status CpuGemm::validate(const ITensorInfo *a,
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(c, d);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->dimension(1),
"The C matrix must have the same number of rows as the matrix A");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != c->dimension(0),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(b_to_use->dimension(0) != c->dimension(0),
"The C matrix must have the same number of columns as the matrix B");
}
if (d->total_size() != 0)
{
// For fixed format we are expecting some kind of blocked format for B/RHS so the dimension won't necessarily match the result matrix any more.
- ARM_COMPUTE_RETURN_ERROR_ON(!gemm_info.fixed_format() && b->dimension(0) != d->dimension(0));
+ ARM_COMPUTE_RETURN_ERROR_ON(!gemm_info.fixed_format() && b_to_use->dimension(0) != d->dimension(0));
if (gemm_info.depth_output_gemm3d() != 0)
{
if (gemm_info.reinterpret_input_as_3d())
@@ -258,10 +310,14 @@ Status CpuGemm::validate(const ITensorInfo *a,
ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != d->dimension(1));
}
}
+ // TODO: COMPMID-6597 fix-end
// Check if we need to run the optimized assembly kernel
cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
- const bool run_optimised =
+
+ // Note we use b instead of b_to_use here because asm_info also captures the pretranspose_b() flag
+ // so we pass the original b to CpuGemmAssemblyDispatch
+ const bool run_optimised =
bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, d, asm_info)) &&
(c == nullptr || beta == 0.f || beta == 1.f) && // Optimized GeMM doesn't support beta coefficient.
!(!b->are_values_constant() &&
@@ -277,13 +333,13 @@ Status CpuGemm::validate(const ITensorInfo *a,
// Check if the first input tensor is a vector.
const bool run_vector_matrix_multiplication = a->dimension(1) < 2;
// Check if we need to reshape the matrix A and matrix B
- const bool run_interleave_transpose = !run_vector_matrix_multiplication && !b->are_values_constant();
+ const bool run_interleave_transpose = !run_vector_matrix_multiplication;
// Arguments used by GEMMReshapeInfo
// If we pass the matrix A and matrix B reshaped to CpuGemmMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to GEMMReshapeInfo
// in order to know how the matrices have been reshaped
const int m = a->dimension(1);
- const int n = b->dimension(0);
+ const int n = b_to_use->dimension(0);
const int k = a->dimension(0);
int mult_transpose1xW_width = 1;
int mult_interleave4x4_height = 1;
@@ -292,7 +348,7 @@ Status CpuGemm::validate(const ITensorInfo *a,
m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());
const ITensorInfo *matrix_a_info = a;
- const ITensorInfo *matrix_b_info = b;
+ const ITensorInfo *matrix_b_info = b_to_use;
TensorInfo tmp_a_info{};
TensorInfo tmp_b_info{};
@@ -309,9 +365,10 @@ Status CpuGemm::validate(const ITensorInfo *a,
ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmInterleave4x4Kernel::validate(a, &tmp_a_info));
// Validate transpose kernel
- auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(
- *b, mult_transpose1xW_width)));
- ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmTranspose1xWKernel::validate(b, &tmp_b_info));
+ auto_init_if_empty(tmp_b_info,
+ b_to_use->clone()->set_tensor_shape(
+ compute_transpose1xW_with_element_size_shape(*b_to_use, mult_transpose1xW_width)));
+ ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmTranspose1xWKernel::validate(b_to_use, &tmp_b_info));
}
// Validate matrix multiply
@@ -367,29 +424,46 @@ void CpuGemm::run(ITensorPack &tensors)
else
{
CpuAuxTensorHandler interleaved_a(offset_int_vec(InterleavedLHS), _tmp_a, tensors, true);
- CpuAuxTensorHandler transposed_b(offset_int_vec(TransposedRHS), _tmp_b, tensors, true);
+ CpuAuxTensorHandler pretransposed_b(offset_int_vec(PreTransposedRHS), _pretransposed_b, tensors);
+ CpuAuxTensorHandler transposed1xw_b(offset_int_vec(Transposed1xWRHS), _tmp_b, tensors, true);
CpuAuxTensorHandler temp_d(offset_int_vec(TempResult), _tmp_d, tensors, true);
ITensorPack mm_pack{{ACL_SRC_0, a}, {ACL_SRC_1, b}, {ACL_DST, (_run_bias_addition) ? temp_d.get() : d}};
- if (!_run_vector_matrix_multiplication)
+
+ if (_run_interleave_transpose)
{
// Run interleave kernel
ITensorPack interleave_pack{{ACL_SRC, a}, {ACL_DST, interleaved_a.get()}};
NEScheduler::get().schedule_op(_interleave_kernel.get(), Window::DimY, _interleave_kernel->window(),
interleave_pack);
+ // Use reshaped matrices
+ mm_pack.add_const_tensor(ACL_SRC_0, interleaved_a.get());
+ }
+ const ITensor *b_to_use = b;
+ if (_pretranspose_b_func)
+ {
if (!_reshape_b_only_on_first_run)
{
- // Run transpose kernel
- ITensorPack transpose_pack{{ACL_SRC, b}, {ACL_DST, transposed_b.get()}};
- NEScheduler::get().schedule_op(_transpose_kernel.get(), Window::DimY, _transpose_kernel->window(),
- transpose_pack);
+ // Run pretranspose kernel
+ ITensorPack pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pretransposed_b.get()}};
+ _pretranspose_b_func->run(pretranspose_pack);
}
-
- // Use reshaped matrices
- mm_pack.add_const_tensor(ACL_SRC_0, interleaved_a.get());
- mm_pack.add_const_tensor(ACL_SRC_1, transposed_b.get());
+ b_to_use = pretransposed_b.get();
+ }
+ if (_run_interleave_transpose)
+ {
+ if (!_reshape_b_only_on_first_run)
+ {
+ // Run transpose1xw kernel
+ ITensorPack transpose_pack{{ACL_SRC, b_to_use}, {ACL_DST, transposed1xw_b.get()}};
+ NEScheduler::get().schedule_op(_transpose1xW_b_kernel.get(), Window::DimY,
+ _transpose1xW_b_kernel->window(), transpose_pack);
+ }
+ b_to_use = transposed1xw_b.get();
}
+ // Use reshaped matrices
+ mm_pack.add_const_tensor(ACL_SRC_1, b_to_use);
NEScheduler::get().schedule_op(_mm_kernel.get(),
_run_vector_matrix_multiplication ? Window::DimX : Window::DimY,
@@ -426,17 +500,32 @@ void CpuGemm::prepare(ITensorPack &tensors)
{
_asm_glue->prepare(tensors);
}
- else if (_reshape_b_only_on_first_run && !_run_vector_matrix_multiplication)
+ else if (_reshape_b_only_on_first_run)
{
- const ITensor *b = tensors.get_const_tensor(ACL_SRC_1);
- ITensor *b_aux =
- utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(TransposedRHS)));
- ARM_COMPUTE_ERROR_ON_NULLPTR(b, b_aux);
-
- CpuAuxTensorHandler transposed_b(_tmp_b, *b_aux);
- ITensorPack transpose_pack{{ACL_SRC, b}, {ACL_DST, transposed_b.get()}};
- NEScheduler::get().schedule_op(_transpose_kernel.get(), Window::DimY, _transpose_kernel->window(),
- transpose_pack);
+ const ITensor *b = tensors.get_const_tensor(ACL_SRC_1);
+ const ITensor *b_to_use = b;
+ CpuAuxTensorHandler pretransposed_b(
+ offset_int_vec(PreTransposedRHS), _pretransposed_b, tensors,
+ false /*pack_inject: no need to inject into tensors*/,
+ _pretranspose_b_func ==
+ nullptr /*bypass_alloc: no need to allocate if _pretranspose_b_func is not run*/);
+ CpuAuxTensorHandler transposed1xw_b(offset_int_vec(Transposed1xWRHS), _tmp_b, tensors,
+ false /*pack_inject*/, !_run_interleave_transpose /*bypass_alloc*/);
+
+ if (_pretranspose_b_func)
+ {
+ // Run pretranspose kernel
+ ITensorPack pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pretransposed_b.get()}};
+ _pretranspose_b_func->run(pretranspose_pack);
+ b_to_use = pretransposed_b.get();
+ }
+ if (_run_interleave_transpose)
+ {
+ // Run transpose kernel
+ ITensorPack transpose_pack{{ACL_SRC, b_to_use}, {ACL_DST, transposed1xw_b.get()}};
+ NEScheduler::get().schedule_op(_transpose1xW_b_kernel.get(), Window::DimY,
+ _transpose1xW_b_kernel->window(), transpose_pack);
+ }
}
_is_prepared = true;
}
diff --git a/src/cpu/operators/CpuGemm.h b/src/cpu/operators/CpuGemm.h
index 6b30d134fa..a05258d206 100644
--- a/src/cpu/operators/CpuGemm.h
+++ b/src/cpu/operators/CpuGemm.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_GEMM_H
-#define ARM_COMPUTE_CPU_GEMM_H
+#ifndef ACL_SRC_CPU_OPERATORS_CPUGEMM_H
+#define ACL_SRC_CPU_OPERATORS_CPUGEMM_H
#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/TensorInfo.h"
@@ -36,6 +36,7 @@
#include "src/cpu/kernels/CpuGemmTranspose1xWKernel.h"
#include "src/cpu/operators/CpuActivation.h"
#include "src/cpu/operators/CpuAdd.h"
+#include "src/cpu/operators/CpuTranspose.h"
#include "src/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
#include <memory>
@@ -144,16 +145,17 @@ public:
private:
enum AuxTensorIdx
{
- AsmGemmWorkspace = 0,
- Pretraspose,
- InterleavedLHS,
- TransposedRHS,
+ /* Slots 0 - 2 reserved for CpuGemmAssemblyDispatch */
+ InterleavedLHS = 3,
+ PreTransposedRHS,
+ Transposed1xWRHS,
TempResult,
Count
};
std::unique_ptr<kernels::CpuGemmInterleave4x4Kernel> _interleave_kernel{nullptr};
- std::unique_ptr<kernels::CpuGemmTranspose1xWKernel> _transpose_kernel{nullptr};
+ std::unique_ptr<CpuTranspose> _pretranspose_b_func{nullptr};
+ std::unique_ptr<kernels::CpuGemmTranspose1xWKernel> _transpose1xW_b_kernel{nullptr};
std::unique_ptr<kernels::CpuGemmMatrixMultiplyKernel> _mm_kernel{nullptr};
std::unique_ptr<CpuGemmAssemblyDispatch> _asm_glue{nullptr};
std::unique_ptr<kernels::CpuGemmMatrixAdditionKernel> _ma_kernel{nullptr};
@@ -162,10 +164,13 @@ private:
std::unique_ptr<CpuActivation> _activation_func{nullptr};
TensorInfo _tmp_a{};
+ TensorInfo _pretransposed_b{};
TensorInfo _tmp_b{};
TensorInfo _tmp_d{};
bool _run_vector_matrix_multiplication{false};
+ bool _run_interleave_transpose{
+ true}; /**< If we run CpuGemmInterleave4x4Kernel on lhs and CpuGemmTranspose1xWKernel on rhs */
bool _run_alpha_scale{false};
bool _run_addition{false};
bool _run_bias_addition{false};
@@ -177,4 +182,4 @@ private:
};
} // namespace cpu
} // namespace arm_compute
-#endif /*ARM_COMPUTE_CPU_GEMM_H */
+#endif // ACL_SRC_CPU_OPERATORS_CPUGEMM_H
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index 7c59d88c61..117527ccc1 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -32,7 +32,9 @@
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "src/common/utils/Log.h"
+#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/MemoryHelpers.h"
+#include "src/core/helpers/Utils.h"
#include "src/cpu/kernels/CpuCol2ImKernel.h"
#include "src/cpu/kernels/CpuIm2ColKernel.h"
#include "src/cpu/kernels/CpuWeightsReshapeKernel.h"
@@ -52,6 +54,117 @@ namespace arm_compute
{
namespace cpu
{
+
+/** @section note_CpuGemmConv2d_weight_transformation Weight Transformations in CpuGemmConv2d
+ *
+ * A. Terminology
+ * Throughout CpuGemmConv2d, we use the following terms in ways that may differ from other operators / kernels:
+ * - "Transform" or "Reshape" of the weights: they both mean all the operations that we perform on the weight
+ * tensor up until they are consumed by gemm (CpuGemm or CpuGemmLowpMatrixMultiplyCore)
+ * Note that the specific gemm operator may perform further transformations on the weights, but the
+ * transformations here only mean those performed in CpuGemmConv2d
+ * - "Transpose" of weights: The @ref CpuTranspose operation. I.e. transpose of the weights' lowest two
+ * dimensions
+ *
+ * B. Gemm-based conv2d
+ * We want to convert the 2d convolution op (ignoring bias):
+ * dst = conv2d(src, weight)
+ * into a matrix multiplication op:
+ * gemm_dst = gemm(lhs, rhs)
+ *
+ * E.g.: For data layout NHWC
+ * 3 (hi) <----------> (lo) 0
+ * src.shape = [batch, in_h , in_w, in_c]
+ * weight.shape = [out_c, k_h , k_w, in_c]
+ * dst.shape = [batch, out_h, out_w, out_c]
+ *
+ * This requires three transformations:
+ * * src -> lhs, transform conv input to gemm lhs; gemm_lhs is a 2d matrix where each row (or column,
+ * depending on the convention) is a linearized "patch" of the conv_input that corresponds to
+ * the receptive field of the corresponding output element.
+ * The convention is to use "column", but to disambiguate from the column vector of a matrix,
+ * in this documentation we shall use "patch".
+ * This transform is called im2col (for details see @ref CpuIm2ColKernel)
+ * * weight -> rhs, transform conv weight to gemm rhs, known as weight transform/reshape (wt)
+ * * gemm_dst -> dst, transform gemm output back to conv output, known as col2im (for details see
+ * @ref CpuCol2ImKernel)
+ *
+ * This section focuses on the weight transformation and assumes the im2col is already performed
+ *
+ * C. Weight Transformation
+ * After im2col, assume: lhs.shape = [num_patch, patch_size],
+ * where patch_size is the number of elements in a "patch": patch_size = k_h * k_w * in_c
+ * num_patch is the number of patches; we can ignore it here (for details see @ref CpuIm2ColKernel)
+ *
+ * After wt, rhs should have the shape: rhs = [patch_size, out_c]
+ *
+ * Therefore, the weight transformation consists of two steps:
+ * 1. Collapsing all 3 spatial dimensions: [out_c, k_h, k_w, in_c] -> [out_c, patch_size]
+ * 2. Transpose the collapsed shape: [out_c, patch_size] -> [patch_size, out_c]
+ *
+ * D. Implementation
+ * There are 4 paths for weight transformation
+ *
+ * 1. Path 1: Fixed weight format - no transformation
+ * The underlying gemm kernel may adopt fixed weight format (isVarWeightsKernel() == true), which requires
+ * that no weight transformation shall be performed
+ * Note that this no-transform requirement applies both to this op (CpuGemmConv2d) and the constituent ops, up
+ * until the fixed format kernels themselves
+ *
+ * 2. Path 2: Reinterpret then transpose later
+ * If the weight tensor has no "holes" (see @ref has_holes), there are two optimizations we can apply:
+ * - We can ignore the first step (collapsing of spatial dimensions) by simply re-interpreting the shape
+ * in TensorInfo
+ * - Instead of performing transpose here, we can pass the transpose flag to the underlying gemm. The gemm
+ * may then decide to fuse the transpose with any further transformations
+ *
+ * 3. Path 3: Reshape then transpose later
+ * If the weight tensor has holes, then we use a dedicated @ref CpuReshape, followed by transpose later
+ *
+ * 4. Path 4: Fused reshape and transpose
+ * This is only for quantized types for now (TODO: Remove (COMPMID-6596)). We fall back to a legacy
+ * non-optimized kernel @ref CpuWeightsReshapeKernel to perform a fused reshape + transpose
+ *
+ * Path 1 is the long term solution that we shall migrate to once (if) we adopt fixed weight format for all gemm
+ * kernels.
+ * In the short term, Path 2 is the favored, more performant path.
+ */
+
+namespace
+{
+/** Initialize reshaped / transformed weight info
+ *
+ * @param[in] weights Input weights
+ * @param[out] reshaped_weights Transformed weights
+ */
+void initialize_reshaped_weight_info(const ITensorInfo &weights, ITensorInfo &reshaped_weights)
+{
+ auto_init_if_empty(reshaped_weights, weights);
+ if (is_data_type_quantized(weights.data_type()))
+ {
+ // WT method: FusedReshapeAndTranspose
+ reshaped_weights.set_tensor_shape(compute_weights_reshaped_shape(weights, /* has_bias */ false));
+ }
+ else
+ {
+ TensorShape collapsed_weights = weights.tensor_shape();
+ collapsed_weights.collapse(3);
+ reshaped_weights.set_tensor_shape(collapsed_weights);
+ }
+}
+} // namespace
+
+CpuGemmConv2d::WeightTransformMethod CpuGemmConv2d::get_wt_method(const ITensorInfo &weights)
+{
+ // TODO: Extend ReinterpretThenTranspose support for quantized data types COMPMID-6596
+ if (is_data_type_quantized(weights.data_type()))
+ {
+ return WeightTransformMethod::FusedReshapeAndTranspose;
+ }
+ return has_holes(weights) ? WeightTransformMethod::ReshapeThenTranspose
+ : WeightTransformMethod::ReinterpretThenTranspose;
+}
+
CpuGemmConv2d::SkipInfo CpuGemmConv2d::skip_im_col_info(const ITensorInfo *src,
const ITensorInfo *weights,
const PadStrideInfo &conv_info,
@@ -96,7 +209,8 @@ CpuGemmConv2d::SkipInfo CpuGemmConv2d::skip_im_col_info(const ITensorInfo
}
CpuGemmConv2d::CpuGemmConv2d()
- : _weights_reshape_kernel(nullptr),
+ : _weights_reshape(nullptr),
+ _weights_reshape_and_transpose_kernel(nullptr),
_im2col_kernel(),
_mm_gemm(),
_mm_gemmlowp(),
@@ -111,6 +225,8 @@ CpuGemmConv2d::CpuGemmConv2d()
_skip_col2im(false),
_is_quantized(false),
_is_prepared(false),
+ _wt_method(WeightTransformMethod::ReshapeThenTranspose),
+ _run_wt(true),
_aux_mem(AuxTensorIdx::Count)
{
}
@@ -130,12 +246,6 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src,
ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth,
_skip_im2col, fixed_format, weight_format));
- // Create GEMMInfo structure
- const GEMMInfo &gemm_info =
- GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
- _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false, GEMMLowpOutputStageInfo(),
- false, enable_fast_math, false, act_info, fixed_format, weight_format);
-
// Supported activations in GEMM
const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = {
ActivationLayerInfo::ActivationFunction::RELU, ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
@@ -184,7 +294,8 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src,
_mm_gemmlowp = std::make_unique<CpuGemmLowpMatrixMultiplyCore>();
_mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst,
GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false,
- enable_fast_math, false, act_info, fixed_format, weight_format));
+ enable_fast_math, false, act_info, fixed_format, weight_format,
+ false /* pretranspose_B. TODO: COMPMID-6596 */));
auto mm_mem_req = _mm_gemmlowp->workspace();
for (unsigned int cont = 0; cont < mm_mem_req.size(); ++cont)
@@ -194,6 +305,13 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src,
}
else
{
+ // Create GEMMInfo structure
+ const GEMMInfo &gemm_info =
+ GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
+ _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false,
+ GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, fixed_format, weight_format,
+ true /*pretranspose_B. For fp gemm (wt path 1 - 3), We always pretranspose B (for wt path 1 this
+ flag is ignored)*/);
// Configure matrix multiply function
_mm_gemm = std::make_unique<CpuGemm>();
_mm_gemm->configure(src, weights, biases, dst, 1.0f, 1.0f, gemm_info);
@@ -220,12 +338,6 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src,
const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
const bool is_activation_enabled = act_info.enabled();
- // Create GEMMInfo structure
- const GEMMInfo gemm_info =
- GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
- skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false, GEMMLowpOutputStageInfo(),
- false, enable_fast_math, false, act_info, fixed_format, weight_format);
-
if (is_quantized)
{
// Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
@@ -266,10 +378,19 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src,
return CpuGemmLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), biases, dst,
GEMMInfo(false, false, true, gemm_3d_depth, skip_im2col, false,
- output_info, false, enable_fast_math, false, act_info));
+ output_info, false, enable_fast_math, false, act_info,
+ false /* pretranspose_B. TODO: COMPMID-6596 */));
}
else
{
+ // Create GEMMInfo structure
+ const GEMMInfo gemm_info =
+ GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
+ skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false,
+ GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, fixed_format, weight_format,
+ true /*pretranspose_B. For fp gemm (wt path 1 - 3), We always pretranspose B (for wt path 1 this
+ flag is ignored)*/);
+
// Perform validation step on Matrix multiply function
return CpuGemm::validate(src, weights, biases, dst, 1.0f, 1.0f, gemm_info);
}
@@ -353,13 +474,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src,
unsigned int stride_y = 0;
std::tie(stride_x, stride_y) = conv_info.stride();
- unsigned int mat_weights_cols = weights->dimension(idx_kernels);
-
- // _weights_reshaped will be auto configured in the kernel.
- // Just append biases and do not transpose 1xW as it will be reshaped in CpuGemm
- _weights_reshape_kernel = std::make_unique<kernels::CpuWeightsReshapeKernel>();
- _weights_reshape_kernel->configure(weights, nullptr, &_weights_reshaped);
- _weights_reshaped.set_quantization_info(weights->quantization_info());
+ // Initialize reshaped weights
+ initialize_reshaped_weight_info(*weights, _weights_reshaped);
// Create tensor to store im2col reshaped inputs
if (!_skip_im2col)
@@ -380,6 +496,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src,
gemm_input_to_use = &_im2col_output;
}
+ const unsigned int mat_weights_cols = weights->dimension(idx_kernels);
+
// Create temporary GEMM output tensor in case we cannot skip col2im
const DataType output_data_type = data_type == DataType::BFLOAT16 ? DataType::F32 : data_type;
if (!_skip_col2im)
@@ -412,9 +530,38 @@ void CpuGemmConv2d::configure(const ITensorInfo *src,
// In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the output matrix
const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0;
const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
+ /** @section note_CpuGemmConv2d_weight_use_in_configure Which weights tensor should we use to configure gemm
+ *
+ * A. The problem:
+ * In principle, we should use the weights tensor corresponding to the weights transformation path. I.e.:
+ * - If no weight transformation (_run_wt == false): Use original weights
+ * - else: Use transformed weights
+ * However in practice we have a dilemma:
+ * - We need to know _run_wt before we can configure gemm with the corresponding weights, but
+ * - _run_wt depends on isVarWeightsKernel(), which is only known after gemm is configured
+ *
+ * B. The decision:
+ * To simplify the matter, we decide to always use the transformed weights, regardless of _run_wt
+ *
+ * This decision requires the following conditions:
+ * 1. The underlying gemm where isVarWeightsKernel() == true, must guarantee that:
+ * A. Ignore the flag to transpose weights (GEMMInfo::pretranspose_B)
+ * B. Use weights/B tensor passed to it at prepare() or run() instead of that passed at configure()
+ * 2. CpuGemmConv2d where isVarWeightsKernel() == true, must guarantee that:
+ * A. Pass original weights instead of reshaped or reinterpreted weights
+ *
+ * C. Future actions:
+ * Condition 2 is a given, based on our implementation.
+ * If condition 1 cannot hold, we must make changes to the underlying gemm to:
+ * 1. Either expose isVarWeightsKernel() before gemm is configured somehow, or
+ * 2. Take in an additional "original_weights" tensor info at configure
+ */
configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math,
gemm_3d_depth, fixed_format, weights_info.weight_format());
+ // Can only decide isVarWeightsKernel after gemm is configured
+ _run_wt = !isVarWeightsKernel();
+
if (!_skip_col2im && _data_layout == DataLayout::NCHW)
{
// Configure col2im
@@ -428,18 +575,27 @@ void CpuGemmConv2d::configure(const ITensorInfo *src,
_reshape->configure(gemm_output_to_use, dst);
}
- // Check if GEMM transforms weights
- // Modernise through COMPMID-4535
- bool gemm_trans_wei = _aux_mem[1].size > 0; // Asm Pretranspose
- gemm_trans_wei = _mm_gemm != nullptr ? _aux_mem[3].size > 0 : gemm_trans_wei; // Tranpose RHS
- gemm_trans_wei = _mm_gemmlowp != nullptr ? _aux_mem[5].size > 0 : gemm_trans_wei; // Transpose RHS
-
// Check lifetime
_aux_mem[Im2ColOutput] =
MemoryInfo(offset_int_vec(Im2ColOutput), MemoryLifetime::Temporary, _im2col_output.total_size());
- _aux_mem[WeightsReshaped] = MemoryInfo(offset_int_vec(WeightsReshaped),
- gemm_trans_wei ? MemoryLifetime::Prepare : MemoryLifetime::Persistent,
- _weights_reshaped.total_size());
+ // Add WeightsReshaped memory requirement to workspace
+ // Note that in case of WeightTransformMethod::ReinterpretThenTranspose, we do not need to allocate this memory
+ // However since we cannot determine weight transformation method until prepare (see prepare()), we will have to
+ // settle with allocating more
+ if (_run_wt)
+ {
+ // Check if GEMM transforms weights
+ // If weight is further transformed by underlying gemm after ReshapeThenTranspose then we can free
+ // WeightsReshaped in prepare
+ // Otherwise WeightsReshaped is the final transformation of weights and needs to persist
+ bool gemm_trans_wei = _aux_mem[GemmAsmPretransposedRHS].size > 0;
+ gemm_trans_wei = _mm_gemm != nullptr ? _aux_mem[GemmTransposed1xWRHS].size > 0 : gemm_trans_wei;
+ gemm_trans_wei = _mm_gemmlowp != nullptr ? _aux_mem[GemmLowpTransposed1xWRHS].size > 0 : gemm_trans_wei;
+
+ _aux_mem[WeightsReshaped] = MemoryInfo(offset_int_vec(WeightsReshaped),
+ gemm_trans_wei ? MemoryLifetime::Prepare : MemoryLifetime::Persistent,
+ _weights_reshaped.total_size());
+ }
_aux_mem[GemmOutput] = MemoryInfo(offset_int_vec(GemmOutput), MemoryLifetime::Temporary, _gemm_output.total_size());
}
@@ -471,10 +627,18 @@ Status CpuGemmConv2d::has_opt_impl(arm_compute::WeightFormat &expected_weight_fo
const bool skip_col2im = skip_info.skip_col2im;
const unsigned int gemm_3d_depth = skip_col2im ? conv_h : 0;
const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
- const GEMMInfo gemm_info =
- GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
- skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false, GEMMLowpOutputStageInfo(),
- false, enable_fast_math, false, act_info, fixed_format, weights_info.weight_format());
+
+ /** @section note_CpuGemmConv2d_weight_use_in_has_opt_impl Which weights tensor should we use for has_opt_impl
+ *
+ * For the pretranspose_B flag, this shares a similar problem and thus the same decision as that of
+ * @ref note_CpuGemmConv2d_weight_use_in_configure
+ *
+ * But for the weights, we shall always use the original instead of reshaped weights here
+ */
+ const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
+ skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false,
+ GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info,
+ fixed_format, weights_info.weight_format(), true /* pretranspose_B */);
return CpuGemm::has_opt_impl(expected_weight_format, src, weights, biases, dst, gemm_info);
}
@@ -565,8 +729,10 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src,
unsigned int mat_weights_rows =
weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel);
- weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, append_bias), 1, weights->data_type());
- weights_reshaped_info.set_quantization_info(weights->quantization_info());
+ // Initialize reshaped weights
+ initialize_reshaped_weight_info(*weights, weights_reshaped_info);
+ // No need to call CpuReshape::validate() or CpuTranspose::validate() as the dst info is auto-configured from the
+ // src
weights_to_use = &weights_reshaped_info;
if (!skip_im2col)
@@ -613,6 +779,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src,
gemm_output_to_use = &info_gemm;
const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
+ // See note_CpuGemmConv2d_weight_use_in_configure regarding the choice of the weights
ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info,
enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col, fixed_format,
weights_info.weight_format()));
@@ -637,7 +804,6 @@ void CpuGemmConv2d::run(ITensorPack &tensors)
CpuAuxTensorHandler im2col_output(offset_int_vec(Im2ColOutput), _im2col_output, tensors, false);
CpuAuxTensorHandler gemm_output(offset_int_vec(GemmOutput), _gemm_output, tensors, false);
- CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors, false);
bool out_has_padding = _skip_col2im && (dst->info()->padding().bottom != 0 || dst->info()->padding().top != 0);
if (!_skip_im2col)
@@ -666,25 +832,32 @@ void CpuGemmConv2d::run(ITensorPack &tensors)
gemm_output_to_use = dst;
}
- // Runs CpuGemm or CpuGemmLowpMatrixMultiplyCore functions
- ITensorPack pack_mm = tensors;
- pack_mm.add_const_tensor(TensorType::ACL_SRC_0, gemm_input_to_use);
- if (!this->isVarWeightsKernel())
- {
- pack_mm.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
- }
- pack_mm.add_tensor(TensorType::ACL_DST, gemm_output_to_use);
- if (_is_quantized)
+ ITensorPack gemm_pack = tensors;
+ gemm_pack.add_const_tensor(TensorType::ACL_SRC_0, gemm_input_to_use);
+ gemm_pack.add_tensor(TensorType::ACL_DST, gemm_output_to_use);
+ // Allocate reshaped weights if required
+ auto weights = gemm_pack.get_const_tensor(TensorType::ACL_SRC_1);
+ CpuAuxTensorHandler reinterpreted_wei(
+ _weights_reshaped,
+ *weights); // Re-interpreted weights. Only tensor shape is changed. No allocation
+ CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
+ // Update the weights to use if it has been reshaped
+ if (_run_wt)
{
- // Run gemmlowp
- _mm_gemmlowp->run(pack_mm);
- }
- else
- {
- // Run gemm
- _mm_gemm->run(pack_mm);
+ if (_wt_method == WeightTransformMethod::ReinterpretThenTranspose)
+ {
+ gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reinterpreted_wei.get());
+ }
+ else if (_wt_method == WeightTransformMethod::ReshapeThenTranspose ||
+ _wt_method == WeightTransformMethod::FusedReshapeAndTranspose)
+ {
+ gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
+ }
}
+ // Runs CpuGemm or CpuGemmLowpMatrixMultiplyCore functions
+ _is_quantized ? _mm_gemmlowp->run(gemm_pack) : _mm_gemm->run(gemm_pack);
+
// Reshape output matrix
if (!_skip_col2im)
{
@@ -710,24 +883,87 @@ void CpuGemmConv2d::prepare(ITensorPack &tensors)
{
if (!_is_prepared)
{
- // Variable weights executions that use fixed-format kernels
- // need no reshaping of the weights.
- if (this->isVarWeightsKernel())
+ auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+ // Determine which weights reshape path to take
+ // Note that this decision can only occur at prepare instead of configure because it relies on the presence of
+ // any holes in the weight tensor, which may change after configure (e.g. from extending padding)
+ if (_run_wt)
{
- _is_quantized ? _mm_gemmlowp->prepare(tensors) : _mm_gemm->prepare(tensors);
- _is_prepared = true;
- return;
+ _wt_method = get_wt_method(*(weights->info()));
+ switch (_wt_method)
+ {
+ case (WeightTransformMethod::FusedReshapeAndTranspose):
+ {
+ ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Perform weight transformation: FusedReshapeAndTranspose");
+ _weights_reshape_and_transpose_kernel = std::make_unique<kernels::CpuWeightsReshapeKernel>();
+ _weights_reshape_and_transpose_kernel->configure(weights->info(), nullptr, &_weights_reshaped);
+ break;
+ }
+ case (WeightTransformMethod::ReshapeThenTranspose):
+ {
+ ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Perform weight transformation: ReshapeThenTranspose");
+ _weights_reshape = std::make_unique<CpuReshape>();
+ _weights_reshape->configure(weights->info(), &_weights_reshaped);
+ break;
+ }
+ case (WeightTransformMethod::ReinterpretThenTranspose):
+ {
+ ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Perform weight transformation: ReinterpretThenTranspose");
+ // Nothing to configure
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Unsupported weight transform method");
+ }
+ }
+ }
+ else
+ {
+ ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("No weight transformation is performed");
}
-
- // Run weights reshaping and mark original weights tensor as unused
- CpuAuxTensorHandler weights_reshaped(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
- auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
- ITensorPack pack = {{TensorType::ACL_SRC, weights}, {TensorType::ACL_DST, weights_reshaped.get()}};
- NEScheduler::get().schedule_op(_weights_reshape_kernel.get(), 3, _weights_reshape_kernel->window(), pack);
- weights->mark_as_unused();
ITensorPack gemm_pack = tensors;
- gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, weights_reshaped.get());
+ // Allocate reshaped weights if required
+ CpuAuxTensorHandler reinterpreted_wei(
+ _weights_reshaped,
+ *weights); // Re-interpreted weights. Only tensor shape is changed. No allocation
+ CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
+ // Run weights reshape if required
+ if (_run_wt)
+ {
+ switch (_wt_method)
+ {
+ case (WeightTransformMethod::FusedReshapeAndTranspose):
+ {
+ ITensorPack pack = {{TensorType::ACL_SRC, weights}, {TensorType::ACL_DST, reshaped_wei.get()}};
+ NEScheduler::get().schedule_op(_weights_reshape_and_transpose_kernel.get(), Window::DimW,
+ _weights_reshape_and_transpose_kernel->window(), pack);
+ weights->mark_as_unused();
+ gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
+ break;
+ }
+ case (WeightTransformMethod::ReshapeThenTranspose):
+ {
+ ITensorPack pack = {{TensorType::ACL_SRC, weights}, {TensorType::ACL_DST, reshaped_wei.get()}};
+ _weights_reshape->run(pack);
+ weights->mark_as_unused();
+ gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
+ break;
+ }
+ case (WeightTransformMethod::ReinterpretThenTranspose):
+ {
+ gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reinterpreted_wei.get());
+ // Nothing to run
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Unsupported weight transform method");
+ }
+ }
+ }
_is_quantized ? _mm_gemmlowp->prepare(gemm_pack) : _mm_gemm->prepare(gemm_pack);
+
_is_prepared = true;
}
}
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h
index 118d366517..48a0d11107 100644
--- a/src/cpu/operators/CpuGemmConv2d.h
+++ b/src/cpu/operators/CpuGemmConv2d.h
@@ -42,21 +42,12 @@ class CpuGemmLowpOutputStage;
class CpuReshape;
namespace kernels
{
-class CpuWeightsReshapeKernel;
class CpuIm2ColKernel;
class CpuCol2ImKernel;
+class CpuWeightsReshapeKernel;
} // namespace kernels
-/** Basic function to compute the convolution layer. This function calls the following kernels/functions:
- *
- * -# @ref cpu::kernels::CpuIm2ColKernel
- * -# @ref CpuGemm (if the data type is BFLOAT16/FP16/FP32)
- * -# @ref CpuGemmLowpMatrixMultiplyCore (if the data type is QASYMM8/QASYMM8_SIGNED)
- * -# @ref CpuGemmLowpOutputStage (if the data type is QASYMM8/QASYMM8_SIGNED)
- * -# @ref cpu::kernels::CpuCol2ImKernel (if NCHW data layout)
- * -# @ref kernels::CpuWeightsReshapeKernel
- *
- */
+/** Basic function to compute the convolution layer. @ref note_CpuGemmConv2d_weight_transformation */
class CpuGemmConv2d : public ICpuOperator
{
public:
@@ -99,7 +90,7 @@ public:
* @param[out] dst Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
* Data types supported: Same as @p input.
* @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo.
- * @param[in] weights_info Specifies if the weights tensor has been reshaped with NEWeightsReshapeKernel. If this is not part of the fully connected layer the weights
+ * @param[in] weights_info Specifies if the weights tensor has been reshaped with CpuWeightsReshapeKernel. If this is not part of the fully connected layer the weights
* tensor has also been transposed with cpu::kernels::CpuGemmTranspose1xWKernel. Data type supported: Same as @p input.
* @param[in] dilation (Optional) Dilation, in elements, across x and y. Defaults to (1, 1).
* @param[in] act_info (Optional) Activation layer information in case of a fused activation. Only RELU, BOUNDED_RELU and LU_BOUNDED_RELU supported.
@@ -136,7 +127,7 @@ public:
/** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters.
*
- * The paramter list is the same as @ref NEGEMMConvolutionLayer::has_opt_impl
+ * The parameter list is the same as @ref NEGEMMConvolutionLayer::has_opt_impl
*
* @return a status.
*/
@@ -254,15 +245,35 @@ private:
bool isVarWeightsKernel() const;
enum AuxTensorIdx
{
- // CpuGemmLowpMatrixMultiplyCore has up to 8 internal tensors
- Im2ColOutput = 9,
+ GemmAsmPretransposedRHS = 2, // CpuGemmAssemblyDispatch::Pretranspose
+ GemmTransposed1xWRHS = 5, // CpuGemm::Transposed1xWRHS
+ GemmLowpTransposed1xWRHS = 6, // CpuGemmLowpMatrixMultiplyCore::TmpB
+ /* Slots 0 - 9 reserved and shared by CpuGemmLowpMatrixMultiplyCore and CpuGemm */
+ Im2ColOutput = 10,
WeightsReshaped,
GemmOutput,
Count
};
- std::unique_ptr<kernels::CpuWeightsReshapeKernel> _weights_reshape_kernel;
- std::unique_ptr<cpu::kernels::CpuIm2ColKernel> _im2col_kernel;
+ /** Weight transformation method. See @ref note_CpuGemmConv2d_weight_transformation */
+ enum class WeightTransformMethod
+ {
+ ReinterpretThenTranspose,
+ ReshapeThenTranspose,
+ FusedReshapeAndTranspose,
+ };
+
+ /** Select weight transformation method
+ *
+ * @param[in] weights Input weights
+ *
+ * @return WeightTransformMethod
+ */
+ static WeightTransformMethod get_wt_method(const ITensorInfo &weights);
+
+ std::unique_ptr<CpuReshape> _weights_reshape;
+ std::unique_ptr<kernels::CpuWeightsReshapeKernel> _weights_reshape_and_transpose_kernel;
+ std::unique_ptr<kernels::CpuIm2ColKernel> _im2col_kernel;
std::unique_ptr<CpuGemm> _mm_gemm;
std::unique_ptr<CpuGemmLowpMatrixMultiplyCore> _mm_gemmlowp;
std::unique_ptr<kernels::CpuCol2ImKernel> _col2im_kernel;
@@ -275,10 +286,12 @@ private:
DataLayout _data_layout;
- bool _skip_im2col;
- bool _skip_col2im;
- bool _is_quantized;
- bool _is_prepared;
+ bool _skip_im2col;
+ bool _skip_col2im;
+ bool _is_quantized;
+ bool _is_prepared;
+ WeightTransformMethod _wt_method;
+ bool _run_wt;
experimental::MemoryRequirements _aux_mem{Count};
};
diff --git a/src/cpu/operators/CpuGemmDirectConv2d.cpp b/src/cpu/operators/CpuGemmDirectConv2d.cpp
index 8fa81b1907..9187927541 100644
--- a/src/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -140,9 +140,11 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src,
}
// Add auxiliary memory requirements of the assembly dispatch
- auto asm_mem_req = _gemm_asm_func->workspace();
- _aux_mem[AsmGemmWorkspace] = asm_mem_req[AsmGemmWorkspace];
- _aux_mem[Pretranspose] = asm_mem_req[Pretranspose];
+ const auto asm_mem_req = _gemm_asm_func->workspace();
+ for (unsigned int slot = 0; slot < asm_mem_req.size(); ++slot)
+ {
+ _aux_mem[slot] = asm_mem_req[slot];
+ }
if (_aux_mem[Pretranspose].size > 0)
{
diff --git a/src/cpu/operators/CpuGemmDirectConv2d.h b/src/cpu/operators/CpuGemmDirectConv2d.h
index 1cc3caadae..a7365615b9 100644
--- a/src/cpu/operators/CpuGemmDirectConv2d.h
+++ b/src/cpu/operators/CpuGemmDirectConv2d.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_GEMM_DIRECT_CONV_2D_H
-#define ARM_COMPUTE_CPU_GEMM_DIRECT_CONV_2D_H
+#ifndef ACL_SRC_CPU_OPERATORS_CPUGEMMDIRECTCONV2D_H
+#define ACL_SRC_CPU_OPERATORS_CPUGEMMDIRECTCONV2D_H
#include "arm_compute/core/TensorInfo.h"
@@ -95,8 +95,10 @@ public:
private:
enum AuxTensorIdx
{
- AsmGemmWorkspace = 0,
+ GemmTemp0 = 0,
+ GemmTemp1,
Pretranspose,
+ /* Slots above (0-2) are reserved for CpuGemmAssemblyDispatch */
PermutedWeights,
Count
};
@@ -112,4 +114,4 @@ private:
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_GEMM_DIRECT_CONV_2D_H */
+#endif // ACL_SRC_CPU_OPERATORS_CPUGEMMDIRECTCONV2D_H
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index 2ee879b67b..b25505a85d 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -297,9 +297,11 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
if (_assembly_path)
{
- auto asm_mem_req = _asm_glue->workspace();
- _aux_mem[AsmGemmWorkspace] = asm_mem_req[AsmGemmWorkspace];
- _aux_mem[Pretranspose] = asm_mem_req[Pretranspose];
+ const auto asm_mem_req = _asm_glue->workspace();
+ for (unsigned int slot = 0; slot < asm_mem_req.size(); ++slot)
+ {
+ _aux_mem[slot] = asm_mem_req[slot];
+ }
}
// Request memory for LHS and RHS reshape matrix
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
index a7798938e7..78065a8953 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_GEMMLOWP_MATRIXMULTIPLY_CORE_H
-#define ARM_COMPUTE_CPU_GEMMLOWP_MATRIXMULTIPLY_CORE_H
+#ifndef ACL_SRC_CPU_OPERATORS_CPUGEMMLOWPMATRIXMULTIPLYCORE_H
+#define ACL_SRC_CPU_OPERATORS_CPUGEMMLOWPMATRIXMULTIPLYCORE_H
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/function_info/GEMMInfo.h"
@@ -134,9 +134,8 @@ public:
private:
enum AuxTensorIdx
{
- AsmGemmWorkspace = 0,
- Pretranspose,
- VectorSumCol,
+ /* Slots 0 - 2 reserved for CpuGemmAssemblyDispatch */
+ VectorSumCol = 3,
VectorSumRow,
TmpA,
TmpB,
@@ -181,4 +180,4 @@ private:
};
} // namespace cpu
} // namespace arm_compute
-#endif /*ARM_COMPUTE_CPU_GEMMLOWP_MATRIXMULTIPLY_CORE_H */
+#endif // ACL_SRC_CPU_OPERATORS_CPUGEMMLOWPMATRIXMULTIPLYCORE_H
diff --git a/src/cpu/operators/CpuMatMul.h b/src/cpu/operators/CpuMatMul.h
index 24db3da346..2b1b4cf0ff 100644
--- a/src/cpu/operators/CpuMatMul.h
+++ b/src/cpu/operators/CpuMatMul.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ACL_SRC_CPU_OPERATORS_CPUMATMUL
-#define ACL_SRC_CPU_OPERATORS_CPUMATMUL
+#ifndef ACL_SRC_CPU_OPERATORS_CPUMATMUL_H
+#define ACL_SRC_CPU_OPERATORS_CPUMATMUL_H
#include "arm_compute/core/TensorInfo.h"
@@ -93,9 +93,8 @@ public:
private:
enum InternalTensorIdx
{
- AsmGemmWorkspace = 0, // Pre-allocate workspace tensors for CpuGemmAssemblyDispatch
- PretransposeRHS, // Pre-allocate workspace tensors for CpuGemmAssemblyDispatch
- TransposeLHS,
+ /* Slots 0 - 2 reserved for CpuGemmAssemblyDispatch */
+ TransposeLHS = 3,
TransposeRHS,
Count
};
@@ -124,4 +123,4 @@ private:
} // namespace cpu
} // namespace arm_compute
-#endif /* ACL_SRC_CPU_OPERATORS_CPUMATMUL */
+#endif // ACL_SRC_CPU_OPERATORS_CPUMATMUL_H
diff --git a/src/cpu/operators/CpuWinogradConv2d.cpp b/src/cpu/operators/CpuWinogradConv2d.cpp
index 9d07736c13..e4bcdc0b64 100644
--- a/src/cpu/operators/CpuWinogradConv2d.cpp
+++ b/src/cpu/operators/CpuWinogradConv2d.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -294,12 +294,11 @@ void CpuWinogradConv2d::configure(const ITensorInfo *src,
_activation_func->configure(dst, nullptr, act_info);
}
- auto asm_mem_req = _gemm_function->workspace();
- _aux_mem[GemmWorkspace] = asm_mem_req[GemmWorkspace];
- _aux_mem[Pretranspose] = asm_mem_req[Pretranspose];
- _aux_mem[InterleavedLHS] = asm_mem_req[InterleavedLHS];
- _aux_mem[TransposedRHS] = asm_mem_req[TransposedRHS];
- _aux_mem[TempResult] = asm_mem_req[TempResult];
+ const auto mm_mem_req = _gemm_function->workspace();
+ for (unsigned int slot = 0; slot < mm_mem_req.size(); ++slot)
+ {
+ _aux_mem[slot] = mm_mem_req[slot];
+ }
// Request temporary memory. Overlap memory needed for Input/Output transformations as they run on different non-overlapping time-steps.
_aux_mem[TransformedInput] = MemoryInfo(offset_int_vec(TransformedInput), MemoryLifetime::Temporary,
diff --git a/src/cpu/operators/CpuWinogradConv2d.h b/src/cpu/operators/CpuWinogradConv2d.h
index ba9b879431..03bfc51a46 100644
--- a/src/cpu/operators/CpuWinogradConv2d.h
+++ b/src/cpu/operators/CpuWinogradConv2d.h
@@ -29,8 +29,8 @@
#include "src/core/common/Macros.h"
#include "src/cpu/ICpuOperator.h"
-#include "src/cpu/kernels/CpuWinogradConv2dKernel.h"
#include "src/cpu/kernels/assembly/gemm_common.hpp"
+#include "src/cpu/kernels/CpuWinogradConv2dKernel.h"
#include "src/cpu/operators/CpuActivation.h"
#include "src/cpu/operators/CpuGemm.h"
#include "src/cpu/operators/CpuPermute.h"
@@ -96,26 +96,22 @@ public:
bool enable_fast_math = false);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &constants) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &constants) override;
experimental::MemoryRequirements workspace() const override;
private:
enum AuxTensorIdx
{
- GemmWorkspace = 0,
- Pretranspose = 1,
- InterleavedLHS = 2,
- TransposedRHS = 3,
- TempResult = 4,
- TransformedInput = 5,
- TransformedOutput = 6,
- WorkspaceIO = 7,
- TransformedWeights = 8,
- PermutedWeights = 9,
- PermutedInput = TransformedOutput,
- PermutedOutput = TransformedInput,
- Count = 10
+ /** Slot 0 - 6 reserved for CpuGemm */
+ TransformedInput = 7,
+ TransformedOutput,
+ WorkspaceIO,
+ TransformedWeights,
+ PermutedWeights,
+ Count,
+ PermutedInput = TransformedOutput,
+ PermutedOutput = TransformedInput
};
std::unique_ptr<CpuGemm> _gemm_function;
std::unique_ptr<CpuActivation> _activation_func;
@@ -124,9 +120,9 @@ private:
std::unique_ptr<CpuPermute> _permute_input;
std::unique_ptr<CpuPermute> _permute_output;
std::unique_ptr<CpuPermute> _permute_weights;
- experimental::MemoryRequirements _aux_mem{ Count };
+ experimental::MemoryRequirements _aux_mem{Count};
std::unique_ptr<arm_conv::ConvolutionArgs>
- _conv_args; // Make it unique ptr because this type does not have a default constructor
+ _conv_args; // Make it unique ptr because this type does not have a default constructor
arm_conv::winograd::WinogradImpl _winograd_impl;
DataLayout _data_layout;
TensorInfo _winograd_transformed_input;
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 343ef21c0b..82bd465c99 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -31,6 +31,7 @@
#include "src/core/utils/AssemblyUtils.h"
#include "src/cpu/kernels/assembly/arm_gemm.hpp"
#include "src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h"
+#include "src/cpu/operators/CpuTranspose.h"
#include "src/cpu/utils/CpuAuxTensorHandler.h"
#include <arm_neon.h>
@@ -229,6 +230,7 @@ private:
enum AuxTensorIdx
{
AsmGemmWorkspace = 0,
+ PrePretransposedB, /* Transposed B (rhs) before being passed to gemm or pretranspose_B_array */
Pretranspose,
Count
};
@@ -244,12 +246,16 @@ private:
/** Prepare the indirect buffer */
void prepare_indirect_buffer(ITensorPack &tensors);
+ /** Operator to transpose B before gemm or pretranspose_B_array*/
+ std::unique_ptr<CpuTranspose> _pre_pretranspose_b{nullptr};
/** Assembly Gemm kernel */
std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{nullptr};
/** Optimised Arm® Neon™ kernel */
std::unique_ptr<INEKernel> _optimised_kernel{nullptr};
/** Assembly GEMM workspace tensor info */
TensorInfo _workspace_info{};
+ /** Pre-pre-transposed B tensor info */
+ TensorInfo _pre_pretransposed_b_info{};
/** Pre-transpose tensor info */
TensorInfo _pretranspose_info{};
/** Prepared flag */
@@ -473,9 +479,45 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
_optimised_kernel = std::move(acl_gemm_wrapper);
_gemm_info = gemm_info;
+ // Check if we need to pre-pretranspose B. Fixed format kernels need no pre-pretranspose.
+ const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
+ if (run_pre_pretranspose_b)
+ {
+ _pre_pretranspose_b = std::make_unique<CpuTranspose>();
+ _pre_pretranspose_b->configure(b, &_pre_pretransposed_b_info);
+ MemoryLifetime lifetime;
+ if (_is_b_constant)
+ {
+ if (_gemm_kernel_asm->B_pretranspose_required())
+ {
+ // PrePretransposedB tensor is only used in prepare(), but is then succeeded by Pretranspose
+ // So PrePretransposedB can be freed inside prepare()
+ lifetime = MemoryLifetime::Prepare;
+ }
+ else
+ {
+ // PrePretransposedB tensor is only used in prepare(), but is the final transformation of B
+ // So PrePretransposedB needs to persist beyond prepare()
+ lifetime = MemoryLifetime::Persistent;
+ }
+ }
+ else
+ {
+ // PrePretransposedB tensor is always used in run() and doesn't need to persist
+ lifetime = MemoryLifetime::Temporary;
+ }
+ // Forcing 128-byte alignment (required by 32-bit kernels)
+ const unsigned int alignment = 128;
+ _aux_mem[PrePretransposedB] =
+ MemoryInfo(offset_int_vec(PrePretransposedB), lifetime, _pre_pretransposed_b_info.total_size(), alignment);
+ }
+
// Check for pre-transposed support
if (_gemm_kernel_asm->B_pretranspose_required())
{
+ // Fixed format kernels need no pretranspose.
+ ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
+ assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
// Forcing 128-byte alignment (required by 32-bit kernels)
const unsigned int alignment = 128;
const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
@@ -506,6 +548,22 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
_gemm_kernel_asm->set_quantized_bias(
reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
}
+ const ITensor *b_to_use = b;
+ // Pre-pretranspose B if required
+ const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
+ CpuAuxTensorHandler pre_pretransposed_b(
+ offset_int_vec(PrePretransposedB), _pre_pretransposed_b_info, tensors,
+ /*pack_inject: no need to inject into tensors*/
+ false,
+ /*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/
+ !run_pre_pretranspose_b);
+ if (run_pre_pretranspose_b)
+ {
+ ARM_COMPUTE_ERROR_ON(_pre_pretranspose_b == nullptr);
+ ITensorPack pre_pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pre_pretransposed_b.get()}};
+ _pre_pretranspose_b->run(pre_pretranspose_pack);
+ b_to_use = pre_pretransposed_b.get();
+ }
// Pretranspose B if required
if (_gemm_kernel_asm->B_pretranspose_required())
@@ -513,10 +571,10 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
// Fixed format kernels need no pretranspose.
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
- const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
- const auto in1_ptr =
- reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
- const int multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
+ const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
+ const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
+ b_to_use->info()->offset_first_element_in_bytes());
+ const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
@@ -525,6 +583,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
NEScheduler::get().num_threads());
b->mark_as_unused();
+ // Note that we don't need to mark b_to_use as unused, as if it's been assigned to pre_pretransposed_b, its memory will be auto-managed by the handler
}
if (_gemm_info.method == AsmConvMethod::Indirect)
@@ -576,16 +635,33 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
const TypeInput *in1_ptr = nullptr;
auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
+ const ITensor *b_to_use = b;
+
+ // Pre-pretranspose B if required
+ const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
+ CpuAuxTensorHandler pre_pretransposed_b(
+ offset_int_vec(PrePretransposedB), _pre_pretransposed_b_info, tensors,
+ false /*pack_inject: no need to inject into tensors*/,
+ !run_pre_pretranspose_b /*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/);
+ if (b_to_use && !_is_b_constant && run_pre_pretranspose_b)
+ {
+ ARM_COMPUTE_ERROR_ON(_pre_pretranspose_b == nullptr);
+ ITensorPack pre_pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pre_pretransposed_b.get()}};
+ _pre_pretranspose_b->run(pre_pretranspose_pack);
+ b_to_use = pre_pretransposed_b.get();
+ }
+
// Check if B is pre-tranposed and de-reference if not
if (!_gemm_kernel_asm->B_is_pretransposed())
{
- ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
- multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
- in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+ ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
+ multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
+ in1_ptr =
+ reinterpret_cast<const TypeInput *>(b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes());
}
// If necessary, run pretranspose every time if either weights or biases are non-constant
- if ((b && !_is_b_constant) || (c && !_is_c_constant && c->info()->data_type() == DataType::S32))
+ if ((b_to_use && !_is_b_constant) || (c && !_is_c_constant && c->info()->data_type() == DataType::S32))
{
if (c && c->info()->data_type() == DataType::S32)
{
@@ -596,10 +672,13 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Pretranspose B if required
if (_B_pretranspose_required)
{
- const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
- const auto b_ptr =
- reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
- const int multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
+ // Fixed format kernels need no pretranspose.
+ ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
+ assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
+ const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
+ const auto b_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
+ b_to_use->info()->offset_first_element_in_bytes());
+ const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
@@ -762,6 +841,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
info.fixed_format, info.fast_mode, &cfg);
+ // TODO: Incorporate info.transpose_b COMPMID-6595
switch (a->data_type())
{
case DataType::F32:
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 5be39a54c0..671a222fed 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
-#define ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
+#ifndef ACL_SRC_CPU_OPERATORS_INTERNAL_CPUGEMMASSEMBLYDISPATCH_H
+#define ACL_SRC_CPU_OPERATORS_INTERNAL_CPUGEMMASSEMBLYDISPATCH_H
#include "arm_compute/function_info/ActivationLayerInfo.h"
@@ -57,6 +57,13 @@ struct AsmGemmInfo
bool fixed_format{false};
arm_compute::WeightFormat weight_format{arm_compute::WeightFormat::UNSPECIFIED};
bool reshape_b_only_on_first_run{true};
+ /** Whether we want to perform an additional transpose of b before passing it to gemm or pretranspose_B_array
+ * @note This transpose b operation is also considered a form of "reshape" or "transform", so should be counted for
+ * by the reshape_b_only_on_first_run flag
+ * @note This flag will be silently ignored (assumed to be false) when the weight_format is a fixed format. Because
+ * fixed format kernels do not accept weights (B) with any prior transformations
+ */
+ bool transpose_b{false};
};
/** Assembly kernel glue */
@@ -187,4 +194,4 @@ private:
};
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H */
+#endif // ACL_SRC_CPU_OPERATORS_INTERNAL_CPUGEMMASSEMBLYDISPATCH_H
diff --git a/src/cpu/utils/CpuAuxTensorHandler.h b/src/cpu/utils/CpuAuxTensorHandler.h
index e23b88a777..627216837b 100644
--- a/src/cpu/utils/CpuAuxTensorHandler.h
+++ b/src/cpu/utils/CpuAuxTensorHandler.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_UTILS_CPU_AUX_TENSOR_HANDLER_H
-#define ARM_COMPUTE_CPU_UTILS_CPU_AUX_TENSOR_HANDLER_H
+#ifndef ACL_SRC_CPU_UTILS_CPUAUXTENSORHANDLER_H
+#define ACL_SRC_CPU_UTILS_CPUAUXTENSORHANDLER_H
#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/TensorInfo.h"
@@ -71,7 +71,13 @@ public:
}
}
- CpuAuxTensorHandler(TensorInfo &info, ITensor &tensor) : _tensor()
+ /** Create a temporary handle to the original tensor with a new @ref TensorInfo
+ * This is useful if we want to change a tensor's tensor info at run time without modifying the original tensor
+ *
+ * @param[in] info New tensor info to "assign" to @p tensor
+ * @param[in] tensor Tensor to be assigned a new @ref TensorInfo
+ */
+ CpuAuxTensorHandler(TensorInfo &info, const ITensor &tensor) : _tensor()
{
_tensor.allocator()->soft_init(info);
if (info.total_size() <= tensor.info()->total_size())
@@ -108,4 +114,4 @@ private:
};
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_UTILS_CPU_AUX_TENSOR_HANDLER_H */
+#endif // ACL_SRC_CPU_UTILS_CPUAUXTENSORHANDLER_H
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index 7a274906a6..98a5be5484 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -1032,6 +1032,8 @@ TEST_SUITE(GEMMConvolutionLayer)
template <typename T>
using NEGEMMConvolutionLayerFixture = ConvolutionValidationFixture<Tensor, Accessor, NEConvolutionLayer, T>;
template <typename T>
+using NEGEMMConvolutionLayerPaddedWeightsFixture = ConvolutionValidationPaddedWeightsFixture<Tensor, Accessor, NEConvolutionLayer, T>;
+template <typename T>
using NEGEMMConvolutionLayerMixedDataLayoutFixture = ConvolutionValidationFixture<Tensor, Accessor, NEConvolutionLayer, T, true>;
/** Test case for memory injection in @ref cpu::CpuGemmConv2d.
@@ -1184,9 +1186,25 @@ FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEGEMMConvolutionLayerMixedDataLayout
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
}
+/** Padded weights
+ * CpuGemmConv2d uses two different paths for reshaping the weights based on if the weight tensor has holes (a common
+ * way to have "holes" in tensor is via extended paddings)
+ *
+ * We only need to test the padded weight path here on a single floating data type and a single layout, because the fallback path is agnostic of them
+ */
+FIXTURE_DATA_TEST_CASE(RunPaddedWeights, NEGEMMConvolutionLayerPaddedWeightsFixture<float>, framework::DatasetMode::ALL, combine(datasets::SmallConvolutionLayerDataset(),
+ framework::dataset::make("ReshapeWeights", { true }),
+ framework::dataset::make("DataType", DataType::F32),
+ framework::dataset::make("DataLayout", { DataLayout::NHWC })
+ ))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
+}
TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
+// TODO: COMPMID-6596 Extend quantized tests with at least one suite where the weight is padded (the legacy case, see floating point's RunPaddedWeights)
template <typename T>
using NEGEMMConvolutionLayerQuantizedFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEConvolutionLayer, T>;
template <typename T>
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 2051add225..0622e5e6f0 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -123,7 +123,7 @@ public:
public:
void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights,
DataType data_type, DataType weights_data_type, DataLayout data_layout, QuantizationInfo quantization_info, QuantizationInfo weight_quantization_info, ActivationLayerInfo act_info,
- bool mixed_layout = false, PaddingList pre_pad_layer = PaddingList({}))
+ bool mixed_layout = false, PaddingList pre_pad_layer = PaddingList({}), bool padded_weights = false)
{
// This hash is used by random generators. There may be hash collisions but
// this is intentional as it's a very easy way to make the the current
@@ -151,7 +151,7 @@ public:
_use_dynamic_output_quant = true;
}
- _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, reshape_weights, dilation, act_info, pre_pad_layer);
+ _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, reshape_weights, dilation, act_info, pre_pad_layer, padded_weights);
_reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, dilation, act_info, pre_pad_layer);
}
@@ -267,7 +267,7 @@ protected:
// given input is IN nchw format
TensorType compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape &bias_shape, TensorShape output_shape, const PadStrideInfo &info,
- bool reshape_weights, const Size2D &dilation, const ActivationLayerInfo act_info, PaddingList pre_pad_layer = PaddingList({}))
+ bool reshape_weights, const Size2D &dilation, const ActivationLayerInfo act_info, PaddingList pre_pad_layer = PaddingList({}), bool padded_weights = false)
{
ARM_COMPUTE_ERROR_ON((input_shape[2] % weights_shape[2]) != 0);
@@ -335,8 +335,13 @@ protected:
ARM_COMPUTE_ASSERT(weights.info()->is_resizable());
ARM_COMPUTE_ASSERT(bias.info()->is_resizable());
ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
-
- add_padding_x({ &src, &weights, &bias, &dst }, _data_layout);
+ // Test "add padding after configure" behavior. This behavior should not affect the correctness
+ add_padding_x({ &src, &bias, &dst }, _data_layout);
+ // Padding weights may affect code path in some backends
+ if (padded_weights)
+ {
+ add_padding_x({ &weights }, _data_layout);
+ }
// Allocate tensors
src.allocator()->allocate();
@@ -437,6 +442,19 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
+class ConvolutionValidationPaddedWeightsFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>
+{
+public:
+ void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type,
+ DataLayout data_layout)
+ {
+ ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights,
+ data_type, data_type, data_layout,
+ QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), mixed_layout, PaddingList({}), true);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
class ConvolutionValidationWithPaddingFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>
{
public:
@@ -481,6 +499,7 @@ public:
}
};
+
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_compute::WeightFormat weight_format)
{