diff options
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/operators')
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp | 35 |
1 files changed, 29 insertions, 6 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp b/src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp index 9cb4ee7815..048ee01f35 100644 --- a/src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp +++ b/src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp @@ -23,16 +23,19 @@ */ #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuConv2d.h" +#include "arm_compute/core/KernelDescriptors.h" #include "arm_compute/core/Validate.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "arm_compute/runtime/CL/CLScheduler.h" +#include "src/common/utils/Log.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/dynamic_fusion/sketch/ArgumentPack.h" #include "src/dynamic_fusion/sketch/gpu/GpuWorkloadSketchImpl.h" #include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h" #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" - -#include "src/common/utils/Log.h" +#include "src/runtime/heuristics/direct_conv/ClDirectConvKernelConfig.h" +#include "src/runtime/heuristics/direct_conv/IClDirectConvKernelConfig.h" namespace arm_compute { @@ -85,6 +88,16 @@ bool export_to_cl_image_support(const ITensorInfo *tensor, GPUTarget gpu_target, return true; } +DirectConvComputeKernelInfo config_direct_convolution_nhwc(const ITensorInfo *src, const ITensorInfo *weights, const PadStrideInfo &conv_info) +{ + // Get GPU target + GPUTarget gpu_target = CLScheduler::get().target(); + + std::unique_ptr<arm_compute::cl_direct_conv::IClDirectConvKernelConfig> t = arm_compute::cl_direct_conv::ClDirectConvKernelConfigurationFactory::create(gpu_target); + + return t->configure(src, weights, conv_info); +} + constexpr GpuOperatorType operator_type = GpuOperatorType::Complex; } // namespace @@ -112,6 +125,11 @@ Status GpuConv2d::validate_op(const GpuWorkloadSketch &sketch, attributes.pad().right, attributes.pad().top, attributes.pad().bottom, DimensionRoundingType::FLOOR)); // use the default DimensionRoundingType + // Checks performed when dst is configured + if(dst->total_size() != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), shape); + } auto_init_if_empty(dst_info_to_validate, src->clone()->set_tensor_shape(shape)); } @@ -175,6 +193,12 @@ void GpuConv2d::create_op(GpuWorkloadSketch &sketch, const Conv2dAttributes &attributes) { ARM_COMPUTE_LOG_PARAMS(src, wei, bia, dst, attributes); + PadStrideInfo conv_info(attributes.stride().x(), attributes.stride().y(), attributes.pad().left, + attributes.pad().right, + attributes.pad().top, attributes.pad().bottom, DimensionRoundingType::FLOOR); + // Initialize the direct convolution descriptor + const DirectConvComputeKernelInfo desc = config_direct_convolution_nhwc(src, wei, conv_info); + // Assert validation ARM_COMPUTE_ERROR_THROW_ON(GpuConv2d::validate_op(sketch, src, wei, bia, dst, attributes)); ARM_COMPUTE_ERROR_ON_NULLPTR(src, wei, dst); @@ -182,10 +206,7 @@ void GpuConv2d::create_op(GpuWorkloadSketch &sketch, // Auto initialize dst tensor { - auto shape = misc::shape_calculator::compute_deep_convolution_shape(src->tensor_shape(), data_layout, wei->tensor_shape(), - PadStrideInfo(attributes.stride().x(), attributes.stride().y(), attributes.pad().left, - attributes.pad().right, - attributes.pad().top, attributes.pad().bottom, DimensionRoundingType::FLOOR)); // use the default DimensionRoundingType + auto shape = misc::shape_calculator::compute_deep_convolution_shape(src->tensor_shape(), data_layout, wei->tensor_shape(), conv_info); // use the default DimensionRoundingType auto_init_if_empty(*dst, src->clone()->set_tensor_shape(shape)); } @@ -221,6 +242,8 @@ void GpuConv2d::create_op(GpuWorkloadSketch &sketch, arm_compute::opencl::kernels::gemm::update_padding_for_cl_image(wei); } + settings.direct_conv_descriptor(desc); + ArgumentPack<ITensorInfo> arguments; arguments.add_const_tensor(ACL_SRC_0, src); arguments.add_const_tensor(ACL_SRC_1, wei); |