author     Ramy Elgammal <ramy.elgammal@arm.com>    2022-07-20 14:57:37 +0100
committer  Ramy Elgammal <ramy.elgammal@arm.com>    2022-07-26 11:57:27 +0000
commit     91780021e25575086c6c31d014d34b6513649a9d (patch)
tree       06eeb9c1b6c92e766464fb43dccced2c4f8aa90f
parent     bf5274d1cbc2ba592b3990c40065e3c843075252 (diff)
download   ComputeLibrary-91780021e25575086c6c31d014d34b6513649a9d.tar.gz
Fix for inclusion of "arm_gemm" from src into "Types.h" from core
- Added arm_compute::WeightFormat and converted to/from arm_gemm::WeightFormat when needed through two map functions.
- Moved to_string(WeightFormat) to TypePrinter.h

Resolves: COMPMID-5415
Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com>
Change-Id: I65f7942100bcd4dbf2c5cf6c07f26c8e1e3bf86e
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/438511
Tested-by: bsgcomp <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-by: Sicong Li <sicong.li@arm.com>
Comments-Addressed: bsgcomp <bsgcomp@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7985
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  arm_compute/core/Types.h                                     | 189
-rw-r--r--  arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h  | 10
-rw-r--r--  src/core/utils/AssemblyUtils.cpp                             | 242
-rw-r--r--  src/core/utils/AssemblyUtils.h                               | 18
-rw-r--r--  src/cpu/kernels/assembly/arm_gemm.hpp                        | 114
-rw-r--r--  src/cpu/operators/CpuGemm.cpp                                | 2
-rw-r--r--  src/cpu/operators/CpuGemm.h                                  | 8
-rw-r--r--  src/cpu/operators/CpuGemmConv2d.cpp                          | 64
-rw-r--r--  src/cpu/operators/CpuGemmConv2d.h                            | 10
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp       | 43
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.h         | 46
-rw-r--r--  src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp        | 2
-rw-r--r--  tests/framework/Asserts.h                                    | 6
-rw-r--r--  tests/validation/NEON/ConvolutionLayer.cpp                   | 120
-rw-r--r--  tests/validation/fixtures/ConvolutionLayerFixture.h          | 30
-rw-r--r--  utils/TypePrinter.h                                          | 82
16 files changed, 648 insertions, 338 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 989cdfb8cc..c87c97cb06 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -32,7 +32,6 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/experimental/IPostOp.h"
#include "arm_compute/core/utils/misc/Macros.h"
-#include "src/cpu/kernels/assembly/arm_gemm.hpp"
#include "support/Bfloat16.h"
#include "support/Half.h"
@@ -775,10 +774,10 @@ public:
private:
std::pair<unsigned int, unsigned int> _stride;
- unsigned int _pad_left;
- unsigned int _pad_top;
- unsigned int _pad_right;
- unsigned int _pad_bottom;
+ unsigned int _pad_left;
+ unsigned int _pad_top;
+ unsigned int _pad_right;
+ unsigned int _pad_bottom;
DimensionRoundingType _round_type;
};
@@ -920,14 +919,14 @@ public:
}
private:
- std::vector<float> _min_sizes;
- std::vector<float> _variances;
- float _offset;
- bool _flip;
- bool _clip;
- std::vector<float> _max_sizes;
- std::vector<float> _aspect_ratios;
- Coordinates2D _img_size;
+ std::vector<float> _min_sizes;
+ std::vector<float> _variances;
+ float _offset;
+ bool _flip;
+ bool _clip;
+ std::vector<float> _max_sizes;
+ std::vector<float> _aspect_ratios;
+ Coordinates2D _img_size;
std::array<float, 2> _steps;
};
@@ -1172,15 +1171,15 @@ public:
}
private:
- unsigned int _max_detections;
- unsigned int _max_classes_per_detection;
- float _nms_score_threshold;
- float _iou_threshold;
- unsigned int _num_classes;
+ unsigned int _max_detections;
+ unsigned int _max_classes_per_detection;
+ float _nms_score_threshold;
+ float _iou_threshold;
+ unsigned int _num_classes;
std::array<float, 4> _scales_values;
- bool _use_regular_nms;
- unsigned int _detection_per_class;
- bool _dequantize_scores;
+ bool _use_regular_nms;
+ unsigned int _detection_per_class;
+ bool _dequantize_scores;
};
/** Pooling Layer Information struct*/
@@ -1613,13 +1612,13 @@ public:
}
private:
- float _img_width;
- float _img_height;
- float _scale;
- bool _apply_scale;
- bool _correct_transform_coords;
+ float _img_width;
+ float _img_height;
+ float _scale;
+ bool _apply_scale;
+ bool _correct_transform_coords;
std::array<float, 4> _weights;
- float _bbox_xform_clip;
+ float _bbox_xform_clip;
};
/** Activation Layer Information class */
@@ -1895,13 +1894,117 @@ private:
int32_t _shrink_axis_mask;
};
+/** Memory layouts for the weights tensor.
+ *
+ * * UNSPECIFIED is used to select kernels that do not run in
+ * variable weights mode.
+ *
+ * * ANY is used to query the kernel database to retrieve any of the
+ * kernels that runs in variable weights mode. Once a kernel is
+ * found, the specific format expected by the kernel can be
+ * retrieved by the user for reordering the weights tensor
+ * accordingly.
+ *
+ * The other values OHWIo{interleave_by}i{block_by} describe the
+ * memory layout of a 4D tensor with layout OHWI that has been
+ * transformed into a 4D tensor with dimensions O'HWI' where:
+ *
+ * O' = first multiple of {interleave_by} s.t. O<=O'
+ * I' = first multiple of {block_by} s.t. I<=I'
+ *
+ * The total size of the dst tensor is O' x H x W x I'
+ *
+ * The access function of the tensor with layout
+ * OHWIo{interleave_by}i{block_by} and size O'HWI' is a 6-parameter
+ * access function, where the 6 parameters are computed as follows:
+ *
+ * x5 = floor(o/{interleave_by}) RANGE [0, O'/{interleave_by} -1] SIZE: O'/{interleave_by}
+ *
+ * x4 = h RANGE [0, H-1] SIZE: H
+ * x3 = w RANGE [0, W-1] SIZE: W
+ * x2 = floor(i/{block_by}) RANGE [0, I'/{block_by} -1] SIZE: I'/{block_by}
+ * x1 = o%{interleave_by} RANGE [0, {interleave_by} -1] SIZE: {interleave_by}
+ * x0 = i%{block_by} RANGE [0, {block_by} -1] SIZE: {block_by}
+ * TOTAL SIZE: O' * H * W * I'
+ *
+ * 4D 6D
+ * ----------------- -----------------------------------
+ * value(o, h, w, i) = x5 * H * W * I' * {interleave_by}
+ * + x4 * W * I' * {interleave_by}
+ * + x3 * I' * {interleave_by}
+ * + x2 * {interleave_by} * {block_by}
+ * + x1 * {block_by}
+ * + x0
+ *
+ * Notice that in arm_gemm the 4D tensor of dimension O'HWI' created
+ * for the OHWIo{interleave_by}i{block_by} format is in reality seen
+ * as a 2D tensor, where the number of rows is O'/{interleave_by}
+ * and the number of columns is {interleave_by} * H * W * I'.
+ *
+ * The postfix *_bf16 is for the memory layout needed for the
+ * fast-mode kernels, in which the weights are passed in bfloat16
+ * format.
+ */
+enum class WeightFormat
+{
+ UNSPECIFIED = 0x1,
+ ANY = 0x2,
+ OHWI = 0x100100,
+ OHWIo2 = 0x100200,
+ OHWIo4 = 0x100400,
+ OHWIo8 = 0x100800,
+ OHWIo16 = 0x101000,
+ OHWIo32 = 0x102000,
+ OHWIo64 = 0x104000,
+ OHWIo128 = 0x108000,
+ OHWIo4i2 = 0x200400,
+ OHWIo4i2_bf16 = 0x200410,
+ OHWIo8i2 = 0x200800,
+ OHWIo8i2_bf16 = 0x200810,
+ OHWIo16i2 = 0x201000,
+ OHWIo16i2_bf16 = 0x201010,
+ OHWIo32i2 = 0x202000,
+ OHWIo32i2_bf16 = 0x202010,
+ OHWIo64i2 = 0x204000,
+ OHWIo64i2_bf16 = 0x204010,
+ OHWIo4i4 = 0x400400,
+ OHWIo4i4_bf16 = 0x400410,
+ OHWIo8i4 = 0x400800,
+ OHWIo8i4_bf16 = 0x400810,
+ OHWIo16i4 = 0x401000,
+ OHWIo16i4_bf16 = 0x401010,
+ OHWIo32i4 = 0x402000,
+ OHWIo32i4_bf16 = 0x402010,
+ OHWIo64i4 = 0x404000,
+ OHWIo64i4_bf16 = 0x404010,
+ OHWIo2i8 = 0x800200,
+ OHWIo4i8 = 0x800400,
+ OHWIo8i8 = 0x800800,
+ OHWIo16i8 = 0x801000,
+ OHWIo32i8 = 0x802000,
+ OHWIo64i8 = 0x804000
+};
+// OHWIo<interleave_by>i<block_by>
+inline int interleave_by(const WeightFormat wf)
+{
+ return ((int)wf >> 8) & 0xFFF;
+}
+inline int block_by(const WeightFormat wf)
+{
+ return ((int)wf >> 20) & 0xF;
+}
+inline bool is_fixed_format(const WeightFormat wf)
+{
+ return wf != WeightFormat::UNSPECIFIED && wf != WeightFormat::ANY;
+}
+
/** Convolution Layer Weights Information class. This class stores the necessary information to compute convolution layer when the weights are already reshaped */
class WeightsInfo
{
public:
/** Default constructor */
WeightsInfo()
- : _are_reshaped(false), _kernel_width(0), _kernel_height(0), _num_kernels(0), _retain_internal_weights(false), _weight_format(arm_gemm::WeightFormat::UNSPECIFIED)
+ : _are_reshaped(false), _kernel_width(0), _kernel_height(0), _num_kernels(0), _retain_internal_weights(false), _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
{
}
/** Constructor
@@ -1911,10 +2014,10 @@ public:
* @param[in] kernel_height Kernel height.
* @param[in] num_kernels Number of convolution kernels.
* @param[in] retain_internal_weights (Optional) True if internal reshaped weights must be retained. Used for reconfiguration purposes. Default is false.
- * @param[in] weight_format (Optional) arm_gemm:WeightFormat enumeration requested by the user. Default is arm_gemm::WeightFormat::UNSPECIFIED.
+ * @param[in] weight_format (Optional) arm_compute::WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
*/
WeightsInfo(bool are_reshaped, unsigned int kernel_width, unsigned int kernel_height, unsigned int num_kernels, bool retain_internal_weights = false,
- arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED)
+ arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED)
: _are_reshaped(are_reshaped), _kernel_width(kernel_width), _kernel_height(kernel_height), _num_kernels(num_kernels), _retain_internal_weights(retain_internal_weights), _weight_format(weight_format)
{
}
@@ -1946,7 +2049,7 @@ public:
{
return _retain_internal_weights;
}
- arm_gemm::WeightFormat weight_format() const
+ arm_compute::WeightFormat weight_format() const
{
return _weight_format;
}
@@ -1960,12 +2063,12 @@ public:
}
private:
- bool _are_reshaped;
- unsigned int _kernel_width;
- unsigned int _kernel_height;
- unsigned int _num_kernels;
- bool _retain_internal_weights;
- arm_gemm::WeightFormat _weight_format;
+ bool _are_reshaped;
+ unsigned int _kernel_width;
+ unsigned int _kernel_height;
+ unsigned int _num_kernels;
+ bool _retain_internal_weights;
+ arm_compute::WeightFormat _weight_format;
};
/** GEMM reshape information class. This class stores the necessary information about matrix A and matrix B reshape.
@@ -2177,7 +2280,7 @@ public:
_activation_info(),
_post_ops(),
_fixed_format(false),
- _weight_format(arm_gemm::WeightFormat::UNSPECIFIED)
+ _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
{
}
/** Constructor
@@ -2196,13 +2299,13 @@ public:
* @param[in] broadcast_bias (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
* @param[in] activation_info (Optional) Activation to apply after the matrix multiplication
* @param[in] post_ops (Optional) A sequence of post operations that are performed after the main operation.
- * @param[in] fixed_format (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in amemory format that is fixed by the kernel itself. For more information, see arm_gemm::WeightFormat.
- * @param[in] weight_format (Optional) arm_gemm:WeightFormat enumeration requested by the user. Default is arm_gemm::WeightFormat::UNSPECIFIED.
+ * @param[in] fixed_format (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in a memory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
+ * @param[in] weight_format (Optional) arm_compute::WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
*/
GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool fast_math = false, bool broadcast_bias = false,
const ActivationLayerInfo &activation_info = ActivationLayerInfo(), const experimental::PostOpList<ITensorInfo *> &post_ops = experimental::PostOpList<ITensorInfo *>(),
- bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED) noexcept
+ bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED) noexcept
: _is_a_reshaped(is_a_reshaped),
_is_b_reshaped(is_b_reshaped),
_reshape_b_only_on_first_run(reshape_b_only_on_first_run),
@@ -2392,7 +2495,7 @@ public:
return _fixed_format;
}
- arm_gemm::WeightFormat weight_format() const
+ arm_compute::WeightFormat weight_format() const
{
return _weight_format;
}
@@ -2413,7 +2516,7 @@ private:
ActivationLayerInfo _activation_info;
experimental::PostOpList<ITensorInfo *> _post_ops;
bool _fixed_format;
- arm_gemm::WeightFormat _weight_format;
+ arm_compute::WeightFormat _weight_format;
};
/** Winograd information */
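To make the encoding of the new enum concrete (a standalone sketch with made-up sizes, not part of the patch): interleave_by() reads bits 8-19 of the enumerator value and block_by() reads bits 20-23, so OHWIo8i4_bf16 (0x400810) decodes to interleave_by = 8 and block_by = 4, and the 6-parameter access function documented above can be evaluated directly:

#include <cassert>
#include <cstdint>

int main()
{
    // Decode OHWIo8i4_bf16 = 0x400810 the same way interleave_by()/block_by() do.
    const int64_t wf         = 0x400810;
    const int     interleave = (wf >> 8) & 0xFFF; // 0x008 -> 8
    const int     block      = (wf >> 20) & 0xF;  // 0x4   -> 4
    assert(interleave == 8 && block == 4);

    // Offset of element (o, h, w, i) in the O'HWI' buffer, following the 6D
    // access function in the comment above (Ip stands for the padded I').
    const int H = 3, W = 3, Ip = 16;
    const int o = 10, h = 1, w = 2, i = 5;
    const int64_t offset = (o / interleave) * H * W * Ip * interleave
                           + h * W * Ip * interleave
                           + w * Ip * interleave
                           + (i / block) * interleave * block
                           + (o % interleave) * block
                           + (i % block);
    (void)offset; // 1152 + 384 + 256 + 32 + 8 + 1 = 1833
    return 0;
}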
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
index 2af11ad656..a28266265d 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
@@ -131,7 +131,7 @@ public:
*
* The user can query the database of optimised kernels in
* arm_gemm by specifying one of the enumerations of
- * arm_gemm::WeightFormat in the weight_format field of the input
+ * arm_compute::WeightFormat in the weight_format field of the input
* parameter weights_info. In case of success, the method
* writes the expected format in the output parameter
* expected_weight_format. The expected_weight_format can then be
@@ -140,7 +140,7 @@ public:
*
* Use case one - query for a specific format:
*
- * WeightInfo weights_info(..., arm_gemm::WeightFormat::OHWIo4, ...); // Set the value of the input query.
+ * WeightsInfo weights_info(..., arm_compute::WeightFormat::OHWIo4, ...); // Set the value of the input query.
* if (NEGEMMConvolutionLayer::has_opt_impl(WeightFormat(), ...., weights_info, ...))
* {
* auto conv = std::unique_ptr<NEGEMMConvolutionLayer>();
@@ -150,8 +150,8 @@ public:
*
* Use case two - query for any format that would be optimal for the GEMM to execute:
*
- * WeightInfo weights_info(..., arm_gemm::WeightFormat::ANY, ...); // Set the value of the input query.
- * arm_gemm::WeightFormat expected_wf;
+ * WeightsInfo weights_info(..., arm_compute::WeightFormat::ANY, ...); // Set the value of the input query.
+ * arm_compute::WeightFormat expected_wf;
* if (NEGEMMConvolutionLayer::has_opt_impl(expected_wf, ...., weights_info, ...))
* {
* auto conv = std::unique_ptr<NEGEMMConvolutionLayer>();
@@ -177,7 +177,7 @@ public:
*
* @return a Status
*/
- static Status has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
+ static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
const PadStrideInfo &conv_info,
const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
bool enable_fast_math = false);
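Use case two above, written out as a compilable sketch (the shapes, the 1x1 kernel and the NHWC layout are illustrative assumptions; error handling is elided):

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h"

bool query_any_weight_format()
{
    using namespace arm_compute;
    // NHWC tensors: dimension 0 is the channel, so src is 1x32x32x16,
    // the weights are 64 OHWI 1x1x16 kernels and dst is 1x32x32x64.
    TensorInfo src(TensorShape(16U, 32U, 32U, 1U), 1, DataType::F32);
    TensorInfo weights(TensorShape(16U, 1U, 1U, 64U), 1, DataType::F32);
    TensorInfo biases(TensorShape(64U), 1, DataType::F32);
    TensorInfo dst(TensorShape(64U, 32U, 32U, 1U), 1, DataType::F32);
    src.set_data_layout(DataLayout::NHWC);
    weights.set_data_layout(DataLayout::NHWC);
    dst.set_data_layout(DataLayout::NHWC);

    const PadStrideInfo conv_info(1, 1, 0, 0);
    const WeightsInfo   weights_info(false, 1U, 1U, 64U, false, WeightFormat::ANY);

    WeightFormat expected_wf = WeightFormat::UNSPECIFIED;
    const Status status      = NEGEMMConvolutionLayer::has_opt_impl(expected_wf, &src, &weights, &biases, &dst,
                                                                    conv_info, weights_info);
    // On success, reorder the weights tensor to expected_wf before configure().
    return bool(status);
}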
diff --git a/src/core/utils/AssemblyUtils.cpp b/src/core/utils/AssemblyUtils.cpp
index 1e8a2a54c9..45e7ff78be 100644
--- a/src/core/utils/AssemblyUtils.cpp
+++ b/src/core/utils/AssemblyUtils.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -66,5 +66,245 @@ arm_conv::PaddingValues map_to_arm_conv_padding(const PadStrideInfo &pad_stride_
pad_stride_info.pad_right(),
pad_stride_info.pad_bottom() };
}
+
+arm_gemm::WeightFormat map_to_arm_gemm_weight_format(const arm_compute::WeightFormat &weight_format)
+{
+ arm_gemm::WeightFormat gemm_weight_format;
+
+ switch(weight_format)
+ {
+ case arm_compute::WeightFormat::UNSPECIFIED:
+ gemm_weight_format = arm_gemm::WeightFormat::UNSPECIFIED;
+ break;
+ case arm_compute::WeightFormat::ANY:
+ gemm_weight_format = arm_gemm::WeightFormat::ANY;
+ break;
+ case arm_compute::WeightFormat::OHWI:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWI;
+ break;
+ case arm_compute::WeightFormat::OHWIo2:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo2;
+ break;
+ case arm_compute::WeightFormat::OHWIo4:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo4;
+ break;
+ case arm_compute::WeightFormat::OHWIo8:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo8;
+ break;
+ case arm_compute::WeightFormat::OHWIo16:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo16;
+ break;
+ case arm_compute::WeightFormat::OHWIo32:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo32;
+ break;
+ case arm_compute::WeightFormat::OHWIo64:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo64;
+ break;
+ case arm_compute::WeightFormat::OHWIo128:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo128;
+ break;
+ case arm_compute::WeightFormat::OHWIo4i2:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i2;
+ break;
+ case arm_compute::WeightFormat::OHWIo4i2_bf16:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i2_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo8i2:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i2;
+ break;
+ case arm_compute::WeightFormat::OHWIo8i2_bf16:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i2_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo16i2:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i2;
+ break;
+ case arm_compute::WeightFormat::OHWIo16i2_bf16:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i2_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo32i2:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i2;
+ break;
+ case arm_compute::WeightFormat::OHWIo32i2_bf16:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i2_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo64i2:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i2;
+ break;
+ case arm_compute::WeightFormat::OHWIo64i2_bf16:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i2_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo4i4:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i4;
+ break;
+ case arm_compute::WeightFormat::OHWIo4i4_bf16:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i4_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo8i4:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i4;
+ break;
+ case arm_compute::WeightFormat::OHWIo8i4_bf16:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i4_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo16i4:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i4;
+ break;
+ case arm_compute::WeightFormat::OHWIo16i4_bf16:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i4_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo32i4:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i4;
+ break;
+ case arm_compute::WeightFormat::OHWIo32i4_bf16:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i4_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo64i4:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i4;
+ break;
+ case arm_compute::WeightFormat::OHWIo64i4_bf16:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i4_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo2i8:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo2i8;
+ break;
+ case arm_compute::WeightFormat::OHWIo4i8:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i8;
+ break;
+ case arm_compute::WeightFormat::OHWIo8i8:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i8;
+ break;
+ case arm_compute::WeightFormat::OHWIo16i8:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i8;
+ break;
+ case arm_compute::WeightFormat::OHWIo32i8:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i8;
+ break;
+ case arm_compute::WeightFormat::OHWIo64i8:
+ gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i8;
+ break;
+ default:
+ gemm_weight_format = arm_gemm::WeightFormat::UNSPECIFIED;
+ }
+ return gemm_weight_format;
+}
+
+arm_compute::WeightFormat map_to_arm_compute_weight_format(const arm_gemm::WeightFormat &weight_format)
+{
+ arm_compute::WeightFormat acl_weight_format;
+
+ switch(weight_format)
+ {
+ case arm_gemm::WeightFormat::UNSPECIFIED:
+ acl_weight_format = arm_compute::WeightFormat::UNSPECIFIED;
+ break;
+ case arm_gemm::WeightFormat::ANY:
+ acl_weight_format = arm_compute::WeightFormat::ANY;
+ break;
+ case arm_gemm::WeightFormat::OHWI:
+ acl_weight_format = arm_compute::WeightFormat::OHWI;
+ break;
+ case arm_gemm::WeightFormat::OHWIo2:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo2;
+ break;
+ case arm_gemm::WeightFormat::OHWIo4:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo4;
+ break;
+ case arm_gemm::WeightFormat::OHWIo8:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo8;
+ break;
+ case arm_gemm::WeightFormat::OHWIo16:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo32:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo32;
+ break;
+ case arm_gemm::WeightFormat::OHWIo64:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo64;
+ break;
+ case arm_gemm::WeightFormat::OHWIo128:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo128;
+ break;
+ case arm_gemm::WeightFormat::OHWIo4i2:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo4i2;
+ break;
+ case arm_gemm::WeightFormat::OHWIo4i2_bf16:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo4i2_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo8i2:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo8i2;
+ break;
+ case arm_gemm::WeightFormat::OHWIo8i2_bf16:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo8i2_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo16i2:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo16i2;
+ break;
+ case arm_gemm::WeightFormat::OHWIo16i2_bf16:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo16i2_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo32i2:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo32i2;
+ break;
+ case arm_gemm::WeightFormat::OHWIo32i2_bf16:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo32i2_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo64i2:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo64i2;
+ break;
+ case arm_gemm::WeightFormat::OHWIo64i2_bf16:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo64i2_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo4i4:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo4i4;
+ break;
+ case arm_gemm::WeightFormat::OHWIo4i4_bf16:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo4i4_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo8i4:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo8i4;
+ break;
+ case arm_gemm::WeightFormat::OHWIo8i4_bf16:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo8i4_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo16i4:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo16i4;
+ break;
+ case arm_gemm::WeightFormat::OHWIo16i4_bf16:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo16i4_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo32i4:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo32i4;
+ break;
+ case arm_gemm::WeightFormat::OHWIo32i4_bf16:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo32i4_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo64i4:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo64i4;
+ break;
+ case arm_gemm::WeightFormat::OHWIo64i4_bf16:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo64i4_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo2i8:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo2i8;
+ break;
+ case arm_gemm::WeightFormat::OHWIo4i8:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo4i8;
+ break;
+ case arm_gemm::WeightFormat::OHWIo8i8:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo8i8;
+ break;
+ case arm_gemm::WeightFormat::OHWIo16i8:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo16i8;
+ break;
+ case arm_gemm::WeightFormat::OHWIo32i8:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo32i8;
+ break;
+ case arm_gemm::WeightFormat::OHWIo64i8:
+ acl_weight_format = arm_compute::WeightFormat::OHWIo64i8;
+ break;
+ default:
+ acl_weight_format = arm_compute::WeightFormat::UNSPECIFIED;
+ }
+ return acl_weight_format;
+}
} // namespace assembly_utils
} // namespace arm_compute
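Because the two switches are exact mirrors, mapping to arm_gemm and back is expected to be the identity for every enumerator; a hypothetical spot check (not part of the patch) would read:

#include "src/core/utils/AssemblyUtils.h"
#include <cassert>

void weight_format_round_trip()
{
    using namespace arm_compute;
    const WeightFormat           wf      = WeightFormat::OHWIo4i4_bf16;
    const arm_gemm::WeightFormat gemm_wf = assembly_utils::map_to_arm_gemm_weight_format(wf);
    assert(assembly_utils::map_to_arm_compute_weight_format(gemm_wf) == wf);
}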
diff --git a/src/core/utils/AssemblyUtils.h b/src/core/utils/AssemblyUtils.h
index b1aee64d5d..7514175ed6 100644
--- a/src/core/utils/AssemblyUtils.h
+++ b/src/core/utils/AssemblyUtils.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -47,6 +47,22 @@ arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act);
* @return Assembly padding values.
*/
arm_conv::PaddingValues map_to_arm_conv_padding(const PadStrideInfo &pad_stride_info);
+
+/** Performs a mapping from Compute Library WeightFormat to the assembly WeightFormat enum
+ *
+ * @param[in] weight_format Compute Library WeightFormat enum value
+ *
+ * @return Assembly WeightFormat
+ */
+arm_gemm::WeightFormat map_to_arm_gemm_weight_format(const arm_compute::WeightFormat &weight_format);
+
+/** Performs a mapping from Assembly WeightFormat to the Compute Library WeightFormat enum
+ *
+ * @param[in] weight_format Assembly WeightFormat enum value
+ *
+ * @return Compute Library WeightFormat
+ */
+arm_compute::WeightFormat map_to_arm_compute_weight_format(const arm_gemm::WeightFormat &weight_format);
} // namespace assembly_utils
} // namespace arm_compute
#endif /* UTILS_CORE_ASSEMBLY_UTILS_H */
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 48fd7c6b43..4c127b4ec3 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -47,57 +47,6 @@ enum class GemmMethod
GEMM_HYBRID_QUANTIZED
};
-/** Memory layouts for the weights tensor.
- *
- * * UNSPECIFIED is used to select kernels that do not run in
- * variable weights mode.
- *
- * * ANY is used to query the kernel database to retrieve any of the
- * kernels that runs in variable weights mode. Once a kernel is
- * found, the specific format expected by the kernel can be
- * retrieved by the user for reordering the weights tensor
- * accordingly.
- *
- * The other values OHWIo{interleave_by}i{block_by} describe the
- * memory layout of a 4D tensor with layout OHWI that has been
- * transformed into a 4D tensor with dimensions O'HWI' where:
- *
- * O' = first multiple of {interleave_by} s.t. O<=O'
- * I' = first multiple of {block_by} s.t. I<=I'
- *
- * The total size of the dst tensor is O' x H x W x I'
- *
- * The access function of the tensor with layout
- * OHWIo{interleave_by}i{block_by} and size O'HWI' is a 6-parameter
- * access function, where the 6 parameters are computed as follows:
- *
- * x5 = floor(o/{interleave_by}) RANGE [0, O'/{interleave_by} -1] SIZE: O'/{interleave_by}
- *
- * x4 = h RANGE [0, H-1] SIZE: H
- * x3 = w RANGE [0, W-1] SIZE: W
- * x2 = floor(i/{block_by}) RANGE [0, I'/{block_by} -1] SIZE: I'/{block_by}
- * x1 = o%{interleave_by} RANGE [0, {interleave_by} -1] SIZE: {interleave_by}
- * x0 = i%{block_by} RANGE [0, {block_by} -1] SIZE: {block_by}
- * TOTAL SIZE: O' * H * W * I'
- *
- * 4D 6D
- * ----------------- -----------------------------------
- * value(o, h, w, i) = x5 * H * W * I' * {interleave_by}
- * + x4 * W * I' * {interleave_by}
- * + x3 * I' * {interleave_by}
- * + x2 * {interleave_by} * {block_by}
- * + x1 * {block_by}
- * + x0
- *
- * Notice that in arm_gemm the 4D tensor of dimension O'HWI' created
- * for the OHWIo{interleave_by}i{block_by} format is in reality seen
- * as a 2D tensor, where the number of rows is O'/{interleave_by}
- * and the number of columns is {interleave_by} * H * W * I'.
- *
- * The postfix *_bf16 is for the memory layout needed for the
- * fast-mode kernels, in which the weights are passed in bfloat16
- * format.
- */
enum class WeightFormat
{
UNSPECIFIED = 0x1,
@@ -138,69 +87,6 @@ enum class WeightFormat
OHWIo64i8 = 0x804000
};
-// OHWIo<interleave_by>i<block_by>
-inline int interleave_by(const WeightFormat wf)
-{
- return ((int)wf >> 8) & 0xFFF;
-}
-inline int block_by(const WeightFormat wf)
-{
- return ((int)wf >> 20) & 0xF;
-}
-inline bool is_fixed_format(const WeightFormat wf)
-{
- return wf != WeightFormat::UNSPECIFIED && wf != WeightFormat::ANY;
-}
-
-inline std::string to_string(WeightFormat wf)
-{
-#define __CASE_WEIGHT_FORMAT(wf) \
-case WeightFormat::wf: \
- return #wf;
- switch(wf)
- {
- __CASE_WEIGHT_FORMAT(UNSPECIFIED)
- __CASE_WEIGHT_FORMAT(ANY)
- __CASE_WEIGHT_FORMAT(OHWI)
- __CASE_WEIGHT_FORMAT(OHWIo2)
- __CASE_WEIGHT_FORMAT(OHWIo4)
- __CASE_WEIGHT_FORMAT(OHWIo8)
- __CASE_WEIGHT_FORMAT(OHWIo16)
- __CASE_WEIGHT_FORMAT(OHWIo32)
- __CASE_WEIGHT_FORMAT(OHWIo64)
- __CASE_WEIGHT_FORMAT(OHWIo128)
- __CASE_WEIGHT_FORMAT(OHWIo4i2)
- __CASE_WEIGHT_FORMAT(OHWIo4i2_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo8i2)
- __CASE_WEIGHT_FORMAT(OHWIo8i2_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo16i2)
- __CASE_WEIGHT_FORMAT(OHWIo16i2_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo32i2)
- __CASE_WEIGHT_FORMAT(OHWIo32i2_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo64i2)
- __CASE_WEIGHT_FORMAT(OHWIo64i2_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo4i4)
- __CASE_WEIGHT_FORMAT(OHWIo4i4_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo8i4)
- __CASE_WEIGHT_FORMAT(OHWIo8i4_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo16i4)
- __CASE_WEIGHT_FORMAT(OHWIo16i4_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo32i4)
- __CASE_WEIGHT_FORMAT(OHWIo32i4_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo64i4)
- __CASE_WEIGHT_FORMAT(OHWIo64i4_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo2i8)
- __CASE_WEIGHT_FORMAT(OHWIo4i8)
- __CASE_WEIGHT_FORMAT(OHWIo8i8)
- __CASE_WEIGHT_FORMAT(OHWIo16i8)
- __CASE_WEIGHT_FORMAT(OHWIo32i8)
- __CASE_WEIGHT_FORMAT(OHWIo64i8)
- default:
- return "invalid value";
- }
-#undef __CASE_WEIGHT_FORMAT
-}
-
struct KernelDescription
{
GemmMethod method = GemmMethod::DEFAULT;
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index f3fff608dc..f6582c73f8 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -368,7 +368,7 @@ experimental::MemoryRequirements CpuGemm::workspace() const
return _aux_mem;
}
-Status CpuGemm::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+Status CpuGemm::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
const GEMMInfo &gemm_info)
{
const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
diff --git a/src/cpu/operators/CpuGemm.h b/src/cpu/operators/CpuGemm.h
index b37ab73485..8d34b22437 100644
--- a/src/cpu/operators/CpuGemm.h
+++ b/src/cpu/operators/CpuGemm.h
@@ -105,15 +105,15 @@ public:
*
* This method has the same use of @ref
* NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that
- * the value of arm_gemm::WeightFormat need to be passed via the
+ * the value of arm_compute::WeightFormat needs to be passed via the
* parameter gemm_info.
*/
- static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+ static Status has_opt_impl(arm_compute::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
const GEMMInfo &gemm_info = GEMMInfo());
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &constants) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &constants) override;
experimental::MemoryRequirements workspace() const override;
/** Indicates if the convolution executes in variable weights mode.
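So while NEGEMMConvolutionLayer::has_opt_impl receives the query through weights_info, here the same query travels inside gemm_info; a hypothetical wrapper (tensor infos a, b, d assumed to be prepared by the caller) could look like:

#include "arm_compute/core/Types.h"
#include "src/cpu/operators/CpuGemm.h"

arm_compute::Status query_gemm_weight_format(const arm_compute::ITensorInfo *a,
                                             const arm_compute::ITensorInfo *b,
                                             const arm_compute::ITensorInfo *d,
                                             arm_compute::WeightFormat      &expected_wf)
{
    using namespace arm_compute;
    // fixed_format = true plus weight_format = ANY form the actual query.
    const GEMMInfo gemm_info(false, false, true /* reshape_b_only_on_first_run */, 0, false, false,
                             GEMMLowpOutputStageInfo(), false, false, false,
                             ActivationLayerInfo(), experimental::PostOpList<ITensorInfo *>(),
                             true /* fixed_format */, WeightFormat::ANY);
    return cpu::CpuGemm::has_opt_impl(expected_wf, a, b, nullptr /* c */, d, gemm_info);
}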
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index 0174d0eed3..f3a16f104f 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -62,13 +62,13 @@ CpuGemmConv2d::SkipInfo CpuGemmConv2d::skip_im_col_info(const ITensorInfo *src,
const unsigned int kernel_height = weights->dimension(idx_height);
unsigned int conv_w = 0;
unsigned int conv_h = 0;
- std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
- src->dimension(idx_height),
- kernel_width,
- kernel_height,
- conv_info,
- dilation);
- const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
+ std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
+ src->dimension(idx_height),
+ kernel_width,
+ kernel_height,
+ conv_info,
+ dilation);
+ const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
if(skip_im2col)
{
@@ -99,7 +99,7 @@ CpuGemmConv2d::CpuGemmConv2d()
CpuGemmConv2d::~CpuGemmConv2d() = default;
void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const ActivationLayerInfo &act_info,
- bool enable_fast_math, int gemm_3d_depth, bool fixed_format, arm_gemm::WeightFormat weight_format)
+ bool enable_fast_math, int gemm_3d_depth, bool fixed_format, arm_compute::WeightFormat weight_format)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights);
ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col, fixed_format, weight_format));
@@ -139,8 +139,8 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weig
PixelValue type_min{};
PixelValue type_max{};
std::tie(type_min, type_max) = get_min_max(data_type);
- int32_t min_activation = type_min.get<int32_t>();
- int32_t max_activation = type_max.get<int32_t>();
+ int32_t min_activation = type_min.get<int32_t>();
+ int32_t max_activation = type_max.get<int32_t>();
if(supported_acts.count(act_info.activation()) != 0)
{
@@ -179,7 +179,7 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weig
}
Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
- const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col, bool fixed_format, arm_gemm::WeightFormat weight_format)
+ const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col, bool fixed_format, arm_compute::WeightFormat weight_format)
{
const DataType data_type = src->data_type();
const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
@@ -203,8 +203,8 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *wei
PixelValue type_min{};
PixelValue type_max{};
std::tie(type_min, type_max) = get_min_max(data_type);
- int32_t min_activation = type_min.get<int32_t>();
- int32_t max_activation = type_max.get<int32_t>();
+ int32_t min_activation = type_min.get<int32_t>();
+ int32_t max_activation = type_max.get<int32_t>();
const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU,
ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
@@ -288,8 +288,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights
ITensorInfo *gemm_output_to_use = dst;
// Get convolved dimensions
- unsigned int conv_w = 0;
- unsigned int conv_h = 0;
+ unsigned int conv_w = 0;
+ unsigned int conv_h = 0;
std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
src->dimension(idx_height),
kernel_width,
@@ -306,8 +306,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights
_skip_col2im = skip_info.skip_col2im;
// Get parameters from conv_info
- unsigned int stride_x = 0;
- unsigned int stride_y = 0;
+ unsigned int stride_x = 0;
+ unsigned int stride_y = 0;
std::tie(stride_x, stride_y) = conv_info.stride();
unsigned int mat_weights_cols = weights->dimension(idx_kernels);
@@ -360,7 +360,7 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights
// Configure GEMM
// In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the output matrix
const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0;
- const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+ const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth, fixed_format, weights_info.weight_format());
if(!_skip_col2im && _data_layout == DataLayout::NCHW)
@@ -388,7 +388,7 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights
_aux_mem[GemmOutput] = MemoryInfo(offset_int_vec(GemmOutput), MemoryLifetime::Temporary, _gemm_output.total_size());
}
-Status CpuGemmConv2d::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
+Status CpuGemmConv2d::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
const PadStrideInfo &conv_info,
const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, const bool enable_fast_math)
{
@@ -399,12 +399,12 @@ Status CpuGemmConv2d::has_opt_impl(arm_gemm::WeightFormat &expected_weight_forma
const unsigned int kernel_height = weights->dimension(idx_height);
unsigned int conv_w = 0;
unsigned int conv_h = 0;
- std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
- src->dimension(idx_height),
- kernel_width,
- kernel_height,
- conv_info,
- dilation);
+ std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
+ src->dimension(idx_height),
+ kernel_width,
+ kernel_height,
+ conv_info,
+ dilation);
const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
dilation, act_info);
@@ -412,7 +412,7 @@ Status CpuGemmConv2d::has_opt_impl(arm_gemm::WeightFormat &expected_weight_forma
const bool skip_im2col = skip_info.skip_im2col;
const bool skip_col2im = skip_info.skip_col2im;
const unsigned int gemm_3d_depth = skip_col2im ? conv_h : 0;
- const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+ const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weights_info.weight_format());
@@ -464,9 +464,9 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
dilation);
// Check if GEMM3D is supported
- const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
- dilation, act_info);
- const bool skip_im2col = skip_info.skip_im2col, skip_col2im = skip_info.skip_col2im;
+ const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
+ dilation, act_info);
+ const bool skip_im2col = skip_info.skip_im2col, skip_col2im = skip_info.skip_col2im;
ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != src->dimension(idx_channel));
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -527,7 +527,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
}
info_gemm.set_quantization_info(dst->quantization_info()).set_data_layout(src->data_layout());
gemm_output_to_use = &info_gemm;
- const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+ const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col, fixed_format,
weights_info.weight_format()));
@@ -558,7 +558,7 @@ void CpuGemmConv2d::run(ITensorPack &tensors)
{
// Run input reshaping
unsigned int y_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
- ITensorPack pack =
+ ITensorPack pack =
{
{ TensorType::ACL_SRC, src },
{ TensorType::ACL_DST, im2col_output.get() }
@@ -652,7 +652,7 @@ void CpuGemmConv2d::prepare(ITensorPack &tensors)
// Run weights reshaping and mark original weights tensor as unused
CpuAuxTensorHandler weights_reshaped(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
- ITensorPack pack =
+ ITensorPack pack =
{
{ TensorType::ACL_SRC, weights },
{ TensorType::ACL_DST, weights_reshaped.get() }
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h
index f8f0bce048..08b76a6c46 100644
--- a/src/cpu/operators/CpuGemmConv2d.h
+++ b/src/cpu/operators/CpuGemmConv2d.h
@@ -123,14 +123,14 @@ public:
*
* @return a status.
*/
- static Status has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
+ static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
const PadStrideInfo &conv_info,
const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
const bool enable_fast_math = false);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
experimental::MemoryRequirements workspace() const override;
private:
@@ -150,7 +150,7 @@ private:
* @param[in] weight_format (Optional) The layout to be used for the weights tensor when running GEMM with variable weights.
*/
void configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo(),
- bool enable_fast_math = false, int gemm_3d_depth = 1, bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED);
+ bool enable_fast_math = false, int gemm_3d_depth = 1, bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED);
/** Static function to check if given info will lead to a valid configuration of @ref NEGEMMConvolutionLayer matrix multiply routines
*
* @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32.
@@ -170,7 +170,7 @@ private:
* @return a status
*/
static Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act_info = ActivationLayerInfo(),
- bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false, bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED);
+ bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false, bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED);
/** Static function to check if GEMM3D is supported in @ref NEGEMM or in @ref CpuGemmMLowpMatrixMultiplyCore
*
* @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32.
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 558ff41a5c..c969c9f4f6 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -164,8 +164,8 @@ public:
{
if(!_gemm_kernel_asm)
return false;
- const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format;
- return wf != arm_gemm::WeightFormat::UNSPECIFIED && wf != arm_gemm::WeightFormat::ANY;
+ const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
+ return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY;
}
private:
@@ -428,7 +428,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
if(_gemm_kernel_asm->B_pretranspose_required())
{
// Fixed format kernels need no pretranspose.
- ARM_COMPUTE_ERROR_ON(arm_gemm::is_fixed_format(_gemm_kernel_asm->get_config().weight_format));
+ ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
const auto in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
@@ -492,8 +492,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Check if B is pre-tranposed and de-reference if not
if(!_gemm_kernel_asm->B_is_pretransposed())
{
- ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
- const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format;
+ ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
if(is_fixed_format(wf))
{
// The 4D tensor of dimension O'HWI' created for the
@@ -507,7 +507,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
const int Ip = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
- const int interleave_by = arm_gemm::interleave_by(wf);
+ const int interleave_by = arm_compute::interleave_by(wf);
ldb = (interleave_by * H * W * Ip);
}
multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
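To make the new ldb computation concrete (hypothetical numbers, not from the patch): for OHWIo8i4 the helper arm_compute::interleave_by() returns 8, so a 3x3 kernel with a padded channel count I' = 16 gives a row stride of 8 * 3 * 3 * 16 = 1152 elements between consecutive rows of the 2D view of the O'HWI' buffer:

const arm_compute::WeightFormat wf = arm_compute::WeightFormat::OHWIo8i4;
const int interleave = arm_compute::interleave_by(wf); // 8
const int H = 3, W = 3, Ip = 16;                       // illustrative sizes, Ip = I'
const int ldb = interleave * H * W * Ip;               // 1152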
@@ -603,7 +603,7 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
unsigned int num_threads = NEScheduler::get().num_threads();
arm_gemm::GemmConfig cfg;
- cfg.weight_format = info.weight_format;
+ cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
// Create arm_gemm fallback
@@ -623,7 +623,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
const unsigned int num_threads = NEScheduler::get().num_threads();
arm_gemm::GemmConfig cfg;
- cfg.weight_format = info.weight_format;
+ cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
// Create arm_gemm fallback
@@ -665,7 +665,7 @@ CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch()
{
}
-Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
@@ -675,13 +675,13 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
arm_gemm::GemmConfig cfg;
- cfg.weight_format = info.weight_format;
-
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg);
+ cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
+ arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg);
switch(a->data_type())
{
case DataType::F32:
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for F32 input");
break;
#ifdef __aarch64__
@@ -689,12 +689,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
case DataType::QASYMM8:
if(d->data_type() == DataType::S32)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for U8/QASYMM8 input and S32 output");
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for U8 input and U8 output");
}
break;
@@ -702,12 +702,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
case DataType::QASYMM8_SIGNED:
if(d->data_type() == DataType::S32)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for S8 input and S32 output");
}
break;
@@ -722,7 +722,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for BFLOAT16 input and F32 output");
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
@@ -730,6 +730,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Unsupported type. Could not find a kernel");
break;
}
+ expected_weight_format = assembly_utils::map_to_arm_compute_weight_format(arm_gemm_expected_wf);
return Status{};
}
@@ -762,9 +763,9 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
- arm_gemm::WeightFormat expected_weight_format;
- const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
- if((bool)ret && expected_weight_format != arm_gemm::WeightFormat::ANY)
+ arm_compute::WeightFormat expected_weight_format;
+ const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
+ if((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY)
{
// Correctness check: if the format expected by the kernel is
// not "any", make sure that the one found matches the format
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 4ef108d430..691eeff8d2 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -41,19 +41,19 @@ enum class AsmConvMethod
struct AsmGemmInfo
{
- AsmConvMethod method{ AsmConvMethod::Im2Col };
- PadStrideInfo ps_info{};
- ActivationLayerInfo activation_info{};
- GEMMLowpOutputStageInfo output_stage{};
- bool negated_offsets{ true };
- bool reinterpret_input_as_3d{ false };
- bool depth_output_gemm3d{ false };
- int64_t padding_top{ 0 };
- int64_t padding_left{ 0 };
- float padding_value{ 0.f };
- bool fast_mode{ false };
- bool fixed_format{ false };
- arm_gemm::WeightFormat weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
+ AsmConvMethod method{ AsmConvMethod::Im2Col };
+ PadStrideInfo ps_info{};
+ ActivationLayerInfo activation_info{};
+ GEMMLowpOutputStageInfo output_stage{};
+ bool negated_offsets{ true };
+ bool reinterpret_input_as_3d{ false };
+ bool depth_output_gemm3d{ false };
+ int64_t padding_top{ 0 };
+ int64_t padding_left{ 0 };
+ float padding_value{ 0.f };
+ bool fast_mode{ false };
+ bool fixed_format{ false };
+ arm_compute::WeightFormat weight_format{ arm_compute::WeightFormat::UNSPECIFIED };
};
/** Assembly kernel glue */
@@ -70,12 +70,12 @@ public:
class IFallback
{
public:
- virtual void run(ITensorPack &tensors) = 0;
- virtual void prepare(ITensorPack &tensors) = 0;
- virtual experimental::MemoryRequirements workspace() const = 0;
- virtual bool is_configured() const = 0;
- virtual bool isVarWeightsKernel() const = 0;
- virtual ~IFallback() = default;
+ virtual void run(ITensorPack &tensors) = 0;
+ virtual void prepare(ITensorPack &tensors) = 0;
+ virtual experimental::MemoryRequirements workspace() const = 0;
+ virtual bool is_configured() const = 0;
+ virtual bool isVarWeightsKernel() const = 0;
+ virtual ~IFallback() = default;
};
public:
@@ -105,12 +105,12 @@ public:
*
* This method has the same use of @ref
* NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that
- * the value of arm_gemm::WeightFormat need to be passed via the
+ * the value of arm_compute::WeightFormat needs to be passed via the
* parameter info.
*
* @return a status.
*/
- static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
+ static Status has_opt_impl(arm_compute::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
/** Checks if activation is supported by the gemm assembly dispatcher
*
* @param[in] activation Activation to check
@@ -133,8 +133,8 @@ public:
}
// Inherited methods overridden:
- void prepare(ITensorPack &tensors) override;
- void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
experimental::MemoryRequirements workspace() const override;
private:
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index 13635c6e34..fe3ea6a767 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -80,7 +80,7 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
return cpu::CpuGemmConv2d::validate(input, weights, biases, output, conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups);
}
-Status NEGEMMConvolutionLayer::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
+Status NEGEMMConvolutionLayer::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
const PadStrideInfo &conv_info,
const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, const bool enable_fast_math)
{
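
The updated public entry point returns the selected format through its first parameter. A hedged usage sketch (the tensor infos, conv_info, and the kernel_w/kernel_h/num_kernels values are assumed to be set up elsewhere; the WeightsInfo constructor arguments mirror those used in the test fixture further down):

    arm_compute::WeightFormat expected_wf = arm_compute::WeightFormat::UNSPECIFIED;
    const WeightsInfo query(false /* reshape_weights */, kernel_w, kernel_h, num_kernels,
                            false, arm_compute::WeightFormat::ANY);
    const Status st = NEGEMMConvolutionLayer::has_opt_impl(expected_wf, &src_info, &wei_info,
                                                           &bias_info, &dst_info, conv_info, query,
                                                           Size2D(1U, 1U), ActivationLayerInfo(),
                                                           false /* enable_fast_math */);
    if(bool(st) && arm_compute::is_fixed_format(expected_wf))
    {
        // expected_wf now names a concrete OHWIo<interleave_by>i<block_by> layout.
    }
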
diff --git a/tests/framework/Asserts.h b/tests/framework/Asserts.h
index 5f462773d0..7adfa8f2f3 100644
--- a/tests/framework/Asserts.h
+++ b/tests/framework/Asserts.h
@@ -30,6 +30,8 @@
#include <sstream>
#include <type_traits>
+#include "utils/TypePrinter.h"
+
namespace arm_compute
{
namespace test
@@ -42,9 +44,9 @@ inline int make_printable(int8_t value)
return value;
}
-inline std::string make_printable(arm_gemm::WeightFormat wf)
+inline std::string make_printable(const arm_compute::WeightFormat wf)
{
- return arm_gemm::to_string(wf);
+ return arm_compute::to_string(wf);
}
inline unsigned int make_printable(uint8_t value)
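
The assert helper now prints the public enum through the arm_compute::to_string overload added in utils/TypePrinter.h further down. For instance (sketch):

    make_printable(arm_compute::WeightFormat::OHWIo4); // returns the string "OHWIo4"
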
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index 940983f42b..0194220e1a 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -511,13 +511,13 @@ TEST_SUITE(VariableWeightUtils)
FIXTURE_DATA_TEST_CASE(UC2_1_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo2 })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo2 })))
{
ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
}
FIXTURE_DATA_TEST_CASE(UC2_1_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo2 })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo2 })))
{
ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
}
@@ -527,18 +527,18 @@ FIXTURE_DATA_TEST_CASE(UC2_1_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMCon
FIXTURE_DATA_TEST_CASE(UC2_2_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo4 })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo4 })))
{
ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format == arm_gemm::WeightFormat::OHWIo4, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo4, framework::LogLevel::ERRORS);
}
FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo4 })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo4 })))
{
ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format == arm_gemm::WeightFormat::OHWIo4, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo4, framework::LogLevel::ERRORS);
}
// UC3_1_* tests: the user queries for ANY fixed format, but there is
@@ -548,14 +548,14 @@ FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMCon
FIXTURE_DATA_TEST_CASE(UC3_1_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::S32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
{
ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
}
FIXTURE_DATA_TEST_CASE(UC3_1_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::S32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
{
ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
}
@@ -572,24 +572,24 @@ FIXTURE_DATA_TEST_CASE(UC3_1_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMCon
FIXTURE_DATA_TEST_CASE(UC3_2_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
{
ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::ANY, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
}
FIXTURE_DATA_TEST_CASE(UC3_2_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
{
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::ANY, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
}
namespace
{
-using TestCaseType = std::tuple<TensorShape, TensorShape, arm_gemm::WeightFormat>;
+using TestCaseType = std::tuple<TensorShape, TensorShape, arm_compute::WeightFormat>;
auto prepare_weights_shapes = framework::dataset::make("TensorShape",
{
// OHWIo<interleave_by>i<block_by>
@@ -601,51 +601,51 @@ auto prepare_weights_shapes = framework::dataset::make("TensorShape",
//
// Change N for OHWIo4
- TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 12U }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 12U }, arm_compute::WeightFormat::OHWIo4 }),
// // Change N for OHWIo8
- TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 16U }, arm_gemm::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 16U }, arm_compute::WeightFormat::OHWIo8 }),
// // Change N for OHWIo4 when H, W and C are not 1
- TestCaseType({ { 3U, 4U, 2U, 1U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 2U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 3U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 4U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 6U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 7U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 8U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 12 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 1U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 2U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 3U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 4U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 6U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 7U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 8U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 12 }, arm_compute::WeightFormat::OHWIo4 }),
// // Fix N and move HWI around, with different data layouts and formats
- TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 2U, 4U, 3U, 9U }, { 2, 4, 3, 16 }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 16 }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1024U, 1U, 1U, 1001U }, { 1024, 1, 1, 1008 }, arm_gemm::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 2U, 4U, 3U, 9U }, { 2, 4, 3, 16 }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 16 }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1024U, 1U, 1U, 1001U }, { 1024, 1, 1, 1008 }, arm_compute::WeightFormat::OHWIo8 }),
// // Adding <block_by> on I (=C)
- TestCaseType({ { 1U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }),
- TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }),
- TestCaseType({ { 3U, 4U, 3U, 5U }, { 4, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }),
+ TestCaseType({ { 1U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4i2 }),
+ TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4i2 }),
+ TestCaseType({ { 3U, 4U, 3U, 5U }, { 4, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4i2 }),
// ---------
- TestCaseType({ { 2, 2, 1, 5 }, { 2, 2, 1, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1, 2, 2, 5 }, { 1, 2, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 2, 2, 1, 5 }, { 2, 2, 1, 8 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1, 2, 2, 5 }, { 1, 2, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }),
});
} // unnamed namespace
@@ -653,14 +653,14 @@ auto prepare_weights_shapes = framework::dataset::make("TensorShape",
DATA_TEST_CASE(PrepareWeightShape, framework::DatasetMode::ALL,
prepare_weights_shapes, shapes)
{
- const TensorShape input_shape = std::get<0>(shapes);
- const TensorShape expected_shape = std::get<1>(shapes);
- const arm_gemm::WeightFormat wf = std::get<2>(shapes);
- const DataType DT = DataType::F32;
- const DataLayout DL = DataLayout::NHWC;
- const auto TI = TensorInfo(input_shape, 1 /*num_channels, deprecated*/, DT, DL);
- const TensorInfo computed = ::arm_compute::test::validation::prepare_weights(TI, wf);
- const TensorInfo expected = TensorInfo(expected_shape, 1 /*num_channels, deprecated*/, DT, DL);
+ const TensorShape input_shape = std::get<0>(shapes);
+ const TensorShape expected_shape = std::get<1>(shapes);
+ const arm_compute::WeightFormat wf = std::get<2>(shapes);
+ const DataType DT = DataType::F32;
+ const DataLayout DL = DataLayout::NHWC;
+ const auto TI = TensorInfo(input_shape, 1 /*num_channels, deprecated*/, DT, DL);
+ const TensorInfo computed = ::arm_compute::test::validation::prepare_weights(TI, wf);
+ const TensorInfo expected = TensorInfo(expected_shape, 1 /*num_channels, deprecated*/, DT, DL);
ARM_COMPUTE_EXPECT_EQUAL(computed, expected, framework::LogLevel::ERRORS);
}
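
Every PrepareWeightShape case above follows the same rule: the I (=C) extent is rounded up to a multiple of block_by, and the O (=N) extent to a multiple of interleave_by. A self-contained sketch of that rule (the roundup helper here is a hypothetical stand-in for the arm_gemm::roundup used by prepare_weights in the fixture below):

    #include <cstdio>

    // Next multiple of m at or above x (stand-in for arm_gemm::roundup).
    static unsigned int roundup(unsigned int x, unsigned int m)
    {
        return ((x + m - 1) / m) * m;
    }

    int main()
    {
        // OHWIo4i2: interleave_by = 4 (applies to O), block_by = 2 (applies to I).
        // Matches the dataset entry { { 1U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, OHWIo4i2 }.
        std::printf("I'=%u O'=%u\n", roundup(1U, 2U), roundup(5U, 4U)); // prints I'=2 O'=8
        return 0;
    }
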
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index d3804ee371..c58a0a2c91 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -122,14 +122,14 @@ protected:
{
case DataType::QASYMM8:
{
- std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
+ std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
std::uniform_int_distribution<uint32_t> distribution(bounds.first, bounds.second);
library->fill(tensor, distribution, i);
break;
}
case DataType::QASYMM8_SIGNED:
{
- std::pair<int, int> bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f);
+ std::pair<int, int> bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f);
std::uniform_int_distribution<int32_t> distribution(bounds.first, bounds.second);
library->fill(tensor, distribution, i);
break;
@@ -400,7 +400,7 @@ public:
};
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
-inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_gemm::WeightFormat weight_format)
+inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_compute::WeightFormat weight_format)
{
const DataLayout data_layout = tensor_info.data_layout();
ARM_COMPUTE_EXPECT(data_layout == DataLayout::NHWC, framework::LogLevel::ERRORS);
@@ -411,8 +411,8 @@ inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_gemm::
const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
const int C = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; // C=I
- const int interleave_by = arm_gemm::interleave_by(weight_format);
- const int block_by = arm_gemm::block_by(weight_format);
+ const int interleave_by = arm_compute::interleave_by(weight_format);
+ const int block_by = arm_compute::block_by(weight_format);
const int Ip = arm_gemm::roundup<unsigned int>(C, block_by); // C'=I'
const int Op = arm_gemm::roundup<unsigned int>(N, interleave_by); // O'=N'
@@ -421,12 +421,12 @@ inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_gemm::
}
template <typename ScalarType, typename AccessorType>
-inline void rearrange_data(const AccessorType src, AccessorType dst, const arm_gemm::WeightFormat weight_format)
+inline void rearrange_data(const AccessorType src, AccessorType dst, const arm_compute::WeightFormat weight_format)
{
- ARM_COMPUTE_EXPECT(arm_gemm::is_fixed_format(weight_format), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format(weight_format), framework::LogLevel::ERRORS);
// Data Layout: OHWIo<interleave_by>i<block_by>
- const int interleave_by = arm_gemm::interleave_by(weight_format);
- const int block_by = arm_gemm::block_by(weight_format);
+ const int interleave_by = arm_compute::interleave_by(weight_format);
+ const int block_by = arm_compute::block_by(weight_format);
const TensorShape src_tensor_shape = src.shape();
const DataLayout data_layout = src.data_layout();
ARM_COMPUTE_EXPECT(data_layout == DataLayout::NHWC, framework::LogLevel::ERRORS);
@@ -545,12 +545,12 @@ private:
const int kernel_width = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH)];
const int num_kernels = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::BATCHES)];
- const WeightsInfo query_weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, arm_gemm::WeightFormat::ANY);
+ const WeightsInfo query_weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, arm_compute::WeightFormat::ANY);
const bool kernel_found = bool(ConvolutionFunction::has_opt_impl(_computed_weight_format, &src_tensor_info, &weight_tensor_info,
&bias_tensor_info, &dst_tensor_info, conv_info, query_weights_info));
    // Make sure that the setup finds a fixed-format kernel as requested by the test case.
ARM_COMPUTE_EXPECT(kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(arm_gemm::is_fixed_format(_computed_weight_format), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format(_computed_weight_format), framework::LogLevel::ERRORS);
const WeightsInfo weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, _computed_weight_format);
configure_and_execute_kernel(src_tensor_info, weight_tensor_info, bias_tensor_info, dst_tensor_info, weights_info, conv_info,
@@ -576,7 +576,7 @@ private:
protected:
std::unique_ptr<ConvolutionFunction> conv{};
- arm_gemm::WeightFormat _computed_weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
+ arm_compute::WeightFormat _computed_weight_format{ arm_compute::WeightFormat::UNSPECIFIED };
TensorClass _target{};
SimpleTensor<ScalarType> _reference{};
};
@@ -669,7 +669,7 @@ class HasOptImplFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(DataType data_type, arm_gemm::WeightFormat query_weight_format)
+ void setup(DataType data_type, arm_compute::WeightFormat query_weight_format)
{
auto conv = std::make_unique<ConvolutionClass>();
const auto src_info = TensorInfo(TensorShape(1U, 5U, 2U), 1, data_type, DataLayout::NHWC);
@@ -683,8 +683,8 @@ public:
}
protected:
- bool _kernel_found{ false };
- arm_gemm::WeightFormat _computed_weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
+ bool _kernel_found{ false };
+ arm_compute::WeightFormat _computed_weight_format{ arm_compute::WeightFormat::UNSPECIFIED };
};
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
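
The same rounding is observable end-to-end through prepare_weights. A hedged usage sketch (assumes the validation utilities above are in scope; the shapes are taken from the PrepareWeightShape dataset):

    const TensorInfo ti(TensorShape(3U, 4U, 2U, 5U), 1, DataType::F32, DataLayout::NHWC);
    const TensorInfo padded = arm_compute::test::validation::prepare_weights(ti, arm_compute::WeightFormat::OHWIo4);
    // padded.tensor_shape() == TensorShape(3U, 4U, 2U, 8U): O (=5) is rounded up to the
    // next multiple of interleave_by(arm_compute::WeightFormat::OHWIo4) == 4.
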
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index 23e73f6a9e..f47943aa77 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -473,14 +473,14 @@ inline ::std::ostream &operator<<(::std::ostream &os, const BoundingBoxTransform
}
#if defined(ARM_COMPUTE_ENABLE_BF16)
-inline ::std::ostream &operator<<(::std::ostream &os, const bfloat16& v)
+inline ::std::ostream &operator<<(::std::ostream &os, const bfloat16 &v)
{
std::stringstream str;
str << v;
os << str.str();
return os;
}
-#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
+#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
/** Formatted output of the BoundingBoxTransformInfo type.
*
@@ -3252,19 +3252,81 @@ inline std::string to_string(const Conv3dInfo &conv3d_info)
return str.str();
}
-inline ::std::ostream &operator<<(::std::ostream &os, const arm_gemm::WeightFormat &wf)
+/** Formatted output of the arm_compute::WeightFormat type.
+ *
+ * @param[in] wf arm_compute::WeightFormat type to output.
+ *
+ * @return Formatted string.
+ */
+inline std::string to_string(const WeightFormat wf)
{
- os << arm_gemm::to_string(wf);
- return os;
+#define __CASE_WEIGHT_FORMAT(wf) \
+case WeightFormat::wf: \
+ return #wf;
+ switch(wf)
+ {
+ __CASE_WEIGHT_FORMAT(UNSPECIFIED)
+ __CASE_WEIGHT_FORMAT(ANY)
+ __CASE_WEIGHT_FORMAT(OHWI)
+ __CASE_WEIGHT_FORMAT(OHWIo2)
+ __CASE_WEIGHT_FORMAT(OHWIo4)
+ __CASE_WEIGHT_FORMAT(OHWIo8)
+ __CASE_WEIGHT_FORMAT(OHWIo16)
+ __CASE_WEIGHT_FORMAT(OHWIo32)
+ __CASE_WEIGHT_FORMAT(OHWIo64)
+ __CASE_WEIGHT_FORMAT(OHWIo128)
+ __CASE_WEIGHT_FORMAT(OHWIo4i2)
+ __CASE_WEIGHT_FORMAT(OHWIo4i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo8i2)
+ __CASE_WEIGHT_FORMAT(OHWIo8i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo16i2)
+ __CASE_WEIGHT_FORMAT(OHWIo16i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo32i2)
+ __CASE_WEIGHT_FORMAT(OHWIo32i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo64i2)
+ __CASE_WEIGHT_FORMAT(OHWIo64i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo4i4)
+ __CASE_WEIGHT_FORMAT(OHWIo4i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo8i4)
+ __CASE_WEIGHT_FORMAT(OHWIo8i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo16i4)
+ __CASE_WEIGHT_FORMAT(OHWIo16i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo32i4)
+ __CASE_WEIGHT_FORMAT(OHWIo32i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo64i4)
+ __CASE_WEIGHT_FORMAT(OHWIo64i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo2i8)
+ __CASE_WEIGHT_FORMAT(OHWIo4i8)
+ __CASE_WEIGHT_FORMAT(OHWIo8i8)
+ __CASE_WEIGHT_FORMAT(OHWIo16i8)
+ __CASE_WEIGHT_FORMAT(OHWIo32i8)
+ __CASE_WEIGHT_FORMAT(OHWIo64i8)
+ default:
+ return "invalid value";
+ }
+#undef __CASE_WEIGHT_FORMAT
}
-inline std::string to_string(const arm_gemm::WeightFormat wf)
+
+/** Formatted output of the arm_compute::WeightFormat type.
+ *
+ * @param[out] os Output stream.
+ * @param[in] wf WeightFormat to output.
+ *
+ * @return Modified output stream.
+ */
+inline ::std::ostream &operator<<(::std::ostream &os, const arm_compute::WeightFormat &wf)
{
- std::stringstream str;
- str << wf;
- return str.str();
+ os << to_string(wf);
+ return os;
}
-inline std::string to_string(const std::tuple<TensorShape, TensorShape, arm_gemm::WeightFormat> values)
+/** Formatted output of the std::tuple<TensorShape, TensorShape, arm_compute::WeightFormat> tuple.
+ *
+ * @param[in] values Tuple of input and output tensor shapes and the WeightFormat used.
+ *
+ * @return Formatted string.
+ */
+inline std::string to_string(const std::tuple<TensorShape, TensorShape, arm_compute::WeightFormat> values)
{
std::stringstream str;
str << "[Input shape = " << std::get<0>(values);