diff options
Diffstat (limited to 'src/core/gpu/cl/kernels/ClDirectConv2dKernel.cpp')
-rw-r--r-- | src/core/gpu/cl/kernels/ClDirectConv2dKernel.cpp | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/src/core/gpu/cl/kernels/ClDirectConv2dKernel.cpp b/src/core/gpu/cl/kernels/ClDirectConv2dKernel.cpp index 2c9a4f301b..94c4044bff 100644 --- a/src/core/gpu/cl/kernels/ClDirectConv2dKernel.cpp +++ b/src/core/gpu/cl/kernels/ClDirectConv2dKernel.cpp @@ -48,7 +48,8 @@ namespace kernels { namespace { -Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const PadStrideInfo &conv_info) +Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, + const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info) { ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::F16, DataType::F32); @@ -67,6 +68,8 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, co ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 3 || weights->dimension(width_idx) == 5 || weights->dimension(width_idx) == 9) && std::get<0>(conv_info.stride()) > 2, "Strides larger than 2 not supported for 3x3, 5x5, 9x9 convolution."); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(data_layout != DataLayout::NHWC && !is_data_type_float(src->data_type()) && act_info.enabled(), + "Activation supported only for floating point and NHWC."); if(data_layout == DataLayout::NCHW) { @@ -375,16 +378,12 @@ BorderSize ClDirectConv2dKernel::border_size() const } void ClDirectConv2dKernel::configure(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *biases, ITensorInfo *dst, - const PadStrideInfo &conv_info) + const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst); // Perform validation - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, - weights, - (biases != nullptr) ? biases : nullptr, - dst, - conv_info)); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, biases, dst, conv_info, act_info)); const int conv_stride_x = std::get<0>(conv_info.stride()); const int conv_stride_y = std::get<1>(conv_info.stride()); @@ -457,6 +456,7 @@ void ClDirectConv2dKernel::configure(const CLCompileContext &compile_context, IT build_options.add_option("-DM0=" + support::cpp11::to_string(m0)); build_options.add_option("-DK0=" + support::cpp11::to_string(k0)); build_options.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0)); + build_options.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(act_info.activation()))); if(is_data_type_quantized(data_type)) { @@ -488,6 +488,8 @@ void ClDirectConv2dKernel::configure(const CLCompileContext &compile_context, IT build_options.add_option("-DSRC_OFFSET=" + support::cpp11::to_string(0)); build_options.add_option("-DWEI_OFFSET=" + support::cpp11::to_string(0)); build_options.add_option("-DDST_OFFSET=" + support::cpp11::to_string(0)); + build_options.add_option_if(act_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(act_info.a())); + build_options.add_option_if(act_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(act_info.b())); } } else @@ -564,10 +566,10 @@ void ClDirectConv2dKernel::configure(const CLCompileContext &compile_context, IT _config_id += lower_string(string_from_data_layout(_data_layout)); } -Status ClDirectConv2dKernel::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const PadStrideInfo &conv_info, - const GPUTarget target) +Status ClDirectConv2dKernel::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, + const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info, const GPUTarget target) { - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, biases, dst, conv_info)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, biases, dst, conv_info, act_info)); ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(src->clone().get(), weights->clone().get(), dst->clone().get(), conv_info, target).first); return Status{}; |