aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/prototype/include/ckw/TensorOperand.h')
-rw-r--r--compute_kernel_writer/prototype/include/ckw/TensorOperand.h53
1 files changed, 34 insertions, 19 deletions
diff --git a/compute_kernel_writer/prototype/include/ckw/TensorOperand.h b/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
index 3a2509e7c8..6d88932c66 100644
--- a/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
+++ b/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
@@ -48,10 +48,11 @@ class TensorOperand : public OperandBase
public:
/** Initialize a new instance of @ref TensorOperand class.
*
- * @param[in] name The name of the tensor.
- * @param[in] info The tensor info.
+ * @param[in] name The name of the tensor.
+ * @param[in] info The tensor info.
+ * @param[in] storage_type The tensor storage type.
*/
- TensorOperand(const ::std::string &name, const TensorInfo &info);
+ TensorOperand(const ::std::string &name, const TensorInfo &info, TensorStorageType storage_type);
/** No copy constructor. */
TensorOperand(const TensorOperand &other) = delete;
@@ -71,6 +72,9 @@ public:
/** Get the tensor info. */
TensorInfo &info();
+ /** Get the tensor storage type. */
+ TensorStorageType storage_type() const;
+
/** Get the data type. */
virtual DataType data_type() const override;
@@ -96,43 +100,44 @@ public:
TensorOperand &tile_sampler(const TensorTileSampler &value);
/** Get the operand that contains the stride in y dimension of the tensor. */
- TileOperand &stride1();
+ TensorComponentOperand &stride1();
/** Get the operand that contains the stride in z dimension of the tensor. */
- TileOperand &stride2();
+ TensorComponentOperand &stride2();
/** Get the operand that contains the stride in w dimension of the tensor. */
- TileOperand &stride3();
+ TensorComponentOperand &stride3();
/** Get the operand that contains the stride in w dimension of the tensor. */
- TileOperand &stride4();
+ TensorComponentOperand &stride4();
/** Get the operand that contains the size of dimension 0 of the tensor. */
- TileOperand &dim0();
+ TensorComponentOperand &dim0();
/** Get the operand that contains the size of dimension 1 of the tensor. */
- TileOperand &dim1();
+ TensorComponentOperand &dim1();
/** Get the operand that contains the size of dimension 2 of the tensor. */
- TileOperand &dim2();
+ TensorComponentOperand &dim2();
/** Get the operand that contains the size of dimension 3 of the tensor. */
- TileOperand &dim3();
+ TensorComponentOperand &dim3();
/** Get the operand that contains the size of dimension 4 of the tensor. */
- TileOperand &dim4();
+ TensorComponentOperand &dim4();
/** Get the operand that contains the size of dimensions 1 and 2 collapsed. */
- TileOperand &dim1_dim2();
+ TensorComponentOperand &dim1_dim2();
/** Get the operand that contains the size of dimensions 1, 2 and 3 collapsed. */
- TileOperand &dim1_dim2_dim3();
+ TensorComponentOperand &dim1_dim2_dim3();
/** Get the operand that contains the offset in bytes to the first element. */
- TileOperand &offset_first_element_in_bytes();
+ TensorComponentOperand &offset_first_element_in_bytes();
private:
- TensorInfo _info;
+ TensorInfo _info;
+ TensorStorageType _storage_type;
TileOperand *_tile{ nullptr };
TensorTileSampler _tile_sampler{};
@@ -161,10 +166,19 @@ class TensorComponentOperand : public TileOperand
public:
/** Initialize a new instance of @ref TensorComponentOperand class.
*
- * @param[in] name The name of the operand.
+ * @param[in] tensor The tensor operand.
* @param[in] component The tensor info component.
*/
- TensorComponentOperand(const ::std::string &name, TensorComponent component);
+ TensorComponentOperand(TensorOperand &tensor, TensorComponentType component);
+
+ /** Get the tensor operand. */
+ TensorOperand &tensor();
+
+ /** Get the tensor operand. */
+ const TensorOperand &tensor() const;
+
+ /** Get the tensor component. */
+ TensorComponentType component_type() const;
/** (Internal use only) Create the implementation operand.
*
@@ -173,7 +187,8 @@ public:
virtual prototype::Operand create_impl_operand(prototype::IGpuKernelWriter *writer) const override;
private:
- TensorComponent _component;
+ TensorOperand &_tensor;
+ TensorComponentType _component;
};
} // namespace ckw