aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp73
1 files changed, 18 insertions, 55 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 8ff81afe54..bf3ec5a1ac 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -430,9 +430,9 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
{
// Fixed format kernels need no pretranspose.
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
- const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
const auto in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
- const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+ const int multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
@@ -470,21 +470,21 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
auto d = tensors.get_tensor(TensorType::ACL_DST);
- int lda = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ int lda = a->info()->strides_in_bytes().y() / a->info()->element_size();
int ldb = 0;
- const int ldd = d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
+ const int ldd = d->info()->strides_in_bytes().y() / d->info()->element_size();
const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2;
const size_t a_multi_idx = a_batch_idx + 1;
const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2;
const size_t d_multi_idx = d_batch_idx + 1;
- int batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
- const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);
+ int batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / a->info()->element_size();
+ const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / d->info()->element_size();
- int multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
+ int multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / a->info()->element_size();
int multi_stride_b = 0;
- const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);
+ const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size();
auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
const TypeInput *in1_ptr = nullptr;
@@ -493,50 +493,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Check if B is pre-tranposed and de-reference if not
if(!_gemm_kernel_asm->B_is_pretransposed())
{
- ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
- multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
- const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
- if(is_fixed_format(wf))
- {
- // The 4D tensor of dimension O'HWI' created for the
- // OHWIo<interleave_by>i<block_by> format is in reality seen
- // as a 2D tensor at arm_gemm level, where the rows are
- // O'/<interleave_by> and the columns are <interleave_by> *
- // H * W * I'.
- ITensorInfo *tensor_info = b->info();
- const DataLayout data_layout = tensor_info->data_layout();
- const TensorShape tensor_shape = tensor_info->tensor_shape();
- const int tensor_height = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
- const int tensor_width = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
- int tensor_channels = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
- const int interleave_by = arm_compute::interleave_by(wf);
- const int blocked_by = arm_compute::block_by(wf);
- // We need to find a new stride that is distance from the data for one
- // set of output channels to the next
- if(ldb == tensor_channels && multi_stride_b == tensor_channels * tensor_width)
- {
- // In this case dimensions that are packed are height, width and channel
- // so we need to stride it by interleave_by
- if(tensor_channels % blocked_by != 0)
- {
- // We need to pad
- tensor_channels = arm_gemm::iceildiv(tensor_channels, blocked_by) * blocked_by;
- }
- ldb = interleave_by * tensor_height * tensor_width * tensor_channels;
- }
- else if(multi_stride_b == 0 || (ldb == tensor_width && multi_stride_b == tensor_height * tensor_width))
- {
- // In this case dimension that is packed is only height
- // so we need to stride only height by interleave_by
- ldb = interleave_by * tensor_height;
- }
- else
- {
- // If dimensions are not packed as above error is thrown
- // as at the moment other forms of packing are not supported
- ARM_COMPUTE_ERROR("Unsupported packing for fixed format kernel");
- }
- }
+ ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
+ multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
}
@@ -551,9 +509,9 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Pretranspose B if required
if(_B_pretranspose_required)
{
- const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
const auto b_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
- const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+ const int multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
@@ -780,6 +738,11 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8_SIGNED, DataType::S8);
}
+ else if(is_fixed_format_fast_math(info.weight_format))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
+ }
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);