Diffstat (limited to 'src/cpu/operators/CpuGemm.cpp')
 src/cpu/operators/CpuGemm.cpp | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index b9d18c4cb6..7411d763b4 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -155,14 +155,14 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso
Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_UNUSED(alpha);
- const bool is_c_bias = beta == 1 && c != nullptr;
+ const bool is_c_bias = beta == 1 && c != nullptr;
const bool run_addition = c != nullptr && beta != 0 && beta != 1;
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32);
- if (is_fixed_format_fast_math(gemm_info.weight_format()))
+ if(is_fixed_format_fast_math(gemm_info.weight_format()))
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
@@ -172,7 +172,24 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
}
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
+ const int block_by = arm_compute::block_by(gemm_info.weight_format());
+ if(block_by > 1)
+ {
+ // have to verify bias
+ const size_t dim0_sz = a->dimension(0);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((dim0_sz % block_by) != 0, ("The matrix A number of columns must be a multiple of block_by=" + std::to_string(block_by)).c_str());
+ // a->dimension(0) = kernel_area * input_channel + kernel_area * input_pad_right
+ // b->dimension(1) = kernel_area * input_channel
+ // a->dimension(0) = b->dimension(1) + kernel_area * input_pad_right
+ const size_t input_pad_right = (dim0_sz - b->dimension(1)) % block_by;
+ const size_t kernel_area = (dim0_sz - b->dimension(1)) / input_pad_right;
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((dim0_sz - kernel_area * input_pad_right) != b->dimension(1), "The product AB is defined only if A number of columns and B number of rows are related");
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
+ }
+
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
if(a->data_type() != DataType::BFLOAT16)
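
For reference, below is a minimal standalone sketch (not part of the patch) of the dimension arithmetic behind the new block_by check, using hypothetical values: a 3x3 kernel, 10 input channels and block_by = 4, so two channels of right padding are assumed to reach a multiple of block_by.

// Standalone sketch of the relationship the patch validates, with assumed values.
#include <cassert>
#include <cstddef>

int main()
{
    const std::size_t block_by        = 4;      // assumed blocked weight format
    const std::size_t kernel_area     = 3 * 3;  // hypothetical 3x3 kernel
    const std::size_t input_channel   = 10;
    // Pad the channel count up to the next multiple of block_by (here: 2 extra channels).
    const std::size_t input_pad_right = (block_by - input_channel % block_by) % block_by;

    // a->dimension(0) = kernel_area * input_channel + kernel_area * input_pad_right
    // b->dimension(1) = kernel_area * input_channel
    const std::size_t a_dim0 = kernel_area * (input_channel + input_pad_right); // 108
    const std::size_t b_dim1 = kernel_area * input_channel;                     // 90

    // First check in the patch: A's number of columns is a multiple of block_by.
    assert(a_dim0 % block_by == 0);

    // Recover the padding and kernel area the same way validate() does.
    const std::size_t pad  = (a_dim0 - b_dim1) % block_by; // 18 % 4 = 2
    const std::size_t area = (a_dim0 - b_dim1) / pad;      // 18 / 2 = 9

    // Second check in the patch: stripping the padding recovers B's number of rows.
    assert(a_dim0 - area * pad == b_dim1);
    return 0;
}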