diff options
Diffstat (limited to 'compute_kernel_writer/src/ITensorArgument.h')
-rw-r--r-- | compute_kernel_writer/src/ITensorArgument.h | 39 |
1 files changed, 26 insertions, 13 deletions
diff --git a/compute_kernel_writer/src/ITensorArgument.h b/compute_kernel_writer/src/ITensorArgument.h index 40ad69fdc0..838bd40f85 100644 --- a/compute_kernel_writer/src/ITensorArgument.h +++ b/compute_kernel_writer/src/ITensorArgument.h @@ -35,11 +35,14 @@ namespace ckw { + +class ITensorComponent; + /** Tensor storage variable */ struct TensorStorageVariable { - std::string val{ "" }; /** Tensor storage as a string */ - std::string type{ "" }; /** Tensor storage type as a string */ + std::string val{ "" }; /** Tensor storage as a string */ + TensorStorageType type{ TensorStorageType::Unknown }; /** Tensor storage type */ }; /** Tensor argument base class. @@ -60,11 +63,21 @@ public: { return _basename; } + + /** Method to get the tensor info + * + * @return the @ref TensorInfo + */ + TensorInfo &info() + { + return _info; + } + /** Method to get the tensor info * * @return the @ref TensorInfo */ - TensorInfo info() const + const TensorInfo &info() const { return _info; } @@ -75,38 +88,38 @@ protected: }; /** Tensor component argument base class */ -class ITensorComponentArgument +class ITensorComponentAccess { public: - virtual ~ITensorComponentArgument() = default; - /** Method to get the tensor component variable as a string + virtual ~ITensorComponentAccess() = default; + /** Method to get the tensor component variable as a tile. * * @param[in] x The tensor component to query * - * @return the tensor component variable as a @ref TileVariable + * @return the tensor component variable as a @ref ITile. */ - virtual TileVariable component(TensorComponentType x) = 0; + virtual ITile &component(TensorComponentType x) = 0; /** Method to get all tensor components needed to access the data in the tensor * * The tensor components returned by this method must be all passed as kernel argument * - * @return a vector containing all the tensor components as @ref TileVariable objects + * @return a vector containing all the tensor components as pointers to @ref ITensorComponent objects. */ - virtual std::vector<TileVariable> components() const = 0; + virtual std::vector<const ITensorComponent *> components() const = 0; }; /** Tensor storage argument base class */ -class ITensorStorageArgument +class ITensorStorageAccess { public: - virtual ~ITensorStorageArgument() = default; + virtual ~ITensorStorageAccess() = default; /** Method to get the tensor storage as a string * * @param[in] x The tensor storage to query * * @return the tensor storage as a @ref TensorStorageVariable */ - virtual TensorStorageVariable storage(TensorStorageType x) = 0; + virtual TensorStorageVariable &storage(TensorStorageType x) = 0; /** Method to get all tensor storages needed to access the data in the tensor * * The tensor storages returned by this method must be all passed as kernel argument |