diff options
Diffstat (limited to 'src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp')
-rw-r--r-- | src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp b/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp index e40f9c6da9..6db1ca4cf5 100644 --- a/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp @@ -61,13 +61,15 @@ Status add_tensor_intermed(ClKernelBlueprint &kernel_blueprint, ArgumentID &id) return Status{}; } -Status add_kcomp_gemm_native(ClKernelBlueprint &kernel_blueprint, const ClKernelComponentDescriptor &, const GemmNativeDescriptor &, +Status add_kcomp_gemm_native(ClKernelBlueprint &kernel_blueprint, const ClKernelComponentDescriptor &, + const GemmNativeDescriptor &gemm_native_desc, ArgumentID lhs_id, ArgumentID rhs_id, ArgumentID bias_id, ArgumentID &dst_id) { kernel_blueprint.impl().validate_arg_ids({ lhs_id, rhs_id, bias_id, dst_id }); - kernel_blueprint.impl().add_component( std::make_unique<ClGemmNativeKernelComponent>( + &kernel_blueprint, + gemm_native_desc, SharedVarLink{ lhs_id, SharedVarIO::Input, kernel_blueprint.impl().group(lhs_id) }, SharedVarLink{ rhs_id, SharedVarIO::Input, kernel_blueprint.impl().group(rhs_id) }, SharedVarLink{ dst_id, SharedVarIO::Output, kernel_blueprint.impl().group(dst_id) }, @@ -81,6 +83,7 @@ Status add_kcomp_eltwise_add(ClKernelBlueprint &kernel_blueprint, const ClKernel { kernel_blueprint.impl().add_component( std::make_unique<ClElementwiseAddKernelComponent>( + &kernel_blueprint, SharedVarLink{ src0_id, SharedVarIO::Input, kernel_blueprint.impl().group(src0_id) }, SharedVarLink{ src1_id, SharedVarIO::Input, kernel_blueprint.impl().group(src1_id) }, SharedVarLink{ dst_id, SharedVarIO::Output, kernel_blueprint.impl().group(dst_id) })); @@ -98,6 +101,7 @@ Status add_kcomp_store(ClKernelBlueprint &kernel_blueprint, const ClKernelCompon case StoreType::StoreBlockBoundaryAware: kernel_blueprint.impl().add_component( std::make_unique<ClStoreBlockBoundaryAwareKernelComponent>( + &kernel_blueprint, SharedVarLink{ src_tile, SharedVarIO::Input, kernel_blueprint.impl().group(src_tile) }, SharedVarLink{ dst_tile, SharedVarIO::Output, kernel_blueprint.impl().group(dst_tile) })); break; |