aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLWeightsReshapeKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLWeightsReshapeKernel.cpp8
1 files changed, 5 insertions, 3 deletions
diff --git a/src/core/CL/kernels/CLWeightsReshapeKernel.cpp b/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
index 9330b3b8a1..e325feac1f 100644
--- a/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
+++ b/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
@@ -33,7 +33,8 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-using namespace arm_compute;
+namespace arm_compute
+{
using namespace arm_compute::misc::shape_calculator;
namespace
@@ -42,7 +43,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *biases, c
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
ARM_COMPUTE_RETURN_ERROR_ON(num_groups == 0);
ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::NHWC && num_groups > 1);
ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4 && num_groups > 1);
@@ -50,7 +51,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *biases, c
if(biases != nullptr)
{
- ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(input->data_type()));
+ ARM_COMPUTE_RETURN_ERROR_ON(!is_data_type_float(input->data_type()));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() == 4) && (biases->num_dimensions() != 1));
ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() == 5) && (biases->num_dimensions() != 2));
@@ -160,3 +161,4 @@ void CLWeightsReshapeKernel::run(const Window &window, cl::CommandQueue &queue)
}
while(window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_2D(out_slice));
}
+} // namespace arm_compute