aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/src/cl/CLTensorArgument.h
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/src/cl/CLTensorArgument.h')
-rw-r--r--compute_kernel_writer/src/cl/CLTensorArgument.h40
1 files changed, 29 insertions, 11 deletions
diff --git a/compute_kernel_writer/src/cl/CLTensorArgument.h b/compute_kernel_writer/src/cl/CLTensorArgument.h
index cd924846c5..4cbbee21ee 100644
--- a/compute_kernel_writer/src/cl/CLTensorArgument.h
+++ b/compute_kernel_writer/src/cl/CLTensorArgument.h
@@ -24,8 +24,10 @@
#ifndef CKW_SRC_CL_CLTENSORARGUMENT_H
#define CKW_SRC_CL_CLTENSORARGUMENT_H
-#include "src/ITensorArgument.h"
-
+#include "ckw/types/TensorComponentType.h"
+#include "ckw/types/TensorStorageType.h"
+#include "src/ITensor.h"
+#include <memory>
#include <string>
#include <vector>
@@ -34,12 +36,16 @@ namespace ckw
// Forward declarations
class TensorInfo;
+class ITensorComponent;
+class CLTensorComponent;
+class CLTensorStorage;
+
/** OpenCL specific tensor argument
* Internally, the object keeps track of the components and storages used to minimize the number
* of kernel arguments required. Therefore, if we create this object but we do not access any components
* or storages, the storages() and components() method will return an empty list.
*/
-class CLTensorArgument : public ITensorArgument, ITensorStorageArgument, ITensorComponentArgument
+class CLTensorArgument : public ITensor
{
public:
/** Constructor
@@ -51,20 +57,32 @@ public:
*/
CLTensorArgument(const std::string &name, const TensorInfo &info, bool return_dims_by_value);
+ /** Destructor. */
+ ~CLTensorArgument();
+
+ /** Get a tensor component of the given type.
+ *
+ * This function is for internal use as it returns a reference to @ref CLTensorComponent object.
+ * It provides rich functionalities and doesn't require unnecessary casting
+ * unlike @ref CLTensorComponent::component which is for the public API and only returns
+ * a reference to a generic @ref ITile object.
+ */
+ CLTensorComponent& cl_component(TensorComponentType component_type);
+
// Inherited method overridden
- TensorStorageVariable storage(TensorStorageType x);
- TileVariable component(TensorComponentType x);
- std::vector<TensorStorageVariable> storages() const;
- std::vector<TileVariable> components() const;
+ TensorStorageVariable &storage(TensorStorageType x) override;
+ ITile &component(TensorComponentType x) override;
+ std::vector<TensorStorageVariable> storages() const override;
+ std::vector<const ITensorComponent *> components() const override;
private:
std::string create_storage_name(TensorStorageType x) const;
- std::string create_component_name(TensorComponentType x) const;
- bool _return_dims_by_value{ false };
- std::vector<TensorStorageType> _storages_used{};
- std::vector<TensorComponentType> _components_used{};
+ bool _return_dims_by_value{ false };
+ std::vector<TensorStorageVariable> _storages_used{};
+ std::vector<std::unique_ptr<CLTensorComponent>> _components_used{};
};
+
} // namespace ckw
#endif // CKW_SRC_CL_CLTENSORARGUMENT_H