aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp')
-rw-r--r--src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp6
1 files changed, 4 insertions, 2 deletions
diff --git a/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp b/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp
index 9ba2451482..2c1cb15786 100644
--- a/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp
+++ b/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,6 +23,7 @@
*/
#include "src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.h"
+#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
@@ -100,7 +101,6 @@ Status
CpuPool2dAssemblyWrapperKernel::validate(const ITensorInfo *src, const ITensorInfo *dst, const PoolingLayerInfo &info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst);
-
#ifndef __aarch64__
ARM_COMPUTE_RETURN_ERROR_MSG("32-bit is not supported by assembly kernels");
#endif /* __aarch64__ */
@@ -120,6 +120,8 @@ CpuPool2dAssemblyWrapperKernel::validate(const ITensorInfo *src, const ITensorIn
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
+ const TensorInfo out_info(compute_pool_shape(*src, info), 1, dst->data_type());
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(dst, &out_info);
const auto src_qinfo = src->quantization_info().uniform();
const auto dst_qinfo = dst->quantization_info().uniform();