diff options
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/components')
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp | 18 | ||||
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h | 11 |
2 files changed, 26 insertions, 3 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp index dc05825500..1fbcb41028 100644 --- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp +++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.cpp @@ -57,13 +57,24 @@ bool ClComponentDirectConv2dSettings::fast_relaxed_math() const return _fast_relaxed_math; } +ClComponentDirectConv2dSettings &ClComponentDirectConv2dSettings::direct_conv_descriptor(const DirectConvComputeKernelInfo &desc) +{ + _desc = desc; + return *this; +} + +DirectConvComputeKernelInfo ClComponentDirectConv2dSettings::direct_conv_descriptor() const +{ + return _desc; +} + Status ClComponentDirectConv2d::validate( const Properties &properties, const ArgumentPack<ITensorInfo> &tensors, const Attributes &attributes, const Settings &settings) { - ARM_COMPUTE_UNUSED(properties, settings); + ARM_COMPUTE_UNUSED(properties); const auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0); const auto wei = tensors.get_const_tensor(TensorType::ACL_SRC_1); const auto bia = tensors.get_const_tensor(TensorType::ACL_SRC_2); @@ -125,6 +136,11 @@ Status ClComponentDirectConv2d::validate( // Data layout ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(src, DataLayout::NHWC); + const auto desc = settings.direct_conv_descriptor(); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(desc.n0 != 1 && desc.n0 != 2 && desc.n0 != 3 && desc.n0 != 4 && desc.n0 != 8 && desc.n0 != 16, + "N0 can only be: 1, 2, 3, 4, 8, and 16"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(desc.k0 != 1 && desc.k0 != 2 && desc.k0 != 3 && desc.k0 != 4 && desc.k0 != 8 && desc.k0 != 16, + "K0 can only be: 1, 2, 3, 4, 8, and 16"); return Status{}; } diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h index fec22b84a5..c3a70ef3ae 100644 --- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h +++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentDirectConv2d.h @@ -25,6 +25,7 @@ #define SRC_DYNAMIC_FUSION_SKETCH_GPU_COMPONENTS_CL_CLCOMPONENTDIRECTCONV2D #include "arm_compute/core/Error.h" +#include "arm_compute/core/KernelDescriptors.h" #include "src/dynamic_fusion/sketch/gpu/components/IGpuKernelComponent.h" #include <memory> @@ -56,9 +57,15 @@ public: /** Get fast_relaxed_math flag */ bool fast_relaxed_math() const; + /** Set direct convolution descriptor */ + ClComponentDirectConv2dSettings &direct_conv_descriptor(const DirectConvComputeKernelInfo &desc); + /** Get direct convolution descriptor */ + DirectConvComputeKernelInfo direct_conv_descriptor() const; + private: - bool _export_to_cl_image{ false }; - bool _fast_relaxed_math{ true }; + bool _export_to_cl_image{ false }; + bool _fast_relaxed_math{ true }; + DirectConvComputeKernelInfo _desc{}; // Direct convolution descriptor }; /** Forward declaration */ |