aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuGemmConv2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuGemmConv2d.cpp')
-rw-r--r--src/cpu/operators/CpuGemmConv2d.cpp11
1 files changed, 8 insertions, 3 deletions
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index f3a16f104f..9bf6ed1e85 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -427,7 +427,12 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!");
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::BFLOAT16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::BFLOAT16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights);
+
+ if (!is_fixed_format(weights_info.weight_format()))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights);
+ }
+
ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups > 1, "Grouping (num_groups != 1) is not supported");
const DataLayout data_layout = src->data_layout();
@@ -493,7 +498,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
unsigned int mat_weights_cols = weights->dimension(idx_kernels);
unsigned int mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel);
- weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, append_bias), 1, data_type);
+ weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, append_bias), 1, weights->data_type());
weights_reshaped_info.set_quantization_info(weights->quantization_info());
weights_to_use = &weights_reshaped_info;