diff options
Diffstat (limited to 'compute_kernel_writer/src/cl/CLTensorArgument.h')
-rw-r--r-- | compute_kernel_writer/src/cl/CLTensorArgument.h | 40 |
1 files changed, 29 insertions, 11 deletions
diff --git a/compute_kernel_writer/src/cl/CLTensorArgument.h b/compute_kernel_writer/src/cl/CLTensorArgument.h index cd924846c5..4cbbee21ee 100644 --- a/compute_kernel_writer/src/cl/CLTensorArgument.h +++ b/compute_kernel_writer/src/cl/CLTensorArgument.h @@ -24,8 +24,10 @@ #ifndef CKW_SRC_CL_CLTENSORARGUMENT_H #define CKW_SRC_CL_CLTENSORARGUMENT_H -#include "src/ITensorArgument.h" - +#include "ckw/types/TensorComponentType.h" +#include "ckw/types/TensorStorageType.h" +#include "src/ITensor.h" +#include <memory> #include <string> #include <vector> @@ -34,12 +36,16 @@ namespace ckw // Forward declarations class TensorInfo; +class ITensorComponent; +class CLTensorComponent; +class CLTensorStorage; + /** OpenCL specific tensor argument * Internally, the object keeps track of the components and storages used to minimize the number * of kernel arguments required. Therefore, if we create this object but we do not access any components * or storages, the storages() and components() method will return an empty list. */ -class CLTensorArgument : public ITensorArgument, ITensorStorageArgument, ITensorComponentArgument +class CLTensorArgument : public ITensor { public: /** Constructor @@ -51,20 +57,32 @@ public: */ CLTensorArgument(const std::string &name, const TensorInfo &info, bool return_dims_by_value); + /** Destructor. */ + ~CLTensorArgument(); + + /** Get a tensor component of the given type. + * + * This function is for internal use as it returns a reference to @ref CLTensorComponent object. + * It provides rich functionalities and doesn't require unnecessary casting + * unlike @ref CLTensorComponent::component which is for the public API and only returns + * a reference to a generic @ref ITile object. + */ + CLTensorComponent& cl_component(TensorComponentType component_type); + // Inherited method overridden - TensorStorageVariable storage(TensorStorageType x); - TileVariable component(TensorComponentType x); - std::vector<TensorStorageVariable> storages() const; - std::vector<TileVariable> components() const; + TensorStorageVariable &storage(TensorStorageType x) override; + ITile &component(TensorComponentType x) override; + std::vector<TensorStorageVariable> storages() const override; + std::vector<const ITensorComponent *> components() const override; private: std::string create_storage_name(TensorStorageType x) const; - std::string create_component_name(TensorComponentType x) const; - bool _return_dims_by_value{ false }; - std::vector<TensorStorageType> _storages_used{}; - std::vector<TensorComponentType> _components_used{}; + bool _return_dims_by_value{ false }; + std::vector<TensorStorageVariable> _storages_used{}; + std::vector<std::unique_ptr<CLTensorComponent>> _components_used{}; }; + } // namespace ckw #endif // CKW_SRC_CL_CLTENSORARGUMENT_H |