diff options
Diffstat (limited to 'compute_kernel_writer/prototype/include/ckw/Kernel.h')
-rw-r--r-- | compute_kernel_writer/prototype/include/ckw/Kernel.h | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/compute_kernel_writer/prototype/include/ckw/Kernel.h b/compute_kernel_writer/prototype/include/ckw/Kernel.h index 527206feec..3deb2ace0d 100644 --- a/compute_kernel_writer/prototype/include/ckw/Kernel.h +++ b/compute_kernel_writer/prototype/include/ckw/Kernel.h @@ -25,16 +25,20 @@ #ifndef CKW_PROTOTYPE_INCLUDE_CKW_KERNEL_H #define CKW_PROTOTYPE_INCLUDE_CKW_KERNEL_H +#include "ckw/KernelArgument.h" #include "ckw/OperandBase.h" #include "ckw/types/GpuTargetLanguage.h" #include <map> #include <memory> #include <string> +#include <vector> namespace ckw { +class TileOperand; + namespace prototype { class GpuKernelWriterDataHolder; @@ -57,11 +61,20 @@ public: /** Get the name of the kernel function. */ const std::string &name() const; - /** (Internal use only) Get the map from operand name to the operand declared in this kernel. */ - const ::std::map<::std::string, ::std::unique_ptr<OperandBase>> &operands() const; + /** Get the list of kernel arguments. */ + ::std::vector<KernelArgument> arguments() const; + + /** (Internal use only) Register the tile operand. + * + * @param operand The tile operand to be registered. + */ + TileOperand ®ister_operand(::std::unique_ptr<TileOperand> operand); - /** (Internal use only) Get the map from operand name to the operand declared in this kernel. */ - ::std::map<::std::string, ::std::unique_ptr<OperandBase>> &operands(); + /** (Internal use only) Register the tensor operand. + * + * @param operand The tensor operand to be registered. + */ + TensorOperand ®ister_operand(::std::unique_ptr<TensorOperand> operand); /** (Internal use only) Get the implementation data. */ prototype::GpuKernelWriterDataHolder *impl(); @@ -70,6 +83,7 @@ private: ::std::string _name; ::std::unique_ptr<prototype::GpuKernelWriterDataHolder> _kernel; ::std::map<::std::string, ::std::unique_ptr<OperandBase>> _operands; + ::std::map<int32_t, TensorOperand *> _tensor_id_operands; }; } // namespace ckw |