diff options
Diffstat (limited to 'src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp | 31 |
1 files changed, 3 insertions, 28 deletions
diff --git a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp index a9df36dfcc..4ea093fe04 100644 --- a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp +++ b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp @@ -51,36 +51,23 @@ Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, ARM_COMPUTE_UNUSED(rounding_policy); ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input1); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input2); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative."); const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape()); ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); - - if(is_data_type_fixed_point(input1->data_type())) - { - // All data types must be all QS8 or all QS16 - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale != 1, "Unsupported scaling factor for QS8/QS16. Scale must be 1."); - } // Validate in case of configured output if(output->total_size() > 0) { ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), "Output can only be U8 if both inputs are U8"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output"); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); - if(is_data_type_fixed_point(input1->data_type())) - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output); - } } return Status{}; @@ -174,14 +161,6 @@ void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const I { compute_type = "int"; } - else if(input1->info()->data_type() == DataType::QS8) - { - compute_type = "qs8"; - } - else if(input1->info()->data_type() == DataType::QS16) - { - compute_type = "qs16"; - } else { compute_type = "ushort"; @@ -197,10 +176,6 @@ void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const I std::set<std::string> build_opts; build_opts.emplace((overflow_policy == ConvertPolicy::WRAP || is_data_type_float(output->info()->data_type())) ? "-DWRAP" : "-DSATURATE"); build_opts.emplace((rounding_policy == RoundingPolicy::TO_ZERO) ? "-DROUND=_rtz" : "-DROUND=_rte"); - if(is_data_type_fixed_point(input1->info()->data_type())) - { - build_opts.emplace("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input1->info()->fixed_point_position())); - } build_opts.emplace("-DDATA_TYPE_IN1=" + get_cl_type_from_data_type(input1->info()->data_type())); build_opts.emplace("-DDATA_TYPE_IN2=" + get_cl_type_from_data_type(input2->info()->data_type())); build_opts.emplace("-DDATA_TYPE_OUT=" + get_cl_type_from_data_type(output->info()->data_type())); |