aboutsummaryrefslogtreecommitdiff
path: root/src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp')
-rw-r--r--src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp35
1 files changed, 29 insertions, 6 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp b/src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp
index 9cb4ee7815..048ee01f35 100644
--- a/src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp
+++ b/src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp
@@ -23,16 +23,19 @@
*/
#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuConv2d.h"
+#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/runtime/CL/CLScheduler.h"
+#include "src/common/utils/Log.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/dynamic_fusion/sketch/ArgumentPack.h"
#include "src/dynamic_fusion/sketch/gpu/GpuWorkloadSketchImpl.h"
#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h"
#include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
-
-#include "src/common/utils/Log.h"
+#include "src/runtime/heuristics/direct_conv/ClDirectConvKernelConfig.h"
+#include "src/runtime/heuristics/direct_conv/IClDirectConvKernelConfig.h"
namespace arm_compute
{
@@ -85,6 +88,16 @@ bool export_to_cl_image_support(const ITensorInfo *tensor, GPUTarget gpu_target,
return true;
}
+DirectConvComputeKernelInfo config_direct_convolution_nhwc(const ITensorInfo *src, const ITensorInfo *weights, const PadStrideInfo &conv_info)
+{
+ // Get GPU target
+ GPUTarget gpu_target = CLScheduler::get().target();
+
+ std::unique_ptr<arm_compute::cl_direct_conv::IClDirectConvKernelConfig> t = arm_compute::cl_direct_conv::ClDirectConvKernelConfigurationFactory::create(gpu_target);
+
+ return t->configure(src, weights, conv_info);
+}
+
constexpr GpuOperatorType operator_type = GpuOperatorType::Complex;
} // namespace
@@ -112,6 +125,11 @@ Status GpuConv2d::validate_op(const GpuWorkloadSketch &sketch,
attributes.pad().right,
attributes.pad().top, attributes.pad().bottom, DimensionRoundingType::FLOOR)); // use the default DimensionRoundingType
+ // Checks performed when dst is configured
+ if(dst->total_size() != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), shape);
+ }
auto_init_if_empty(dst_info_to_validate, src->clone()->set_tensor_shape(shape));
}
@@ -175,6 +193,12 @@ void GpuConv2d::create_op(GpuWorkloadSketch &sketch,
const Conv2dAttributes &attributes)
{
ARM_COMPUTE_LOG_PARAMS(src, wei, bia, dst, attributes);
+ PadStrideInfo conv_info(attributes.stride().x(), attributes.stride().y(), attributes.pad().left,
+ attributes.pad().right,
+ attributes.pad().top, attributes.pad().bottom, DimensionRoundingType::FLOOR);
+ // Initialize the direct convolution descriptor
+ const DirectConvComputeKernelInfo desc = config_direct_convolution_nhwc(src, wei, conv_info);
+
// Assert validation
ARM_COMPUTE_ERROR_THROW_ON(GpuConv2d::validate_op(sketch, src, wei, bia, dst, attributes));
ARM_COMPUTE_ERROR_ON_NULLPTR(src, wei, dst);
@@ -182,10 +206,7 @@ void GpuConv2d::create_op(GpuWorkloadSketch &sketch,
// Auto initialize dst tensor
{
- auto shape = misc::shape_calculator::compute_deep_convolution_shape(src->tensor_shape(), data_layout, wei->tensor_shape(),
- PadStrideInfo(attributes.stride().x(), attributes.stride().y(), attributes.pad().left,
- attributes.pad().right,
- attributes.pad().top, attributes.pad().bottom, DimensionRoundingType::FLOOR)); // use the default DimensionRoundingType
+ auto shape = misc::shape_calculator::compute_deep_convolution_shape(src->tensor_shape(), data_layout, wei->tensor_shape(), conv_info); // use the default DimensionRoundingType
auto_init_if_empty(*dst, src->clone()->set_tensor_shape(shape));
}
@@ -221,6 +242,8 @@ void GpuConv2d::create_op(GpuWorkloadSketch &sketch,
arm_compute::opencl::kernels::gemm::update_padding_for_cl_image(wei);
}
+ settings.direct_conv_descriptor(desc);
+
ArgumentPack<ITensorInfo> arguments;
arguments.add_const_tensor(ACL_SRC_0, src);
arguments.add_const_tensor(ACL_SRC_1, wei);