diff options
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp')
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp index dc05825500..1fbcb41028 100644 --- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp +++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp @@ -57,13 +57,24 @@ bool ClComponentDirectConv2dSettings::fast_relaxed_math() const return _fast_relaxed_math; } +ClComponentDirectConv2dSettings &ClComponentDirectConv2dSettings::direct_conv_descriptor(const DirectConvComputeKernelInfo &desc) +{ + _desc = desc; + return *this; +} + +DirectConvComputeKernelInfo ClComponentDirectConv2dSettings::direct_conv_descriptor() const +{ + return _desc; +} + Status ClComponentDirectConv2d::validate( const Properties &properties, const ArgumentPack<ITensorInfo> &tensors, const Attributes &attributes, const Settings &settings) { - ARM_COMPUTE_UNUSED(properties, settings); + ARM_COMPUTE_UNUSED(properties); const auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0); const auto wei = tensors.get_const_tensor(TensorType::ACL_SRC_1); const auto bia = tensors.get_const_tensor(TensorType::ACL_SRC_2); @@ -125,6 +136,11 @@ Status ClComponentDirectConv2d::validate( // Data layout ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(src, DataLayout::NHWC); + const auto desc = settings.direct_conv_descriptor(); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(desc.n0 != 1 && desc.n0 != 2 && desc.n0 != 3 && desc.n0 != 4 && desc.n0 != 8 && desc.n0 != 16, + "N0 can only be: 1, 2, 3, 4, 8, and 16"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(desc.k0 != 1 && desc.k0 != 2 && desc.k0 != 3 && desc.k0 != 4 && desc.k0 != 8 && desc.k0 != 16, + "K0 can only be: 1, 2, 3, 4, 8, and 16"); return Status{}; } |