aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNikolaj Jensen <nikolaj.jensen@arm.com>2023-06-27 14:13:24 +0100
committerNikolaj Jensen <nikolaj.jensen@arm.com>2023-07-10 16:04:14 +0000
commit5ff480265a110ea1f2ce24491e082f52348b0f92 (patch)
tree438268e9c4465213d57477104620a260d59ae33a
parent4c0a38a33046416a8f8fd779a467502b98311bcd (diff)
downloadComputeLibrary-5ff480265a110ea1f2ce24491e082f52348b0f92.tar.gz
Port operations to CKW prototype
Resolves: COMPMID-6334 Signed-off-by: Nikolaj Jensen <nikolaj.jensen@arm.com> Change-Id: I500d30f09daec4087eb3e7aecd1de77dc8fd53b4 Signed-off-by: Nikolaj Jensen <nikolaj.jensen@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9828 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-by: Jakub Sujak <jakub.sujak@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--compute_kernel_writer/include/ckw/TensorInfo.h2
-rw-r--r--compute_kernel_writer/include/ckw/TileInfo.h2
-rw-r--r--compute_kernel_writer/include/ckw/types/DataType.h50
-rw-r--r--compute_kernel_writer/prototype/examples/add_exp_store.cpp11
-rw-r--r--compute_kernel_writer/prototype/include/ckw/Kernel.h2
-rw-r--r--compute_kernel_writer/prototype/include/ckw/KernelWriter.h123
-rw-r--r--compute_kernel_writer/prototype/include/ckw/OperandBase.h2
-rw-r--r--compute_kernel_writer/prototype/include/ckw/TensorInfo.h2
-rw-r--r--compute_kernel_writer/prototype/include/ckw/TensorOperand.h2
-rw-r--r--compute_kernel_writer/prototype/include/ckw/TensorTileSampler.h2
-rw-r--r--compute_kernel_writer/prototype/include/ckw/TileInfo.h2
-rw-r--r--compute_kernel_writer/prototype/include/ckw/types/ConvertPolicy.h41
-rw-r--r--compute_kernel_writer/prototype/include/ckw/types/DataType.h50
-rw-r--r--compute_kernel_writer/prototype/include/ckw/types/Functions.h61
-rw-r--r--compute_kernel_writer/prototype/include/ckw/types/GpuTargetLanguage.h (renamed from compute_kernel_writer/include/ckw/Types.h)24
-rw-r--r--compute_kernel_writer/prototype/include/ckw/types/Operators.h74
-rw-r--r--compute_kernel_writer/prototype/include/ckw/types/TensorSamplerTypes.h (renamed from compute_kernel_writer/prototype/include/ckw/Types.h)68
-rw-r--r--compute_kernel_writer/prototype/src/Kernel.cpp2
-rw-r--r--compute_kernel_writer/prototype/src/KernelWriter.cpp96
-rw-r--r--compute_kernel_writer/prototype/src/Prototype.h389
-rw-r--r--compute_kernel_writer/prototype/src/TensorTileSampler.cpp2
-rw-r--r--compute_kernel_writer/src/TensorUtils.cpp1
-rw-r--r--compute_kernel_writer/src/cl/CLConstantTile.h8
-rw-r--r--compute_kernel_writer/src/cl/CLHelpers.cpp2
-rw-r--r--compute_kernel_writer/src/cl/CLTile.h8
-rw-r--r--src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwVariableTable.cpp2
-rw-r--r--src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwElementwiseBinary.cpp4
27 files changed, 786 insertions, 246 deletions
diff --git a/compute_kernel_writer/include/ckw/TensorInfo.h b/compute_kernel_writer/include/ckw/TensorInfo.h
index 41abe60f35..63d9f412b6 100644
--- a/compute_kernel_writer/include/ckw/TensorInfo.h
+++ b/compute_kernel_writer/include/ckw/TensorInfo.h
@@ -25,7 +25,7 @@
#ifndef COMPUTE_KERNEL_WRITER_INCLUDE_CKW_TENSORINFO_H
#define COMPUTE_KERNEL_WRITER_INCLUDE_CKW_TENSORINFO_H
-#include "ckw/Types.h"
+#include "ckw/types/DataType.h"
#include <array>
#include <cstdint>
diff --git a/compute_kernel_writer/include/ckw/TileInfo.h b/compute_kernel_writer/include/ckw/TileInfo.h
index 293a90fb94..b8094f79bf 100644
--- a/compute_kernel_writer/include/ckw/TileInfo.h
+++ b/compute_kernel_writer/include/ckw/TileInfo.h
@@ -25,7 +25,7 @@
#ifndef COMPUTE_KERNEL_WRITER_INCLUDE_CKW_TILEINFO
#define COMPUTE_KERNEL_WRITER_INCLUDE_CKW_TILEINFO
-#include "ckw/Types.h"
+#include "ckw/types/DataType.h"
#include <array>
#include <cstdint>
diff --git a/compute_kernel_writer/include/ckw/types/DataType.h b/compute_kernel_writer/include/ckw/types/DataType.h
new file mode 100644
index 0000000000..3447dd61d6
--- /dev/null
+++ b/compute_kernel_writer/include/ckw/types/DataType.h
@@ -0,0 +1,50 @@
+/*
+* Copyright (c) 2023 Arm Limited.
+*
+* SPDX-License-Identifier: MIT
+*
+* Permission is hereby granted, free of charge, to any person obtaining a copy
+* of this software and associated documentation files (the "Software"), to
+* deal in the Software without restriction, including without limitation the
+* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+* sell copies of the Software, and to permit persons to whom the Software is
+* furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in all
+* copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+* SOFTWARE.
+*/
+
+#ifndef CKW_INCLUDE_CKW_DATATYPE_H
+#define CKW_INCLUDE_CKW_DATATYPE_H
+
+#include <cstdint>
+
+namespace ckw
+{
+
+/** Compute Kernel Writer data types. This data type is used by the code variables and tensor arguments. */
+enum class DataType : int32_t
+{
+ Unknown = 0x00,
+ Fp32 = 0x11,
+ Fp16 = 0x12,
+ Int32 = 0x21,
+ Int16 = 0x22,
+ Int8 = 0x24,
+ Uint32 = 0x31,
+ Uint16 = 0x32,
+ Uint8 = 0x34,
+ Bool = 0x41
+};
+
+} // namespace ckw
+
+#endif //CKW_INCLUDE_CKW_DATATYPE_H
diff --git a/compute_kernel_writer/prototype/examples/add_exp_store.cpp b/compute_kernel_writer/prototype/examples/add_exp_store.cpp
index a9be0495ec..9529268c9a 100644
--- a/compute_kernel_writer/prototype/examples/add_exp_store.cpp
+++ b/compute_kernel_writer/prototype/examples/add_exp_store.cpp
@@ -27,7 +27,6 @@
#include "ckw/TensorOperand.h"
#include "ckw/TensorTileSampler.h"
#include "ckw/TileOperand.h"
-#include "ckw/Types.h"
#include "common/ExampleComponentArgument.h"
#include "common/ExampleKernelWriter.h"
@@ -110,7 +109,7 @@ void op_binary_elementwise(ExampleScopedKernelWriter writer, std::vector<Example
auto &dst_tile = dst->tile();
// Perform the operation.
- writer->op_binary_expression(dst_tile, lhs_tile, rhs_tile, BinaryOp::Add);
+ writer->op_binary_expression(dst_tile, lhs_tile, BinaryOp::Add, rhs_tile);
}
void op_exp(ExampleScopedKernelWriter writer, std::vector<ExampleComponentArgument *> operands)
@@ -138,7 +137,7 @@ void op_exp(ExampleScopedKernelWriter writer, std::vector<ExampleComponentArgume
auto &dst_tile = dst->tile();
// Perform the operation.
- writer->op_scalar_function(dst_tile, src_tile, ScalarUnaryFunction::Exp);
+ writer->op_unary_elementwise_function(dst_tile, UnaryFunction::Exp, src_tile);
}
void op_store(ExampleScopedKernelWriter writer, std::vector<ExampleComponentArgument *> operands)
@@ -164,9 +163,9 @@ int main()
const TensorInfo src1_info(DataType::Fp32, TensorShape({ 3, 10, 20, 1, 1 }), TensorDataLayout::Nhwc, 1);
const TensorInfo dst_info(DataType::Fp32, TensorShape({ 3, 10, 20, 1, 1 }), TensorDataLayout::Nhwc, 2);
- ExampleComponentArgument src0(writer->create_tensor_argument("src0", src0_info));
- ExampleComponentArgument src1(writer->create_tensor_argument("src1", src1_info));
- ExampleComponentArgument dst(writer->create_tensor_argument("dst", dst_info));
+ ExampleComponentArgument src0(writer->declare_tensor_argument("src0", src0_info));
+ ExampleComponentArgument src1(writer->declare_tensor_argument("src1", src1_info));
+ ExampleComponentArgument dst(writer->declare_tensor_argument("dst", dst_info));
ExampleComponentArgument ans;
diff --git a/compute_kernel_writer/prototype/include/ckw/Kernel.h b/compute_kernel_writer/prototype/include/ckw/Kernel.h
index 57a8a40341..527206feec 100644
--- a/compute_kernel_writer/prototype/include/ckw/Kernel.h
+++ b/compute_kernel_writer/prototype/include/ckw/Kernel.h
@@ -26,7 +26,7 @@
#define CKW_PROTOTYPE_INCLUDE_CKW_KERNEL_H
#include "ckw/OperandBase.h"
-#include "ckw/Types.h"
+#include "ckw/types/GpuTargetLanguage.h"
#include <map>
#include <memory>
diff --git a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
index 3b1539116a..2bf443cd53 100644
--- a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
+++ b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
@@ -30,6 +30,9 @@
#include "ckw/TensorOperand.h"
#include "ckw/TileInfo.h"
#include "ckw/TileOperand.h"
+#include "ckw/types/ConvertPolicy.h"
+#include "ckw/types/Functions.h"
+#include "ckw/types/Operators.h"
#include <memory>
@@ -83,23 +86,23 @@ public:
// Tensor and tile declaration
// =============================================================================================
- /** Define a tensor argument.
+ /** Declare a tensor argument.
*
* @param[in] name The name of the tensor.
* @param[in] info The tensor info.
*
* @return The @ref TensorOperand object.
*/
- TensorOperand &create_tensor_argument(const char *name, const TensorInfo &info);
+ TensorOperand &declare_tensor_argument(const std::string &name, const TensorInfo &info);
- /** Define a compile-time constant scalar argument.
+ /** Declare a compile-time constant scalar argument.
*
* @param[in] name The name of the tile.
* @param[in] value The value of the tile.
*
* @return The @ref TileOperand object.
*/
- TileOperand &create_tile_argument(const char *name, int32_t value);
+ TileOperand &declare_tile_argument(const std::string &name, int32_t value);
/** Declare a new tile.
*
@@ -111,7 +114,7 @@ public:
* @return The @ref TileOperand object.
*/
template <typename... TArgs>
- TileOperand &declare_tile(const char *name, TArgs &&...args)
+ TileOperand &declare_tile(const std::string &name, TArgs &&...args)
{
const auto var_name = generate_variable_name(name);
auto operand = new TileOperand(var_name, ::std::forward<TArgs>(args)...);
@@ -144,29 +147,103 @@ public:
// Data processing
// =============================================================================================
- /** Write assignment: `<dst> = <src>`.
+ /** Write assignment: `<dst> = <src>;`.
*
- * @param[in] dst The destination tile.
- * @param[in] src The source tile.
+ * @param[out] dst The destination tile.
+ * @param[in] src The source tile.
*/
- void op_assign(TileOperand &dst, const TileOperand &src);
+ void op_assign(const TileOperand &dst, const TileOperand &src);
- /** Write binary expression: `<dst> = <lhs> <op> <rhs>`.
+ /** Write the cast: `<dst> = convert_<dst.type><_sat>(<src>);`.
*
- * @param[in] dst The destination tile.
- * @param[in] lhs The LHS operand.
- * @param[in] rhs The RHS operand.
- * @param[in] op The binary operator.
+ * @param[out] dst The destination tile.
+ * @param[in] src The source tile.
+ * @param[in] policy The policy governing the behavior of the cast.
*/
- void op_binary_expression(TileOperand &dst, const TileOperand &lhs, const TileOperand &rhs, BinaryOp op);
+ void op_cast_expression(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy);
- /** Write function applied to scalar value: `<dst> = <func>(<src>)`.
+ /** Write the unary expression: `<dst> = <op> <src>;`.
*
- * @param[in] dst The destination tile.
- * @param[in] src The source tile.
- * @param[in] func The function to be applied to the source tile.
+ * @param[out] dst The destination tile.
+ * @param[in] op The unary operator.
+ * @param[in] src The source tile.
*/
- void op_scalar_function(TileOperand &dst, const TileOperand &src, ScalarUnaryFunction func);
+ void op_unary_expression(const TileOperand &dst, UnaryOp op, const TileOperand &src);
+
+ /** Write binary expression: `<dst> = <lhs> <op> <rhs>;`.
+ *
+ * @param[out] dst The destination tile.
+ * @param[in] lhs The LHS tile.
+ * @param[in] op The binary operator.
+ * @param[in] rhs The RHS tile.
+ */
+ void op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs);
+
+ /** Write unary elementwise function: `<dst> = <func>(<src>);`.
+ *
+ * @param[out] dst The destination tile.
+ * @param[in] func The function to be applied to the source tile.
+ * @param[in] src The source tile.
+ */
+ void op_unary_elementwise_function(const TileOperand &dst, UnaryFunction func, const TileOperand &src);
+
+ /** Write binary elementwise function: `<dst> = <func>(<first>, <second>);`.
+ *
+ * @param[out] dst The destination tile.
+ * @param[in] func The function to be applied to the source tiles.
+ * @param[in] first The first argument tile.
+ * @param[in] second The second argument tile.
+ */
+ void op_binary_elementwise_function(const TileOperand &dst, BinaryFunction func, const TileOperand &first, const TileOperand &second);
+
+ /** Write ternary elementwise function: `<dst> = <func>(<first>, <second>, <third>);`.
+ *
+ * @param[out] dst The destination tile.
+ * @param[in] func The function to be applied to the source tiles.
+ * @param[in] first The first argument tile.
+ * @param[in] second The second argument tile.
+ * @param[in] third The third argument tile.
+ */
+ void op_ternary_elementwise_function(const TileOperand &dst, TernaryFunction func, const TileOperand &first, const TileOperand &second, const TileOperand &third);
+
+ /** Write if-statement: `if(<lhs> <op> <rhs>) { <body> }`.
+ *
+ * @param[in] lhs The LHS tile of the condition.
+ * @param[in] op The relational binary operator.
+ * @param[in] rhs The RHS tile of the condition.
+ * @param[in] body The body of the if-statement.
+ */
+ void op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body);
+
+ /** Write else-if-statement: `else if(<lhs> <op> <rhs>) { <body> }`.
+ *
+ * @param[in] lhs The LHS tile of the condition.
+ * @param[in] op The relational binary operator.
+ * @param[in] rhs The RHS tile of the condition.
+ * @param[in] body The body of the else-if-statement.
+ */
+ void op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body);
+
+ /** Write an else-statement: `else { <body> }`.
+ *
+ * @param[in] body The body of the else-statement.
+ */
+ void op_else(const std::function<void()> &body);
+
+ /** Write for-loop: `for(; <var> <cond_op> <cond_value>; <update_op> <update_value>) { <body> }`.
+ *
+ * @param[in] var_name The name of the variable used in condition.
+ * @param[in] cond_op The relational binary operator used in condition.
+ * @param[in] cond_value_name The value which the variable is compared against.
+ * @param[in] update_op The assignment operator used for updating the update value.
+ * @param[in, out] update_value_name The value which is updated at every iteration.
+ * @param[in] body The body of the for-loop.
+ */
+ void op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body);
+
+ /** Write the return statement: `return;`
+ */
+ void op_return();
// =============================================================================================
// Misc
@@ -174,8 +251,8 @@ public:
/** Set `dst` the global ID of dimension `dim`.
*
- * @param[in] dst The tile to be written to.
- * @param[in] dim The global ID dimension.
+ * @param[out] dst The tile to be written to.
+ * @param[in] dim The global ID dimension.
*/
void op_get_global_id(TileOperand &dst, int32_t dim);
@@ -193,7 +270,7 @@ private:
*
* @return The full variable name.
*/
- ::std::string generate_variable_name(const char *name) const;
+ ::std::string generate_variable_name(const std::string &name) const;
/** Register the operand to the kernel.
*
diff --git a/compute_kernel_writer/prototype/include/ckw/OperandBase.h b/compute_kernel_writer/prototype/include/ckw/OperandBase.h
index a9e313fc0a..06d9f82756 100644
--- a/compute_kernel_writer/prototype/include/ckw/OperandBase.h
+++ b/compute_kernel_writer/prototype/include/ckw/OperandBase.h
@@ -25,7 +25,7 @@
#ifndef CKW_PROTOTYPE_INCLUDE_CKW_OPERANDBASE_H
#define CKW_PROTOTYPE_INCLUDE_CKW_OPERANDBASE_H
-#include "ckw/Types.h"
+#include "ckw/types/DataType.h"
#include <string>
namespace ckw
diff --git a/compute_kernel_writer/prototype/include/ckw/TensorInfo.h b/compute_kernel_writer/prototype/include/ckw/TensorInfo.h
index 807158896b..8eaa6ae314 100644
--- a/compute_kernel_writer/prototype/include/ckw/TensorInfo.h
+++ b/compute_kernel_writer/prototype/include/ckw/TensorInfo.h
@@ -25,7 +25,7 @@
#ifndef CKW_PROTOTYPE_INCLUDE_CKW_TENSORINFO_H
#define CKW_PROTOTYPE_INCLUDE_CKW_TENSORINFO_H
-#include "ckw/Types.h"
+#include "ckw/types/DataType.h"
#include <array>
#include <cstdint>
diff --git a/compute_kernel_writer/prototype/include/ckw/TensorOperand.h b/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
index 7a663f095b..3a2509e7c8 100644
--- a/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
+++ b/compute_kernel_writer/prototype/include/ckw/TensorOperand.h
@@ -29,7 +29,7 @@
#include "ckw/TensorInfo.h"
#include "ckw/TensorTileSampler.h"
#include "ckw/TileOperand.h"
-#include "ckw/Types.h"
+#include "ckw/types/DataType.h"
#include <memory>
diff --git a/compute_kernel_writer/prototype/include/ckw/TensorTileSampler.h b/compute_kernel_writer/prototype/include/ckw/TensorTileSampler.h
index 2ea65bce9e..e1bf0c52b8 100644
--- a/compute_kernel_writer/prototype/include/ckw/TensorTileSampler.h
+++ b/compute_kernel_writer/prototype/include/ckw/TensorTileSampler.h
@@ -25,7 +25,7 @@
#ifndef CKW_PROTOTYPE_INCLUDE_CKW_TENSORTILESAMPLER_H
#define CKW_PROTOTYPE_INCLUDE_CKW_TENSORTILESAMPLER_H
-#include "ckw/Types.h"
+#include "ckw/types/TensorSamplerTypes.h"
#include <functional>
namespace ckw
diff --git a/compute_kernel_writer/prototype/include/ckw/TileInfo.h b/compute_kernel_writer/prototype/include/ckw/TileInfo.h
index c60880dcd1..de9e47af2b 100644
--- a/compute_kernel_writer/prototype/include/ckw/TileInfo.h
+++ b/compute_kernel_writer/prototype/include/ckw/TileInfo.h
@@ -25,7 +25,7 @@
#ifndef CKW_PROTOTYPE_INCLUDE_CKW_TILEINFO_H
#define CKW_PROTOTYPE_INCLUDE_CKW_TILEINFO_H
-#include "ckw/Types.h"
+#include "ckw/types/DataType.h"
#include <array>
#include <cstdint>
diff --git a/compute_kernel_writer/prototype/include/ckw/types/ConvertPolicy.h b/compute_kernel_writer/prototype/include/ckw/types/ConvertPolicy.h
new file mode 100644
index 0000000000..2a198507eb
--- /dev/null
+++ b/compute_kernel_writer/prototype/include/ckw/types/ConvertPolicy.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_INCLUDE_CKW_CONVERTPOLICY_H
+#define CKW_INCLUDE_CKW_CONVERTPOLICY_H
+
+#include <cstdint>
+
+namespace ckw
+{
+
+enum class ConvertPolicy : int32_t
+{
+ None = 0, // No policy specified.
+ Saturate = 1, // Saturated.
+};
+
+} // namespace ckw
+
+#endif //CKW_INCLUDE_CKW_CONVERTPOLICY_H
diff --git a/compute_kernel_writer/prototype/include/ckw/types/DataType.h b/compute_kernel_writer/prototype/include/ckw/types/DataType.h
new file mode 100644
index 0000000000..3447dd61d6
--- /dev/null
+++ b/compute_kernel_writer/prototype/include/ckw/types/DataType.h
@@ -0,0 +1,50 @@
+/*
+* Copyright (c) 2023 Arm Limited.
+*
+* SPDX-License-Identifier: MIT
+*
+* Permission is hereby granted, free of charge, to any person obtaining a copy
+* of this software and associated documentation files (the "Software"), to
+* deal in the Software without restriction, including without limitation the
+* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+* sell copies of the Software, and to permit persons to whom the Software is
+* furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in all
+* copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+* SOFTWARE.
+*/
+
+#ifndef CKW_INCLUDE_CKW_DATATYPE_H
+#define CKW_INCLUDE_CKW_DATATYPE_H
+
+#include <cstdint>
+
+namespace ckw
+{
+
+/** Compute Kernel Writer data types. This data type is used by the code variables and tensor arguments. */
+enum class DataType : int32_t
+{
+ Unknown = 0x00,
+ Fp32 = 0x11,
+ Fp16 = 0x12,
+ Int32 = 0x21,
+ Int16 = 0x22,
+ Int8 = 0x24,
+ Uint32 = 0x31,
+ Uint16 = 0x32,
+ Uint8 = 0x34,
+ Bool = 0x41
+};
+
+} // namespace ckw
+
+#endif //CKW_INCLUDE_CKW_DATATYPE_H
diff --git a/compute_kernel_writer/prototype/include/ckw/types/Functions.h b/compute_kernel_writer/prototype/include/ckw/types/Functions.h
new file mode 100644
index 0000000000..68146cb1c8
--- /dev/null
+++ b/compute_kernel_writer/prototype/include/ckw/types/Functions.h
@@ -0,0 +1,61 @@
+/*
+* Copyright (c) 2023 Arm Limited.
+*
+* SPDX-License-Identifier: MIT
+*
+* Permission is hereby granted, free of charge, to any person obtaining a copy
+* of this software and associated documentation files (the "Software"), to
+* deal in the Software without restriction, including without limitation the
+* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+* sell copies of the Software, and to permit persons to whom the Software is
+* furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in all
+* copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+* SOFTWARE.
+*/
+
+#ifndef CKW_INCLUDE_CKW_FUNCTIONS_H
+#define CKW_INCLUDE_CKW_FUNCTIONS_H
+
+#include <cstdint>
+
+namespace ckw
+{
+
+enum class UnaryFunction : int32_t
+{
+ Exp = 0x0000,
+ Tanh = 0x0001,
+ Sqrt = 0x0002,
+ Erf = 0x0003,
+ Fabs = 0x0004,
+ IsGreaterEqual = 0x0005,
+ Log = 0x0006,
+ Round = 0x0007,
+
+ // Misc
+ SizeOf = 0x0008,
+};
+
+enum class BinaryFunction : int32_t
+{
+ Min = 0x0000,
+ Max = 0x0001,
+};
+
+enum class TernaryFunction : int32_t
+{
+ Select = 0x0000,
+};
+
+} // namespace ckw
+
+#endif //CKW_INCLUDE_CKW_FUNCTIONS_H
diff --git a/compute_kernel_writer/include/ckw/Types.h b/compute_kernel_writer/prototype/include/ckw/types/GpuTargetLanguage.h
index c9f80b65e0..6c08617949 100644
--- a/compute_kernel_writer/include/ckw/Types.h
+++ b/compute_kernel_writer/prototype/include/ckw/types/GpuTargetLanguage.h
@@ -21,25 +21,21 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef COMPUTE_KERNEL_WRITER_INCLUDE_CKW_TYPES_H
-#define COMPUTE_KERNEL_WRITER_INCLUDE_CKW_TYPES_H
+
+#ifndef CKW_INCLUDE_CKW_GPUTARGETLANGUAGE_H
+#define CKW_INCLUDE_CKW_GPUTARGETLANGUAGE_H
+
+#include <cstdint>
namespace ckw
{
-/** Compute Kernel Writer data types. This data type is used by the code variables and tensor arguments. */
-enum class DataType
+
+enum class GpuTargetLanguage : int32_t
{
Unknown,
- Fp32,
- Fp16,
- Int32,
- Int16,
- Int8,
- Uint32,
- Uint16,
- Uint8,
- Bool
+ OpenCL
};
+
} // namespace ckw
-#endif /* COMPUTE_KERNEL_WRITER_INCLUDE_CKW_TYPES_H */
+#endif //CKW_INCLUDE_CKW_GPUTARGETLANGUAGE_H
diff --git a/compute_kernel_writer/prototype/include/ckw/types/Operators.h b/compute_kernel_writer/prototype/include/ckw/types/Operators.h
new file mode 100644
index 0000000000..78027f1ed5
--- /dev/null
+++ b/compute_kernel_writer/prototype/include/ckw/types/Operators.h
@@ -0,0 +1,74 @@
+/*
+* Copyright (c) 2023 Arm Limited.
+*
+* SPDX-License-Identifier: MIT
+*
+* Permission is hereby granted, free of charge, to any person obtaining a copy
+* of this software and associated documentation files (the "Software"), to
+* deal in the Software without restriction, including without limitation the
+* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+* sell copies of the Software, and to permit persons to whom the Software is
+* furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in all
+* copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+* SOFTWARE.
+*/
+
+#ifndef CKW_INCLUDE_CKW_OPERATORS_H
+#define CKW_INCLUDE_CKW_OPERATORS_H
+
+#include <cstdint>
+
+namespace ckw
+{
+
+enum class UnaryOp : int32_t
+{
+ LogicalNot = 0x0000, // !
+};
+
+/* Binary operations
+*/
+enum class BinaryOp : int32_t
+{
+ // Elementwise
+ Add = 0x0000, // +
+ Sub = 0x0001, // -
+ Mul = 0x0002, // *
+ Div = 0x0003, // /
+ Mod = 0x0004, // %
+ // Relational
+ Equal = 0x1000, // ==
+ Less = 0x1001, // <
+ LessEqual = 0x1002, // <=
+ Greater = 0x1003, // >
+ GreaterEqual = 0x1004, // >=
+ // Algebra
+ MatMul_Nt_Nt = 0x2000, // X
+ MatMul_Nt_T = 0x2001, // X
+ MatMul_T_Nt = 0x2002, // X
+ MatMul_T_T = 0x2003, // X
+ Dot = 0x2004, // .
+ // Logical
+ LogicalAnd = 0x3000, // &&
+ LogicalOr = 0x3001, // ||
+};
+
+enum class AssignmentOp : int32_t
+{
+ // Unary
+ Increment = 0x0000, // +=
+ Decrement = 0x0001, // -=
+};
+
+} // namespace ckw
+
+#endif //CKW_INCLUDE_CKW_OPERATORS_H
diff --git a/compute_kernel_writer/prototype/include/ckw/Types.h b/compute_kernel_writer/prototype/include/ckw/types/TensorSamplerTypes.h
index bb5d7ce077..836bd13c95 100644
--- a/compute_kernel_writer/prototype/include/ckw/Types.h
+++ b/compute_kernel_writer/prototype/include/ckw/types/TensorSamplerTypes.h
@@ -22,76 +22,14 @@
* SOFTWARE.
*/
-#ifndef CKW_PROTOTYPE_INCLUDE_CKW_TYPES_H
-#define CKW_PROTOTYPE_INCLUDE_CKW_TYPES_H
+#ifndef CKW_INCLUDE_CKW_TENSORSAMPLERTYPES_H
+#define CKW_INCLUDE_CKW_TENSORSAMPLERTYPES_H
-#include <array>
#include <cstdint>
namespace ckw
{
-/** Compute Kernel Writer data types. This data type is used by the code variables and tensor arguments. */
-enum class DataType
-{
- Unknown = 0x00,
- Fp32 = 0x11,
- Fp16 = 0x12,
- Int32 = 0x21,
- Int16 = 0x22,
- Int8 = 0x24,
- Uint32 = 0x31,
- Uint16 = 0x32,
- Uint8 = 0x34,
- Bool = 0x41
-};
-
-enum class GpuTargetLanguage
-{
- Unknown,
- OpenCL
-};
-
-/* Binary operations
-*/
-enum class BinaryOp : int32_t
-{
- // Elementwise
- Add = 0x0000, // +
- Sub = 0x0001, // -
- Mul = 0x0002, // *
- Div = 0x0003, // /
- Mod = 0x0004, // %
- // Relational
- Equal = 0x1000, // ==
- Less = 0x1001, // <
- LessEqual = 0x1002, // <=
- Greater = 0x1003, // >
- GreaterEqual = 0x1004, // >=
- // Algebra
- MatMul_Nt_Nt = 0x2000, // X
- MatMul_Nt_T = 0x2001, // X
- MatMul_T_Nt = 0x2002, // X
- MatMul_T_T = 0x2003, // X
- Dot = 0x2004, // .
- // Logical
- LogicalAnd = 0x3000, // &&
- LogicalOr = 0x3001, // ||
- LogicalNot = 0x3002 // !
-};
-
-enum class AssignmentOp : int32_t
-{
- // Unary
- Increment = 0x0000, // +=
- Decrement = 0x0001, // -=
-};
-
-enum class ScalarUnaryFunction : int32_t
-{
- Exp,
-};
-
enum class TensorSamplerFormat : int32_t
{
Unknown = 0,
@@ -137,4 +75,4 @@ enum class TensorSamplerAddressModeZ : int32_t
} // namespace ckw
-#endif // CKW_PROTOTYPE_INCLUDE_CKW_TYPES_H
+#endif //CKW_INCLUDE_CKW_TENSORSAMPLERTYPES_H
diff --git a/compute_kernel_writer/prototype/src/Kernel.cpp b/compute_kernel_writer/prototype/src/Kernel.cpp
index bbf5c440a7..692d504887 100644
--- a/compute_kernel_writer/prototype/src/Kernel.cpp
+++ b/compute_kernel_writer/prototype/src/Kernel.cpp
@@ -23,7 +23,7 @@
*/
#include "ckw/Kernel.h"
-#include "ckw/Types.h"
+#include "ckw/types/GpuTargetLanguage.h"
#include "src/Prototype.h"
namespace ckw
diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp
index 5d79985e87..73458efa1d 100644
--- a/compute_kernel_writer/prototype/src/KernelWriter.cpp
+++ b/compute_kernel_writer/prototype/src/KernelWriter.cpp
@@ -85,7 +85,7 @@ int32_t KernelWriter::next_id_space()
// Tensor and tile declaration
// =================================================================================================
-TensorOperand &KernelWriter::create_tensor_argument(const char *name, const TensorInfo &info)
+TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info)
{
const auto var_name = generate_variable_name(name);
@@ -97,7 +97,7 @@ TensorOperand &KernelWriter::create_tensor_argument(const char *name, const Tens
return *operand;
}
-TileOperand &KernelWriter::create_tile_argument(const char *name, int32_t value)
+TileOperand &KernelWriter::declare_tile_argument(const std::string &name, int32_t value)
{
const auto var_name = generate_variable_name(name);
@@ -107,7 +107,7 @@ TileOperand &KernelWriter::create_tile_argument(const char *name, int32_t value)
return *operand;
}
-std::string KernelWriter::generate_variable_name(const char *name) const
+std::string KernelWriter::generate_variable_name(const std::string &name) const
{
std::stringstream var_name;
@@ -181,7 +181,7 @@ void KernelWriter::op_store(TensorOperand &tensor, const TileOperand &tile, cons
// Data processing
// =================================================================================================
-void KernelWriter::op_assign(TileOperand &dst, const TileOperand &src)
+void KernelWriter::op_assign(const TileOperand &dst, const TileOperand &src)
{
auto impl_dst = dst.create_impl_operand(_impl.get());
auto impl_src = src.create_impl_operand(_impl.get());
@@ -189,7 +189,15 @@ void KernelWriter::op_assign(TileOperand &dst, const TileOperand &src)
_impl->op_assign(impl_dst, impl_src);
}
-void KernelWriter::op_binary_expression(TileOperand &dst, const TileOperand &lhs, const TileOperand &rhs, BinaryOp op)
+// Forward a cast from `src` to `dst`, together with the conversion policy, to the
+// implementation writer. The destination operand is created before the source operand.
+void KernelWriter::op_cast_expression(const TileOperand &dst, const TileOperand &src, const ConvertPolicy policy)
+{
+    auto *const writer = _impl.get();
+
+    const auto dst_operand = dst.create_impl_operand(writer);
+    const auto src_operand = src.create_impl_operand(writer);
+
+    writer->op_cast_expression(dst_operand, src_operand, policy);
+}
+
+void KernelWriter::op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs)
{
auto impl_lhs = lhs.create_impl_operand(_impl.get());
auto impl_rhs = rhs.create_impl_operand(_impl.get());
@@ -198,12 +206,81 @@ void KernelWriter::op_binary_expression(TileOperand &dst, const TileOperand &lhs
_impl->op_binary_expression(impl_dst, impl_lhs, op, impl_rhs);
}
-void KernelWriter::op_scalar_function(TileOperand &dst, const TileOperand &src, ScalarUnaryFunction opcode)
+void KernelWriter::op_unary_expression(const TileOperand &dst, UnaryOp op, const TileOperand &src)
{
auto impl_dst = dst.create_impl_operand(_impl.get());
auto impl_src = src.create_impl_operand(_impl.get());
- _impl->op_scalar_function(impl_dst, impl_src, opcode);
+ _impl->op_unary_expression(impl_dst, op, impl_src);
+}
+
+// Forward a one-operand elementwise function call to the implementation writer.
+// The destination operand is created before the source operand.
+void KernelWriter::op_unary_elementwise_function(const TileOperand &dst, UnaryFunction opcode, const TileOperand &src)
+{
+    auto *const writer = _impl.get();
+
+    const auto dst_operand = dst.create_impl_operand(writer);
+    const auto src_operand = src.create_impl_operand(writer);
+
+    writer->op_unary_elementwise_function(dst_operand, opcode, src_operand);
+}
+
+// Forward a two-operand elementwise function call to the implementation writer.
+// Operands are created in the order: destination, first, second.
+void KernelWriter::op_binary_elementwise_function(const TileOperand &dst, BinaryFunction opcode, const TileOperand &first, const TileOperand &second)
+{
+    auto *const writer = _impl.get();
+
+    const auto dst_operand    = dst.create_impl_operand(writer);
+    const auto first_operand  = first.create_impl_operand(writer);
+    const auto second_operand = second.create_impl_operand(writer);
+
+    writer->op_binary_elementwise_function(dst_operand, opcode, first_operand, second_operand);
+}
+
+// Forward a three-operand elementwise function call to the implementation writer.
+// Operands are created in the order: destination, first, second, third.
+void KernelWriter::op_ternary_elementwise_function(const TileOperand &dst, TernaryFunction opcode, const TileOperand &first, const TileOperand &second, const TileOperand &third)
+{
+    auto *const writer = _impl.get();
+
+    const auto dst_operand    = dst.create_impl_operand(writer);
+    const auto first_operand  = first.create_impl_operand(writer);
+    const auto second_operand = second.create_impl_operand(writer);
+    const auto third_operand  = third.create_impl_operand(writer);
+
+    writer->op_ternary_elementwise_function(dst_operand, opcode, first_operand, second_operand, third_operand);
+}
+
+// Write an `if` header for the condition `lhs op rhs`, then invoke `body` between
+// compound_statement_begin() and compound_statement_end() of the implementation writer.
+void KernelWriter::op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body)
+{
+    auto *const writer = _impl.get();
+
+    const auto cond_lhs = lhs.create_impl_operand(writer);
+    const auto cond_rhs = rhs.create_impl_operand(writer);
+    writer->op_if_header(cond_lhs, op, cond_rhs);
+
+    writer->compound_statement_begin();
+    body();
+    writer->compound_statement_end();
+}
+
+// Write an `else if` header for the condition `lhs op rhs`, then invoke `body` between
+// compound_statement_begin() and compound_statement_end() of the implementation writer.
+void KernelWriter::op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body)
+{
+    auto *const writer = _impl.get();
+
+    const auto cond_lhs = lhs.create_impl_operand(writer);
+    const auto cond_rhs = rhs.create_impl_operand(writer);
+    writer->op_else_if_header(cond_lhs, op, cond_rhs);
+
+    writer->compound_statement_begin();
+    body();
+    writer->compound_statement_end();
+}
+
+// Write an `else` header, then invoke `body` between compound_statement_begin()
+// and compound_statement_end() of the implementation writer.
+void KernelWriter::op_else(const std::function<void()> &body)
+{
+    auto *const writer = _impl.get();
+
+    writer->op_else_header();
+
+    writer->compound_statement_begin();
+    body();
+    writer->compound_statement_end();
+}
+
+// Write a `for` loop header built from the loop variable, the condition
+// (`cond_op`, `cond_value_name`) and the update (`update_op`, `update_value_name`),
+// then invoke `body` between compound_statement_begin() and compound_statement_end().
+void KernelWriter::op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body)
+{
+    auto *const writer = _impl.get();
+
+    const auto loop_var     = var_name.create_impl_operand(writer);
+    const auto cond_value   = cond_value_name.create_impl_operand(writer);
+    const auto update_value = update_value_name.create_impl_operand(writer);
+    writer->op_for_loop_header(loop_var, cond_op, cond_value, update_op, update_value);
+
+    writer->compound_statement_begin();
+    body();
+    writer->compound_statement_end();
}
// =================================================================================================
@@ -215,6 +292,11 @@ void KernelWriter::op_get_global_id(TileOperand &dst, int32_t dim)
_impl->op_get_global_id(prototype::Operand(dst.name()), dim);
}
+// Forward a return statement request to the implementation writer.
+void KernelWriter::op_return()
+{
+    _impl->op_return();
+}
+
// =================================================================================================
// Code generation
// =================================================================================================
diff --git a/compute_kernel_writer/prototype/src/Prototype.h b/compute_kernel_writer/prototype/src/Prototype.h
index fdb4ab1bab..18f284b2b1 100644
--- a/compute_kernel_writer/prototype/src/Prototype.h
+++ b/compute_kernel_writer/prototype/src/Prototype.h
@@ -31,6 +31,7 @@
#include <chrono>
#include <cmath>
#include <cstdint> // int32_t
+#include <functional>
#include <iostream> // cout (to be removed)
#include <map>
#include <memory>
@@ -41,7 +42,12 @@
#include "ckw/Error.h"
#include "ckw/TensorInfo.h"
-#include "ckw/Types.h"
+#include "ckw/types/ConvertPolicy.h"
+#include "ckw/types/DataType.h"
+#include "ckw/types/Functions.h"
+#include "ckw/types/GpuTargetLanguage.h"
+#include "ckw/types/Operators.h"
+#include "ckw/types/TensorSamplerTypes.h"
namespace ckw
{
@@ -1548,6 +1554,18 @@ inline std::string to_string(AssignmentOp op)
}
}
+/** Convert a unary operator to the token emitted for it.
+ *
+ * Only UnaryOp::LogicalNot is handled ("!"). Any other value trips the
+ * assertion and, when assertions are compiled out, yields an empty string.
+ */
+inline std::string to_string(UnaryOp op)
+{
+    if(op == UnaryOp::LogicalNot)
+    {
+        return "!";
+    }
+
+    assert(false);
+    return "";
+}
+
inline std::string to_string(BinaryOp op)
{
switch(op)
@@ -1576,8 +1594,6 @@ inline std::string to_string(BinaryOp op)
return "&&";
case BinaryOp::LogicalOr:
return "||";
- case BinaryOp::LogicalNot:
- return "!";
default:
assert(false);
return "";
@@ -2407,12 +2423,6 @@ struct GpuKernelWriterAttribute
bool return_tensor_component_by_value{ false };
};
-enum class ConvertPolicy
-{
- Wrap, /**< Wrap around */
- Saturate /**< Saturate */
-};
-
enum class RoundingMode
{
None,
@@ -2445,36 +2455,44 @@ public:
virtual void compound_statement_end() = 0;
// Operations
- virtual void op_get_global_id(const Operand &dst_var, int32_t dim) = 0;
+ virtual void op_get_global_id(const Operand &dst_var, int32_t dim) = 0;
+
+ virtual void op_get_global_coord(const Operand &dst, const Operand &step, const TensorOperand &tensor, int32_t dim) = 0;
- virtual void op_get_global_coord(const Operand &dst, const Operand &step, const TensorOperand &tensor, int32_t dim) = 0;
+ virtual void op_get_global_batch(const Operand &dst, const TensorOperand &tensor) = 0;
- virtual void op_get_global_batch(const Operand &dst, const TensorOperand &tensor) = 0;
+ virtual void op_get_global_size(const Operand &dst_var, int32_t dim) = 0;
- virtual void op_get_global_size(const Operand &dst_var, int32_t dim) = 0;
+ virtual void op_unary_expression(const Operand &dst, UnaryOp op, const Operand &src) = 0;
- virtual void op_binary_expression(const Operand &dst, const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
+ virtual void op_binary_expression(const Operand &dst, const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
- virtual void op_assign(const Operand &dst_name, const Operand &src_name) = 0;
+ virtual void op_assign(const Operand &dst_name, const Operand &src_name) = 0;
- virtual void op_scalar_function(const Operand &dst_name, const Operand &src_name, ScalarUnaryFunction func) = 0;
+ virtual void op_unary_elementwise_function(const Operand &dst_name, UnaryFunction func, const Operand &src_name) = 0;
- virtual void op_if(const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
+ virtual void op_binary_elementwise_function(const Operand &dst_name, BinaryFunction func, const Operand &first_name, const Operand &second_name) = 0;
- virtual void op_for_loop(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value, AssignmentOp update_op, const Operand &update_value) = 0;
+ virtual void op_ternary_elementwise_function(const Operand &dst_name, TernaryFunction func, const Operand &first_name, const Operand &second_name, const Operand &third_name) = 0;
- virtual void op_load_indirect(const TensorOperand &tensor, const Operand &dst, const Operand &x, const Operand &y_indirect, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0;
+ virtual void op_if_header(const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
+
+ virtual void op_else_if_header(const Operand &lhs, BinaryOp op, const Operand &rhs) = 0;
+
+ virtual void op_else_header() = 0;
+
+ virtual void op_for_loop_header(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value, AssignmentOp update_op, const Operand &update_value) = 0;
+
+ virtual void op_load_indirect(const TensorOperand &tensor, const Operand &dst, const Operand &x, const Operand &y_indirect, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0;
virtual void op_load_immediate(const TensorOperand &tensor, const Operand &dst, const Operand &x, const Operand &y, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32), const Operand &dilation_y = Operand("1", OperandType::ScalarInt32)) = 0;
- virtual void op_store_immediate(const TensorOperand &tensor, const Operand &src, const Operand &x, const Operand &y, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0;
+ virtual void op_store_immediate(const TensorOperand &tensor, const Operand &src, const Operand &x, const Operand &y, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0;
- virtual void op_cast_expression(const Operand &dst, const Operand &src, ConvertPolicy policy) = 0;
+ virtual void op_cast_expression(const Operand &dst, const Operand &src, ConvertPolicy policy) = 0;
- virtual void op_return() = 0;
+ virtual void op_return() = 0;
- // virtual void op_else() = 0;
- // virtual void op_elseif() = 0;
// Utils
// It is the process of converting
virtual void util_get_indirect_buffer(const Operand &dst, const TensorOperand &tensor, const Operand &x,
@@ -2929,10 +2947,10 @@ private:
std::string to_ls_buffer_address(const std::string &x, const std::string &y, const std::string &z,
const std::string &b) const
{
- auto tensor_storage = static_cast<GpuTensorStorage>(_mapper.gpu_sampler().storage);
+ auto tensor_storage = static_cast<GpuTensorStorage>(_mapper.gpu_sampler().storage);
assert(tensor_storage == GpuTensorStorage::BufferUint8Ptr);
- const std::string ptr_buf = _mapper.tensor_argument()->storage(tensor_storage);
- const std::string dst_type = get_cl_data_type(_dst->format().dt, 1);
+ const std::string ptr_buf = _mapper.tensor_argument()->storage(tensor_storage);
+ const std::string dst_type = get_cl_data_type(_dst->format().dt, 1);
std::string address;
address += "(__global ";
@@ -3135,7 +3153,6 @@ private:
auto tensor_storage = static_cast<GpuTensorStorage>(_mapper.gpu_sampler().storage);
const std::string image2d_obj = _mapper.tensor_argument()->storage(tensor_storage);
- // const DataType dt = _dst->format().dt;
const std::string post_fix = _dst->format().dt == DataType::Fp32 ? "f" : "h";
switch(type)
@@ -3242,7 +3259,7 @@ public:
};
// This utility method needs to go in utils.h
-inline bool is_tile_scalar(IVectorTile *x)
+inline bool is_tile_scalar(const IVectorTile *x)
{
return x->format().w == 1 && x->format().h == 1;
}
@@ -3415,11 +3432,11 @@ public:
void op_get_global_batch(const Operand &o_dst, const TensorOperand &o_tensor) override
{
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto dst = operands.unpack(o_dst);
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *dst = operands.unpack(o_dst);
TensorOperandUnpacker tensor_operands(_data->arguments);
- auto tensor = tensor_operands.unpack(o_tensor);
+ IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor);
auto gpu_sampler = o_tensor.sampler();
GpuTensor3dMapper mapper(tensor, gpu_sampler);
@@ -3450,13 +3467,39 @@ public:
_data->code += ");\n";
}
+ /** Append one `dst = <op><src>;` statement per row of the destination tile.
+  *
+  * When the destination width is not 1 and the source width is 1, the source
+  * is prefixed with a cast to the type of the destination's first underlying
+  * source variable.
+  */
+ void op_unary_expression(const Operand &dst_name, UnaryOp op, const Operand &src_name) override
+ {
+     OperandUnpacker    unpacker(_data->tiles, _data->arguments);
+     const IVectorTile *src_tile = unpacker.unpack(src_name);
+     const IVectorTile *dst_tile = unpacker.unpack(dst_name);
+
+     const int32_t     dst_width = dst_tile->format().w;
+     const int32_t     num_rows  = dst_tile->format().h;
+     const int32_t     src_width = src_tile->format().w;
+     const std::string dst_type  = dst_tile->underlying_source_variables()[0].type.str;
+
+     const bool        needs_x_broadcast = (dst_width != 1) && (src_width == 1);
+     const std::string cast_prefix       = needs_x_broadcast ? "(" + dst_type + ")" : "";
+
+     // Rows are addressed with the destination's row index for both tiles
+     // (the original comment notes that broadcasting on Y is automatic).
+     for(int32_t row = 0; row < num_rows; ++row)
+     {
+         std::string statement = dst_tile->vector(row).str;
+         statement += " = ";
+         statement += to_string(op);
+         statement += cast_prefix + src_tile->vector(row).str;
+         statement += ";\n";
+
+         _data->code += statement;
+     }
+ }
+
void op_binary_expression(const Operand &dst_name, const Operand &lhs_name, BinaryOp op,
const Operand &rhs_name) override
{
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto lhs = operands.unpack(lhs_name);
- auto rhs = operands.unpack(rhs_name);
- auto dst = operands.unpack(dst_name);
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *lhs = operands.unpack(lhs_name);
+ const IVectorTile *rhs = operands.unpack(rhs_name);
+ const IVectorTile *dst = operands.unpack(dst_name);
const int32_t dst_w = dst->format().w;
const int32_t dst_h = dst->format().h;
@@ -3488,12 +3531,12 @@ public:
return;
}
- bool broadcast_lhs_x = dst_w != 1 && lhs_w == 1;
- bool broadcast_rhs_x = dst_w != 1 && rhs_w == 1;
+ const bool broadcast_lhs_x = dst_w != 1 && lhs_w == 1;
+ const bool broadcast_rhs_x = dst_w != 1 && rhs_w == 1;
- std::string lhs_prefix = broadcast_lhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : "";
- std::string rhs_prefix = broadcast_rhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : "";
- std::string op_str = to_string(op);
+ const std::string lhs_prefix = broadcast_lhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : "";
+ const std::string rhs_prefix = broadcast_rhs_x ? "(" + dst->underlying_source_variables()[0].type.str + ")" : "";
+ const std::string op_str = to_string(op);
// Broadcasting on Y is automatic
for(int32_t y = 0; y < dst_h; ++y)
@@ -3511,21 +3554,20 @@ public:
void op_cast_expression(const Operand &o_dst, const Operand &o_src, ConvertPolicy policy) override
{
- CKW_UNUSED(policy);
-
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto src = operands.unpack(o_src);
- auto dst = operands.unpack(o_dst);
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *src = operands.unpack(o_src);
+ const IVectorTile *dst = operands.unpack(o_dst);
// const int32_t dst_w = dst->format().w;
const int32_t dst_h = dst->format().h;
- const std::string dt = dst->scalar(0, 0).type.str;
+ const std::string dt = dst->underlying_source_variables()[0].type.str;
+ const std::string sat = (policy == ConvertPolicy::Saturate ? "_sat" : "");
// Broadcasting on Y is automatic
for(int32_t y = 0; y < dst_h; ++y)
{
_data->code += dst->vector(y).str;
- _data->code += " = convert_" + dt + "(";
+ _data->code += " = convert_" + dt + sat + "(";
_data->code += src->vector(y).str;
_data->code += ");\n";
}
@@ -3533,19 +3575,18 @@ public:
void op_assign(const Operand &dst_name, const Operand &src_name) override
{
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto src = operands.unpack(src_name);
- auto dst = operands.unpack(dst_name);
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *src = operands.unpack(src_name);
+ const IVectorTile *dst = operands.unpack(dst_name);
- const int32_t dst_w = dst->format().w;
- const int32_t dst_h = dst->format().h;
- const int32_t src_w = src->format().w;
- // const int32_t src_h = src->format().h;
- const std::string dt = dst->scalar(0, 0).type.str;
+ const int32_t dst_w = dst->format().w;
+ const int32_t dst_h = dst->format().h;
+ const int32_t src_w = src->format().w;
+ const std::string dt = dst->underlying_source_variables()[0].type.str;
- bool broadcast_src_x = dst_w != 1 && src_w == 1;
+ const bool broadcast_src_x = dst_w != 1 && src_w == 1;
- std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
+ const std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
// Broadcasting on Y is automatic
for(int32_t y = 0; y < dst_h; ++y)
@@ -3558,21 +3599,20 @@ public:
}
void
- op_scalar_function(const Operand &dst_name, const Operand &src_name, ScalarUnaryFunction func) override
+ op_unary_elementwise_function(const Operand &dst_name, UnaryFunction func, const Operand &src_name) override
{
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto src = operands.unpack(src_name);
- auto dst = operands.unpack(dst_name);
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *src = operands.unpack(src_name);
+ const IVectorTile *dst = operands.unpack(dst_name);
- const int32_t dst_w = dst->format().w;
- const int32_t dst_h = dst->format().h;
- const int32_t src_w = src->format().w;
- // const int32_t src_h = src->format().h;
- const std::string dt = dst->scalar(0, 0).type.str;
+ const int32_t dst_w = dst->format().w;
+ const int32_t dst_h = dst->format().h;
+ const int32_t src_w = src->format().w;
+ const std::string dt = dst->underlying_source_variables()[0].type.str;
- bool broadcast_src_x = dst_w != 1 && src_w == 1;
+ const bool broadcast_src_x = dst_w != 1 && src_w == 1;
- std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
+ const std::string src_prefix = broadcast_src_x ? "(" + dt + ")" : "";
// Broadcasting on Y is automatic
for(int32_t y = 0; y < dst_h; ++y)
@@ -3582,12 +3622,35 @@ public:
switch(func)
{
- case ScalarUnaryFunction::Exp:
+ case UnaryFunction::Exp:
_data->code += "exp(";
break;
-
+ case UnaryFunction::Tanh:
+ _data->code += "tanh(";
+ break;
+ case UnaryFunction::Sqrt:
+ _data->code += "sqrt(";
+ break;
+ case UnaryFunction::Erf:
+ _data->code += "erf(";
+ break;
+ case UnaryFunction::Fabs:
+ _data->code += "fabs(";
+ break;
+ case UnaryFunction::IsGreaterEqual:
+ _data->code += "isgreaterequal(";
+ break;
+ case UnaryFunction::Log:
+ _data->code += "log(";
+ break;
+ case UnaryFunction::SizeOf:
+ _data->code += "sizeof(";
+ break;
+ case UnaryFunction::Round:
+ _data->code += "round(";
+ break;
default:
- CKW_ASSERT(false);
+ CKW_ASSERT_MSG(false, "Unexpected UnaryFunction used.");
}
_data->code += src_prefix + src->vector(y).str;
@@ -3595,11 +3658,105 @@ public:
}
}
- void op_if(const Operand &o_lhs, BinaryOp op, const Operand &o_rhs) override
+ void op_binary_elementwise_function(const Operand &dst_name, BinaryFunction func, const Operand &first_name, const Operand &second_name) override
{
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto lhs = operands.unpack(o_lhs);
- auto rhs = operands.unpack(o_rhs);
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *first = operands.unpack(first_name);
+ const IVectorTile *second = operands.unpack(second_name);
+ const IVectorTile *dst = operands.unpack(dst_name);
+
+ const int32_t dst_w = dst->format().w;
+ const int32_t dst_h = dst->format().h;
+ const int32_t first_w = first->format().w;
+ const int32_t second_w = second->format().w;
+ const auto datatype = dst->underlying_source_variables()[0].type;
+ const std::string datatype_str = datatype.str;
+
+ const bool broadcast_first_x = dst_w != 1 && first_w == 1;
+ const bool broadcast_second_x = dst_w != 1 && second_w == 1;
+
+ const std::string first_prefix = broadcast_first_x ? "(" + datatype_str + ")" : "";
+ const std::string second_prefix = broadcast_second_x ? "(" + datatype_str + ")" : "";
+
+ const bool is_float = (datatype.dt == DataType::Fp32 || datatype.dt == DataType::Fp16);
+
+ // Broadcasting on Y is automatic
+ for(int32_t y = 0; y < dst_h; ++y)
+ {
+ _data->code += dst->vector(y).str;
+ _data->code += " = ";
+
+ switch(func)
+ {
+ case BinaryFunction::Min:
+ _data->code += is_float ? "fmin(" : "min(";
+ break;
+ case BinaryFunction::Max:
+ _data->code += is_float ? "fmax(" : "max(";
+ break;
+ default:
+ CKW_ASSERT_MSG(false, "Unexpected BinaryFunction used.");
+ }
+
+ _data->code += first_prefix + first->vector(y).str;
+ _data->code += ", ";
+ _data->code += second_prefix + second->vector(y).str;
+ _data->code += ");\n";
+ }
+ }
+
+ /** Append one `dst = func(first, second, third);` statement per row of the destination tile.
+  *
+  * Only TernaryFunction::Select is handled (emitted as `select(`); any other
+  * value trips CKW_ASSERT_MSG. An operand of width 1 is prefixed with a cast
+  * to the destination type when the destination width is not 1.
+  *
+  * NOTE(review): the third operand receives the same destination-type cast as
+  * the other two when broadcast; confirm this is valid for `select` when the
+  * destination is a floating-point vector.
+  */
+ void op_ternary_elementwise_function(const Operand &dst_name, TernaryFunction func, const Operand &first_name, const Operand &second_name, const Operand &third_name) override
+ {
+     OperandUnpacker    unpacker(_data->tiles, _data->arguments);
+     const IVectorTile *first_tile  = unpacker.unpack(first_name);
+     const IVectorTile *second_tile = unpacker.unpack(second_name);
+     const IVectorTile *third_tile  = unpacker.unpack(third_name);
+     const IVectorTile *dst_tile    = unpacker.unpack(dst_name);
+
+     const int32_t     dst_width    = dst_tile->format().w;
+     const int32_t     num_rows     = dst_tile->format().h;
+     const int32_t     first_width  = first_tile->format().w;
+     const int32_t     second_width = second_tile->format().w;
+     const int32_t     third_width  = third_tile->format().w;
+     const std::string dst_type     = dst_tile->underlying_source_variables()[0].type.str;
+
+     const bool        dst_is_wide = (dst_width != 1);
+     const std::string cast        = "(" + dst_type + ")";
+
+     const std::string first_prefix  = (dst_is_wide && first_width == 1) ? cast : "";
+     const std::string second_prefix = (dst_is_wide && second_width == 1) ? cast : "";
+     const std::string third_prefix  = (dst_is_wide && third_width == 1) ? cast : "";
+
+     // Rows are addressed with the destination's row index for every tile
+     // (the original comment notes that broadcasting on Y is automatic).
+     for(int32_t row = 0; row < num_rows; ++row)
+     {
+         _data->code += dst_tile->vector(row).str;
+         _data->code += " = ";
+
+         if(func == TernaryFunction::Select)
+         {
+             _data->code += "select(";
+         }
+         else
+         {
+             CKW_ASSERT_MSG(false, "Unexpected TernaryFunction used.");
+         }
+
+         _data->code += first_prefix + first_tile->vector(row).str;
+         _data->code += ", ";
+         _data->code += second_prefix + second_tile->vector(row).str;
+         _data->code += ", ";
+         _data->code += third_prefix + third_tile->vector(row).str;
+         _data->code += ");\n";
+     }
+ }
+
+ void op_if_header(const Operand &o_lhs, BinaryOp op, const Operand &o_rhs) override
+ {
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *lhs = operands.unpack(o_lhs);
+ const IVectorTile *rhs = operands.unpack(o_rhs);
assert(is_tile_scalar(lhs));
assert(is_tile_scalar(rhs));
@@ -3613,13 +3770,23 @@ public:
_data->code += ")\n";
}
- void op_for_loop(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value_name,
- AssignmentOp update_op, const Operand &update_value_name) override
+ void op_else_if_header(const Operand &o_lhs, BinaryOp op, const Operand &o_rhs) override
{
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto var = operands.unpack(var_name);
- auto cond_value = operands.unpack(cond_value_name);
- auto update_value = operands.unpack(update_value_name);
+ _data->code += "else ";
+ op_if_header(o_lhs, op, o_rhs);
+ }
+
+ /** Append the `else` keyword followed by a newline to the generated code. */
+ void op_else_header() override
+ {
+     _data->code += "else\n";
+ }
+
+ void op_for_loop_header(const Operand& var_name, BinaryOp cond_op, const Operand& cond_value_name, AssignmentOp update_op, const Operand& update_value_name) override
+ {
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *var = operands.unpack(var_name);
+ const IVectorTile *cond_value = operands.unpack(cond_value_name);
+ const IVectorTile *update_value = operands.unpack(update_value_name);
const int32_t dst_w = var->format().w;
const int32_t dst_h = var->format().h;
@@ -3646,15 +3813,17 @@ public:
const Operand &dilation_y) override
{
OperandUnpacker operands(_data->tiles, _data->arguments);
- auto dst = operands.unpack(o_dst);
- auto x = operands.unpack(o_x);
- auto y = operands.unpack(o_y);
- auto z = operands.unpack(o_z);
- auto dil_y = operands.unpack(dilation_y);
- auto b = operands.unpack(o_batch_idx);
+
+ // Not const as it requires changes to 'load_writer'.
+ IVectorTile *dst = operands.unpack(o_dst);
+ IVectorTile *x = operands.unpack(o_x);
+ IVectorTile *y = operands.unpack(o_y);
+ IVectorTile *z = operands.unpack(o_z);
+ IVectorTile *dil_y = operands.unpack(dilation_y);
+ IVectorTile *b = operands.unpack(o_batch_idx);
TensorOperandUnpacker tensor_operands(_data->arguments);
- auto tensor = tensor_operands.unpack(o_tensor);
+ IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor);
auto gpu_sampler = o_tensor.sampler();
GpuTensor3dMapper mapper(tensor, gpu_sampler);
@@ -3682,14 +3851,16 @@ public:
const Operand &o_batch_idx) override
{
OperandUnpacker operands(_data->tiles, _data->arguments);
- auto dst = operands.unpack(o_dst);
- auto x = operands.unpack(o_x);
- auto y_ind = operands.unpack(o_indirect_h);
- auto z = operands.unpack(o_z);
- auto b = operands.unpack(o_batch_idx);
+
+ // Not const as it requires changes to 'load_writer'.
+ IVectorTile *dst = operands.unpack(o_dst);
+ IVectorTile *x = operands.unpack(o_x);
+ IVectorTile *y_ind = operands.unpack(o_indirect_h);
+ IVectorTile *z = operands.unpack(o_z);
+ IVectorTile *b = operands.unpack(o_batch_idx);
TensorOperandUnpacker tensor_operands(_data->arguments);
- auto tensor = tensor_operands.unpack(o_tensor);
+ IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor);
auto gpu_sampler = o_tensor.sampler();
GpuTensor3dMapper mapper(tensor, gpu_sampler);
@@ -3712,14 +3883,16 @@ public:
const Operand &batch_index_name) override
{
OperandUnpacker operands(_data->tiles, _data->arguments);
- auto src = operands.unpack(src_name);
- auto x = operands.unpack(x_name);
- auto y = operands.unpack(y_name);
- auto z = operands.unpack(z_name);
- auto b = operands.unpack(batch_index_name);
+
+ // Not const as it requires changes to 'load_writer'.
+ IVectorTile *src = operands.unpack(src_name);
+ IVectorTile *x = operands.unpack(x_name);
+ IVectorTile *y = operands.unpack(y_name);
+ IVectorTile *z = operands.unpack(z_name);
+ IVectorTile *b = operands.unpack(batch_index_name);
TensorOperandUnpacker tensor_operands(_data->arguments);
- auto tensor = tensor_operands.unpack(tensor_name);
+ IGpuTensorArgument *tensor = tensor_operands.unpack(tensor_name);
auto gpu_sampler = tensor_name.sampler();
GpuTensor3dMapper mapper(tensor, gpu_sampler);
@@ -3747,15 +3920,15 @@ public:
void util_get_indirect_buffer(const Operand &o_dst, const TensorOperand &o_tensor, const Operand &o_x,
const Operand &o_y, const Operand &o_x_off, const Operand &o_y_off) override
{
- OperandUnpacker operands(_data->tiles, _data->arguments);
- auto dst = operands.unpack(o_dst);
- auto x = operands.unpack(o_x);
- auto y = operands.unpack(o_y);
- auto x_off = operands.unpack(o_x_off);
- auto y_off = operands.unpack(o_y_off);
+ OperandUnpacker operands(_data->tiles, _data->arguments);
+ const IVectorTile *dst = operands.unpack(o_dst);
+ const IVectorTile *x = operands.unpack(o_x);
+ const IVectorTile *y = operands.unpack(o_y);
+ const IVectorTile *x_off = operands.unpack(o_x_off);
+ const IVectorTile *y_off = operands.unpack(o_y_off);
TensorOperandUnpacker tensor_operands(_data->arguments);
- auto tensor = tensor_operands.unpack(o_tensor);
+ IGpuTensorArgument *tensor = tensor_operands.unpack(o_tensor);
assert(dst->format().w == 1);
assert(x->format().w == 1);
diff --git a/compute_kernel_writer/prototype/src/TensorTileSampler.cpp b/compute_kernel_writer/prototype/src/TensorTileSampler.cpp
index 143d550dec..28e54df3a5 100644
--- a/compute_kernel_writer/prototype/src/TensorTileSampler.cpp
+++ b/compute_kernel_writer/prototype/src/TensorTileSampler.cpp
@@ -24,7 +24,7 @@
#include "ckw/TensorTileSampler.h"
#include "ckw/TileOperand.h"
-#include "ckw/Types.h"
+#include "ckw/types/TensorSamplerTypes.h"
namespace ckw
{
diff --git a/compute_kernel_writer/src/TensorUtils.cpp b/compute_kernel_writer/src/TensorUtils.cpp
index cc179b4b51..4970de75a6 100644
--- a/compute_kernel_writer/src/TensorUtils.cpp
+++ b/compute_kernel_writer/src/TensorUtils.cpp
@@ -24,7 +24,6 @@
#include "ckw/Error.h"
#include "ckw/TensorInfo.h"
-#include "ckw/Types.h"
#include "src/TensorUtils.h"
diff --git a/compute_kernel_writer/src/cl/CLConstantTile.h b/compute_kernel_writer/src/cl/CLConstantTile.h
index c8318487e6..658fb63f7f 100644
--- a/compute_kernel_writer/src/cl/CLConstantTile.h
+++ b/compute_kernel_writer/src/cl/CLConstantTile.h
@@ -47,15 +47,15 @@ public:
CLConstantTile(const TileContainer &vals, DataType dt);
// Inherited method overridden
- TileVariable scalar(int32_t row, int32_t col) const override;
+ TileVariable scalar(int32_t row, int32_t col) const override;
- TileVariable vector(int32_t row) const override;
+ TileVariable vector(int32_t row) const override;
- TileVariable vector(int32_t row, int32_t col_start, int32_t width) const override;
+ TileVariable vector(int32_t row, int32_t col_start, int32_t width) const override;
std::vector<TileVariable> all() const override;
- bool is_assignable() const override;
+ bool is_assignable() const override;
private:
TileContainer _vals{};
diff --git a/compute_kernel_writer/src/cl/CLHelpers.cpp b/compute_kernel_writer/src/cl/CLHelpers.cpp
index 68d7db252b..d940a5a529 100644
--- a/compute_kernel_writer/src/cl/CLHelpers.cpp
+++ b/compute_kernel_writer/src/cl/CLHelpers.cpp
@@ -22,7 +22,7 @@
* SOFTWARE.
*/
#include "ckw/Error.h"
-#include "ckw/Types.h"
+#include "ckw/types/DataType.h"
#include "src/cl/CLHelpers.h"
diff --git a/compute_kernel_writer/src/cl/CLTile.h b/compute_kernel_writer/src/cl/CLTile.h
index 039bd5613f..7e69de847b 100644
--- a/compute_kernel_writer/src/cl/CLTile.h
+++ b/compute_kernel_writer/src/cl/CLTile.h
@@ -47,15 +47,15 @@ public:
CLTile(const std::string &name, const TileInfo &info);
// Inherited method overridden
- TileVariable scalar(int32_t row, int32_t col) const override;
+ TileVariable scalar(int32_t row, int32_t col) const override;
- TileVariable vector(int32_t row) const override;
+ TileVariable vector(int32_t row) const override;
- TileVariable vector(int32_t row, int32_t col_start, int32_t width) const override;
+ TileVariable vector(int32_t row, int32_t col_start, int32_t width) const override;
std::vector<TileVariable> all() const override;
- bool is_assignable() const override;
+ bool is_assignable() const override;
private:
std::string create_var_name(int32_t row) const;
diff --git a/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwVariableTable.cpp b/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwVariableTable.cpp
index 4475586db8..154968775c 100644
--- a/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwVariableTable.cpp
+++ b/src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwVariableTable.cpp
@@ -59,7 +59,7 @@ GpuCkwComponentArgument *GpuCkwVariableTable::declare_variable(const GpuKernelCo
std::stringstream ss;
ss << alias << "_t" << abs(tensor->id());
const auto uniq_name = ss.str();
- GpuCkwComponentArgument var{ writer->create_tensor_argument(uniq_name.c_str(), to_ckw(*tensor)) };
+ GpuCkwComponentArgument var{ writer->declare_tensor_argument(uniq_name.c_str(), to_ckw(*tensor)) };
auto &&inserted = _vars.emplace(tensor->id(), var);
return &(inserted.first->second);
}
diff --git a/src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwElementwiseBinary.cpp b/src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwElementwiseBinary.cpp
index cba1cfbe40..685bf391dc 100644
--- a/src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwElementwiseBinary.cpp
+++ b/src/dynamic_fusion/sketch/gpu/ckw_driver/components/GpuCkwElementwiseBinary.cpp
@@ -28,7 +28,7 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Validate.h"
#include "ckw/TensorTileSampler.h"
-#include "ckw/Types.h"
+#include "ckw/types/TensorSamplerTypes.h"
#include "src/core/helpers/WindowHelpers.h"
#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
#include "src/dynamic_fusion/sketch/gpu/ckw_driver/GpuCkwVariableTable.h"
@@ -120,7 +120,7 @@ void GpuCkwElementwiseBinary::write_component_code(const ComponentGroup &comp_gr
auto &dst_tile = dst->tile();
// Perform the operation.
- writer->op_binary_expression(dst_tile, lhs_tile, rhs_tile, BinaryOp::Add);
+ writer->op_binary_expression(dst_tile, lhs_tile, BinaryOp::Add, rhs_tile);
}
Window GpuCkwElementwiseBinary::get_window() const