diff options
Diffstat (limited to 'compute_kernel_writer/include/ckw')
5 files changed, 41 insertions, 11 deletions
diff --git a/compute_kernel_writer/include/ckw/KernelWriter.h b/compute_kernel_writer/include/ckw/KernelWriter.h index 0d739e859a..da41b940d7 100644 --- a/compute_kernel_writer/include/ckw/KernelWriter.h +++ b/compute_kernel_writer/include/ckw/KernelWriter.h @@ -25,11 +25,22 @@ #ifndef CKW_INCLUDE_CKW_KERNELWRITER_H #define CKW_INCLUDE_CKW_KERNELWRITER_H +#include "ckw/Kernel.h" +#include "ckw/TensorInfo.h" #include "ckw/TensorOperand.h" +#include "ckw/TensorSampler.h" +#include "ckw/TileInfo.h" #include "ckw/TileOperand.h" #include "ckw/types/ConstantData.h" #include "ckw/types/ConvertPolicy.h" +#include "ckw/types/DataType.h" #include "ckw/types/Operators.h" +#include "ckw/types/TargetArchitecture.h" +#include "ckw/types/TargetLanguage.h" +#include "ckw/types/TensorComponentType.h" +#include "ckw/types/TensorDataLayout.h" +#include "ckw/types/TensorSamplerTypes.h" +#include "ckw/types/TensorStorageType.h" #include <functional> #include <memory> @@ -39,16 +50,8 @@ namespace ckw { -/** Forward Declerations */ -class Kernel; -class TensorInfo; -class TensorSampler; +/** Forward Declarations */ class TileArea; -class TileInfo; - -enum class DataType; -enum class TargetArchitecture; -enum class TargetLanguage; /** A kernel writer. * @@ -350,7 +353,6 @@ public: const TileOperand &z, const TileOperand &batch_op) = 0; -protected: // ============================================================================================= // ID space management // ============================================================================================= @@ -367,6 +369,7 @@ protected: /** Get the current ID space. */ int32_t id_space() const; +protected: /** Set the current ID space. * * @param[in] value The ID space to be used. diff --git a/compute_kernel_writer/include/ckw/TensorOperand.h b/compute_kernel_writer/include/ckw/TensorOperand.h index 2672cd5334..a3e53d1314 100644 --- a/compute_kernel_writer/include/ckw/TensorOperand.h +++ b/compute_kernel_writer/include/ckw/TensorOperand.h @@ -43,6 +43,15 @@ public: // Only kernel writer class interacts with tensor operand hence we allow it to access this field. friend class KernelWriter; + /** Create an empty tensor operand. + * + * The new tensor operand doesn't refer to any tensor therefore it is not useable. + */ + TensorOperand(); + + /** Check if the tensor operand contains a tensor and therefore useable. */ + bool is_valid() const; + /** Get the tensor info. */ const TensorInfo &info() const; @@ -92,7 +101,7 @@ private: /** Initialize a new instance of @ref TensorOperand class for a tensor. */ TensorOperand(ITensor &tensor); - ITensor &_tensor; + ITensor *_tensor; }; } // namespace ckw diff --git a/compute_kernel_writer/include/ckw/TileOperand.h b/compute_kernel_writer/include/ckw/TileOperand.h index 56dc5e7b2b..556d589bc0 100644 --- a/compute_kernel_writer/include/ckw/TileOperand.h +++ b/compute_kernel_writer/include/ckw/TileOperand.h @@ -33,6 +33,7 @@ namespace ckw class KernelWriter; class TensorOperand; class ITile; +class TileInfo; /** A tile operand refers to a tile object that can be used for kernel writing. */ class TileOperand @@ -43,6 +44,18 @@ public: friend class KernelWriter; friend class TensorOperand; + /** Create an empty tile operand. + * + * The new tile operand doesn't refer to any tile therefore it is not useable. + */ + TileOperand(); + + /** Check if the tile operand contains a tile and therefore useable. */ + bool is_valid() const; + + /** Get the tile info. */ + const TileInfo &tile_info() const; + /** Get a row vector of the current tile operand. * * @param[in] row The index of the row to be accessed in the current tile operand. diff --git a/compute_kernel_writer/include/ckw/types/ConstantData.h b/compute_kernel_writer/include/ckw/types/ConstantData.h index 7708818ca8..ea95049c9e 100644 --- a/compute_kernel_writer/include/ckw/types/ConstantData.h +++ b/compute_kernel_writer/include/ckw/types/ConstantData.h @@ -53,6 +53,10 @@ public: template <typename T> ConstantData(std::initializer_list<std::initializer_list<T>> values, DataType data_type); + /** Templated constructor */ + template <typename T> + ConstantData(const std::vector<std::vector<T>> &values, DataType data_type); + private: /** Validate the given data type and the template type * diff --git a/compute_kernel_writer/include/ckw/types/Operators.h b/compute_kernel_writer/include/ckw/types/Operators.h index 1e5f9bd542..77b0519422 100644 --- a/compute_kernel_writer/include/ckw/types/Operators.h +++ b/compute_kernel_writer/include/ckw/types/Operators.h @@ -43,6 +43,7 @@ enum class UnaryOp : int32_t Fabs = 0x0014, Log = 0x0015, Round = 0x0016, + Floor = 0x0017, }; /** Assignment operators. */ |