aboutsummaryrefslogtreecommitdiff
path: root/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.h')
-rw-r--r--src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.h19
1 files changed, 13 insertions, 6 deletions
diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.h b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.h
index 38f007c07c..09933a8932 100644
--- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.h
+++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClGemmNativeKernelComponent.h
@@ -26,7 +26,10 @@
#ifndef ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_COMPONENTS_CLGEMMNATIVEKERNELCOMPONENT_H
#define ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_COMPONENTS_CLGEMMNATIVEKERNELCOMPONENT_H
+#include "arm_compute/core/Steps.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h"
+#include "src/core/helpers/AutoConfiguration.h"
namespace arm_compute
{
@@ -37,14 +40,17 @@ namespace dynamic_fusion
class ClGemmNativeKernelComponent : public IClKernelComponent
{
public:
- ClGemmNativeKernelComponent(const Link &lhs, const Link &rhs, const Link &dst, const Link &bias = Link{})
- : _lhs{ lhs }, _rhs{ rhs }, _bias{ bias }, _dst{ dst }
+ ClGemmNativeKernelComponent(const ClKernelBlueprint *blueprint, const GemmNativeDescriptor &desc,
+ const Link &lhs, const Link &rhs, const Link &dst, const Link &bias = Link{})
+ : IClKernelComponent(blueprint), _desc{ desc }, _lhs{ lhs }, _rhs{ rhs }, _bias{ bias }, _dst{ dst }
{
}
+
ComponentType get_component_type() const override;
std::set<std::string> get_headers_list() const override;
std::string get_additional_macros() const override;
std::string get_component_code() const override;
+ Window get_window() const override;
ClKernelArgList get_args();
virtual std::vector<Link> get_links() const override
@@ -60,10 +66,11 @@ public:
}
private:
- Link _lhs{};
- Link _rhs{};
- Link _bias{};
- Link _dst{};
+ GemmNativeDescriptor _desc{};
+ Link _lhs{};
+ Link _rhs{};
+ Link _bias{};
+ Link _dst{};
};
} // namespace dynamic_fusion