about summary refs log tree commit diff
path: root/src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp')
-rw-r--r-- src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp | 29
1 files changed, 14 insertions, 15 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp b/src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp
index 12aa4d1b9f..9cb4ee7815 100644
--- a/src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp
+++ b/src/dynamic_fusion/sketch/gpu/operators/GpuConv2d.cpp
@@ -23,18 +23,17 @@
*/
#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuConv2d.h"
-#include "arm_compute/core/CL/CLCompileContext.h"
#include "arm_compute/core/Validate.h"
-#include "arm_compute/core/experimental/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/dynamic_fusion/sketch/ArgumentPack.h"
#include "src/dynamic_fusion/sketch/gpu/GpuWorkloadSketchImpl.h"
-#include "src/dynamic_fusion/sketch/gpu/GpuWorkloadSourceCode.h"
#include "src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h"
#include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
+#include "src/common/utils/Log.h"
+
namespace arm_compute
{
namespace experimental
@@ -103,18 +102,6 @@ Status GpuConv2d::validate_op(const GpuWorkloadSketch &sketch,
{
ARM_COMPUTE_RETURN_ERROR_ON(!bia->has_valid_id());
}
-
- // Perform fusion test
- // Pack tensor infos
- ArgumentPack<ITensorInfo> tensors;
- tensors.add_const_tensor(ACL_SRC_0, src);
- tensors.add_const_tensor(ACL_SRC_1, wei);
- tensors.add_const_tensor(ACL_SRC_2, bia);
- tensors.add_const_tensor(ACL_DST_0, dst);
- const auto op = sketch.implementation().operator_group().new_operator(operator_type, tensors);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!sketch.implementation().operator_group().try_add_operator(op),
- "Operator fusion test failed. This operator cannot be fused into the workload");
-
// Auto initialize dst tensor info
TensorInfo dst_info_to_validate = *dst;
const auto data_layout = src->data_layout();
@@ -128,6 +115,17 @@ Status GpuConv2d::validate_op(const GpuWorkloadSketch &sketch,
auto_init_if_empty(dst_info_to_validate, src->clone()->set_tensor_shape(shape));
}
+ // Perform fusion test
+ // Pack tensor infos
+ ArgumentPack<ITensorInfo> tensors;
+ tensors.add_const_tensor(ACL_SRC_0, src);
+ tensors.add_const_tensor(ACL_SRC_1, wei);
+ tensors.add_const_tensor(ACL_SRC_2, bia);
+ tensors.add_const_tensor(ACL_DST_0, &dst_info_to_validate);
+ const auto op = sketch.implementation().operator_group().new_operator(operator_type, tensors);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!sketch.implementation().operator_group().try_add_operator(op),
+ "Operator fusion test failed. This operator cannot be fused into the workload");
+
// Check support level
// Data type
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32);
@@ -176,6 +174,7 @@ void GpuConv2d::create_op(GpuWorkloadSketch &sketch,
ITensorInfo *dst,
const Conv2dAttributes &attributes)
{
+ ARM_COMPUTE_LOG_PARAMS(src, wei, bia, dst, attributes);
// Assert validation
ARM_COMPUTE_ERROR_THROW_ON(GpuConv2d::validate_op(sketch, src, wei, bia, dst, attributes));
ARM_COMPUTE_ERROR_ON_NULLPTR(src, wei, dst);