aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorViet-Hoa Do <viet-hoa.do@arm.com>2023-07-25 14:00:46 +0100
committerViet-Hoa Do <viet-hoa.do@arm.com>2023-07-27 14:34:04 +0000
commit0b23e0e6402cb18ddf621d36454cadbb73959518 (patch)
tree244c32e5a44a8c2a644cb6a1e965c114175d2515
parent9662ac062bafe454afb77a563648e5577c5a8360 (diff)
downloadComputeLibrary-0b23e0e6402cb18ddf621d36454cadbb73959518.tar.gz
Add TensorOperand and declare tensor argument
Partially resolves: COMPMID-6391 Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Change-Id: I849d486401f99a93919015f2e173559dca5bffa2 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9972 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--compute_kernel_writer/CMakeLists.txt2
-rw-r--r--compute_kernel_writer/include/ckw/KernelWriter.h33
-rw-r--r--compute_kernel_writer/include/ckw/TensorOperand.h100
-rw-r--r--compute_kernel_writer/include/ckw/TileOperand.h4
-rw-r--r--compute_kernel_writer/src/ITensor.h46
-rw-r--r--compute_kernel_writer/src/ITensorArgument.h39
-rw-r--r--compute_kernel_writer/src/ITensorComponent.h53
-rw-r--r--compute_kernel_writer/src/KernelWriter.cpp12
-rw-r--r--compute_kernel_writer/src/TensorOperand.cpp111
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.cpp14
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.h13
-rw-r--r--compute_kernel_writer/src/cl/CLTensorArgument.cpp150
-rw-r--r--compute_kernel_writer/src/cl/CLTensorArgument.h40
-rw-r--r--compute_kernel_writer/src/cl/CLTensorComponent.cpp123
-rw-r--r--compute_kernel_writer/src/cl/CLTensorComponent.h80
-rw-r--r--compute_kernel_writer/validation/tests/CLTensorArgumentTest.h59
16 files changed, 723 insertions, 156 deletions
diff --git a/compute_kernel_writer/CMakeLists.txt b/compute_kernel_writer/CMakeLists.txt
index 659c1058b1..1e82f9c6b3 100644
--- a/compute_kernel_writer/CMakeLists.txt
+++ b/compute_kernel_writer/CMakeLists.txt
@@ -123,6 +123,7 @@ target_sources(ckw PRIVATE
src/Kernel.cpp
src/KernelWriter.cpp
src/TensorInfo.cpp
+ src/TensorOperand.cpp
src/TensorSampler.cpp
src/TensorUtils.cpp
src/TileInfo.cpp
@@ -132,6 +133,7 @@ target_sources(ckw PRIVATE
if(CKW_ENABLE_OPENCL)
target_sources(ckw PRIVATE
src/cl/CLTensorArgument.cpp
+ src/cl/CLTensorComponent.cpp
src/cl/CLHelpers.cpp
src/cl/CLTile.cpp
src/cl/CLKernelWriter.cpp
diff --git a/compute_kernel_writer/include/ckw/KernelWriter.h b/compute_kernel_writer/include/ckw/KernelWriter.h
index cfd24d35a3..894f7b6758 100644
--- a/compute_kernel_writer/include/ckw/KernelWriter.h
+++ b/compute_kernel_writer/include/ckw/KernelWriter.h
@@ -25,6 +25,7 @@
#ifndef CKW_INCLUDE_CKW_KERNELWRITER_H
#define CKW_INCLUDE_CKW_KERNELWRITER_H
+#include "ckw/TensorOperand.h"
#include "ckw/TileOperand.h"
#include <memory>
@@ -36,6 +37,7 @@ namespace ckw
class Kernel;
/** Forward Declerations */
+class TensorInfo;
class TileInfo;
enum class TargetArchitecture;
enum class TargetLanguage;
@@ -95,6 +97,19 @@ public:
*/
virtual std::unique_ptr<Kernel> emit_kernel(const std::string &name) = 0;
+ // =============================================================================================
+ // Tensor and tile declaration
+ // =============================================================================================
+
+ /** Declare a tensor argument.
+ *
+ * @param[in] name The name of the tensor.
+ * @param[in] info The tensor info.
+ *
+ * @return The @ref TensorOperand object.
+ */
+ virtual TensorOperand declare_tensor_argument(const std::string &name, const TensorInfo &info) = 0;
+
/** Declare a tile given its name and tile info
*
* @param[in] name Name of the tile
@@ -110,20 +125,18 @@ protected:
/** Generate full variable name by prefixing it with id space */
std::string generate_full_name(const std::string &name) const;
- /** Create a new tile operand referring to the specified tile object.
- *
- * This class has friendship relationship with @ref TileOperand which allows it
- * to access the private constructor.
- */
+ /** Create a new tile operand referring to the specified tile object. */
static TileOperand create_tile_operand(ITile &tile);
- /** Get the reference to tile object from the tile operand.
- *
- * This class has friendship relationship with @ref TileOperand which allows it
- * to access the private reference to the private tile field.
- */
+ /** Get the reference to tile object from the tile operand. */
static ITile &get_tile(const TileOperand &operand);
+ /** Create a new tensor operand from a tensor object. */
+ static TensorOperand create_tensor_operand(ITensor &tensor);
+
+ /** Get the reference to tensor object from the tensor operand. */
+ static ITensor &get_tensor(const TensorOperand &operand);
+
private:
int32_t _id_space{ 0 };
};
diff --git a/compute_kernel_writer/include/ckw/TensorOperand.h b/compute_kernel_writer/include/ckw/TensorOperand.h
new file mode 100644
index 0000000000..2672cd5334
--- /dev/null
+++ b/compute_kernel_writer/include/ckw/TensorOperand.h
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_INCLUDE_CKW_TENSOROPERAND_H
+#define CKW_INCLUDE_CKW_TENSOROPERAND_H
+
+#include "ckw/TileOperand.h"
+
+namespace ckw
+{
+
+class ITensor;
+class TensorInfo;
+
+/** A tensor operand provides access to the tensor info, tensor storages for load/store operations
+ * and tensor components (e.g. shape, strides, etc.) in the form of @ref TileOperand objects.
+ */
+class TensorOperand
+{
+public:
+ // _tensor field is completely hidden from the public API to avoid any misuse.
+ // Only kernel writer class interacts with tensor operand hence we allow it to access this field.
+ friend class KernelWriter;
+
+ /** Get the tensor info. */
+ const TensorInfo &info() const;
+
+ /** Get the operand that contains the stride in dimension 0 of the tensor. */
+ TileOperand stride0();
+
+ /** Get the operand that contains the stride in dimension 1 of the tensor. */
+ TileOperand stride1();
+
+ /** Get the operand that contains the stride in dimension 2 of the tensor. */
+ TileOperand stride2();
+
+ /** Get the operand that contains the stride in dimension 3 of the tensor. */
+ TileOperand stride3();
+
+ /** Get the operand that contains the stride in dimension 4 of the tensor. */
+ TileOperand stride4();
+
+ /** Get the operand that contains the size of dimension 0 of the tensor. */
+ TileOperand dim0();
+
+ /** Get the operand that contains the size of dimension 1 of the tensor. */
+ TileOperand dim1();
+
+ /** Get the operand that contains the size of dimension 2 of the tensor. */
+ TileOperand dim2();
+
+ /** Get the operand that contains the size of dimension 3 of the tensor. */
+ TileOperand dim3();
+
+ /** Get the operand that contains the size of dimension 4 of the tensor. */
+ TileOperand dim4();
+
+ /** Get the operand that contains the size of dimensions 1 and 2 collapsed. */
+ TileOperand dim1_dim2();
+
+ /** Get the operand that contains the size of dimensions 1, 2 and 3 collapsed. */
+ TileOperand dim1_dim2_dim3();
+
+ /** Get the operand that contains the size of dimensions 2 and 3 collapsed. */
+ TileOperand dim2_dim3();
+
+ /** Get the operand that contains the offset in bytes to the first element. */
+ TileOperand offset_first_element_in_bytes();
+
+private:
+ /** Initialize a new instance of @ref TensorOperand class for a tensor. */
+ TensorOperand(ITensor &tensor);
+
+ ITensor &_tensor;
+};
+
+} // namespace ckw
+
+#endif // CKW_INCLUDE_CKW_TENSOROPERAND_H
diff --git a/compute_kernel_writer/include/ckw/TileOperand.h b/compute_kernel_writer/include/ckw/TileOperand.h
index fe44b73e82..873a9825f3 100644
--- a/compute_kernel_writer/include/ckw/TileOperand.h
+++ b/compute_kernel_writer/include/ckw/TileOperand.h
@@ -29,6 +29,7 @@ namespace ckw
{
class KernelWriter;
+class TensorOperand;
class ITile;
/** A tile operand refers to a tile object that can be used for kernel writing. */
@@ -36,8 +37,9 @@ class TileOperand
{
public:
// The constructor and _tile field is completely hidden from the public API to avoid any misuse.
- // Only kernel writer class interacts with tile operand hence we allow it to access this field.
+ // Only kernel writer and tensor operand classes create and interact with tile operand hence we allow them to access this field.
friend class KernelWriter;
+ friend class TensorOperand;
private:
// These are hidden from the public API to avoid any misuse.
diff --git a/compute_kernel_writer/src/ITensor.h b/compute_kernel_writer/src/ITensor.h
new file mode 100644
index 0000000000..4c1c56fd35
--- /dev/null
+++ b/compute_kernel_writer/src/ITensor.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_SRC_ITENSOR_H
+#define CKW_SRC_ITENSOR_H
+
+#include "src/ITensorArgument.h"
+
+namespace ckw
+{
+
+/** The generic class for all tensor objects in CKW.
+ *
+ * Tensors in CKW are always kernel arguments consisting of:
+ * - Essential information such as name, tensor info, etc.
+ * - Tensor storage access: allowing load/store operation to perform.
+ * - Tensor component access: allowing interaction with tensor information such as shape, strides, etc. in the form of tile objects.
+ */
+class ITensor : public ITensorArgument, public ITensorStorageAccess, public ITensorComponentAccess
+{
+};
+
+} // namespace ckw
+
+#endif // CKW_SRC_ITENSOR_H \ No newline at end of file
diff --git a/compute_kernel_writer/src/ITensorArgument.h b/compute_kernel_writer/src/ITensorArgument.h
index 40ad69fdc0..838bd40f85 100644
--- a/compute_kernel_writer/src/ITensorArgument.h
+++ b/compute_kernel_writer/src/ITensorArgument.h
@@ -35,11 +35,14 @@
namespace ckw
{
+
+class ITensorComponent;
+
/** Tensor storage variable */
struct TensorStorageVariable
{
- std::string val{ "" }; /** Tensor storage as a string */
- std::string type{ "" }; /** Tensor storage type as a string */
+ std::string val{ "" }; /** Tensor storage as a string */
+ TensorStorageType type{ TensorStorageType::Unknown }; /** Tensor storage type */
};
/** Tensor argument base class.
@@ -60,11 +63,21 @@ public:
{
return _basename;
}
+
+ /** Method to get the tensor info
+ *
+ * @return the @ref TensorInfo
+ */
+ TensorInfo &info()
+ {
+ return _info;
+ }
+
/** Method to get the tensor info
*
* @return the @ref TensorInfo
*/
- TensorInfo info() const
+ const TensorInfo &info() const
{
return _info;
}
@@ -75,38 +88,38 @@ protected:
};
/** Tensor component argument base class */
-class ITensorComponentArgument
+class ITensorComponentAccess
{
public:
- virtual ~ITensorComponentArgument() = default;
- /** Method to get the tensor component variable as a string
+ virtual ~ITensorComponentAccess() = default;
+ /** Method to get the tensor component variable as a tile.
*
* @param[in] x The tensor component to query
*
- * @return the tensor component variable as a @ref TileVariable
+ * @return the tensor component variable as a @ref ITile.
*/
- virtual TileVariable component(TensorComponentType x) = 0;
+ virtual ITile &component(TensorComponentType x) = 0;
/** Method to get all tensor components needed to access the data in the tensor
*
* The tensor components returned by this method must be all passed as kernel argument
*
- * @return a vector containing all the tensor components as @ref TileVariable objects
+ * @return a vector containing all the tensor components as pointers to @ref ITensorComponent objects.
*/
- virtual std::vector<TileVariable> components() const = 0;
+ virtual std::vector<const ITensorComponent *> components() const = 0;
};
/** Tensor storage argument base class */
-class ITensorStorageArgument
+class ITensorStorageAccess
{
public:
- virtual ~ITensorStorageArgument() = default;
+ virtual ~ITensorStorageAccess() = default;
/** Method to get the tensor storage as a string
*
* @param[in] x The tensor storage to query
*
* @return the tensor storage as a @ref TensorStorageVariable
*/
- virtual TensorStorageVariable storage(TensorStorageType x) = 0;
+ virtual TensorStorageVariable &storage(TensorStorageType x) = 0;
/** Method to get all tensor storages needed to access the data in the tensor
*
* The tensor storages returned by this method must be all passed as kernel argument
diff --git a/compute_kernel_writer/src/ITensorComponent.h b/compute_kernel_writer/src/ITensorComponent.h
new file mode 100644
index 0000000000..e2775b62b0
--- /dev/null
+++ b/compute_kernel_writer/src/ITensorComponent.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_SRC_ITENSORCOMPONENT_H
+#define CKW_SRC_ITENSORCOMPONENT_H
+
+#include "ckw/types/TensorComponentType.h"
+#include "src/ITile.h"
+
+namespace ckw
+{
+
+/** A tensor component provides access to tensor information such as shape, strides, etc. in the form of @ref ITile objects. */
+class ITensorComponent
+{
+public:
+ /** Destructor. */
+ virtual ~ITensorComponent() = default;
+
+ /** Get the tile variable for the component. */
+ virtual ITile &tile() = 0;
+
+ /** Get the const tile variable for the component. */
+ virtual const ITile &tile() const = 0;
+
+ /** Get the component type. */
+ virtual TensorComponentType component_type() const = 0;
+};
+
+} // namespace ckw
+
+#endif // CKW_SRC_ITENSORCOMPONENT_H
diff --git a/compute_kernel_writer/src/KernelWriter.cpp b/compute_kernel_writer/src/KernelWriter.cpp
index 7b83eade6f..ce34a1c2d6 100644
--- a/compute_kernel_writer/src/KernelWriter.cpp
+++ b/compute_kernel_writer/src/KernelWriter.cpp
@@ -28,6 +28,8 @@
#include "ckw/types/TargetArchitecture.h"
#include "ckw/types/TargetLanguage.h"
#include "src/cl/CLKernelWriter.h"
+#include "src/cl/CLTensorArgument.h"
+#include "src/cl/CLTile.h"
namespace ckw
{
@@ -69,4 +71,14 @@ ITile &KernelWriter::get_tile(const TileOperand &operand)
return operand._tile;
}
+TensorOperand KernelWriter::create_tensor_operand(ITensor &tensor)
+{
+ return TensorOperand(tensor);
+}
+
+ITensor &KernelWriter::get_tensor(const TensorOperand &operand)
+{
+ return operand._tensor;
+}
+
} // namespace ckw
diff --git a/compute_kernel_writer/src/TensorOperand.cpp b/compute_kernel_writer/src/TensorOperand.cpp
new file mode 100644
index 0000000000..5ad24c6276
--- /dev/null
+++ b/compute_kernel_writer/src/TensorOperand.cpp
@@ -0,0 +1,111 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ckw/TensorOperand.h"
+#include "src/ITensor.h"
+
+namespace ckw
+{
+
+TensorOperand::TensorOperand(ITensor &tensor)
+ : _tensor(tensor)
+{
+}
+
+const TensorInfo &TensorOperand::info() const
+{
+ return _tensor.info();
+}
+
+TileOperand TensorOperand::stride0()
+{
+ return TileOperand(_tensor.component(TensorComponentType::Stride0));
+}
+
+TileOperand TensorOperand::stride1()
+{
+ return TileOperand(_tensor.component(TensorComponentType::Stride1));
+}
+
+TileOperand TensorOperand::stride2()
+{
+ return TileOperand(_tensor.component(TensorComponentType::Stride2));
+}
+
+TileOperand TensorOperand::stride3()
+{
+ return TileOperand(_tensor.component(TensorComponentType::Stride3));
+}
+
+TileOperand TensorOperand::stride4()
+{
+ return TileOperand(_tensor.component(TensorComponentType::Stride4));
+}
+
+TileOperand TensorOperand::dim0()
+{
+ return TileOperand(_tensor.component(TensorComponentType::Dim0));
+}
+
+TileOperand TensorOperand::dim1()
+{
+ return TileOperand(_tensor.component(TensorComponentType::Dim1));
+}
+
+TileOperand TensorOperand::dim2()
+{
+ return TileOperand(_tensor.component(TensorComponentType::Dim2));
+}
+
+TileOperand TensorOperand::dim3()
+{
+ return TileOperand(_tensor.component(TensorComponentType::Dim3));
+}
+
+TileOperand TensorOperand::dim4()
+{
+ return TileOperand(_tensor.component(TensorComponentType::Dim4));
+}
+
+TileOperand TensorOperand::dim1_dim2()
+{
+ return TileOperand(_tensor.component(TensorComponentType::Dim1xDim2));
+}
+
+TileOperand TensorOperand::dim1_dim2_dim3()
+{
+ return TileOperand(_tensor.component(TensorComponentType::Dim1xDim2xDim3));
+}
+
+TileOperand TensorOperand::dim2_dim3()
+{
+ return TileOperand(_tensor.component(TensorComponentType::Dim2xDim3));
+}
+
+TileOperand TensorOperand::offset_first_element_in_bytes()
+{
+ return TileOperand(_tensor.component(TensorComponentType::OffsetFirstElement));
+}
+
+} // namespace ckw \ No newline at end of file
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
index bc056c67a2..69b5244aa2 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
@@ -24,8 +24,10 @@
#include "src/cl/CLKernelWriter.h"
#include "ckw/Error.h"
+#include "ckw/Kernel.h"
#include "ckw/TileOperand.h"
#include "src/cl/CLHelpers.h"
+#include "src/cl/CLTensorArgument.h"
#include "src/cl/CLTile.h"
#include <cstdint>
@@ -62,6 +64,18 @@ const std::string &CLKernelWriter::body_source_code() const
return _body_source_code;
}
+TensorOperand CLKernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info)
+{
+ const auto fullname = generate_full_name(name);
+
+ auto tensor = std::make_unique<CLTensorArgument>(fullname, info, false /* return_dims_by_value */);
+ const auto operand = create_tensor_operand(*tensor);
+
+ _tensors.insert(std::move(tensor));
+
+ return operand;
+}
+
TileOperand CLKernelWriter::declare_tile(const std::string &name, const TileInfo &tile_info)
{
const std::string fullname = generate_full_name(name);
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.h b/compute_kernel_writer/src/cl/CLKernelWriter.h
index c69a0bc07e..42d2b07ded 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.h
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.h
@@ -26,7 +26,6 @@
#define CKW_SRC_CL_CLKERNELWRITER_H
#include "ckw/KernelWriter.h"
-#include "src/cl/CLTile.h"
#include <memory>
#include <set>
@@ -35,6 +34,9 @@
namespace ckw
{
+class CLTile;
+class CLTensorArgument;
+
/** OpenCL kernel writer. */
class CLKernelWriter : public KernelWriter
{
@@ -61,6 +63,12 @@ public:
std::unique_ptr<Kernel> emit_kernel(const std::string &name) override;
+ // =============================================================================================
+ // Tensor and tile declaration
+ // =============================================================================================
+
+ TensorOperand declare_tensor_argument(const std::string &name, const TensorInfo &info) override;
+
/** Declare a tile given name and tile information
*
* Similar to @ref KernelWriter::declare_tile()
@@ -95,7 +103,8 @@ private:
*/
std::string _body_source_code{};
- std::set<std::unique_ptr<CLTile>> _tiles{};
+ std::set<std::unique_ptr<CLTensorArgument>> _tensors{};
+ std::set<std::unique_ptr<CLTile>> _tiles{};
};
} // namespace ckw
diff --git a/compute_kernel_writer/src/cl/CLTensorArgument.cpp b/compute_kernel_writer/src/cl/CLTensorArgument.cpp
index ed1c5bd687..7d4dc958df 100644
--- a/compute_kernel_writer/src/cl/CLTensorArgument.cpp
+++ b/compute_kernel_writer/src/cl/CLTensorArgument.cpp
@@ -24,7 +24,10 @@
#include "src/cl/CLTensorArgument.h"
#include "ckw/Error.h"
+#include "src/ITensorArgument.h"
+#include "src/ITensorComponent.h"
#include "src/cl/CLHelpers.h"
+#include "src/cl/CLTensorComponent.h"
#include "src/types/TensorComponentType.h"
#include <algorithm>
@@ -39,8 +42,25 @@ CLTensorArgument::CLTensorArgument(const std::string &name, const TensorInfo &in
_info = info;
}
-TileVariable CLTensorArgument::component(TensorComponentType x)
+CLTensorArgument::~CLTensorArgument() = default;
+
+CLTensorComponent &CLTensorArgument::cl_component(TensorComponentType x)
{
+ // Return the component if it has already been created.
+ {
+ const auto it = std::find_if(
+ _components_used.begin(), _components_used.end(),
+ [=](const std::unique_ptr<CLTensorComponent> &item)
+ {
+ return item->component_type() == x;
+ });
+
+ if(it != _components_used.end())
+ {
+ return **it;
+ }
+ }
+
if(_return_dims_by_value)
{
uint32_t component_type = static_cast<uint32_t>(x);
@@ -100,42 +120,47 @@ TileVariable CLTensorArgument::component(TensorComponentType x)
if(idx != kDynamicTensorDimensionValue)
{
- TileVariable t;
- t.str = std::to_string(idx);
- t.desc.dt = DataType::Uint32;
- t.desc.len = 1;
- return t;
+ _components_used.emplace_back(std::make_unique<CLTensorComponent>(*this, x, idx));
+
+ return *_components_used.back();
}
}
}
- auto it = std::find(_components_used.begin(), _components_used.end(), x);
+ _components_used.emplace_back(std::make_unique<CLTensorComponent>(*this, x));
- // Add to the list of used components if not present yet
- if(it == _components_used.end())
- {
- _components_used.push_back(x);
- }
+ return *_components_used.back();
+}
- TileVariable t;
- t.str = create_component_name(x);
- t.desc.dt = DataType::Int32;
- t.desc.len = 1;
- return t;
+ITile &CLTensorArgument::component(TensorComponentType x)
+{
+ return cl_component(x);
}
-TensorStorageVariable CLTensorArgument::storage(TensorStorageType x)
+TensorStorageVariable &CLTensorArgument::storage(TensorStorageType x)
{
- if(std::find(_storages_used.begin(), _storages_used.end(), x) == _storages_used.end())
+ // Return the storage if it has already been created.
{
- _storages_used.push_back(x);
+ const auto it = std::find_if(
+ _storages_used.begin(), _storages_used.end(),
+ [=](const TensorStorageVariable &item)
+ {
+ return item.type == x;
+ });
+
+ if(it != _storages_used.end())
+ {
+ return *it;
+ }
}
TensorStorageVariable t;
t.val = create_storage_name(x);
- t.type = cl_get_variable_storagetype_as_string(x);
+ t.type = x;
+
+ _storages_used.emplace_back(t);
- return t;
+ return _storages_used.back();
}
std::string CLTensorArgument::create_storage_name(TensorStorageType x) const
@@ -159,87 +184,26 @@ std::string CLTensorArgument::create_storage_name(TensorStorageType x) const
return var_name;
}
-std::string CLTensorArgument::create_component_name(TensorComponentType x) const
-{
- std::string var_name = _basename;
-
- switch(x)
- {
- case TensorComponentType::OffsetFirstElement:
- var_name += "_offset_first_element";
- break;
- case TensorComponentType::Stride0:
- var_name += "_stride0";
- break;
- case TensorComponentType::Stride1:
- var_name += "_stride1";
- break;
- case TensorComponentType::Stride2:
- var_name += "_stride2";
- break;
- case TensorComponentType::Stride3:
- var_name += "_stride3";
- break;
- case TensorComponentType::Stride4:
- var_name += "_stride4";
- break;
- case TensorComponentType::Dim0:
- var_name += "_dim0";
- break;
- case TensorComponentType::Dim1:
- var_name += "_dim1";
- break;
- case TensorComponentType::Dim2:
- var_name += "_dim2";
- break;
- case TensorComponentType::Dim3:
- var_name += "_dim3";
- break;
- case TensorComponentType::Dim4:
- var_name += "_dim4";
- break;
- case TensorComponentType::Dim1xDim2:
- var_name += "_dim1xdim2";
- break;
- case TensorComponentType::Dim2xDim3:
- var_name += "_dim2xdim3";
- break;
- case TensorComponentType::Dim1xDim2xDim3:
- var_name += "_dim1xdim2xdim3";
- break;
- default:
- COMPUTE_KERNEL_WRITER_ERROR_ON_MSG("Unsupported tensor component");
- return "";
- }
-
- return var_name;
-}
-
std::vector<TensorStorageVariable> CLTensorArgument::storages() const
{
std::vector<TensorStorageVariable> storages;
- for(auto &val : _storages_used)
- {
- TensorStorageVariable t;
- t.val = create_storage_name(val);
- t.type = cl_get_variable_storagetype_as_string(val);
- storages.push_back(t);
- }
+ storages.reserve(_storages_used.size());
+
+ std::copy(_storages_used.begin(), _storages_used.end(), std::back_inserter(storages));
return storages;
}
-std::vector<TileVariable> CLTensorArgument::components() const
+std::vector<const ITensorComponent *> CLTensorArgument::components() const
{
- std::vector<TileVariable> components;
+ std::vector<const ITensorComponent *> components;
- for(auto &val : _components_used)
+ for(const auto &component : _components_used)
{
- TileVariable t;
- t.str = create_component_name(val);
- t.desc.dt = DataType::Int32;
- t.desc.len = 1;
- components.push_back(t);
+ if(component->is_assignable())
+ {
+ components.push_back(component.get());
+ }
}
return components;
diff --git a/compute_kernel_writer/src/cl/CLTensorArgument.h b/compute_kernel_writer/src/cl/CLTensorArgument.h
index cd924846c5..4cbbee21ee 100644
--- a/compute_kernel_writer/src/cl/CLTensorArgument.h
+++ b/compute_kernel_writer/src/cl/CLTensorArgument.h
@@ -24,8 +24,10 @@
#ifndef CKW_SRC_CL_CLTENSORARGUMENT_H
#define CKW_SRC_CL_CLTENSORARGUMENT_H
-#include "src/ITensorArgument.h"
-
+#include "ckw/types/TensorComponentType.h"
+#include "ckw/types/TensorStorageType.h"
+#include "src/ITensor.h"
+#include <memory>
#include <string>
#include <vector>
@@ -34,12 +36,16 @@ namespace ckw
// Forward declarations
class TensorInfo;
+class ITensorComponent;
+class CLTensorComponent;
+class CLTensorStorage;
+
/** OpenCL specific tensor argument
* Internally, the object keeps track of the components and storages used to minimize the number
* of kernel arguments required. Therefore, if we create this object but we do not access any components
* or storages, the storages() and components() method will return an empty list.
*/
-class CLTensorArgument : public ITensorArgument, ITensorStorageArgument, ITensorComponentArgument
+class CLTensorArgument : public ITensor
{
public:
/** Constructor
@@ -51,20 +57,32 @@ public:
*/
CLTensorArgument(const std::string &name, const TensorInfo &info, bool return_dims_by_value);
+ /** Destructor. */
+ ~CLTensorArgument();
+
+ /** Get a tensor component of the given type.
+ *
+ * This function is for internal use as it returns a reference to @ref CLTensorComponent object.
+ * It provides rich functionalities and doesn't require unnecessary casting
+ * unlike @ref CLTensorComponent::component which is for the public API and only returns
+ * a reference to a generic @ref ITile object.
+ */
+ CLTensorComponent& cl_component(TensorComponentType component_type);
+
// Inherited method overridden
- TensorStorageVariable storage(TensorStorageType x);
- TileVariable component(TensorComponentType x);
- std::vector<TensorStorageVariable> storages() const;
- std::vector<TileVariable> components() const;
+ TensorStorageVariable &storage(TensorStorageType x) override;
+ ITile &component(TensorComponentType x) override;
+ std::vector<TensorStorageVariable> storages() const override;
+ std::vector<const ITensorComponent *> components() const override;
private:
std::string create_storage_name(TensorStorageType x) const;
- std::string create_component_name(TensorComponentType x) const;
- bool _return_dims_by_value{ false };
- std::vector<TensorStorageType> _storages_used{};
- std::vector<TensorComponentType> _components_used{};
+ bool _return_dims_by_value{ false };
+ std::vector<TensorStorageVariable> _storages_used{};
+ std::vector<std::unique_ptr<CLTensorComponent>> _components_used{};
};
+
} // namespace ckw
#endif // CKW_SRC_CL_CLTENSORARGUMENT_H
diff --git a/compute_kernel_writer/src/cl/CLTensorComponent.cpp b/compute_kernel_writer/src/cl/CLTensorComponent.cpp
new file mode 100644
index 0000000000..c29b307748
--- /dev/null
+++ b/compute_kernel_writer/src/cl/CLTensorComponent.cpp
@@ -0,0 +1,123 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cl/CLTensorComponent.h"
+#include "ckw/Error.h"
+#include "ckw/types/TensorComponentType.h"
+#include "src/cl/CLTensorArgument.h"
+#include "src/cl/CLTile.h"
+
+namespace ckw
+{
+
+namespace
+{
+
+std::string create_component_name(const std::string &name, TensorComponentType x)
+{
+ std::string var_name(name);
+
+ switch(x)
+ {
+ case TensorComponentType::OffsetFirstElement:
+ var_name += "_offset_first_element";
+ break;
+ case TensorComponentType::Stride0:
+ var_name += "_stride0";
+ break;
+ case TensorComponentType::Stride1:
+ var_name += "_stride1";
+ break;
+ case TensorComponentType::Stride2:
+ var_name += "_stride2";
+ break;
+ case TensorComponentType::Stride3:
+ var_name += "_stride3";
+ break;
+ case TensorComponentType::Stride4:
+ var_name += "_stride4";
+ break;
+ case TensorComponentType::Dim0:
+ var_name += "_dim0";
+ break;
+ case TensorComponentType::Dim1:
+ var_name += "_dim1";
+ break;
+ case TensorComponentType::Dim2:
+ var_name += "_dim2";
+ break;
+ case TensorComponentType::Dim3:
+ var_name += "_dim3";
+ break;
+ case TensorComponentType::Dim4:
+ var_name += "_dim4";
+ break;
+ case TensorComponentType::Dim1xDim2:
+ var_name += "_dim1xdim2";
+ break;
+ case TensorComponentType::Dim2xDim3:
+ var_name += "_dim2xdim3";
+ break;
+ case TensorComponentType::Dim1xDim2xDim3:
+ var_name += "_dim1xdim2xdim3";
+ break;
+ default:
+ CKW_THROW_MSG("Unsupported tensor component");
+ return "";
+ }
+
+ return var_name;
+}
+
+} // namespace
+
+CLTensorComponent::CLTensorComponent(const CLTensorArgument &tensor, TensorComponentType component_type)
+ : CLTile(create_component_name(tensor.name(), component_type), TileInfo(DataType::Int32)), _component_type(component_type)
+{
+}
+
+CLTensorComponent::CLTensorComponent(const CLTensorArgument &tensor, TensorComponentType component_type, int32_t value)
+ : CLTile({ { std::to_string(value) } }, DataType::Int32), _component_type(component_type)
+{
+ CKW_UNUSED(tensor);
+}
+
+CLTensorComponent::~CLTensorComponent() = default;
+
+ITile &CLTensorComponent::tile()
+{
+ return *this;
+}
+
+const ITile &CLTensorComponent::tile() const
+{
+ return *this;
+}
+
+TensorComponentType CLTensorComponent::component_type() const
+{
+ return _component_type;
+}
+
+} // namespace ckw
diff --git a/compute_kernel_writer/src/cl/CLTensorComponent.h b/compute_kernel_writer/src/cl/CLTensorComponent.h
new file mode 100644
index 0000000000..42a42666dc
--- /dev/null
+++ b/compute_kernel_writer/src/cl/CLTensorComponent.h
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_SRC_CL_CLTENSORCOMPONENT_H
+#define CKW_SRC_CL_CLTENSORCOMPONENT_H
+
+#include "ckw/types/TensorComponentType.h"
+#include "src/ITensorComponent.h"
+#include "src/cl/CLTile.h"
+
+namespace ckw
+{
+
+class CLTensorArgument;
+
+/** A tensor component object that can be used as a tile.
+ *
+ * The tensor component is created by @ref CLTensorArgument object when it is used
+ * either by the user or internally by a kernel writer operation.
+ * It allows the user to perform operation on tensor component just like any other tile.
+ *
+ * Because of the nature of tensor component, it's always a scalar tile of 32-bit integer.
+ *
+ * To find the list of all tensor components, see @ref TensorComponentType.
+ */
+class CLTensorComponent : public CLTile, public ITensorComponent
+{
+public:
+ /** Initialize a new instance of @ref CLTensorComponent class for dynamic component.
+ *
+ * @param[in] tensor The tensor to which this component belongs.
+ * @param[in] component_type The tensor component type.
+ */
+ CLTensorComponent(const CLTensorArgument &tensor, TensorComponentType component_type);
+
+ /** Initialize a new instance of @ref CLTensorComponent class for compile-time constant component.
+ *
+ * @param[in] tensor The tensor to which this component belongs.
+ * @param[in] component_type The tensor component type.
+ * @param[in] value The value of the component.
+ */
+ CLTensorComponent(const CLTensorArgument &tensor, TensorComponentType component_type, int32_t value);
+
+ /** Destructor. */
+ virtual ~CLTensorComponent();
+
+ ITile &tile() override;
+
+ const ITile &tile() const override;
+
+ TensorComponentType component_type() const override;
+
+private:
+ TensorComponentType _component_type{ TensorComponentType::Unknown };
+};
+
+} // namespace ckw
+
+#endif // CKW_SRC_CL_CLTENSORCOMPONENT_H
diff --git a/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h b/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h
index 6db1384247..d3e455cb83 100644
--- a/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h
+++ b/compute_kernel_writer/validation/tests/CLTensorArgumentTest.h
@@ -22,12 +22,13 @@
* SOFTWARE.
*/
-#ifndef CKW_TESTS_CLTENSORARGUMENTTEST_H
-#define CKW_TESTS_CLTENSORARGUMENTTEST_H
+#ifndef CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H
+#define CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H
#include "common/Common.h"
#include "src/cl/CLHelpers.h"
#include "src/cl/CLTensorArgument.h"
+#include "src/cl/CLTensorComponent.h"
#include <string>
#include <vector>
@@ -89,7 +90,7 @@ public:
CLTensorArgument arg(tensor_name, info, false /* return_dims_by_value */);
const std::string expected_var_name = _expected_vars[i];
- const std::string actual_var_name = arg.component(_components[i]).str;
+ const std::string actual_var_name = arg.component(_components[i]).name();
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
}
@@ -200,8 +201,8 @@ public:
{
CLTensorArgument arg(tensor_name, info, true /* return_dims_by_value */);
- const std::string expected_var_val = _expected_vals[i];
- const std::string actual_var_val = arg.component(_components[i]).str;
+ const std::string expected_var_val = std::string("((int)(") + _expected_vals[i] + "))";
+ const std::string actual_var_val = arg.cl_component(_components[i]).scalar(0, 0).str;
VALIDATE_TEST(actual_var_val.compare(expected_var_val) == 0, all_tests_passed, test_idx++);
}
@@ -276,18 +277,20 @@ public:
{
// Validate variable name
const std::string expected_var_name = _expected_vars[i];
- const std::string actual_var_name = actual_vars[i].str;
+ const std::string actual_var_name = actual_vars[i]->tile().name();
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
// Validate data type
const DataType expected_var_type = DataType::Int32;
- const DataType actual_var_type = actual_vars[i].desc.dt;
+ const DataType actual_var_type = actual_vars[i]->tile().info().data_type();
VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++);
- // Validate data type length
- const int32_t expected_var_len = 1;
- const int32_t actual_var_len = actual_vars[i].desc.len;
- VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++);
+ // Validate tile shape
+ const int32_t actual_var_width = actual_vars[i]->tile().info().width();
+ const int32_t actual_var_height = actual_vars[i]->tile().info().height();
+
+ VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++);
+ VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++);
}
return all_tests_passed;
}
@@ -356,18 +359,20 @@ public:
{
// Validate variable name
const std::string expected_var_name = _expected_vars[i];
- const std::string actual_var_name = actual_vars[i].str;
+ const std::string actual_var_name = actual_vars[i]->tile().name();
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
// Validate data type
const DataType expected_var_type = DataType::Int32;
- const DataType actual_var_type = actual_vars[i].desc.dt;
+ const DataType actual_var_type = actual_vars[i]->tile().info().data_type();
VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++);
- // Validate data type length
- const int32_t expected_var_len = 1;
- const int32_t actual_var_len = actual_vars[i].desc.len;
- VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++);
+ // Validate tile shape
+ const int32_t actual_var_width = actual_vars[i]->tile().info().width();
+ const int32_t actual_var_height = actual_vars[i]->tile().info().height();
+
+ VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++);
+ VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++);
}
return all_tests_passed;
}
@@ -430,8 +435,8 @@ public:
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
// Validate storage type
- const std::string expected_var_type = cl_get_variable_storagetype_as_string(_storages[i]);
- const std::string actual_var_type = actual_vars[i].type;
+ const auto expected_var_type = _storages[i];
+ const auto actual_var_type = actual_vars[i].type;
VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++);
}
return all_tests_passed;
@@ -503,18 +508,20 @@ public:
{
// Validate variable name
const std::string expected_var_name = _expected_vars[i];
- const std::string actual_var_name = actual_vars[i].str;
+ const std::string actual_var_name = actual_vars[i]->tile().name();
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
// Validate data type
const DataType expected_var_type = DataType::Int32;
- const DataType actual_var_type = actual_vars[i].desc.dt;
+ const DataType actual_var_type = actual_vars[i]->tile().info().data_type();
VALIDATE_TEST(actual_var_type == expected_var_type, all_tests_passed, test_idx++);
- // Validate data type length
- const int32_t expected_var_len = 1;
- const int32_t actual_var_len = actual_vars[i].desc.len;
- VALIDATE_TEST(actual_var_len == expected_var_len, all_tests_passed, test_idx++);
+ // Validate tile shape
+ const int32_t actual_var_width = actual_vars[i]->tile().info().width();
+ const int32_t actual_var_height = actual_vars[i]->tile().info().height();
+
+ VALIDATE_TEST(actual_var_height == 1, all_tests_passed, test_idx++);
+ VALIDATE_TEST(actual_var_width == 1, all_tests_passed, test_idx++);
}
return all_tests_passed;
}
@@ -530,4 +537,4 @@ private:
};
} // namespace ckw
-#endif // CKW_TESTS_CLTENSORARGUMENTTEST_H
+#endif // CKW_VALIDATION_TESTS_CLTENSORARGUMENTTEST_H