aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/prototype/include/ckw/Kernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/prototype/include/ckw/Kernel.h')
-rw-r--r--compute_kernel_writer/prototype/include/ckw/Kernel.h22
1 files changed, 18 insertions, 4 deletions
diff --git a/compute_kernel_writer/prototype/include/ckw/Kernel.h b/compute_kernel_writer/prototype/include/ckw/Kernel.h
index 527206feec..3deb2ace0d 100644
--- a/compute_kernel_writer/prototype/include/ckw/Kernel.h
+++ b/compute_kernel_writer/prototype/include/ckw/Kernel.h
@@ -25,16 +25,20 @@
#ifndef CKW_PROTOTYPE_INCLUDE_CKW_KERNEL_H
#define CKW_PROTOTYPE_INCLUDE_CKW_KERNEL_H
+#include "ckw/KernelArgument.h"
#include "ckw/OperandBase.h"
#include "ckw/types/GpuTargetLanguage.h"
#include <map>
#include <memory>
#include <string>
+#include <vector>
namespace ckw
{
+class TileOperand;
+
namespace prototype
{
class GpuKernelWriterDataHolder;
@@ -57,11 +61,20 @@ public:
/** Get the name of the kernel function. */
const std::string &name() const;
- /** (Internal use only) Get the map from operand name to the operand declared in this kernel. */
- const ::std::map<::std::string, ::std::unique_ptr<OperandBase>> &operands() const;
+ /** Get the list of kernel arguments. */
+ ::std::vector<KernelArgument> arguments() const;
+
+ /** (Internal use only) Register the tile operand.
+ *
+ * @param operand The tile operand to be registered.
+ */
+ TileOperand &register_operand(::std::unique_ptr<TileOperand> operand);
- /** (Internal use only) Get the map from operand name to the operand declared in this kernel. */
- ::std::map<::std::string, ::std::unique_ptr<OperandBase>> &operands();
+ /** (Internal use only) Register the tensor operand.
+ *
+ * @param operand The tensor operand to be registered.
+ */
+ TensorOperand &register_operand(::std::unique_ptr<TensorOperand> operand);
/** (Internal use only) Get the implementation data. */
prototype::GpuKernelWriterDataHolder *impl();
@@ -70,6 +83,7 @@ private:
::std::string _name;
::std::unique_ptr<prototype::GpuKernelWriterDataHolder> _kernel;
::std::map<::std::string, ::std::unique_ptr<OperandBase>> _operands;
+ ::std::map<int32_t, TensorOperand *> _tensor_id_operands;
};
} // namespace ckw