aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/prototype/include
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/prototype/include')
-rw-r--r--compute_kernel_writer/prototype/include/ckw/Kernel.h22
-rw-r--r--compute_kernel_writer/prototype/include/ckw/KernelArgument.h106
-rw-r--r--compute_kernel_writer/prototype/include/ckw/KernelWriter.h21
-rw-r--r--compute_kernel_writer/prototype/include/ckw/TensorInfo.h4
-rw-r--r--compute_kernel_writer/prototype/include/ckw/TensorOperand.h53
5 files changed, 169 insertions, 37 deletions
diff --git a/compute_kernel_writer/prototype/include/ckw/Kernel.h b/compute_kernel_writer/prototype/include/ckw/Kernel.h
index 527206feec..3deb2ace0d 100644
--- a/compute_kernel_writer/prototype/include/ckw/Kernel.h
+++ b/compute_kernel_writer/prototype/include/ckw/Kernel.h
@@ -25,16 +25,20 @@
#ifndef CKW_PROTOTYPE_INCLUDE_CKW_KERNEL_H
#define CKW_PROTOTYPE_INCLUDE_CKW_KERNEL_H
+#include "ckw/KernelArgument.h"
#include "ckw/OperandBase.h"
#include "ckw/types/GpuTargetLanguage.h"
#include <map>
#include <memory>
#include <string>
+#include <vector>
namespace ckw
{
+class TileOperand;
+
namespace prototype
{
class GpuKernelWriterDataHolder;
@@ -57,11 +61,20 @@ public:
/** Get the name of the kernel function. */
const std::string &name() const;
- /** (Internal use only) Get the map from operand name to the operand declared in this kernel. */
- const ::std::map<::std::string, ::std::unique_ptr<OperandBase>> &operands() const;
+ /** Get the list of kernel arguments. */
+ ::std::vector<KernelArgument> arguments() const;
+
+ /** (Internal use only) Register the tile operand.
+ *
+ * @param operand The tile operand to be registered.
+ */
+ TileOperand &register_operand(::std::unique_ptr<TileOperand> operand);
- /** (Internal use only) Get the map from operand name to the operand declared in this kernel. */
- ::std::map<::std::string, ::std::unique_ptr<OperandBase>> &operands();
+ /** (Internal use only) Register the tensor operand.
+ *
+ * @param operand The tensor operand to be registered.
+ */
+ TensorOperand &register_operand(::std::unique_ptr<TensorOperand> operand);
/** (Internal use only) Get the implementation data. */
prototype::GpuKernelWriterDataHolder *impl();
@@ -70,6 +83,7 @@ private:
::std::string _name;
::std::unique_ptr<prototype::GpuKernelWriterDataHolder> _kernel;
::std::map<::std::string, ::std::unique_ptr<OperandBase>> _operands;
+ ::std::map<int32_t, TensorOperand *> _tensor_id_operands;
};
} // namespace ckw
diff --git a/compute_kernel_writer/prototype/include/ckw/KernelArgument.h b/compute_kernel_writer/prototype/include/ckw/KernelArgument.h
new file mode 100644
index 0000000000..af8bcde634
--- /dev/null
+++ b/compute_kernel_writer/prototype/include/ckw/KernelArgument.h
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_PROTOTYPE_INCLUDE_CKW_KERNELARGUMENT_H
+#define CKW_PROTOTYPE_INCLUDE_CKW_KERNELARGUMENT_H
+
+#include "ckw/TensorInfo.h"
+#include <cstdint>
+
+namespace ckw
+{
+
+class TensorOperand;
+class TensorComponentOperand;
+
+/** A kernel argument which can be either a tensor storage or a tensor component. */
+class KernelArgument
+{
+public:
+ /** The type of kernel argument. */
+ enum class Type : int32_t
+ {
+ /** The argument that provides the read and/or write access to the tensor data.
+ *
+ * See @ref ckw::TensorStorage to see the list of supported storage type.
+ */
+ TensorStorage,
+
+ /** The argument that provides extra information about the tensor.
+ *
+ * See @ref ckw::TensorComponent to see the list of supported component.
+ */
+ TensorComponent,
+ };
+
+ /** Initialize a new instance of kernel argument class for a tensor storage argument.
+ *
+ * @param[in] tensor The tensor whose storage is exposed to kernel arguments.
+ */
+ KernelArgument(TensorOperand &tensor);
+
+ /** Initialize a new instance of kernel argument class for a tensor component argument.
+ *
+ * @param[in] tensor_component The tensor component to be exposed to kernel arguments.
+ */
+ KernelArgument(TensorComponentOperand &tensor_component);
+
+ /** Get the type of kernel argument. */
+ Type type() const;
+
+ /** Get the argument ID.
+ *
+ * This method can be used to get the tensor info ID of both tensor storage and tensor component arguments.
+ */
+ int32_t id() const;
+
+ /** Get the type of tensor storage.
+ *
+ * This method can only be used for tensor storage argument.
+ */
+ TensorStorageType tensor_storage_type() const;
+
+ /** Get the tensor component type.
+ *
+ * This method can only be used for tensor component argument.
+ */
+ TensorComponentType tensor_component_type() const;
+
+private:
+ Type _type;
+ int32_t _id;
+
+ union SubId
+ {
+ int32_t unknown;
+ TensorStorageType tensor_storage_type;
+ TensorComponentType tensor_component_type;
+ };
+
+ SubId _sub_id{ 0 };
+};
+
+} // namespace ckw
+
+#endif // CKW_PROTOTYPE_INCLUDE_CKW_KERNELARGUMENT_H
diff --git a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
index 2bf443cd53..146fdac53e 100644
--- a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
+++ b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
@@ -88,12 +88,13 @@ public:
/** Declare a tensor argument.
*
- * @param[in] name The name of the tensor.
- * @param[in] info The tensor info.
+ * @param[in] name The name of the tensor.
+ * @param[in] info The tensor info.
+ * @param[in] storage_type The tensor storage type.
*
* @return The @ref TensorOperand object.
*/
- TensorOperand &declare_tensor_argument(const std::string &name, const TensorInfo &info);
+ TensorOperand &declare_tensor_argument(const std::string &name, const TensorInfo &info, TensorStorageType storage_type = TensorStorageType::BufferUint8Ptr);
/** Declare a compile-time constant scalar argument.
*
@@ -117,10 +118,9 @@ public:
TileOperand &declare_tile(const std::string &name, TArgs &&...args)
{
const auto var_name = generate_variable_name(name);
- auto operand = new TileOperand(var_name, ::std::forward<TArgs>(args)...);
- register_operand(operand, true);
+ auto operand = std::make_unique<TileOperand>(var_name, ::std::forward<TArgs>(args)...);
- return *operand;
+ return declare_tile_operand(std::move(operand));
}
// =============================================================================================
@@ -272,14 +272,11 @@ private:
*/
::std::string generate_variable_name(const std::string &name) const;
- /** Register the operand to the kernel.
+ /** Declare the tile operand.
*
- * The operand is uniquely owned by the kernel afterward.
- *
- * @param[in] operand The operand to be registered.
- * @param[in] declaring Whether the tile declaration is generated.
+ * @param[in] operand The tile operand to be declared.
*/
- void register_operand(OperandBase *operand, bool declaring);
+ TileOperand &declare_tile_operand(std::unique_ptr<TileOperand> operand);
private:
Kernel *_kernel;
diff --git a/compute_kernel_writer/prototype/include/ckw/TensorInfo.h b/compute_kernel_writer/prototype/include/ckw/TensorInfo.h
index 8eaa6ae314..55f8101a53 100644
--- a/compute_kernel_writer/prototype/include/ckw/TensorInfo.h
+++ b/compute_kernel_writer/prototype/include/ckw/TensorInfo.h
@@ -67,7 +67,7 @@ enum class TensorComponentBitmask : uint32_t
* The data type is represented as an integer. The value of the integer value
* is assigned to retrieve the information through the @ref TensorComponentBitmask.
*/
-enum class TensorComponent : uint32_t
+enum class TensorComponentType : uint32_t
{
Unknown = 0x00000000,
OffsetFirstElement = 0x01000000,
@@ -88,7 +88,7 @@ enum class TensorComponent : uint32_t
/** Compute Kernel Writer tensor storage. The tensor storage represents the type of tensor memory object.
*/
-enum class TensorStorage : uint32_t
+enum class TensorStorageType : uint32_t
{
Unknown = 0x00000000,
BufferUint8Ptr = 0x01000000,
diff --git a/compute_kernel_writer/prototype/include/ckw/TensorOperand.h b/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
index 3a2509e7c8..6d88932c66 100644
--- a/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
+++ b/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
@@ -48,10 +48,11 @@ class TensorOperand : public OperandBase
public:
/** Initialize a new instance of @ref TensorOperand class.
*
- * @param[in] name The name of the tensor.
- * @param[in] info The tensor info.
+ * @param[in] name The name of the tensor.
+ * @param[in] info The tensor info.
+ * @param[in] storage_type The tensor storage type.
*/
- TensorOperand(const ::std::string &name, const TensorInfo &info);
+ TensorOperand(const ::std::string &name, const TensorInfo &info, TensorStorageType storage_type);
/** No copy constructor. */
TensorOperand(const TensorOperand &other) = delete;
@@ -71,6 +72,9 @@ public:
/** Get the tensor info. */
TensorInfo &info();
+ /** Get the tensor storage type. */
+ TensorStorageType storage_type() const;
+
/** Get the data type. */
virtual DataType data_type() const override;
@@ -96,43 +100,44 @@ public:
TensorOperand &tile_sampler(const TensorTileSampler &value);
/** Get the operand that contains the stride in y dimension of the tensor. */
- TileOperand &stride1();
+ TensorComponentOperand &stride1();
/** Get the operand that contains the stride in z dimension of the tensor. */
- TileOperand &stride2();
+ TensorComponentOperand &stride2();
/** Get the operand that contains the stride in w dimension of the tensor. */
- TileOperand &stride3();
+ TensorComponentOperand &stride3();
/** Get the operand that contains the stride in w dimension of the tensor. */
- TileOperand &stride4();
+ TensorComponentOperand &stride4();
/** Get the operand that contains the size of dimension 0 of the tensor. */
- TileOperand &dim0();
+ TensorComponentOperand &dim0();
/** Get the operand that contains the size of dimension 1 of the tensor. */
- TileOperand &dim1();
+ TensorComponentOperand &dim1();
/** Get the operand that contains the size of dimension 2 of the tensor. */
- TileOperand &dim2();
+ TensorComponentOperand &dim2();
/** Get the operand that contains the size of dimension 3 of the tensor. */
- TileOperand &dim3();
+ TensorComponentOperand &dim3();
/** Get the operand that contains the size of dimension 4 of the tensor. */
- TileOperand &dim4();
+ TensorComponentOperand &dim4();
/** Get the operand that contains the size of dimensions 1 and 2 collapsed. */
- TileOperand &dim1_dim2();
+ TensorComponentOperand &dim1_dim2();
/** Get the operand that contains the size of dimensions 1, 2 and 3 collapsed. */
- TileOperand &dim1_dim2_dim3();
+ TensorComponentOperand &dim1_dim2_dim3();
/** Get the operand that contains the offset in bytes to the first element. */
- TileOperand &offset_first_element_in_bytes();
+ TensorComponentOperand &offset_first_element_in_bytes();
private:
- TensorInfo _info;
+ TensorInfo _info;
+ TensorStorageType _storage_type;
TileOperand *_tile{ nullptr };
TensorTileSampler _tile_sampler{};
@@ -161,10 +166,19 @@ class TensorComponentOperand : public TileOperand
public:
/** Initialize a new instance of @ref TensorComponentOperand class.
*
- * @param[in] name The name of the operand.
+ * @param[in] tensor The tensor operand.
* @param[in] component The tensor info component.
*/
- TensorComponentOperand(const ::std::string &name, TensorComponent component);
+ TensorComponentOperand(TensorOperand &tensor, TensorComponentType component);
+
+ /** Get the tensor operand. */
+ TensorOperand &tensor();
+
+ /** Get the tensor operand. */
+ const TensorOperand &tensor() const;
+
+ /** Get the tensor component. */
+ TensorComponentType component_type() const;
/** (Internal use only) Create the implementation operand.
*
@@ -173,7 +187,8 @@ public:
virtual prototype::Operand create_impl_operand(prototype::IGpuKernelWriter *writer) const override;
private:
- TensorComponent _component;
+ TensorOperand &_tensor;
+ TensorComponentType _component;
};
} // namespace ckw