aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonathan Deakin <jonathan.deakin@arm.com>2023-03-02 15:15:15 +0000
committerJakub Sujak <jakub.sujak@arm.com>2023-03-13 17:05:35 +0000
commit0527cea105c8e73c1aebd845616864dff1bee935 (patch)
treec070359a4277c6d84ba49ce3cb7a1cf23b1f425f
parent470cc5dea721716a93fded4d98642f6e9150c69b (diff)
downloadComputeLibrary-0527cea105c8e73c1aebd845616864dff1bee935.tar.gz
[ONCPUML-1174] Allow src/weights mismatch for fixed format
Without this, we have to pass in weights to be NHWC, even if they are in fact blocked/interleaved for consumption by a fixed format kernel. Signed-off-by: Jonathan Deakin <jonathan.deakin@arm.com> Change-Id: I9ee8720a21a16b17816dbecf6308e1668ddda59c Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9285 Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Jakub Sujak <jakub.sujak@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--src/cpu/operators/CpuGemmDirectConv2d.cpp7
1 files changed, 5 insertions, 2 deletions
diff --git a/src/cpu/operators/CpuGemmDirectConv2d.cpp b/src/cpu/operators/CpuGemmDirectConv2d.cpp
index ee47a17d64..5ce285cb6f 100644
--- a/src/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -158,7 +158,10 @@ Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::BFLOAT16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::BFLOAT16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights);
+ if(!is_fixed_format(info.weights_info.weight_format()))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights);
+ }
ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.num_groups > 1, "Grouping (num_groups != 1) is not supported on Neon");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_layout() != DataLayout::NHWC, "Data layout supported is NHWC");
const DataType data_type = src->data_type();