aboutsummaryrefslogtreecommitdiff
path: root/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp')
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp18
1 files changed, 17 insertions, 1 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp
index dc05825500..1fbcb41028 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp
@@ -57,13 +57,24 @@ bool ClComponentDirectConv2dSettings::fast_relaxed_math() const
return _fast_relaxed_math;
}
+ClComponentDirectConv2dSettings &ClComponentDirectConv2dSettings::direct_conv_descriptor(const DirectConvComputeKernelInfo &desc)
+{
+ _desc = desc;
+ return *this;
+}
+
+DirectConvComputeKernelInfo ClComponentDirectConv2dSettings::direct_conv_descriptor() const
+{
+ return _desc;
+}
+
Status ClComponentDirectConv2d::validate(
const Properties &properties,
const ArgumentPack<ITensorInfo> &tensors,
const Attributes &attributes,
const Settings &settings)
{
- ARM_COMPUTE_UNUSED(properties, settings);
+ ARM_COMPUTE_UNUSED(properties);
const auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0);
const auto wei = tensors.get_const_tensor(TensorType::ACL_SRC_1);
const auto bia = tensors.get_const_tensor(TensorType::ACL_SRC_2);
@@ -125,6 +136,11 @@ Status ClComponentDirectConv2d::validate(
// Data layout
ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(src, DataLayout::NHWC);
+ const auto desc = settings.direct_conv_descriptor();
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(desc.n0 != 1 && desc.n0 != 2 && desc.n0 != 3 && desc.n0 != 4 && desc.n0 != 8 && desc.n0 != 16,
+ "N0 can only be: 1, 2, 3, 4, 8, and 16");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(desc.k0 != 1 && desc.k0 != 2 && desc.k0 != 3 && desc.k0 != 4 && desc.k0 != 8 && desc.k0 != 16,
+ "K0 can only be: 1, 2, 3, 4, 8, and 16");
return Status{};
}