aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrancesco Petrogalli <francesco.petrogalli@arm.com>2022-06-30 10:22:01 +0000
committerFrancesco Petrogalli <francesco.petrogalli@arm.com>2022-07-19 09:26:27 +0000
commit553f6953fe3bdfad53c11c25f305a16d79d83b24 (patch)
tree73642b948b79662096f593458c6138d2f7f48ec6
parent99c46475daf277aa53e6747f9e41209f418fed33 (diff)
downloadComputeLibrary-553f6953fe3bdfad53c11c25f305a16d79d83b24.tar.gz
[ONCPUML-951] Variable weight support for Convolution.
API changes for NEGEMMConvolutionLayer and CpuGemmConv2d Built with: scons neon=1 opencl=0 os=linux arch=armv8.2-a multi_isa=1 \ build=native -j32 Werror=false validation_tests=1 build_dir=opt \ standalone=1 asserts=1 experimental_fixed_format_kernels=1 . Tested with: ./build/opt/tests/arm_compute_validation Hardware where the test executable was run: Neoverse N1 Test coverage: * NEGEMMConvolutionLayer, CpuGemmConv2d * NHWC (the only one supported by the fixed-format kernels) * F16, F32 * Shapes: RunSmall Change-Id: I4fd3e495a7cbf61210ea02d37440ba9652934e99 Signed-off-by: Francesco Petrogalli <francesco.petrogalli@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7632 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--SConscript2
-rw-r--r--arm_compute/core/Types.h99
-rw-r--r--arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h61
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp18
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp18
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp18
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp7
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_int16.cpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_int8.cpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/utils.hpp2
-rw-r--r--src/cpu/kernels/assembly/arm_gemm.hpp116
-rw-r--r--src/cpu/operators/CpuGemm.cpp21
-rw-r--r--src/cpu/operators/CpuGemm.h24
-rw-r--r--src/cpu/operators/CpuGemmConv2d.cpp82
-rw-r--r--src/cpu/operators/CpuGemmConv2d.h24
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp69
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.h21
-rw-r--r--src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp10
-rw-r--r--tests/SConscript4
-rw-r--r--tests/framework/Asserts.h7
-rw-r--r--tests/validation/NEON/ConvolutionLayer.cpp262
-rw-r--r--tests/validation/fixtures/ConvolutionLayerFixture.h295
-rw-r--r--utils/TypePrinter.h24
28 files changed, 1040 insertions, 158 deletions
diff --git a/SConscript b/SConscript
index 6f6b078b63..522db94c7c 100644
--- a/SConscript
+++ b/SConscript
@@ -503,7 +503,7 @@ if env['experimental_dynamic_fusion']:
# Fixed format GEMM kernels.
if env['experimental_fixed_format_kernels']:
- arm_compute_env.Append(CPPDEFINES = ['ENABLE_FIXED_FORMAT_KERNELS'])
+ arm_compute_env.Append(CPPDEFINES = ['ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS'])
# Logging files
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 94fe1a07f4..989cdfb8cc 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -32,6 +32,7 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/experimental/IPostOp.h"
#include "arm_compute/core/utils/misc/Macros.h"
+#include "src/cpu/kernels/assembly/arm_gemm.hpp"
#include "support/Bfloat16.h"
#include "support/Half.h"
@@ -774,10 +775,10 @@ public:
private:
std::pair<unsigned int, unsigned int> _stride;
- unsigned int _pad_left;
- unsigned int _pad_top;
- unsigned int _pad_right;
- unsigned int _pad_bottom;
+ unsigned int _pad_left;
+ unsigned int _pad_top;
+ unsigned int _pad_right;
+ unsigned int _pad_bottom;
DimensionRoundingType _round_type;
};
@@ -919,14 +920,14 @@ public:
}
private:
- std::vector<float> _min_sizes;
- std::vector<float> _variances;
- float _offset;
- bool _flip;
- bool _clip;
- std::vector<float> _max_sizes;
- std::vector<float> _aspect_ratios;
- Coordinates2D _img_size;
+ std::vector<float> _min_sizes;
+ std::vector<float> _variances;
+ float _offset;
+ bool _flip;
+ bool _clip;
+ std::vector<float> _max_sizes;
+ std::vector<float> _aspect_ratios;
+ Coordinates2D _img_size;
std::array<float, 2> _steps;
};
@@ -1171,15 +1172,15 @@ public:
}
private:
- unsigned int _max_detections;
- unsigned int _max_classes_per_detection;
- float _nms_score_threshold;
- float _iou_threshold;
- unsigned int _num_classes;
+ unsigned int _max_detections;
+ unsigned int _max_classes_per_detection;
+ float _nms_score_threshold;
+ float _iou_threshold;
+ unsigned int _num_classes;
std::array<float, 4> _scales_values;
- bool _use_regular_nms;
- unsigned int _detection_per_class;
- bool _dequantize_scores;
+ bool _use_regular_nms;
+ unsigned int _detection_per_class;
+ bool _dequantize_scores;
};
/** Pooling Layer Information struct*/
@@ -1612,13 +1613,13 @@ public:
}
private:
- float _img_width;
- float _img_height;
- float _scale;
- bool _apply_scale;
- bool _correct_transform_coords;
+ float _img_width;
+ float _img_height;
+ float _scale;
+ bool _apply_scale;
+ bool _correct_transform_coords;
std::array<float, 4> _weights;
- float _bbox_xform_clip;
+ float _bbox_xform_clip;
};
/** Activation Layer Information class */
@@ -1900,7 +1901,7 @@ class WeightsInfo
public:
/** Default constructor */
WeightsInfo()
- : _are_reshaped(false), _kernel_width(0), _kernel_height(0), _num_kernels(0), _retain_internal_weights(false)
+ : _are_reshaped(false), _kernel_width(0), _kernel_height(0), _num_kernels(0), _retain_internal_weights(false), _weight_format(arm_gemm::WeightFormat::UNSPECIFIED)
{
}
/** Constructor
@@ -1910,9 +1911,11 @@ public:
* @param[in] kernel_height Kernel height.
* @param[in] num_kernels Number of convolution kernels.
* @param[in] retain_internal_weights (Optional) True if internal reshaped weights must be retained. Used for reconfiguration purposes. Default is false.
+ * @param[in] weight_format (Optional) arm_gemm:WeightFormat enumeration requested by the user. Default is arm_gemm::WeightFormat::UNSPECIFIED.
*/
- WeightsInfo(bool are_reshaped, unsigned int kernel_width, unsigned int kernel_height, unsigned int num_kernels, bool retain_internal_weights = false)
- : _are_reshaped(are_reshaped), _kernel_width(kernel_width), _kernel_height(kernel_height), _num_kernels(num_kernels), _retain_internal_weights(retain_internal_weights)
+ WeightsInfo(bool are_reshaped, unsigned int kernel_width, unsigned int kernel_height, unsigned int num_kernels, bool retain_internal_weights = false,
+ arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED)
+ : _are_reshaped(are_reshaped), _kernel_width(kernel_width), _kernel_height(kernel_height), _num_kernels(num_kernels), _retain_internal_weights(retain_internal_weights), _weight_format(weight_format)
{
}
/** Flag which specifies if the weights tensor has been reshaped.
@@ -1943,13 +1946,26 @@ public:
{
return _retain_internal_weights;
}
+ arm_gemm::WeightFormat weight_format() const
+ {
+ return _weight_format;
+ }
+ unsigned int kernel_width() const
+ {
+ return _kernel_width;
+ }
+ unsigned int kernel_height() const
+ {
+ return _kernel_height;
+ }
private:
- bool _are_reshaped;
- unsigned int _kernel_width;
- unsigned int _kernel_height;
- unsigned int _num_kernels;
- bool _retain_internal_weights;
+ bool _are_reshaped;
+ unsigned int _kernel_width;
+ unsigned int _kernel_height;
+ unsigned int _num_kernels;
+ bool _retain_internal_weights;
+ arm_gemm::WeightFormat _weight_format;
};
/** GEMM reshape information class. This class stores the necessary information about matrix A and matrix B reshape.
@@ -2160,7 +2176,8 @@ public:
_pretranspose_B(false),
_activation_info(),
_post_ops(),
- _fixed_format(false)
+ _fixed_format(false),
+ _weight_format(arm_gemm::WeightFormat::UNSPECIFIED)
{
}
/** Constructor
@@ -2180,11 +2197,12 @@ public:
* @param[in] activation_info (Optional) Activation to apply after the matrix multiplication
* @param[in] post_ops (Optional) A sequence of post operations that are performed after the main operation.
* @param[in] fixed_format (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in amemory format that is fixed by the kernel itself. For more information, see arm_gemm::WeightFormat.
+ * @param[in] weight_format (Optional) arm_gemm:WeightFormat enumeration requested by the user. Default is arm_gemm::WeightFormat::UNSPECIFIED.
*/
GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool fast_math = false, bool broadcast_bias = false,
const ActivationLayerInfo &activation_info = ActivationLayerInfo(), const experimental::PostOpList<ITensorInfo *> &post_ops = experimental::PostOpList<ITensorInfo *>(),
- bool fixed_format = false) noexcept
+ bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED) noexcept
: _is_a_reshaped(is_a_reshaped),
_is_b_reshaped(is_b_reshaped),
_reshape_b_only_on_first_run(reshape_b_only_on_first_run),
@@ -2199,7 +2217,8 @@ public:
_pretranspose_B(false),
_activation_info(activation_info),
_post_ops(post_ops),
- _fixed_format(fixed_format)
+ _fixed_format(fixed_format),
+ _weight_format(weight_format)
{
}
/** Flag which specifies if the matrix A has been reshaped
@@ -2373,6 +2392,11 @@ public:
return _fixed_format;
}
+ arm_gemm::WeightFormat weight_format() const
+ {
+ return _weight_format;
+ }
+
private:
bool _is_a_reshaped;
bool _is_b_reshaped;
@@ -2389,6 +2413,7 @@ private:
ActivationLayerInfo _activation_info;
experimental::PostOpList<ITensorInfo *> _post_ops;
bool _fixed_format;
+ arm_gemm::WeightFormat _weight_format;
};
/** Winograd information */
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
index cf5fb82398..2af11ad656 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -122,6 +122,65 @@ public:
const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
bool enable_fast_math = false, unsigned int num_groups = 1);
+ /** Static function to check if there is an optimized version of
+ * GEMM available for the input parameters.
+ *
+ * The method is intended to be used to find out the optimal
+ * memory layout to be used for the weights tensor when running
+ * variable weights execution.
+ *
+ * The user can query the database of optimised kernels in
+ * arm_gemm by specifying one of the enumerations of
+ * arm_gemm::WeightFormat in the weight_format field of the input
+ * parameter weights_info. In case of success, the method
+ * writes the expected format in the output parameter
+ * expected_weight_format. The expected_weight_format can than be
+ * used in the configure method of the class for retrieving the
+ * best optimal kernel.
+ *
+ * Use case one - query for a specific format:
+ *
+ * WeightInfo weights_info(..., arm_gemm::WeightFormat::OHWIo4, ...); // Set the value of the input query.
+ * if (NEGEMMConvolutionlayer::has_opt_impl(WeightFormat(), ...., weights_info, ...))
+ * {
+ * auto conv = std::unique_ptr<NEGEMMConvolutionlayer>();
+ * conv->configure(..., weights_info, ...); // uses the same WeightFormat the user wanted originally, OHWYo4.
+ * conv->run(...);
+ * }
+ *
+ * Use case two - query for any format that would be optimal for the GEMM to execute:
+ *
+ * WeightInfo weights_info(..., arm_gemm::WeightFormat::ANY, ...); // Set the value of the input query.
+ * arm_gemm::WeightFormat expected_wf;
+ * if (NEGEMMConvolutionlayer::has_opt_impl(expected_wf, ...., weights_info, ...))
+ * {
+ * auto conv = std::unique_ptr<NEGEMMConvolutionlayer>();
+ * // ... code to convert the layout of the weights tensor to the layout returned by has_opt_impl
+ * WeightInfo new_weights_info(..., expected_wf, ...); // Set the value of the WeightFormat returned by has_opt_impl.
+ * conv->configure(..., new_weights_info, ...);
+ * conv->run(...);
+ * }
+ *
+ * Notice that a GEMM configured with a WeightFormat other than
+ * UNSPECIFIED will run GEMM with variable weights mode.
+ *
+ * @param[out] expected_weight_format The arm_compute::WeightFormat expected by the kernel.
+ * @param[in] src Source tensor info.
+ * @param[in] weights Weights tensor info.
+ * @param[in] biases Biases tensor info. Shared biases supported.
+ * @param[in] dst Destination tensor info.
+ * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo.
+ * @param[in] weights_info (optional) Specifies additional configuration parameters for the weights of the GEMM computation.
+ * @param[in] dilation (Optional) Dilation, in elements, across x and y. Defaults to (1, 1).
+ * @param[in] act_info (Optional) Activation layer information in case of a fused activation. Only RELU, BOUNDED_RELU and LU_BOUNDED_RELU supported. And no activation (i.e. Linear) which is the default value.
+ * @param[in] enable_fast_math (Optional) Enable fast math computation. In case this flag were set, the function could dispatch the fastest implementation
+ *
+ * @return a Status
+ */
+ static Status has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
+ const PadStrideInfo &conv_info,
+ const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
+ bool enable_fast_math = false);
// Inherited methods overridden:
void run() override;
void prepare() override;
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
index 50fc5bdb8a..58e4861bc0 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
@@ -33,21 +33,21 @@
#include "kernels/a32_sgemm_8x6.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_dot_8x12.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_hybrid_bf16fp32_dot_6x16.hpp"
#include "kernels/a64_hybrid_bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_interleaved_bf16fp32_dot_8x12.hpp"
#include "kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp"
#include "kernels/a64_sgemm_8x12.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp"
#include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_hybrid_bf16fp32_dot_6x4VL.hpp"
#include "kernels/sve_hybrid_bf16fp32_mmla_6x4VL.hpp"
#include "kernels/sve_interleaved_bf16fp32_dot_8x3VL.hpp"
@@ -89,7 +89,7 @@ GemmImplementation<bfloat16, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<bfloat16, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"sve_ffinterleaved_bf16fp32_mmla_8x3VL",
@@ -106,7 +106,7 @@ GemmImplementation<bfloat16, float>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // ARM_COMPUTE_ENABLE_SVE
GemmImplementation<bfloat16, float>::with_estimate(
GemmMethod::GEMM_HYBRID,
@@ -136,7 +136,7 @@ GemmImplementation<bfloat16, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<bfloat16, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"a64_ffinterleaved_bf16fp32_mmla_8x12",
@@ -161,7 +161,7 @@ GemmImplementation<bfloat16, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
[](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<bfloat16, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"a64_sgemm_8x12",
@@ -197,7 +197,7 @@ const GemmImplementation<bfloat16, float> *gemm_implementation_list<bfloat16, fl
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<bfloat16, float> gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<bfloat16, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index 2796b0d204..d749dce98d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -34,17 +34,17 @@
#include "gemm_interleaved.hpp"
#include "kernels/a32_sgemm_8x6.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_fp16_mla_6x32.hpp"
#include "kernels/a64_ffinterleaved_fp16_mla_8x24.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_hgemm_8x24.hpp"
#include "kernels/a64_hybrid_fp16_mla_6x32.hpp"
#include "kernels/a64_sgemm_8x12.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp"
#include "kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_hybrid_fp16_mla_6x4VL.hpp"
#include "kernels/sve_interleaved_fp16_mla_8x3VL.hpp"
@@ -66,7 +66,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<__fp16, __fp16>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"sve_ffinterleaved_fp16_mla_8x3VL",
@@ -83,7 +83,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // ARM_COMPUTE_ENABLE_SVE
#if defined(__aarch64__)
GemmImplementation<__fp16, __fp16>::with_estimate(
@@ -100,7 +100,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<__fp16, __fp16>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"a64_ffinterleaved_fp16_mla_8x24",
@@ -117,7 +117,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
{
GemmMethod::GEMM_INTERLEAVED,
"a64_sgemm_8x12",
@@ -152,7 +152,7 @@ const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp1
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<__fp16, __fp16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 4f7e191fb3..0fc9e8b912 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -31,12 +31,12 @@
#include "gemv_pretransposed.hpp"
#include "kernels/a32_sgemm_8x6.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_fp32_mla_6x16.hpp"
#include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
#include "kernels/a64_ffinterleaved_fp32_mla_8x12.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_hybrid_fp32bf16fp32_mmla_4x24.hpp"
#include "kernels/a64_hybrid_fp32bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_hybrid_fp32_mla_4x24.hpp"
@@ -48,12 +48,12 @@
#include "kernels/a64_smallK_hybrid_fp32_mla_6x4.hpp"
#include "kernels/a64_smallK_hybrid_fp32_mla_8x4.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp"
#include "kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp"
#include "kernels/sve_ffinterleaved_fp32_mla_8x3VL.hpp"
#include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_hybrid_fp32bf16fp32_mmla_4x6VL.hpp"
#include "kernels/sve_hybrid_fp32bf16fp32_mmla_6x4VL.hpp"
#include "kernels/sve_hybrid_fp32_mla_6x4VL.hpp"
@@ -165,7 +165,7 @@ GemmImplementation<float, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>(args); }
),
- #ifdef ENABLE_FIXED_FORMAT_KERNELS
+ #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#ifdef ARM_COMPUTE_ENABLE_BF16
GemmImplementation<float, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
@@ -200,7 +200,7 @@ GemmImplementation<float, float>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // ARM_COMPUTE_ENABLE_SVE
// Cortex-A35 specific kernel - use for any problem on A35, and never in any other cases.
{
@@ -253,7 +253,7 @@ GemmImplementation<float, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_a64_sgemm_8x12, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, float, float>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#ifdef ARM_COMPUTE_ENABLE_BF16
// "fast mode" (BF16) kernels
GemmImplementation<float, float>::with_estimate(
@@ -289,7 +289,7 @@ GemmImplementation<float, float>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // __aarch64__
#ifdef __arm__
@@ -318,7 +318,7 @@ const GemmImplementation<float, float> *gemm_implementation_list<float, float>()
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<float, float> gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<float, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<float, float, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<float, float, Nothing> (const GemmArgs &args, const Nothing &);
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
index c41b0a5b3e..90e2f07607 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
@@ -450,7 +450,7 @@ public:
}
/* Make sure we've been set up correctly. */
- assert(_B_transposed);
+ assert(FixedFormat || _B_transposed);
static_assert(std::is_same<To, Tloi>::value, "gemm_native: Operand types must be the same.");
// static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same.");
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
index 75fb1cb306..19c8fcadd3 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -306,9 +306,12 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons
}
template<typename Top, typename Tret, class OutputStage>
-bool has_opt_gemm(const GemmArgs &args, const OutputStage &os) {
+bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os) {
const GemmImplementation<Top, Tret, OutputStage> *impl;
- return find_implementation<Top, Tret, OutputStage>(args, os, impl);
+ const bool success = find_implementation<Top, Tret, OutputStage>(args, os, impl);
+ if (success)
+ wf = UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os))->get_config().weight_format;
+ return success;
}
template<typename Top, typename Tret, class OutputStage>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
index 3915861112..18d8fc9312 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
@@ -56,7 +56,7 @@ const GemmImplementation<int16_t, int32_t> *gemm_implementation_list<int16_t, in
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<int16_t, int32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<int16_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index 0c68e4dd99..24507486ac 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -159,7 +159,7 @@ const GemmImplementation<int8_t, int32_t> *gemm_implementation_list<int8_t, int3
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<int8_t, int32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<int8_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
index 6b813c7974..1d7b9c5b73 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -230,7 +230,7 @@ const GemmImplementation<int8_t, int8_t, Requantize32> *gemm_implementation_list
}
template UniqueGemmCommon<int8_t, int8_t> gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
-template bool has_opt_gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template bool has_opt_gemm<int8_t, int8_t, Requantize32>(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os);
template std::vector<KernelDescription> get_compatible_kernels<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index 95139c2bf6..be7a4ee570 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -197,7 +197,7 @@ const GemmImplementation<uint8_t, uint8_t, Requantize32> *gemm_implementation_li
}
template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
-template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os);
template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
index 20cee556f0..fc836f9790 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
@@ -56,7 +56,7 @@ const GemmImplementation<uint16_t, uint32_t> *gemm_implementation_list<uint16_t,
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index a2d2cc86f0..03e9cd6c1f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -157,7 +157,7 @@ const GemmImplementation<uint8_t, uint32_t> *gemm_implementation_list<uint8_t, u
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint32_t, Nothing> (const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/utils.hpp b/src/core/NEON/kernels/arm_gemm/utils.hpp
index 18e124b83e..d7b5398488 100644
--- a/src/core/NEON/kernels/arm_gemm/utils.hpp
+++ b/src/core/NEON/kernels/arm_gemm/utils.hpp
@@ -24,7 +24,7 @@
#pragma once
-#include "arm_gemm.hpp"
+#include "src/cpu/kernels/assembly/arm_gemm.hpp"
#include <cstddef>
#include <limits>
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 247cb1d470..48fd7c6b43 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -47,6 +47,57 @@ enum class GemmMethod
GEMM_HYBRID_QUANTIZED
};
+/** Memory layouts for the weights tensor.
+ *
+ * * UNSPECIFIED is used to select kernels that do not run in
+ * variable weights mode.
+ *
+ * * ANY is used to query the kernel database to retrieve any of the
+ * kernels that runs in variable weights mode. Once a kernel is
+ * found, the specific format expected by the kernel can be
+ * retrieved by the user for reordering the weights tensor
+ * accordingly.
+ *
+ * The other values OHWIo{interleave_by}i{block_by} describe the
+ * memory layout of a 4D tensor with layout OHWI that has been
+ * transformed into a 4D tensor with dimensions O'HWI' where:
+ *
+ * O' = first multiple of {interleave_by} s.t. O<=O'
+ * I' = first multiple of {block_by} s.t. I<=I'
+ *
+ * The total size of the dst tensor is O' x H x W x I'
+ *
+ * The access function of the tensor with layout
+ * OHWIo{interleave_by}i{block_by} and size O'HWI' is a 6-parameter
+ * access function, where the 6 parameters are computed as follows:
+ *
+ * x5 = floor(o/{interleave_by}) RANGE [0, O'/{interleave_by} -1] SIZE: O'/{interleave_by}
+ *
+ * x4 = h RANGE [0, H-1] SIZE: H
+ * x3 = w RANGE [0, W-1] SIZE: W
+ * x2 = floor(i/{block_by}) RANGE [0, I'/{block_by} -1] SIZE: I'/{block_by}
+ * x1 = o%{interleave_by} RANGE [0, {interleave_by} -1] SIZE: {interleave_by}
+ * x0 = i%{block_by} RANGE [0, {block_by} -1] SIZE: {block_by}
+ * TOTAL SIZE: O' * H * W * I'
+ *
+ * 4D 6D
+ * ----------------- -----------------------------------
+ * value(o, h, w, i) = x5 * H * W * I' * {interleave_by}
+ * + x4 * W * I' * {interleave_by}
+ * + x3 * I' * {interleave_by}
+ * + x2 * {interleave_by} * {block_by}
+ * + x1 * {block_by}
+ * + x0
+ *
+ * Notice that in arm_gemm the 4D tensor of dimension O'HWI' created
+ * for the OHWIo{interleave_by}i{block_by} format is in reality seen
+ * as a 2D tensor, where the number of rows is O'/{interleave_by}
+ * and the number of columns is {interleave_by} * H * W * I'.
+ *
+ * The postfix *_bf16 is for the memory layout needed for the
+ * fast-mode kernels, in which the weights are passed in bfloat16
+ * format.
+ */
enum class WeightFormat
{
UNSPECIFIED = 0x1,
@@ -87,6 +138,69 @@ enum class WeightFormat
OHWIo64i8 = 0x804000
};
+// OHWIo<interleave_by>i<block_by>
+inline int interleave_by(const WeightFormat wf)
+{
+ return ((int)wf >> 8) & 0xFFF;
+}
+inline int block_by(const WeightFormat wf)
+{
+ return ((int)wf >> 20) & 0xF;
+}
+inline bool is_fixed_format(const WeightFormat wf)
+{
+ return wf != WeightFormat::UNSPECIFIED && wf != WeightFormat::ANY;
+}
+
+inline std::string to_string(WeightFormat wf)
+{
+#define __CASE_WEIGHT_FORMAT(wf) \
+case WeightFormat::wf: \
+ return #wf;
+ switch(wf)
+ {
+ __CASE_WEIGHT_FORMAT(UNSPECIFIED)
+ __CASE_WEIGHT_FORMAT(ANY)
+ __CASE_WEIGHT_FORMAT(OHWI)
+ __CASE_WEIGHT_FORMAT(OHWIo2)
+ __CASE_WEIGHT_FORMAT(OHWIo4)
+ __CASE_WEIGHT_FORMAT(OHWIo8)
+ __CASE_WEIGHT_FORMAT(OHWIo16)
+ __CASE_WEIGHT_FORMAT(OHWIo32)
+ __CASE_WEIGHT_FORMAT(OHWIo64)
+ __CASE_WEIGHT_FORMAT(OHWIo128)
+ __CASE_WEIGHT_FORMAT(OHWIo4i2)
+ __CASE_WEIGHT_FORMAT(OHWIo4i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo8i2)
+ __CASE_WEIGHT_FORMAT(OHWIo8i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo16i2)
+ __CASE_WEIGHT_FORMAT(OHWIo16i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo32i2)
+ __CASE_WEIGHT_FORMAT(OHWIo32i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo64i2)
+ __CASE_WEIGHT_FORMAT(OHWIo64i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo4i4)
+ __CASE_WEIGHT_FORMAT(OHWIo4i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo8i4)
+ __CASE_WEIGHT_FORMAT(OHWIo8i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo16i4)
+ __CASE_WEIGHT_FORMAT(OHWIo16i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo32i4)
+ __CASE_WEIGHT_FORMAT(OHWIo32i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo64i4)
+ __CASE_WEIGHT_FORMAT(OHWIo64i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo2i8)
+ __CASE_WEIGHT_FORMAT(OHWIo4i8)
+ __CASE_WEIGHT_FORMAT(OHWIo8i8)
+ __CASE_WEIGHT_FORMAT(OHWIo16i8)
+ __CASE_WEIGHT_FORMAT(OHWIo32i8)
+ __CASE_WEIGHT_FORMAT(OHWIo64i8)
+ default:
+ return "invalid value";
+ }
+#undef __CASE_WEIGHT_FORMAT
+}
+
struct KernelDescription
{
GemmMethod method = GemmMethod::DEFAULT;
@@ -230,6 +344,6 @@ template <typename Top, typename Tret, class OutputStage = Nothing>
std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
template <typename Top, typename Tret, class OutputStage = Nothing>
-bool has_opt_gemm(const GemmArgs &args, const OutputStage & = {});
+bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
} // namespace arm_gemm
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index 61cd11ece0..f3fff608dc 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -51,6 +51,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.activation_info = info.activation_info();
asm_info.fast_mode = info.fast_math();
asm_info.fixed_format = info.fixed_format();
+ asm_info.weight_format = info.weight_format();
return asm_info;
}
@@ -177,7 +178,8 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
if(d->total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != d->dimension(0));
+ // For fixed format we are expecting some kind of blocked format for B/RHS so the dimension won't necessarily match the result matrix any more.
+ ARM_COMPUTE_RETURN_ERROR_ON(!gemm_info.fixed_format() && b->dimension(0) != d->dimension(0));
if(gemm_info.depth_output_gemm3d() != 0)
{
if(gemm_info.reinterpret_input_as_3d())
@@ -277,7 +279,7 @@ void CpuGemm::run(ITensorPack &tensors)
auto c = tensors.get_const_tensor(ACL_SRC_2);
auto d = tensors.get_tensor(ACL_DST);
- if(_asm_glue->is_configured())
+ if(_asm_glue && _asm_glue->is_configured())
{
// Pass c to asm dispatch only if it's the bias tensor
ITensorPack asm_pack = tensors;
@@ -343,7 +345,7 @@ void CpuGemm::prepare(ITensorPack &tensors)
{
if(!_is_prepared)
{
- if(_asm_glue->is_configured())
+ if(_asm_glue && _asm_glue->is_configured())
{
_asm_glue->prepare(tensors);
}
@@ -365,5 +367,18 @@ experimental::MemoryRequirements CpuGemm::workspace() const
{
return _aux_mem;
}
+
+Status CpuGemm::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+ const GEMMInfo &gemm_info)
+{
+ const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
+
+ return CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, asm_info);
+}
+
+bool CpuGemm::isVarWeightsKernel() const
+{
+ return _asm_glue && _asm_glue->isVarWeightsKernel();
+}
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/operators/CpuGemm.h b/src/cpu/operators/CpuGemm.h
index 334ab6c647..b37ab73485 100644
--- a/src/cpu/operators/CpuGemm.h
+++ b/src/cpu/operators/CpuGemm.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -101,11 +101,29 @@ public:
static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
+ /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters.
+ *
+ * This method has the same use of @ref
+ * NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that
+ * the value of arm_gemm::WeightFormat need to be passed via the
+ * parameter gemm_info.
+ */
+ static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+ const GEMMInfo &gemm_info = GEMMInfo());
+
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &constants) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &constants) override;
experimental::MemoryRequirements workspace() const override;
+ /** Indicates if the convolution executes in variable weights mode.
+ *
+ * When ACL executes convolution in variable weights mode, it does
+ * not perform any processing of the weights tensor. Instead, it
+ * utilizes the data as it is given by the user.
+ */
+ bool isVarWeightsKernel() const;
+
private:
enum AuxTensorIdx
{
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index c021d31059..0174d0eed3 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -99,15 +99,15 @@ CpuGemmConv2d::CpuGemmConv2d()
CpuGemmConv2d::~CpuGemmConv2d() = default;
void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const ActivationLayerInfo &act_info,
- bool enable_fast_math, int gemm_3d_depth)
+ bool enable_fast_math, int gemm_3d_depth, bool fixed_format, arm_gemm::WeightFormat weight_format)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights);
- ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col, fixed_format, weight_format));
// Create GEMMInfo structure
const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
gemm_3d_depth, _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
- false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info);
+ false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weight_format);
// Supported activations in GEMM
const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU,
@@ -156,7 +156,8 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weig
quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, output_info);
_mm_gemmlowp = std::make_unique<CpuGemmLowpMatrixMultiplyCore>();
- _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false, enable_fast_math, false, act_info));
+ _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false, enable_fast_math, false, act_info,
+ experimental::PostOpList<ITensorInfo *>(), fixed_format, weight_format));
auto mm_mem_req = _mm_gemmlowp->workspace();
for(unsigned int cont = 0; cont < mm_mem_req.size(); ++cont)
@@ -178,7 +179,7 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weig
}
Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
- const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col)
+ const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col, bool fixed_format, arm_gemm::WeightFormat weight_format)
{
const DataType data_type = src->data_type();
const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
@@ -187,7 +188,7 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *wei
// Create GEMMInfo structure
const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
- false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info);
+ false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weight_format);
if(is_quantized)
{
@@ -227,6 +228,7 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *wei
std::unique_ptr<ITensorInfo> weights_qa = weights->clone();
input_qa->set_quantization_info(QuantizationInfo(iqinfo.uniform().scale, -iqinfo.uniform().offset));
weights_qa->set_quantization_info(QuantizationInfo(wqinfo.uniform().scale, -wqinfo.uniform().offset));
+
return CpuGemmLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, skip_im2col, false, output_info, false, enable_fast_math,
false, act_info));
}
@@ -294,6 +296,7 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights
kernel_height,
conv_info,
dilation);
+
ARM_COMPUTE_ERROR_ON_MSG((dst->dimension(idx_width) != conv_w) || (dst->dimension(idx_height) != conv_h),
"Output shape does not match the expected one");
@@ -357,7 +360,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights
// Configure GEMM
// In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the output matrix
const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0;
- configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth);
+ const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+ configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth, fixed_format, weights_info.weight_format());
if(!_skip_col2im && _data_layout == DataLayout::NCHW)
{
@@ -384,6 +388,38 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights
_aux_mem[GemmOutput] = MemoryInfo(offset_int_vec(GemmOutput), MemoryLifetime::Temporary, _gemm_output.total_size());
}
+Status CpuGemmConv2d::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
+ const PadStrideInfo &conv_info,
+ const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, const bool enable_fast_math)
+{
+ const DataLayout data_layout = src->data_layout();
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const unsigned int kernel_width = weights->dimension(idx_width);
+ const unsigned int kernel_height = weights->dimension(idx_height);
+ unsigned int conv_w = 0;
+ unsigned int conv_h = 0;
+ std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
+ src->dimension(idx_height),
+ kernel_width,
+ kernel_height,
+ conv_info,
+ dilation);
+
+ const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
+ dilation, act_info);
+
+ const bool skip_im2col = skip_info.skip_im2col;
+ const bool skip_col2im = skip_info.skip_col2im;
+ const unsigned int gemm_3d_depth = skip_col2im ? conv_h : 0;
+ const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+ const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
+ gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
+ false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weights_info.weight_format());
+
+ return CpuGemm::has_opt_impl(expected_weight_format, src, weights, biases, dst, gemm_info);
+}
+
Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const PadStrideInfo &conv_info,
const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups)
{
@@ -450,7 +486,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases);
}
- ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(idx_kernels));
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != dst->dimension(idx_channel));
ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
}
@@ -472,7 +508,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
im2col_reshaped_info = TensorInfo(shape_im2col, 1, data_type);
im2col_reshaped_info.set_quantization_info(src->quantization_info());
- ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuIm2ColKernel::validate(src, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation));
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuIm2ColKernel::validate(src, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation, 1));
gemm_input_to_use = &im2col_reshaped_info;
}
@@ -490,8 +526,11 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
info_gemm = TensorInfo(dst->tensor_shape(), 1, output_data_type);
}
info_gemm.set_quantization_info(dst->quantization_info()).set_data_layout(src->data_layout());
- gemm_output_to_use = &info_gemm;
- ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col));
+ gemm_output_to_use = &info_gemm;
+ const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col, fixed_format,
+ weights_info.weight_format()));
// Validate Col2Im/ReshapeLayer
if(!skip_col2im && (data_layout == DataLayout::NCHW))
@@ -548,7 +587,10 @@ void CpuGemmConv2d::run(ITensorPack &tensors)
// Runs CpuGemm or CpuGemmLowpMatrixMultiplyCore functions
ITensorPack pack_mm = tensors;
pack_mm.add_const_tensor(TensorType::ACL_SRC_0, gemm_input_to_use);
- pack_mm.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
+ if(!this->isVarWeightsKernel())
+ {
+ pack_mm.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
+ }
pack_mm.add_tensor(TensorType::ACL_DST, gemm_output_to_use);
if(_is_quantized)
{
@@ -598,6 +640,15 @@ void CpuGemmConv2d::prepare(ITensorPack &tensors)
{
if(!_is_prepared)
{
+ // Variable weights executions that use fixed-format kernels
+ // need no reshaping of the weights.
+ if(this->isVarWeightsKernel())
+ {
+ _is_quantized ? _mm_gemmlowp->prepare(tensors) : _mm_gemm->prepare(tensors);
+ _is_prepared = true;
+ return;
+ }
+
// Run weights reshaping and mark original weights tensor as unused
CpuAuxTensorHandler weights_reshaped(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
@@ -608,12 +659,9 @@ void CpuGemmConv2d::prepare(ITensorPack &tensors)
};
NEScheduler::get().schedule_op(_weights_reshape_kernel.get(), 3, _weights_reshape_kernel->window(), pack);
weights->mark_as_unused();
-
- // Prepare GEMM
ITensorPack gemm_pack = tensors;
gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, weights_reshaped.get());
_is_quantized ? _mm_gemmlowp->prepare(gemm_pack) : _mm_gemm->prepare(gemm_pack);
-
_is_prepared = true;
}
}
@@ -621,5 +669,9 @@ experimental::MemoryRequirements CpuGemmConv2d::workspace() const
{
return _aux_mem;
}
+bool CpuGemmConv2d::isVarWeightsKernel() const
+{
+ return _mm_gemm && _mm_gemm->isVarWeightsKernel();
+}
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h
index aec4a2ffa5..f8f0bce048 100644
--- a/src/cpu/operators/CpuGemmConv2d.h
+++ b/src/cpu/operators/CpuGemmConv2d.h
@@ -117,6 +117,17 @@ public:
const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
bool enable_fast_math = false, unsigned int num_groups = 1);
+ /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters.
+ *
+ * The paramter list is the same as @ref NEGEMMConvolutionLayer::has_opt_impl
+ *
+ * @return a status.
+ */
+ static Status has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
+ const PadStrideInfo &conv_info,
+ const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
+ const bool enable_fast_math = false);
+
// Inherited methods overridden:
void run(ITensorPack &tensors) override;
void prepare(ITensorPack &tensors) override;
@@ -135,9 +146,11 @@ private:
* @param[in] enable_fast_math (Optional) Enable fast math computation. In case this flag were set, the function could dispatch the fastest implementation
* available which may introduce a drop of accuracy as well. Default is false
* @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1)
+ * @param[in] fixed_format (Optional) Select GEMM execution with variable weights.
+ * @param[in] weight_format (Optional) The layout to be used for the weights tensor when running GEMM with variable weights.
*/
void configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo(),
- bool enable_fast_math = false, int gemm_3d_depth = 1);
+ bool enable_fast_math = false, int gemm_3d_depth = 1, bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED);
/** Static function to check if given info will lead to a valid configuration of @ref NEGEMMConvolutionLayer matrix multiply routines
*
* @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32.
@@ -151,11 +164,13 @@ private:
* available which may introduce a drop of accuracy as well. Default is false
* @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1)
* @param[in] skip_im2col (Optional) Flag which specifies if im2col has to be skipped. i.e. 1x1 convolution with NHWC data layout. (Default to false)
+ * @param[in] fixed_format (Optional) Select GEMM execution with variable weights.
+ * @param[in] weight_format (Optional) The layout to be used for the weights tensor when running GEMM with variable weights.
*
* @return a status
*/
static Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act_info = ActivationLayerInfo(),
- bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false);
+ bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false, bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED);
/** Static function to check if GEMM3D is supported in @ref NEGEMM or in @ref CpuGemmMLowpMatrixMultiplyCore
*
* @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32.
@@ -187,6 +202,11 @@ private:
static SkipInfo skip_im_col_info(const ITensorInfo *src, const ITensorInfo *weights, const PadStrideInfo &conv_info,
const Size2D &dilation, const ActivationLayerInfo &act_info);
+ /** Indicates if the convolution executes in variable weights mode.
+ *
+ * Similar to @ref CpuGemm::isVarWeightsKernel
+ */
+ bool isVarWeightsKernel() const;
enum AuxTensorIdx
{
// CpuGemmLowpMatrixMultiplyCore has up to 8 internal tensors
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 787ea95372..5694a3d9ee 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -160,6 +160,13 @@ public:
void prepare(ITensorPack &tensors) override;
bool is_configured() const override;
experimental::MemoryRequirements workspace() const override;
+ bool isVarWeightsKernel() const override
+ {
+ if(!_gemm_kernel_asm)
+ return false;
+ const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format;
+ return wf != arm_gemm::WeightFormat::UNSPECIFIED && wf != arm_gemm::WeightFormat::ANY;
+ }
private:
enum AuxTensorIdx
@@ -420,6 +427,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
// Pretranspose B if required
if(_gemm_kernel_asm->B_pretranspose_required())
{
+ // Fixed format kernels need no pretranspose.
+ ARM_COMPUTE_ERROR_ON(arm_gemm::is_fixed_format(_gemm_kernel_asm->get_config().weight_format));
const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
const auto in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
@@ -483,7 +492,24 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Check if B is pre-tranposed and de-reference if not
if(!_gemm_kernel_asm->B_is_pretransposed())
{
- ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format;
+ if(is_fixed_format(wf))
+ {
+ // The 4D tensor of dimension O'HWI' created for the
+ // OHWIo<interleave_by>i<block_by> format is in reality seen
+ // as a 2D tensor at arm_gemm level, where the rows are
+ // O'/<interleave_by> and the columns are <interleave_by> *
+ // H * W * I'.
+ ITensorInfo *tensor_info = b->info();
+ const DataLayout data_layout = tensor_info->data_layout();
+ const TensorShape tensor_shape = tensor_info->tensor_shape();
+ const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
+ const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
+ const int Ip = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
+ const int interleave_by = arm_gemm::interleave_by(wf);
+ ldb = (interleave_by * H * W * Ip);
+ }
multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
}
@@ -576,7 +602,9 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode);
+ arm_gemm::GemmConfig cfg;
+ cfg.weight_format = info.weight_format;
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
@@ -594,7 +622,9 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
const CPUInfo &ci = NEScheduler::get().cpu_info();
const unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode);
+ arm_gemm::GemmConfig cfg;
+ cfg.weight_format = info.weight_format;
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
@@ -635,7 +665,8 @@ CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch()
{
}
-Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
+Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+ const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
ARM_COMPUTE_UNUSED(c);
@@ -643,12 +674,14 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor
Params p = extract_parameters(a, b, d, info);
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
+ arm_gemm::GemmConfig cfg;
+ cfg.weight_format = info.weight_format;
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg);
switch(a->data_type())
{
case DataType::F32:
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(expected_weight_format, args, {})),
"We could not find an optimized kernel for F32 input");
break;
#ifdef __aarch64__
@@ -656,12 +689,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor
case DataType::QASYMM8:
if(d->data_type() == DataType::S32)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
"We could not find an optimized kernel for U8/QASYMM8 input and S32 output");
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})),
"We could not find an optimized kernel for U8 input and U8 output");
}
break;
@@ -669,12 +702,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor
case DataType::QASYMM8_SIGNED:
if(d->data_type() == DataType::S32)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
"We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})),
"We could not find an optimized kernel for S8 input and S32 output");
}
break;
@@ -689,7 +722,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor
#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
"We could not find an optimized kernel for BFLOAT16 input and F32 output");
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
@@ -729,7 +762,17 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
- return CpuGemmAssemblyDispatch::has_opt_impl(a, b, c, d, info);
+ arm_gemm::WeightFormat expected_weight_format;
+ const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
+ if((bool)ret && expected_weight_format != arm_gemm::WeightFormat::ANY)
+ {
+ // Correctness check: if the format expected by the kernel is
+ // not "any", make sure that the one found matches the format
+ // intended by the caller.
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((expected_weight_format != info.weight_format),
+ "The format expected by the kernel does not correspond with the one requested by the user.");
+ }
+ return ret;
}
bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation)
@@ -801,7 +844,7 @@ void CpuGemmAssemblyDispatch::prepare(ITensorPack &tensors)
bool CpuGemmAssemblyDispatch::is_configured() const
{
- return _arm_gemm != nullptr && _arm_gemm->is_configured();
+ return _arm_gemm && _arm_gemm->is_configured();
}
void CpuGemmAssemblyDispatch::run(ITensorPack &tensors)
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 3c25866f25..4ef108d430 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -53,6 +53,7 @@ struct AsmGemmInfo
float padding_value{ 0.f };
bool fast_mode{ false };
bool fixed_format{ false };
+ arm_gemm::WeightFormat weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
};
/** Assembly kernel glue */
@@ -73,6 +74,7 @@ public:
virtual void prepare(ITensorPack &tensors) = 0;
virtual experimental::MemoryRequirements workspace() const = 0;
virtual bool is_configured() const = 0;
+ virtual bool isVarWeightsKernel() const = 0;
virtual ~IFallback() = default;
};
@@ -101,15 +103,14 @@ public:
/** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters.
*
- * @param[in] a Input tensor info (Matrix A)
- * @param[in] b Input tensor info (Matrix B)
- * @param[in] c Input tensor info (Matrix C) used to pass the bias for quantized calculations
- * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
- * @param[in] info GEMM meta-data
+ * This method has the same use of @ref
+ * NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that
+ * the value of arm_gemm::WeightFormat need to be passed via the
+ * parameter info.
*
* @return a status.
*/
- static Status has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
+ static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
/** Checks if activation is supported by the gemm assembly dispatcher
*
* @param[in] activation Activation to check
@@ -122,6 +123,14 @@ public:
* @return True if the function is configured and ready to run
*/
bool is_configured() const;
+ /** Indicates if the convolution executes in variable weights mode.
+ *
+ * Similar to @ref CpuGemm::isVarWeightsKernel
+ */
+ bool isVarWeightsKernel() const
+ {
+ return _arm_gemm && _arm_gemm->isVarWeightsKernel();
+ }
// Inherited methods overridden:
void prepare(ITensorPack &tensors) override;
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index c780d63763..13635c6e34 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -58,6 +58,7 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
+
_impl->weights = weights;
_impl->op = std::make_unique<cpu::CpuGemmConv2d>();
_impl->op->configure(input->info(), weights->info(), (biases != nullptr ? biases->info() : nullptr), output->info(), conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups);
@@ -79,6 +80,13 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
return cpu::CpuGemmConv2d::validate(input, weights, biases, output, conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups);
}
+Status NEGEMMConvolutionLayer::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
+ const PadStrideInfo &conv_info,
+ const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, const bool enable_fast_math)
+{
+ return cpu::CpuGemmConv2d::has_opt_impl(expected_weight_format, src, weights, biases, dst, conv_info, weights_info, dilation, act_info, enable_fast_math);
+}
+
void NEGEMMConvolutionLayer::run()
{
prepare();
diff --git a/tests/SConscript b/tests/SConscript
index 8cd13ab914..b848f27043 100644
--- a/tests/SConscript
+++ b/tests/SConscript
@@ -184,6 +184,10 @@ if test_env['reference_openmp'] and env['os'] not in ['bare_metal', 'macos','win
if 'ndk_above_r21' in env:
test_env['LINKFLAGS'].append('-static-openmp')
+# Testing for fixed format GEMM kernels.
+if env['experimental_fixed_format_kernels'] and test_env['validation_tests']:
+ test_env.Append(CPPDEFINES = ['ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS'])
+
if test_env['validation_tests']:
arm_compute_validation_framework = env.StaticLibrary('arm_compute_validation_framework', Glob('validation/reference/*.cpp') + Glob('validation/*.cpp'), LINKFLAGS=test_env['LINKFLAGS'], CXXFLAGS=test_env['CXXFLAGS'], LIBS= [ arm_compute_test_framework, arm_compute_core_a])
Depends(arm_compute_validation_framework , arm_compute_test_framework)
diff --git a/tests/framework/Asserts.h b/tests/framework/Asserts.h
index 28d3da9a85..5f462773d0 100644
--- a/tests/framework/Asserts.h
+++ b/tests/framework/Asserts.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -42,6 +42,11 @@ inline int make_printable(int8_t value)
return value;
}
+inline std::string make_printable(arm_gemm::WeightFormat wf)
+{
+ return arm_gemm::to_string(wf);
+}
+
inline unsigned int make_printable(uint8_t value)
{
return value;
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index 578921bddd..3b385d4724 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -504,6 +504,220 @@ TEST_SUITE_END() // FP16
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
TEST_SUITE_END() // WinogradLayer
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+TEST_SUITE(VariableWeightUtils)
+
+// UC2_1_* tests: the user requests a specific fixed format, but there is no kernel that supports it.
+
+FIXTURE_DATA_TEST_CASE(UC2_1_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
+ combine(framework::dataset::make("DataType", { DataType::F32 }),
+ framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo2 })))
+{
+ ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+FIXTURE_DATA_TEST_CASE(UC2_1_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
+ combine(framework::dataset::make("DataType", { DataType::F32 }),
+ framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo2 })))
+{
+ ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+
+// UC2_1_* tests: the user requests a specific fixed format, and a
+// kernel that support that fixed format is found.
+
+FIXTURE_DATA_TEST_CASE(UC2_2_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
+ combine(framework::dataset::make("DataType", { DataType::F32 }),
+ framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo4 })))
+{
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format == arm_gemm::WeightFormat::OHWIo4, framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
+ combine(framework::dataset::make("DataType", { DataType::F32 }),
+ framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo4 })))
+{
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format == arm_gemm::WeightFormat::OHWIo4, framework::LogLevel::ERRORS);
+}
+
+// UC3_1_* tests: the user queries for ANY fixed format, but there is
+// no kernel that support the use case specified by the user (for
+// example, there is no fixed format kernel for the datatype of the
+// problem).
+
+FIXTURE_DATA_TEST_CASE(UC3_1_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
+ combine(framework::dataset::make("DataType", { DataType::S32 }),
+ framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+{
+ ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(UC3_1_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
+ combine(framework::dataset::make("DataType", { DataType::S32 }),
+ framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+{
+ ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+
+// UC3_2_* tests: the user queries for ANY fixed format. The search
+// succeeded and the fixed format found is prompted back for
+// consumption by the user. Note that we just test the
+// _computed_weight_format to be anything but not the formats that are
+// not fixed formats (ANY and UNSPECIFIED). This is because the weight
+// format that the runtime produces depends on the size of the vector
+// units of the hardware where the tests is executed. For example, a
+// format like OHWIo4 for FP32 data returned for 128-bit NEON hardware
+// is replaced by OHWIo8 when running on 256-bit SVE.
+
+FIXTURE_DATA_TEST_CASE(UC3_2_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
+ combine(framework::dataset::make("DataType", { DataType::F32 }),
+ framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+{
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::ANY, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(UC3_2_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
+ combine(framework::dataset::make("DataType", { DataType::F32 }),
+ framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+{
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::ANY, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+}
+
+namespace
+{
+using TestCaseType = std::tuple<TensorShape, TensorShape, arm_gemm::WeightFormat>;
+auto prepare_weights_shapes = framework::dataset::make("TensorShape",
+{
+ // OHWIo<interleave_by>i<block_by>
+ //
+ // OHWI --> O'HWI', where:
+ //
+ // O'= smallest multiple of <interleave_by> such that O<=O'
+ // I'= smallest multiple of <block_by> such that I<=I'
+ //
+
+ // Change N for OHWIo4
+ TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 12U }, arm_gemm::WeightFormat::OHWIo4 }),
+ // // Change N for OHWIo8
+ TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 16U }, arm_gemm::WeightFormat::OHWIo8 }),
+ // // Change N for OHWIo4 when H, W and C are not 1
+ TestCaseType({ { 3U, 4U, 2U, 1U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 2U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 3U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 4U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 6U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 7U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 8U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 12 }, arm_gemm::WeightFormat::OHWIo4 }),
+
+ // // Fix N and move HWI around, with different data layouts and formats
+ TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 2U, 4U, 3U, 9U }, { 2, 4, 3, 16 }, arm_gemm::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 16 }, arm_gemm::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1024U, 1U, 1U, 1001U }, { 1024, 1, 1, 1008 }, arm_gemm::WeightFormat::OHWIo8 }),
+
+ // // Adding <block_by> on I (=C)
+ TestCaseType({ { 1U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }),
+ TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }),
+ TestCaseType({ { 3U, 4U, 3U, 5U }, { 4, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }),
+
+ // ---------
+ TestCaseType({ { 2, 2, 1, 5 }, { 2, 2, 1, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1, 2, 2, 5 }, { 1, 2, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+
+});
+} // unnamed namespace
+
+DATA_TEST_CASE(PrepareWeightShape, framework::DatasetMode::ALL,
+ prepare_weights_shapes, shapes)
+{
+ const TensorShape input_shape = std::get<0>(shapes);
+ const TensorShape expected_shape = std::get<1>(shapes);
+ const arm_gemm::WeightFormat wf = std::get<2>(shapes);
+ const DataType DT = DataType::F32;
+ const DataLayout DL = DataLayout::NHWC;
+ const auto TI = TensorInfo(input_shape, 1 /*num_channels, deprecated*/, DT, DL);
+ const TensorInfo computed = ::arm_compute::test::validation::prepare_weights(TI, wf);
+ const TensorInfo expected = TensorInfo(expected_shape, 1 /*num_channels, deprecated*/, DT, DL);
+ ARM_COMPUTE_EXPECT_EQUAL(computed, expected, framework::LogLevel::ERRORS);
+}
+
+TEST_SUITE_END() // VariableWeightUtils
+
+TEST_SUITE(ExperimentalCpuAPIVariableWeightWithFixtures)
+
+template <typename ScalarType>
+using VarWidth = VariableWeightsFixture<cpu::CpuGemmConv2d, Tensor, Accessor, ScalarType>;
+
+FIXTURE_DATA_TEST_CASE(RunSmallFloat, VarWidth<float>, framework::DatasetMode::ALL,
+ combine(combine(datasets::SmallConvolutionLayerDataset(),
+ framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+ framework::dataset::make("ACL Scalar type", { DataType::F32 })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallHalf, VarWidth<half>, framework::DatasetMode::ALL,
+ combine(combine(datasets::SmallConvolutionLayerDataset(),
+ framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+ framework::dataset::make("ACL Scalar type", { DataType::F16 })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, rel_tolerance_f16, 0.f, half(abs_tolerance_f16));
+}
+
+TEST_SUITE_END() // ExperimentalCpuAPIVariableWeightWithFixtures
+
+TEST_SUITE(ExperimentalNEAPIVariableWeightWithFixtures)
+
+template <typename ScalarType>
+using NEGEMMVarWidth = VariableWeightsFixtureNEInterface<NEGEMMConvolutionLayer, Tensor, Accessor, ScalarType>;
+
+FIXTURE_DATA_TEST_CASE(NEGEMMRunSmallFloat, NEGEMMVarWidth<float>, framework::DatasetMode::ALL,
+ combine(combine(datasets::SmallConvolutionLayerDataset(),
+ framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+ framework::dataset::make("ACL Scalar type", { DataType::F32 })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
+}
+
+FIXTURE_DATA_TEST_CASE(NEGEMMRunSmallHalf, NEGEMMVarWidth<half>, framework::DatasetMode::ALL,
+ combine(combine(datasets::SmallConvolutionLayerDataset(),
+ framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+ framework::dataset::make("ACL Scalar type", { DataType::F16 })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, rel_tolerance_f16, 0.f, half(abs_tolerance_f16));
+}
+
+TEST_SUITE_END() // ExperimentalNEAPIVariableWeightWithFixtures
+
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+
TEST_SUITE(GEMMConvolutionLayer)
template <typename T>
using NEGEMMConvolutionLayerFixture = ConvolutionValidationFixture<Tensor, Accessor, NEConvolutionLayer, T>;
@@ -609,9 +823,7 @@ TEST_SUITE(Float)
#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
TEST_SUITE(BFLOAT16)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::BFLOAT16)),
- framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::BFLOAT16)), framework::dataset::make("DataLayout", { DataLayout::NHWC })),
ActivationFunctionsDataset))
{
// Validate output
@@ -623,10 +835,7 @@ TEST_SUITE_END() // BFLOAT16
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(FP16)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("DataLayout", { DataLayout::NCHW })),
- ActivationFunctionsDataset))
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("DataLayout", { DataLayout::NCHW })), ActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
@@ -636,9 +845,7 @@ TEST_SUITE_END() // FP16
TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
ActivationFunctionsDataset))
{
// Validate output
@@ -680,11 +887,8 @@ const auto QuantizedActivationFunctionsDataset = framework::dataset::make("Activ
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::QASYMM8)),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
- QuantizedActivationFunctionsDataset))
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), QuantizedActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
@@ -710,11 +914,8 @@ TEST_SUITE_END() // QASYMM8
TEST_SUITE(QASYMM8_SIGNED)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })),
- QuantizedActivationFunctionsDataset))
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })), QuantizedActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
@@ -868,10 +1069,7 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
TEST_SUITE(Float)
TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("DataLayout", { DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("DataLayout", { DataLayout::NHWC })), ActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
@@ -895,11 +1093,8 @@ const auto QuantizedActivationFunctionsDataset = framework::dataset::make("Activ
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::QASYMM8)),
- framework::dataset::make("DataLayout", { DataLayout::NHWC })),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
- QuantizedActivationFunctionsDataset))
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), QuantizedActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
@@ -908,11 +1103,8 @@ TEST_SUITE_END() // QASYMM8
TEST_SUITE(QASYMM8_SIGNED)
FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerQuantizedFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
- framework::dataset::make("DataLayout", { DataLayout::NHWC })),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })),
- QuantizedActivationFunctionsDataset))
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })), QuantizedActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index bffdc59758..d3804ee371 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -28,6 +28,7 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/graph/Utils.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
#include "src/graph/mutators/MutatorUtils.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
@@ -121,14 +122,14 @@ protected:
{
case DataType::QASYMM8:
{
- std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
+ std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
std::uniform_int_distribution<uint32_t> distribution(bounds.first, bounds.second);
library->fill(tensor, distribution, i);
break;
}
case DataType::QASYMM8_SIGNED:
{
- std::pair<int, int> bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f);
+ std::pair<int, int> bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f);
std::uniform_int_distribution<int32_t> distribution(bounds.first, bounds.second);
library->fill(tensor, distribution, i);
break;
@@ -397,6 +398,296 @@ public:
quantization_info, QuantizationInfo(weights_scales), act_info);
}
};
+
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_gemm::WeightFormat weight_format)
+{
+ const DataLayout data_layout = tensor_info.data_layout();
+ ARM_COMPUTE_EXPECT(data_layout == DataLayout::NHWC, framework::LogLevel::ERRORS);
+ const DataType data_type = tensor_info.data_type();
+ const TensorShape tensor_shape = tensor_info.tensor_shape();
+ const int N = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES)]; // N=O
+ const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
+ const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
+ const int C = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; // C=I
+
+ const int interleave_by = arm_gemm::interleave_by(weight_format);
+ const int block_by = arm_gemm::block_by(weight_format);
+ const int Ip = arm_gemm::roundup<unsigned int>(C, block_by); // C'=I'
+ const int Op = arm_gemm::roundup<unsigned int>(N, interleave_by); // O'=N'
+
+ const TensorShape TS(Ip, W, H, Op);
+ return TensorInfo(TS, 1 /*num_channels*/, data_type, data_layout);
+}
+
+template <typename ScalarType, typename AccessorType>
+inline void rearrange_data(const AccessorType src, AccessorType dst, const arm_gemm::WeightFormat weight_format)
+{
+ ARM_COMPUTE_EXPECT(arm_gemm::is_fixed_format(weight_format), framework::LogLevel::ERRORS);
+ // Data Layout: OHWIo<interleave_by>i<block_by>
+ const int interleave_by = arm_gemm::interleave_by(weight_format);
+ const int block_by = arm_gemm::block_by(weight_format);
+ const TensorShape src_tensor_shape = src.shape();
+ const DataLayout data_layout = src.data_layout();
+ ARM_COMPUTE_EXPECT(data_layout == DataLayout::NHWC, framework::LogLevel::ERRORS);
+ const unsigned int O = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES)]; // N=O
+ const unsigned int H = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
+ const unsigned int W = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
+ const unsigned int I = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; // C=I
+ const unsigned int Ip = arm_gemm::roundup<unsigned int>(I, block_by); // C'=I'
+ const unsigned int Op = arm_gemm::roundup<unsigned int>(O, interleave_by); // N'=O'
+
+ ARM_COMPUTE_EXPECT_EQUAL(Op * H * W * Ip, (unsigned)dst.num_elements(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(src.num_elements() <= dst.num_elements(), framework::LogLevel::ERRORS);
+
+ const ScalarType *src_ptr = reinterpret_cast<const ScalarType *>(src.data());
+ ScalarType *dst_ptr = reinterpret_cast<ScalarType *>(dst.data());
+ for(unsigned i = 0; i < I; ++i)
+ for(unsigned w = 0; w < W; ++w)
+ for(unsigned h = 0; h < H; ++h)
+ for(unsigned o = 0; o < O; ++o)
+ {
+ ScalarType src_element;
+ switch(data_layout)
+ {
+ case DataLayout::NHWC:
+ {
+ src_element = src_ptr[o * H * W * I + h * W * I + w * I + i];
+ }
+ break;
+ default:
+ {
+ ARM_COMPUTE_ERROR("Unsupported memory layout.");
+ }
+ }
+ const int x5 = std::floor(((float)o) / interleave_by);
+ const int x4 = h;
+ const int x3 = w;
+ const int x2 = std::floor((float)i / block_by);
+ const int x1 = o % interleave_by;
+ const int x0 = i % block_by;
+ unsigned dst_idx = x5 * H * W * Ip * interleave_by
+ + x4 * W * Ip * interleave_by
+ + x3 * Ip * interleave_by
+ + x2 * interleave_by * block_by
+ + x1 * block_by
+ + x0;
+ dst_ptr[dst_idx] = src_element;
+ }
+}
+
+template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType>
+class VariableWeightsFixtureBaseClass : public framework::Fixture
+{
+public:
+ template <typename...>
+ void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, DataLayout data_layout,
+ const DataType data_type)
+ {
+ conv = std::make_unique<ConvolutionFunction>();
+ // prepare data
+ _data_layout = data_layout;
+ // Fixed format kernels for variable weights can work only with NHWC format.
+ ARM_COMPUTE_EXPECT_EQUAL(_data_layout, DataLayout::NHWC, framework::LogLevel::ERRORS);
+ _data_type = data_type;
+ // run the code
+ compute_target(input_shape, weights_shape, bias_shape, output_shape, info, dilation);
+ compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, dilation);
+ }
+ void teardown()
+ {
+ _target.allocator()->free();
+ }
+
+protected:
+ template <typename U>
+ void fill(U &&tensor, int i)
+ {
+ switch(tensor.data_type())
+ {
+ case DataType::F16:
+ {
+ arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ case DataType::F32:
+ {
+ std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ default:
+ library->fill_tensor_uniform(tensor, i);
+ }
+ }
+
+private:
+ virtual void configure_and_execute_kernel(TensorInfo src_tensor_info, TensorInfo weight_tensor_info, TensorInfo bias_tensor_info, TensorInfo dst_tensor_info, const WeightsInfo weights_info,
+ const PadStrideInfo &conv_info,
+ const Size2D &dilation) = 0;
+
+ void compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape &bias_shape, TensorShape output_shape, const PadStrideInfo &conv_info,
+ const Size2D &dilation)
+ {
+ // The dataset is always in NCHW format - we need to make C the
+ // innermost dimension because the fixed-format kernel work only
+ // with NHWC layout.
+ permute(input_shape, PermutationVector(2U, 0U, 1U));
+ permute(weights_shape, PermutationVector(2U, 0U, 1U));
+ permute(output_shape, PermutationVector(2U, 0U, 1U));
+ const auto src_tensor_info = TensorInfo(input_shape, 1, _data_type, _data_layout);
+ const auto weight_tensor_info = TensorInfo(weights_shape, 1, _data_type, _data_layout);
+ const auto bias_tensor_info = TensorInfo(bias_shape, 1, _data_type, _data_layout);
+ auto dst_tensor_info = TensorInfo(output_shape, 1, _data_type, _data_layout);
+
+ const int kernel_height = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT)];
+ const int kernel_width = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH)];
+ const int num_kernels = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::BATCHES)];
+
+ const WeightsInfo query_weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, arm_gemm::WeightFormat::ANY);
+ const bool kernel_found = bool(ConvolutionFunction::has_opt_impl(_computed_weight_format, &src_tensor_info, &weight_tensor_info,
+ &bias_tensor_info, &dst_tensor_info, conv_info, query_weights_info));
+ // Make surethat the setup founds a fixed-format kernel as requested by the test case.
+ ARM_COMPUTE_EXPECT(kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(arm_gemm::is_fixed_format(_computed_weight_format), framework::LogLevel::ERRORS);
+
+ const WeightsInfo weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, _computed_weight_format);
+ configure_and_execute_kernel(src_tensor_info, weight_tensor_info, bias_tensor_info, dst_tensor_info, weights_info, conv_info,
+ dilation);
+ }
+ void compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
+ const Size2D &dilation)
+ {
+ ARM_COMPUTE_UNUSED(input_shape, weights_shape, bias_shape, output_shape, info,
+ dilation);
+
+ // Create reference
+ SimpleTensor<ScalarType> src{ input_shape, _data_type };
+ SimpleTensor<ScalarType> weights{ weights_shape, _data_type };
+ SimpleTensor<ScalarType> bias{ bias_shape, _data_type };
+ fill(src, 0);
+ fill(bias, 1);
+ fill(weights, 3);
+ _reference = reference::convolution_layer<ScalarType>(src, weights, bias, output_shape, info, dilation, 1 /*num_groups*/);
+ }
+ DataLayout _data_layout{};
+ DataType _data_type{};
+
+protected:
+ std::unique_ptr<ConvolutionFunction> conv{};
+ arm_gemm::WeightFormat _computed_weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
+ TensorClass _target{};
+ SimpleTensor<ScalarType> _reference{};
+};
+
+template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType>
+class VariableWeightsFixture : public VariableWeightsFixtureBaseClass<ConvolutionFunction, TensorClass, AccessorType, ScalarType>
+{
+ void configure_and_execute_kernel(TensorInfo src_tensor_info, TensorInfo weight_tensor_info, TensorInfo bias_tensor_info, TensorInfo dst_tensor_info, const WeightsInfo weights_info,
+ const PadStrideInfo &conv_info,
+ const Size2D &dilation)
+ {
+ this->conv->configure(&src_tensor_info, &weight_tensor_info, &bias_tensor_info, &dst_tensor_info, conv_info, weights_info, dilation);
+
+ // Allocate input tensors
+ auto src = create_tensor<TensorClass>(src_tensor_info);
+ auto weights_original = create_tensor<TensorClass>(weight_tensor_info);
+ const TensorInfo new_tensor_info = prepare_weights(weight_tensor_info, this->_computed_weight_format);
+ auto weights_transformed = create_tensor<TensorClass>(new_tensor_info);
+ auto bias = create_tensor<TensorClass>(bias_tensor_info);
+ src.allocator()->allocate();
+ weights_original.allocator()->allocate();
+ weights_transformed.allocator()->allocate();
+ bias.allocator()->allocate();
+ // Allocate destination tensor
+ this->_target = create_tensor<TensorClass>(dst_tensor_info);
+ this->_target.allocator()->allocate();
+
+ // Prepare source and biases that are left unchanged.
+ this->fill(AccessorType(src), 0);
+ this->fill(AccessorType(bias), 1);
+
+ // First run
+ this->fill(AccessorType(weights_original), 2);
+ rearrange_data<ScalarType, AccessorType>(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format);
+ ITensorPack run_pack{ { TensorType::ACL_SRC_0, &src }, { TensorType::ACL_SRC_1, &weights_transformed }, { TensorType::ACL_SRC_2, &bias }, { TensorType::ACL_DST, &(this->_target) } };
+ this->conv->run(run_pack);
+ // Second run, with new weights
+ this->fill(AccessorType(weights_original), 3);
+ rearrange_data<ScalarType, AccessorType>(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format);
+ this->conv->run(run_pack);
+ src.allocator()->free();
+ weights_original.allocator()->free();
+ weights_transformed.allocator()->free();
+ bias.allocator()->free();
+ }
+};
+
+template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType>
+class VariableWeightsFixtureNEInterface : public VariableWeightsFixtureBaseClass<ConvolutionFunction, TensorClass, AccessorType, ScalarType>
+{
+ void configure_and_execute_kernel(TensorInfo src_tensor_info, TensorInfo weight_tensor_info, TensorInfo bias_tensor_info, TensorInfo dst_tensor_info, const WeightsInfo weights_info,
+ const PadStrideInfo &conv_info,
+ const Size2D &dilation)
+ {
+ // Allocate input tensors
+ auto src = create_tensor<TensorClass>(src_tensor_info);
+ auto weights_original = create_tensor<TensorClass>(weight_tensor_info);
+ const TensorInfo new_tensor_info = prepare_weights(weight_tensor_info, this->_computed_weight_format);
+ auto weights_transformed = create_tensor<TensorClass>(new_tensor_info);
+ auto bias = create_tensor<TensorClass>(bias_tensor_info);
+ src.allocator()->allocate();
+ weights_original.allocator()->allocate();
+ weights_transformed.allocator()->allocate();
+ bias.allocator()->allocate();
+ // Allocate destination tensor
+ this->_target = create_tensor<TensorClass>(dst_tensor_info);
+ this->_target.allocator()->allocate();
+ this->conv->configure(&src, &weights_transformed, &bias, &(this->_target), conv_info, weights_info, dilation);
+ // Prepare source and biases that are left unchanged.
+ this->fill(AccessorType(src), 0);
+ this->fill(AccessorType(bias), 1);
+
+ // First run
+ this->fill(AccessorType(weights_original), 2);
+ rearrange_data<ScalarType, AccessorType>(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format);
+ this->conv->run();
+ // Second run, with new weights
+ this->fill(AccessorType(weights_original), 3);
+ rearrange_data<ScalarType, AccessorType>(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format);
+ this->conv->run();
+ src.allocator()->free();
+ weights_original.allocator()->free();
+ weights_transformed.allocator()->free();
+ bias.allocator()->free();
+ }
+};
+
+template <typename ConvolutionClass>
+class HasOptImplFixture : public framework::Fixture
+{
+public:
+ template <typename...>
+ void setup(DataType data_type, arm_gemm::WeightFormat query_weight_format)
+ {
+ auto conv = std::make_unique<ConvolutionClass>();
+ const auto src_info = TensorInfo(TensorShape(1U, 5U, 2U), 1, data_type, DataLayout::NHWC);
+ const auto weight_info = TensorInfo(TensorShape(1U, 3U, 2U, 3U), 1, data_type, DataLayout::NHWC);
+ const auto bias_info = TensorInfo(TensorShape(3U), 1, data_type, DataLayout::NHWC);
+ auto dst_info = TensorInfo(TensorShape(1U, 7U, 3U), 1, data_type, DataLayout::NHWC);
+ const auto conv_info = PadStrideInfo(1, 1, 0, 0, 2, 2, DimensionRoundingType::FLOOR);
+ const WeightsInfo weights_info(false, 3U, 3U, 1U, false, query_weight_format);
+ _kernel_found = bool(ConvolutionClass::has_opt_impl(_computed_weight_format, &src_info, &weight_info,
+ &bias_info, &dst_info, conv_info, weights_info));
+ }
+
+protected:
+ bool _kernel_found{ false };
+ arm_gemm::WeightFormat _computed_weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
+};
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index 31eff57e6b..a41b3cc9ae 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -3242,6 +3242,30 @@ inline std::string to_string(const Conv3dInfo &conv3d_info)
return str.str();
}
+inline ::std::ostream &operator<<(::std::ostream &os, const arm_gemm::WeightFormat &wf)
+{
+ os << arm_gemm::to_string(wf);
+ return os;
+}
+inline std::string to_string(const arm_gemm::WeightFormat wf)
+{
+ std::stringstream str;
+ str << wf;
+ return str.str();
+}
+
+inline std::string to_string(const std::tuple<TensorShape, TensorShape, arm_gemm::WeightFormat> values)
+{
+ std::stringstream str;
+ str << "[Input shape = " << std::get<0>(values);
+ str << ", ";
+ str << "Expected output shape = " << std::get<1>(values);
+
+ str << ", ";
+ str << "WeightFormat = " << std::get<2>(values) << "]";
+ return str.str();
+}
+
} // namespace arm_compute
#endif /* __ARM_COMPUTE_TYPE_PRINTER_H__ */