aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/src/ITensorArgument.h
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/src/ITensorArgument.h')
-rw-r--r--compute_kernel_writer/src/ITensorArgument.h39
1 files changed, 26 insertions, 13 deletions
diff --git a/compute_kernel_writer/src/ITensorArgument.h b/compute_kernel_writer/src/ITensorArgument.h
index 40ad69fdc0..838bd40f85 100644
--- a/compute_kernel_writer/src/ITensorArgument.h
+++ b/compute_kernel_writer/src/ITensorArgument.h
@@ -35,11 +35,14 @@
namespace ckw
{
+
+class ITensorComponent;
+
/** Tensor storage variable */
struct TensorStorageVariable
{
- std::string val{ "" }; /** Tensor storage as a string */
- std::string type{ "" }; /** Tensor storage type as a string */
+ std::string val{ "" }; /** Tensor storage as a string */
+ TensorStorageType type{ TensorStorageType::Unknown }; /** Tensor storage type */
};
/** Tensor argument base class.
@@ -60,11 +63,21 @@ public:
{
return _basename;
}
+
+ /** Method to get the tensor info
+ *
+ * @return the @ref TensorInfo
+ */
+ TensorInfo &info()
+ {
+ return _info;
+ }
+
/** Method to get the tensor info
*
* @return the @ref TensorInfo
*/
- TensorInfo info() const
+ const TensorInfo &info() const
{
return _info;
}
@@ -75,38 +88,38 @@ protected:
};
/** Tensor component argument base class */
-class ITensorComponentArgument
+class ITensorComponentAccess
{
public:
- virtual ~ITensorComponentArgument() = default;
- /** Method to get the tensor component variable as a string
+ virtual ~ITensorComponentAccess() = default;
+ /** Method to get the tensor component variable as a tile.
*
* @param[in] x The tensor component to query
*
- * @return the tensor component variable as a @ref TileVariable
+ * @return the tensor component variable as a @ref ITile.
*/
- virtual TileVariable component(TensorComponentType x) = 0;
+ virtual ITile &component(TensorComponentType x) = 0;
/** Method to get all tensor components needed to access the data in the tensor
*
* The tensor components returned by this method must be all passed as kernel argument
*
- * @return a vector containing all the tensor components as @ref TileVariable objects
+ * @return a vector containing all the tensor components as pointers to @ref ITensorComponent objects.
*/
- virtual std::vector<TileVariable> components() const = 0;
+ virtual std::vector<const ITensorComponent *> components() const = 0;
};
/** Tensor storage argument base class */
-class ITensorStorageArgument
+class ITensorStorageAccess
{
public:
- virtual ~ITensorStorageArgument() = default;
+ virtual ~ITensorStorageAccess() = default;
/** Method to get the tensor storage as a string
*
* @param[in] x The tensor storage to query
*
* @return the tensor storage as a @ref TensorStorageVariable
*/
- virtual TensorStorageVariable storage(TensorStorageType x) = 0;
+ virtual TensorStorageVariable &storage(TensorStorageType x) = 0;
/** Method to get all tensor storages needed to access the data in the tensor
*
* The tensor storages returned by this method must be all passed as kernel argument