aboutsummaryrefslogtreecommitdiff
path: root/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp')
-rw-r--r--src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp8
1 files changed, 6 insertions, 2 deletions
diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp b/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp
index e40f9c6da9..6db1ca4cf5 100644
--- a/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp
+++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp
@@ -61,13 +61,15 @@ Status add_tensor_intermed(ClKernelBlueprint &kernel_blueprint, ArgumentID &id)
return Status{};
}
-Status add_kcomp_gemm_native(ClKernelBlueprint &kernel_blueprint, const ClKernelComponentDescriptor &, const GemmNativeDescriptor &,
+Status add_kcomp_gemm_native(ClKernelBlueprint &kernel_blueprint, const ClKernelComponentDescriptor &,
+ const GemmNativeDescriptor &gemm_native_desc,
ArgumentID lhs_id, ArgumentID rhs_id, ArgumentID bias_id, ArgumentID &dst_id)
{
kernel_blueprint.impl().validate_arg_ids({ lhs_id, rhs_id, bias_id, dst_id });
-
kernel_blueprint.impl().add_component(
std::make_unique<ClGemmNativeKernelComponent>(
+ &kernel_blueprint,
+ gemm_native_desc,
SharedVarLink{ lhs_id, SharedVarIO::Input, kernel_blueprint.impl().group(lhs_id) },
SharedVarLink{ rhs_id, SharedVarIO::Input, kernel_blueprint.impl().group(rhs_id) },
SharedVarLink{ dst_id, SharedVarIO::Output, kernel_blueprint.impl().group(dst_id) },
@@ -81,6 +83,7 @@ Status add_kcomp_eltwise_add(ClKernelBlueprint &kernel_blueprint, const ClKernel
{
kernel_blueprint.impl().add_component(
std::make_unique<ClElementwiseAddKernelComponent>(
+ &kernel_blueprint,
SharedVarLink{ src0_id, SharedVarIO::Input, kernel_blueprint.impl().group(src0_id) },
SharedVarLink{ src1_id, SharedVarIO::Input, kernel_blueprint.impl().group(src1_id) },
SharedVarLink{ dst_id, SharedVarIO::Output, kernel_blueprint.impl().group(dst_id) }));
@@ -98,6 +101,7 @@ Status add_kcomp_store(ClKernelBlueprint &kernel_blueprint, const ClKernelCompon
case StoreType::StoreBlockBoundaryAware:
kernel_blueprint.impl().add_component(
std::make_unique<ClStoreBlockBoundaryAwareKernelComponent>(
+ &kernel_blueprint,
SharedVarLink{ src_tile, SharedVarIO::Input, kernel_blueprint.impl().group(src_tile) },
SharedVarLink{ dst_tile, SharedVarIO::Output, kernel_blueprint.impl().group(dst_tile) }));
break;