aboutsummaryrefslogtreecommitdiff
path: root/src/cpu
diff options
context:
space:
mode:
authorGunes Bayir <gunes.bayir@arm.com>2023-09-19 15:44:21 +0100
committerGunes Bayir <gunes.bayir@arm.com>2023-09-20 13:55:15 +0000
commite071b5e31004b29afefaa96907032bfd2b4e5a43 (patch)
tree0a53d7c32c6b3f055fdffcd5dcfc3830226e81cb /src/cpu
parent500e10b3222e726cfc5d484f924d5eb98016a754 (diff)
downloadComputeLibrary-e071b5e31004b29afefaa96907032bfd2b4e5a43.tar.gz
Fix the validation issue in AddMulAdd fused kernel
Resolves: COMPMID-6558 Change-Id: I015d504aaa9b8a1a232b01e49ab373d415ea1de9 Signed-off-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10340 Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com> Reviewed-by: TeresaARM <teresa.charlinreyes@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu')
-rw-r--r--src/cpu/operators/CpuAddMulAdd.cpp6
1 files changed, 3 insertions, 3 deletions
diff --git a/src/cpu/operators/CpuAddMulAdd.cpp b/src/cpu/operators/CpuAddMulAdd.cpp
index 3fd690e3f9..590ee482ca 100644
--- a/src/cpu/operators/CpuAddMulAdd.cpp
+++ b/src/cpu/operators/CpuAddMulAdd.cpp
@@ -71,8 +71,8 @@ Status CpuAddMulAdd::validate(const ITensorInfo *input1, const ITensorInfo *inpu
const DataType data_type = input1->data_type();
if(is_data_type_quantized(data_type))
{
- TensorInfo dequantized_bn_mul;
- TensorInfo dequantized_bn_add;
+ TensorInfo dequantized_bn_mul = bn_mul->clone()->set_data_type(DataType::F32);
+ TensorInfo dequantized_bn_add = bn_add->clone()->set_data_type(DataType::F32);
ARM_COMPUTE_RETURN_ON_ERROR(CpuDequantize::validate(bn_mul, &dequantized_bn_mul));
ARM_COMPUTE_RETURN_ON_ERROR(CpuDequantize::validate(bn_add, &dequantized_bn_add));
@@ -87,7 +87,7 @@ Status CpuAddMulAdd::validate(const ITensorInfo *input1, const ITensorInfo *inpu
void CpuAddMulAdd::run(ITensorPack &tensors)
{
- const DataType data_type = tensors.get_const_tensor(TensorType::ACL_SRC_0)->info()->data_type();
+ const DataType data_type = tensors.get_const_tensor(TensorType::ACL_SRC_0)->info()->data_type();
if(is_data_type_quantized(data_type))
{