aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators')
-rw-r--r--src/cpu/operators/CpuGemm.cpp2
-rw-r--r--src/cpu/operators/CpuGemm.h8
-rw-r--r--src/cpu/operators/CpuGemmConv2d.cpp64
-rw-r--r--src/cpu/operators/CpuGemmConv2d.h10
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp43
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.h46
6 files changed, 87 insertions, 86 deletions
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index f3fff608dc..f6582c73f8 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -368,7 +368,7 @@ experimental::MemoryRequirements CpuGemm::workspace() const
return _aux_mem;
}
-Status CpuGemm::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+Status CpuGemm::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
const GEMMInfo &gemm_info)
{
const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
diff --git a/src/cpu/operators/CpuGemm.h b/src/cpu/operators/CpuGemm.h
index b37ab73485..8d34b22437 100644
--- a/src/cpu/operators/CpuGemm.h
+++ b/src/cpu/operators/CpuGemm.h
@@ -105,15 +105,15 @@ public:
*
* This method has the same use of @ref
* NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that
- * the value of arm_gemm::WeightFormat need to be passed via the
+ * the value of arm_compute::WeightFormat need to be passed via the
* parameter gemm_info.
*/
- static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+ static Status has_opt_impl(arm_compute::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
const GEMMInfo &gemm_info = GEMMInfo());
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &constants) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &constants) override;
experimental::MemoryRequirements workspace() const override;
/** Indicates if the convolution executes in variable weights mode.
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index 0174d0eed3..f3a16f104f 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -62,13 +62,13 @@ CpuGemmConv2d::SkipInfo CpuGemmConv2d::skip_im_col_info(const ITensorInfo *src,
const unsigned int kernel_height = weights->dimension(idx_height);
unsigned int conv_w = 0;
unsigned int conv_h = 0;
- std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
- src->dimension(idx_height),
- kernel_width,
- kernel_height,
- conv_info,
- dilation);
- const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
+ std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
+ src->dimension(idx_height),
+ kernel_width,
+ kernel_height,
+ conv_info,
+ dilation);
+ const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
if(skip_im2col)
{
@@ -99,7 +99,7 @@ CpuGemmConv2d::CpuGemmConv2d()
CpuGemmConv2d::~CpuGemmConv2d() = default;
void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const ActivationLayerInfo &act_info,
- bool enable_fast_math, int gemm_3d_depth, bool fixed_format, arm_gemm::WeightFormat weight_format)
+ bool enable_fast_math, int gemm_3d_depth, bool fixed_format, arm_compute::WeightFormat weight_format)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights);
ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col, fixed_format, weight_format));
@@ -139,8 +139,8 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weig
PixelValue type_min{};
PixelValue type_max{};
std::tie(type_min, type_max) = get_min_max(data_type);
- int32_t min_activation = type_min.get<int32_t>();
- int32_t max_activation = type_max.get<int32_t>();
+ int32_t min_activation = type_min.get<int32_t>();
+ int32_t max_activation = type_max.get<int32_t>();
if(supported_acts.count(act_info.activation()) != 0)
{
@@ -179,7 +179,7 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weig
}
Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
- const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col, bool fixed_format, arm_gemm::WeightFormat weight_format)
+ const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col, bool fixed_format, arm_compute::WeightFormat weight_format)
{
const DataType data_type = src->data_type();
const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
@@ -203,8 +203,8 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *wei
PixelValue type_min{};
PixelValue type_max{};
std::tie(type_min, type_max) = get_min_max(data_type);
- int32_t min_activation = type_min.get<int32_t>();
- int32_t max_activation = type_max.get<int32_t>();
+ int32_t min_activation = type_min.get<int32_t>();
+ int32_t max_activation = type_max.get<int32_t>();
const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU,
ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
@@ -288,8 +288,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights
ITensorInfo *gemm_output_to_use = dst;
// Get convolved dimensions
- unsigned int conv_w = 0;
- unsigned int conv_h = 0;
+ unsigned int conv_w = 0;
+ unsigned int conv_h = 0;
std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
src->dimension(idx_height),
kernel_width,
@@ -306,8 +306,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights
_skip_col2im = skip_info.skip_col2im;
// Get parameters from conv_info
- unsigned int stride_x = 0;
- unsigned int stride_y = 0;
+ unsigned int stride_x = 0;
+ unsigned int stride_y = 0;
std::tie(stride_x, stride_y) = conv_info.stride();
unsigned int mat_weights_cols = weights->dimension(idx_kernels);
@@ -360,7 +360,7 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights
// Configure GEMM
// In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the output matrix
const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0;
- const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+ const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth, fixed_format, weights_info.weight_format());
if(!_skip_col2im && _data_layout == DataLayout::NCHW)
@@ -388,7 +388,7 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights
_aux_mem[GemmOutput] = MemoryInfo(offset_int_vec(GemmOutput), MemoryLifetime::Temporary, _gemm_output.total_size());
}
-Status CpuGemmConv2d::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
+Status CpuGemmConv2d::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
const PadStrideInfo &conv_info,
const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, const bool enable_fast_math)
{
@@ -399,12 +399,12 @@ Status CpuGemmConv2d::has_opt_impl(arm_gemm::WeightFormat &expected_weight_forma
const unsigned int kernel_height = weights->dimension(idx_height);
unsigned int conv_w = 0;
unsigned int conv_h = 0;
- std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
- src->dimension(idx_height),
- kernel_width,
- kernel_height,
- conv_info,
- dilation);
+ std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
+ src->dimension(idx_height),
+ kernel_width,
+ kernel_height,
+ conv_info,
+ dilation);
const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
dilation, act_info);
@@ -412,7 +412,7 @@ Status CpuGemmConv2d::has_opt_impl(arm_gemm::WeightFormat &expected_weight_forma
const bool skip_im2col = skip_info.skip_im2col;
const bool skip_col2im = skip_info.skip_col2im;
const unsigned int gemm_3d_depth = skip_col2im ? conv_h : 0;
- const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+ const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weights_info.weight_format());
@@ -464,9 +464,9 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
dilation);
// Check if GEMM3D is supported
- const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
- dilation, act_info);
- const bool skip_im2col = skip_info.skip_im2col, skip_col2im = skip_info.skip_col2im;
+ const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
+ dilation, act_info);
+ const bool skip_im2col = skip_info.skip_im2col, skip_col2im = skip_info.skip_col2im;
ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != src->dimension(idx_channel));
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -527,7 +527,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
}
info_gemm.set_quantization_info(dst->quantization_info()).set_data_layout(src->data_layout());
gemm_output_to_use = &info_gemm;
- const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+ const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col, fixed_format,
weights_info.weight_format()));
@@ -558,7 +558,7 @@ void CpuGemmConv2d::run(ITensorPack &tensors)
{
// Run input reshaping
unsigned int y_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
- ITensorPack pack =
+ ITensorPack pack =
{
{ TensorType::ACL_SRC, src },
{ TensorType::ACL_DST, im2col_output.get() }
@@ -652,7 +652,7 @@ void CpuGemmConv2d::prepare(ITensorPack &tensors)
// Run weights reshaping and mark original weights tensor as unused
CpuAuxTensorHandler weights_reshaped(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
- ITensorPack pack =
+ ITensorPack pack =
{
{ TensorType::ACL_SRC, weights },
{ TensorType::ACL_DST, weights_reshaped.get() }
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h
index f8f0bce048..08b76a6c46 100644
--- a/src/cpu/operators/CpuGemmConv2d.h
+++ b/src/cpu/operators/CpuGemmConv2d.h
@@ -123,14 +123,14 @@ public:
*
* @return a status.
*/
- static Status has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
+ static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
const PadStrideInfo &conv_info,
const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
const bool enable_fast_math = false);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
experimental::MemoryRequirements workspace() const override;
private:
@@ -150,7 +150,7 @@ private:
* @param[in] weight_format (Optional) The layout to be used for the weights tensor when running GEMM with variable weights.
*/
void configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo(),
- bool enable_fast_math = false, int gemm_3d_depth = 1, bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED);
+ bool enable_fast_math = false, int gemm_3d_depth = 1, bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED);
/** Static function to check if given info will lead to a valid configuration of @ref NEGEMMConvolutionLayer matrix multiply routines
*
* @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32.
@@ -170,7 +170,7 @@ private:
* @return a status
*/
static Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act_info = ActivationLayerInfo(),
- bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false, bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED);
+ bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false, bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED);
/** Static function to check if GEMM3D is supported in @ref NEGEMM or in @ref CpuGemmMLowpMatrixMultiplyCore
*
* @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32.
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 558ff41a5c..c969c9f4f6 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -164,8 +164,8 @@ public:
{
if(!_gemm_kernel_asm)
return false;
- const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format;
- return wf != arm_gemm::WeightFormat::UNSPECIFIED && wf != arm_gemm::WeightFormat::ANY;
+ const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
+ return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY;
}
private:
@@ -428,7 +428,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
if(_gemm_kernel_asm->B_pretranspose_required())
{
// Fixed format kernels need no pretranspose.
- ARM_COMPUTE_ERROR_ON(arm_gemm::is_fixed_format(_gemm_kernel_asm->get_config().weight_format));
+ ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
const auto in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
@@ -492,8 +492,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Check if B is pre-tranposed and de-reference if not
if(!_gemm_kernel_asm->B_is_pretransposed())
{
- ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
- const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format;
+ ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
if(is_fixed_format(wf))
{
// The 4D tensor of dimension O'HWI' created for the
@@ -507,7 +507,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
const int Ip = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
- const int interleave_by = arm_gemm::interleave_by(wf);
+ const int interleave_by = arm_compute::interleave_by(wf);
ldb = (interleave_by * H * W * Ip);
}
multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
@@ -603,7 +603,7 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
unsigned int num_threads = NEScheduler::get().num_threads();
arm_gemm::GemmConfig cfg;
- cfg.weight_format = info.weight_format;
+ cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
// Create arm_gemm fallback
@@ -623,7 +623,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
const unsigned int num_threads = NEScheduler::get().num_threads();
arm_gemm::GemmConfig cfg;
- cfg.weight_format = info.weight_format;
+ cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
// Create arm_gemm fallback
@@ -665,7 +665,7 @@ CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch()
{
}
-Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
@@ -675,13 +675,13 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
arm_gemm::GemmConfig cfg;
- cfg.weight_format = info.weight_format;
-
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg);
+ cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
+ arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg);
switch(a->data_type())
{
case DataType::F32:
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for F32 input");
break;
#ifdef __aarch64__
@@ -689,12 +689,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
case DataType::QASYMM8:
if(d->data_type() == DataType::S32)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for U8/QASYMM8 input and S32 output");
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for U8 input and U8 output");
}
break;
@@ -702,12 +702,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
case DataType::QASYMM8_SIGNED:
if(d->data_type() == DataType::S32)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for S8 input and S32 output");
}
break;
@@ -722,7 +722,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for BFLOAT16 input and F32 output");
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
@@ -730,6 +730,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Usupported type. Could not find a kernel");
break;
}
+ expected_weight_format = assembly_utils::map_to_arm_compute_weight_format(arm_gemm_expected_wf);
return Status{};
}
@@ -762,9 +763,9 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
- arm_gemm::WeightFormat expected_weight_format;
- const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
- if((bool)ret && expected_weight_format != arm_gemm::WeightFormat::ANY)
+ arm_compute::WeightFormat expected_weight_format;
+ const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
+ if((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY)
{
// Correctness check: if the format expected by the kernel is
// not "any", make sure that the one found matches the format
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 4ef108d430..691eeff8d2 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -41,19 +41,19 @@ enum class AsmConvMethod
struct AsmGemmInfo
{
- AsmConvMethod method{ AsmConvMethod::Im2Col };
- PadStrideInfo ps_info{};
- ActivationLayerInfo activation_info{};
- GEMMLowpOutputStageInfo output_stage{};
- bool negated_offsets{ true };
- bool reinterpret_input_as_3d{ false };
- bool depth_output_gemm3d{ false };
- int64_t padding_top{ 0 };
- int64_t padding_left{ 0 };
- float padding_value{ 0.f };
- bool fast_mode{ false };
- bool fixed_format{ false };
- arm_gemm::WeightFormat weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
+ AsmConvMethod method{ AsmConvMethod::Im2Col };
+ PadStrideInfo ps_info{};
+ ActivationLayerInfo activation_info{};
+ GEMMLowpOutputStageInfo output_stage{};
+ bool negated_offsets{ true };
+ bool reinterpret_input_as_3d{ false };
+ bool depth_output_gemm3d{ false };
+ int64_t padding_top{ 0 };
+ int64_t padding_left{ 0 };
+ float padding_value{ 0.f };
+ bool fast_mode{ false };
+ bool fixed_format{ false };
+ arm_compute::WeightFormat weight_format{ arm_compute::WeightFormat::UNSPECIFIED };
};
/** Assembly kernel glue */
@@ -70,12 +70,12 @@ public:
class IFallback
{
public:
- virtual void run(ITensorPack &tensors) = 0;
- virtual void prepare(ITensorPack &tensors) = 0;
- virtual experimental::MemoryRequirements workspace() const = 0;
- virtual bool is_configured() const = 0;
- virtual bool isVarWeightsKernel() const = 0;
- virtual ~IFallback() = default;
+ virtual void run(ITensorPack &tensors) = 0;
+ virtual void prepare(ITensorPack &tensors) = 0;
+ virtual experimental::MemoryRequirements workspace() const = 0;
+ virtual bool is_configured() const = 0;
+ virtual bool isVarWeightsKernel() const = 0;
+ virtual ~IFallback() = default;
};
public:
@@ -105,12 +105,12 @@ public:
*
* This method has the same use of @ref
* NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that
- * the value of arm_gemm::WeightFormat need to be passed via the
+ * the value of arm_compute::WeightFormat need to be passed via the
* parameter info.
*
* @return a status.
*/
- static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
+ static Status has_opt_impl(arm_compute::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
/** Checks if activation is supported by the gemm assembly dispatcher
*
* @param[in] activation Activation to check
@@ -133,8 +133,8 @@ public:
}
// Inherited methods overridden:
- void prepare(ITensorPack &tensors) override;
- void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
experimental::MemoryRequirements workspace() const override;
private: