diff options
Diffstat (limited to 'src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.h')
-rw-r--r-- | src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.h | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.h b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.h index 38f007c07c..09933a8932 100644 --- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.h +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.h @@ -26,7 +26,10 @@ #ifndef ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_COMPONENTS_CLGEMMNATIVEKERNELCOMPONENT_H #define ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_COMPONENTS_CLGEMMNATIVEKERNELCOMPONENT_H +#include "arm_compute/core/Steps.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h" +#include "src/core/helpers/AutoConfiguration.h" namespace arm_compute { @@ -37,14 +40,17 @@ namespace dynamic_fusion class ClGemmNativeKernelComponent : public IClKernelComponent { public: - ClGemmNativeKernelComponent(const Link &lhs, const Link &rhs, const Link &dst, const Link &bias = Link{}) - : _lhs{ lhs }, _rhs{ rhs }, _bias{ bias }, _dst{ dst } + ClGemmNativeKernelComponent(const ClKernelBlueprint *blueprint, const GemmNativeDescriptor &desc, + const Link &lhs, const Link &rhs, const Link &dst, const Link &bias = Link{}) + : IClKernelComponent(blueprint), _desc{ desc }, _lhs{ lhs }, _rhs{ rhs }, _bias{ bias }, _dst{ dst } { } + ComponentType get_component_type() const override; std::set<std::string> get_headers_list() const override; std::string get_additional_macros() const override; std::string get_component_code() const override; + Window get_window() const override; ClKernelArgList get_args(); virtual std::vector<Link> get_links() const override @@ -60,10 +66,11 @@ public: } private: - Link _lhs{}; - Link _rhs{}; - Link _bias{}; - Link _dst{}; + GemmNativeDescriptor _desc{}; + Link _lhs{}; + Link _rhs{}; + Link _bias{}; + Link _dst{}; }; } // namespace dynamic_fusion |