about | summary | refs | log | tree | commit | diff
path: root/compute_kernel_writer/prototype/src
diff options
context:
space:
mode:
author: Viet-Hoa Do <viet-hoa.do@arm.com>, 2023-06-27 14:09:46 +0100
committer: Viet-Hoa Do <viet-hoa.do@arm.com>, 2023-07-12 15:46:50 +0000
commit: c8e1617807ef1985a39d8f8f5f69c113b758494d (patch)
tree: 675113cc27d4e95cf61aa719fc29cc98a1ce4a50 /compute_kernel_writer/prototype/src
parent: 3c776066a0195f2e99d3503f8b058e468d53b884 (diff)
download: ComputeLibrary-c8e1617807ef1985a39d8f8f5f69c113b758494d.tar.gz
Add compute kernel writer arguments export
* The information is extracted from the prototype argument registry.

Partially resolves: COMPMID-6283
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: Ia6d69b7c2a2e411597e76a7e03b7c92199a16990
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9848
Reviewed-by: SiCong Li <sicong.li@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'compute_kernel_writer/prototype/src')
-rw-r--r-- compute_kernel_writer/prototype/src/Kernel.cpp | 99
-rw-r--r-- compute_kernel_writer/prototype/src/KernelArgument.cpp | 66
-rw-r--r-- compute_kernel_writer/prototype/src/KernelWriter.cpp | 37
-rw-r--r-- compute_kernel_writer/prototype/src/Prototype.h | 201
-rw-r--r-- compute_kernel_writer/prototype/src/TensorOperand.cpp | 105
5 files changed, 348 insertions, 160 deletions
diff --git a/compute_kernel_writer/prototype/src/Kernel.cpp b/compute_kernel_writer/prototype/src/Kernel.cpp
index 692d504887..884b69afc6 100644
--- a/compute_kernel_writer/prototype/src/Kernel.cpp
+++ b/compute_kernel_writer/prototype/src/Kernel.cpp
@@ -23,6 +23,7 @@
*/
#include "ckw/Kernel.h"
+#include "ckw/TensorOperand.h"
#include "ckw/types/GpuTargetLanguage.h"
#include "src/Prototype.h"
@@ -30,7 +31,7 @@ namespace ckw
{
Kernel::Kernel(const char *name, GpuTargetLanguage language)
- : _name(name), _kernel(std::make_unique<prototype::GpuKernelWriterDataHolder>(language)), _operands{}
+ : _name(name), _kernel(std::make_unique<prototype::GpuKernelWriterDataHolder>(language)), _operands{}, _tensor_id_operands{}
{
}
@@ -43,14 +44,102 @@ const std::string &Kernel::name() const
return _name;
}
-const std::map<std::string, std::unique_ptr<OperandBase>> &Kernel::operands() const
+std::vector<KernelArgument> Kernel::arguments() const
{
- return _operands;
+ std::vector<KernelArgument> arguments;
+
+ const auto impl_args = _kernel->arguments.tensor_argument_declarations();
+
+ for(auto tensor_arg : impl_args)
+ {
+ auto tensor = _tensor_id_operands.at(tensor_arg->format().id);
+ arguments.push_back(*tensor);
+
+ for(auto component_arg : tensor_arg->component_declarations())
+ {
+ switch(component_arg)
+ {
+ case TensorComponentType::OffsetFirstElement:
+ arguments.push_back(tensor->offset_first_element_in_bytes());
+ break;
+
+ case TensorComponentType::Stride1:
+ arguments.push_back(tensor->stride1());
+ break;
+
+ case TensorComponentType::Stride2:
+ arguments.push_back(tensor->stride2());
+ break;
+
+ case TensorComponentType::Stride3:
+ arguments.push_back(tensor->stride3());
+ break;
+
+ case TensorComponentType::Stride4:
+ arguments.push_back(tensor->stride4());
+ break;
+
+ case TensorComponentType::Dim0:
+ arguments.push_back(tensor->dim0());
+ break;
+
+ case TensorComponentType::Dim1:
+ arguments.push_back(tensor->dim1());
+ break;
+
+ case TensorComponentType::Dim2:
+ arguments.push_back(tensor->dim2());
+ break;
+
+ case TensorComponentType::Dim3:
+ arguments.push_back(tensor->dim3());
+ break;
+
+ case TensorComponentType::Dim4:
+ arguments.push_back(tensor->dim4());
+ break;
+
+ case TensorComponentType::Dim1xDim2:
+ arguments.push_back(tensor->dim1_dim2());
+ break;
+
+ case TensorComponentType::Dim1xDim2xDim3:
+ arguments.push_back(tensor->dim1_dim2_dim3());
+ break;
+
+ default:
+ CKW_ASSERT(false);
+ }
+ }
+ }
+
+ return arguments;
+}
+
+TileOperand &Kernel::register_operand(std::unique_ptr<TileOperand> operand)
+{
+ const auto &name = operand->name();
+ auto ptr = operand.get();
+
+ CKW_ASSERT(_operands.find(name) == _operands.end());
+ _operands[name] = std::move(operand);
+
+ return *ptr;
}
-std::map<std::string, std::unique_ptr<OperandBase>> &Kernel::operands()
+TensorOperand &Kernel::register_operand(std::unique_ptr<TensorOperand> operand)
{
- return _operands;
+ const auto id = operand->info().id();
+ const auto &name = operand->name();
+ auto ptr = operand.get();
+
+ CKW_ASSERT(_tensor_id_operands.find(id) == _tensor_id_operands.end());
+ CKW_ASSERT(_operands.find(name) == _operands.end());
+
+ _tensor_id_operands[id] = operand.get();
+ _operands[name] = std::move(operand);
+
+ return *ptr;
}
prototype::GpuKernelWriterDataHolder *Kernel::impl()
diff --git a/compute_kernel_writer/prototype/src/KernelArgument.cpp b/compute_kernel_writer/prototype/src/KernelArgument.cpp
new file mode 100644
index 0000000000..2b4d7c8cee
--- /dev/null
+++ b/compute_kernel_writer/prototype/src/KernelArgument.cpp
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ckw/KernelArgument.h"
+#include "ckw/Error.h"
+#include "ckw/TensorOperand.h"
+
+namespace ckw
+{
+
+KernelArgument::KernelArgument(TensorOperand &tensor)
+ : _type(Type::TensorStorage), _id(tensor.info().id())
+{
+ _sub_id.tensor_storage_type = tensor.storage_type();
+}
+
+KernelArgument::KernelArgument(TensorComponentOperand &tensor_component)
+ : _type(Type::TensorComponent), _id(tensor_component.tensor().info().id())
+{
+ _sub_id.tensor_component_type = tensor_component.component_type();
+}
+
+KernelArgument::Type KernelArgument::type() const
+{
+ return _type;
+}
+
+int32_t KernelArgument::id() const
+{
+ return _id;
+}
+
+TensorStorageType KernelArgument::tensor_storage_type() const
+{
+ CKW_ASSERT(_type == Type::TensorStorage);
+ return _sub_id.tensor_storage_type;
+}
+
+TensorComponentType KernelArgument::tensor_component_type() const
+{
+ CKW_ASSERT(_type == Type::TensorComponent);
+ return _sub_id.tensor_component_type;
+}
+
+} // namespace ckw
diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp
index 73458efa1d..1ac9ede5b5 100644
--- a/compute_kernel_writer/prototype/src/KernelWriter.cpp
+++ b/compute_kernel_writer/prototype/src/KernelWriter.cpp
@@ -24,6 +24,7 @@
#include "ckw/KernelWriter.h"
#include "ckw/Error.h"
+#include "ckw/TensorInfo.h"
#include "ckw/TensorOperand.h"
#include "src/Prototype.h"
@@ -85,26 +86,24 @@ int32_t KernelWriter::next_id_space()
// Tensor and tile declaration
// =================================================================================================
-TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info)
+TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info, TensorStorageType storage_type)
{
const auto var_name = generate_variable_name(name);
_impl->declare_argument(var_name, create_impl_tensor_info(info));
- auto operand = new TensorOperand(var_name, info);
- register_operand(operand, false);
+ auto &operand = _kernel->register_operand(std::make_unique<TensorOperand>(var_name, info, storage_type));
- return *operand;
+ return operand;
}
TileOperand &KernelWriter::declare_tile_argument(const std::string &name, int32_t value)
{
const auto var_name = generate_variable_name(name);
- auto operand = new TileOperand(var_name, value);
- register_operand(operand, false);
+ auto &operand = _kernel->register_operand(std::make_unique<TileOperand>(var_name, value));
- return *operand;
+ return operand;
}
std::string KernelWriter::generate_variable_name(const std::string &name) const
@@ -116,21 +115,21 @@ std::string KernelWriter::generate_variable_name(const std::string &name) const
return var_name.str();
}
-void KernelWriter::register_operand(OperandBase *operand, bool declaring)
+TileOperand &KernelWriter::declare_tile_operand(std::unique_ptr<TileOperand> operand_ptr)
{
- const auto &name = operand->name();
- auto &operands = _kernel->operands();
+ auto &operand = _kernel->register_operand(std::move(operand_ptr));
+ const auto &name = operand.name();
- CKW_ASSERT(operands.find(name) == operands.end());
- operands[name] = std::unique_ptr<OperandBase>(operand);
-
- if(declaring && !operand->is_constant())
+ if(!operand.is_constant())
{
- const auto tile = reinterpret_cast<TileOperand *>(operand);
+ const auto &info = operand.tile_info();
- const auto &info = tile->tile_info();
- _impl->declare_tile(tile->name(), prototype::TileInfo(info.data_type(), info.width(), info.height()));
+ _impl->declare_tile(
+ name,
+ prototype::TileInfo(info.data_type(), info.width(), info.height()));
}
+
+ return operand;
}
// =================================================================================================
@@ -143,7 +142,7 @@ void KernelWriter::op_load(TileOperand &tile, TensorOperand &tensor, const Tenso
tensor.name(),
prototype::GpuSampler{
sampler.format(),
- prototype::GpuSamplerTensorStorage::BufferUint8Ptr,
+ prototype::to_gpu_tensor_storage(tensor.storage_type()),
sampler.address_mode_x(),
sampler.address_mode_y(),
sampler.address_mode_z() });
@@ -164,7 +163,7 @@ void KernelWriter::op_store(TensorOperand &tensor, const TileOperand &tile, cons
tensor.name(),
prototype::GpuSampler{
sampler.format(),
- prototype::GpuSamplerTensorStorage::BufferUint8Ptr,
+ prototype::to_gpu_tensor_storage(tensor.storage_type()),
sampler.address_mode_x(),
sampler.address_mode_y(),
sampler.address_mode_z() });
diff --git a/compute_kernel_writer/prototype/src/Prototype.h b/compute_kernel_writer/prototype/src/Prototype.h
index 18f284b2b1..b9f1efa542 100644
--- a/compute_kernel_writer/prototype/src/Prototype.h
+++ b/compute_kernel_writer/prototype/src/Prototype.h
@@ -561,7 +561,7 @@ enum class TensorComponentIndex : int32_t
IndexMask = 0x0000000f,
};
-enum class TensorComponentType : int32_t
+enum class TensorComponentGroup : int32_t
{
OffsetFirstElement = 0x00000100,
Stride = 0x00001000,
@@ -570,62 +570,39 @@ enum class TensorComponentType : int32_t
Constant = 0x01000000
};
-enum class TensorComponent : int32_t
-{
- Unknown = 0x00000000,
- OffsetFirstElement = 0x00000100,
- Stride1 = 0x00001001,
- Stride2 = 0x00001002,
- Stride3 = 0x00001003,
- Stride4 = 0x00001004,
- Dim0 = 0x00010000,
- Dim1 = 0x00010001,
- Dim2 = 0x00010002,
- Dim3 = 0x00010003,
- Dim4 = 0x00010004,
- C = 0x00010000, // Dim0
- W = 0x00010001, // Dim1
- H = 0x00010002, // Dim2
- D = 0x00010003,
- N = 0x00010004,
- Dim1xDim2 = 0x00100021,
- Dim1xDim2xDim3 = 0x00100321,
- WxH = 0x00100021,
- WxHxD = 0x00100321
-};
-
-inline std::string to_string(TensorComponent x)
+inline std::string to_string(TensorComponentType x)
{
switch(x)
{
- case TensorComponent::Unknown:
+ case TensorComponentType::Unknown:
return "Unknown";
- case TensorComponent::OffsetFirstElement:
+ case TensorComponentType::OffsetFirstElement:
return "OffsetFirstElement";
- case TensorComponent::Stride1:
+ case TensorComponentType::Stride1:
return "Stride1";
- case TensorComponent::Stride2:
+ case TensorComponentType::Stride2:
return "Stride2";
- case TensorComponent::Stride3:
+ case TensorComponentType::Stride3:
return "Stride3";
- case TensorComponent::Stride4:
+ case TensorComponentType::Stride4:
return "Stride4";
- case TensorComponent::Dim0:
+ case TensorComponentType::Dim0:
return "Dim0";
- case TensorComponent::Dim1:
+ case TensorComponentType::Dim1:
return "Dim1";
- case TensorComponent::Dim2:
+ case TensorComponentType::Dim2:
return "Dim2";
- case TensorComponent::Dim3:
+ case TensorComponentType::Dim3:
return "Dim3";
- case TensorComponent::Dim4:
+ case TensorComponentType::Dim4:
return "Dim4";
- case TensorComponent::Dim1xDim2:
+ case TensorComponentType::Dim1xDim2:
return "Dim1xDim2";
- case TensorComponent::Dim1xDim2xDim3:
+ case TensorComponentType::Dim1xDim2xDim3:
return "Dim1xDim2xDim3";
default:
assert(false);
+ return "";
}
}
@@ -640,7 +617,7 @@ public:
*
* @return the tensor component as a string
*/
- virtual std::string component(TensorComponent x) = 0;
+ virtual std::string component(TensorComponentType x) = 0;
/** Method to get the tensor component type declaration as a string
*
@@ -658,7 +635,7 @@ public:
*
* @return a vector containing the tensor component declarations
*/
- virtual std::vector<TensorComponent> component_declarations() const = 0;
+ virtual std::vector<TensorComponentType> component_declarations() const = 0;
/** Method to get the name of the tensor argument.
*
@@ -693,6 +670,50 @@ enum class GpuTensorStorage : int32_t
Image3dWriteOnly = 0x0031
};
+inline GpuTensorStorage to_gpu_tensor_storage(TensorStorageType s)
+{
+ switch(s)
+ {
+ case TensorStorageType::Unknown:
+ return GpuTensorStorage::Unknown;
+
+ case TensorStorageType::BufferUint8Ptr:
+ return GpuTensorStorage::BufferUint8Ptr;
+
+ case TensorStorageType::Texture2dReadOnly:
+ return GpuTensorStorage::Image2dReadOnly;
+
+ case TensorStorageType::Texture2dWriteOnly:
+ return GpuTensorStorage::Image2dWriteOnly;
+
+ default:
+ assert(false);
+ return GpuTensorStorage::Unknown;
+ }
+}
+
+inline TensorStorageType to_tensor_storage(GpuTensorStorage s)
+{
+ switch(s)
+ {
+ case GpuTensorStorage::Unknown:
+ return TensorStorageType::Unknown;
+
+ case GpuTensorStorage::BufferUint8Ptr:
+ return TensorStorageType::BufferUint8Ptr;
+
+ case GpuTensorStorage::Image2dReadOnly:
+ return TensorStorageType::Texture2dReadOnly;
+
+ case GpuTensorStorage::Image2dWriteOnly:
+ return TensorStorageType::Texture2dWriteOnly;
+
+ default:
+ assert(false);
+ return TensorStorageType::Unknown;
+ }
+}
+
class IGpuTensorArgument : public ITensorArgument
{
public:
@@ -732,9 +753,9 @@ public:
}
// Methods to override
- std::string component(TensorComponent x) override
+ std::string component(TensorComponentType x) override
{
- if((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentType::Constant)))
+ if((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentGroup::Constant)))
{
int32_t idx = static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentIndex::IndexMask);
return std::to_string(idx - 1);
@@ -742,19 +763,19 @@ public:
if(_return_by_value_when_possible)
{
- if((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentType::Dimension)))
+ if((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentGroup::Dimension)))
{
int32_t idx = static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentIndex::IndexMask);
return std::to_string(_format.shape[idx]);
}
- if((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentType::FoldedDimension)))
+ if((static_cast<int32_t>(x) & static_cast<int32_t>(TensorComponentGroup::FoldedDimension)))
{
switch(x)
{
- case TensorComponent::Dim1xDim2:
+ case TensorComponentType::Dim1xDim2:
return std::to_string(_format.shape[1] * _format.shape[2]);
- case TensorComponent::Dim1xDim2xDim3:
+ case TensorComponentType::Dim1xDim2xDim3:
return std::to_string(_format.shape[1] * _format.shape[2] * _format.shape[2]);
default:
std::cout << "Unsupported folded dimension" << std::endl;
@@ -817,7 +838,7 @@ public:
return _storage_required;
}
- std::vector<TensorComponent> component_declarations() const override
+ std::vector<TensorComponentType> component_declarations() const override
{
return _components_required;
}
@@ -845,31 +866,31 @@ private:
return var_name;
}
- std::string build_component_name(TensorComponent x) const
+ std::string build_component_name(TensorComponentType x) const
{
std::string var_name = _basename;
switch(x)
{
- case TensorComponent::OffsetFirstElement:
+ case TensorComponentType::OffsetFirstElement:
return var_name + "_offset_first_element";
- case TensorComponent::Stride1:
+ case TensorComponentType::Stride1:
return var_name + "_stride1";
- case TensorComponent::Stride2:
+ case TensorComponentType::Stride2:
return var_name + "_stride2";
- case TensorComponent::Stride3:
+ case TensorComponentType::Stride3:
return var_name + "_stride3";
- case TensorComponent::Dim0:
+ case TensorComponentType::Dim0:
return var_name + "_dim0";
- case TensorComponent::Dim1:
+ case TensorComponentType::Dim1:
return var_name + "_dim1";
- case TensorComponent::Dim2:
+ case TensorComponentType::Dim2:
return var_name + "_dim2";
- case TensorComponent::Dim3:
+ case TensorComponentType::Dim3:
return var_name + "_dim3";
- case TensorComponent::Dim1xDim2:
+ case TensorComponentType::Dim1xDim2:
return var_name + "_dim1xdim2";
- case TensorComponent::Dim1xDim2xDim3:
+ case TensorComponentType::Dim1xDim2xDim3:
return var_name + "_dim1xdim2xdim3";
default:
std::cout << "Unsupported component" << std::endl;
@@ -881,7 +902,7 @@ private:
bool _return_by_value_when_possible{ false };
std::vector<GpuTensorStorage> _storage_required{};
- std::vector<TensorComponent> _components_required{};
+ std::vector<TensorComponentType> _components_required{};
};
/**
@@ -1745,15 +1766,7 @@ private:
ScalarTileCoord _coord{};
};
-enum class GpuSamplerTensorStorage : int32_t
-{
- Unknown = static_cast<int32_t>(GpuTensorStorage::Unknown),
- BufferUint8Ptr = static_cast<int32_t>(GpuTensorStorage::BufferUint8Ptr),
- Image2dReadOnly = static_cast<int32_t>(GpuTensorStorage::Image2dReadOnly),
- Image2dWriteOnly = static_cast<int32_t>(GpuTensorStorage::Image2dWriteOnly),
- Image3dReadOnly = static_cast<int32_t>(GpuTensorStorage::Image3dReadOnly),
- Image3dWriteOnly = static_cast<int32_t>(GpuTensorStorage::Image2dWriteOnly),
-};
+using GpuSamplerTensorStorage = GpuTensorStorage;
struct GpuSampler
{
@@ -2098,37 +2111,37 @@ private:
return static_cast<DataType>(static_cast<int32_t>(x) & 0x00ff);
}
- TensorComponent to_tensor_component(OperandType x)
+ TensorComponentType to_tensor_component(OperandType x)
{
switch(x)
{
case OperandType::TensorDim0:
- return TensorComponent::Dim0;
+ return TensorComponentType::Dim0;
case OperandType::TensorDim1:
- return TensorComponent::Dim1;
+ return TensorComponentType::Dim1;
case OperandType::TensorDim2:
- return TensorComponent::Dim2;
+ return TensorComponentType::Dim2;
case OperandType::TensorDim3:
- return TensorComponent::Dim3;
+ return TensorComponentType::Dim3;
case OperandType::TensorDim4:
- return TensorComponent::Dim4;
+ return TensorComponentType::Dim4;
case OperandType::TensorStride1:
- return TensorComponent::Stride1;
+ return TensorComponentType::Stride1;
case OperandType::TensorStride2:
- return TensorComponent::Stride2;
+ return TensorComponentType::Stride2;
case OperandType::TensorStride3:
- return TensorComponent::Stride3;
+ return TensorComponentType::Stride3;
case OperandType::TensorStride4:
- return TensorComponent::Stride4;
+ return TensorComponentType::Stride4;
case OperandType::TensorDim1xDim2:
- return TensorComponent::Dim1xDim2;
+ return TensorComponentType::Dim1xDim2;
case OperandType::TensorDim1xDim2xDim3:
- return TensorComponent::Dim1xDim2xDim3;
+ return TensorComponentType::Dim1xDim2xDim3;
case OperandType::TensorDataOffset:
- return TensorComponent::OffsetFirstElement;
+ return TensorComponentType::OffsetFirstElement;
default:
assert(false);
- return TensorComponent::Unknown;
+ return TensorComponentType::Unknown;
}
}
@@ -2174,7 +2187,7 @@ struct GpuKernel
// Dispatch stage
GpuOutputSampler output_sampler{}; // GpuOutputSampler, required for the dispatch stage
std::vector<std::pair<int32_t, GpuTensorStorage>> list_tensor_storages; // List of tensor storages, required for the dispatch stage
- std::vector<std::pair<int32_t, TensorComponent>> list_tensor_components; // List of tensor components (width, stride,..), required for the dispatch stage)
+ std::vector<std::pair<int32_t, TensorComponentType>> list_tensor_components; // List of tensor components (width, stride,..), required for the dispatch stage)
};
// This function should produce an object with the source
@@ -2251,7 +2264,7 @@ public:
{
case TensorSamplerFormat::C_WH_1:
case TensorSamplerFormat::C_W_H:
- return _tensor->component(TensorComponent::C);
+ return _tensor->component(TensorComponentType::Dim0);
default:
std::cout << "Unsupported tensor format" << std::endl;
assert(false);
@@ -2265,9 +2278,9 @@ public:
switch(format)
{
case TensorSamplerFormat::C_WH_1:
- return _tensor->component(TensorComponent::WxH);
+ return _tensor->component(TensorComponentType::Dim1xDim2);
case TensorSamplerFormat::C_W_H:
- return _tensor->component(TensorComponent::W);
+ return _tensor->component(TensorComponentType::Dim1);
default:
std::cout << "Unsupported tensor format" << std::endl;
assert(false);
@@ -2283,7 +2296,7 @@ public:
case TensorSamplerFormat::C_WH_1:
return "1";
case TensorSamplerFormat::C_W_H:
- return _tensor->component(TensorComponent::H);
+ return _tensor->component(TensorComponentType::Dim2);
default:
std::cout << "Unsupported tensor format" << std::endl;
assert(false);
@@ -2298,7 +2311,7 @@ public:
{
case TensorSamplerFormat::C_WH_1:
case TensorSamplerFormat::C_W_H:
- return _tensor->component(TensorComponent::Stride1);
+ return _tensor->component(TensorComponentType::Stride1);
default:
std::cout << "Unsupported tensor format" << std::endl;
assert(false);
@@ -2314,7 +2327,7 @@ public:
case TensorSamplerFormat::C_WH_1:
return "0";
case TensorSamplerFormat::C_W_H:
- return _tensor->component(TensorComponent::Stride2);
+ return _tensor->component(TensorComponentType::Stride2);
default:
std::cout << "Unsupported tensor format" << std::endl;
assert(false);
@@ -2329,7 +2342,7 @@ public:
{
case TensorSamplerFormat::C_WH_1:
case TensorSamplerFormat::C_W_H:
- return _tensor->component(TensorComponent::Stride3);
+ return _tensor->component(TensorComponentType::Stride3);
default:
std::cout << "Unsupported tensor format" << std::endl;
assert(false);
@@ -3941,9 +3954,9 @@ public:
assert(x_off->format().dt == DataType::Int32);
assert(y_off->format().dt == DataType::Int32);
- const std::string width = tensor->component(TensorComponent::W);
- const std::string height = tensor->component(TensorComponent::H);
- const std::string wxh = tensor->component(TensorComponent::WxH);
+ const std::string width = tensor->component(TensorComponentType::Dim1);
+ const std::string height = tensor->component(TensorComponentType::Dim2);
+ const std::string wxh = tensor->component(TensorComponentType::Dim1xDim2);
/*
int x_s;
int y_s;
diff --git a/compute_kernel_writer/prototype/src/TensorOperand.cpp b/compute_kernel_writer/prototype/src/TensorOperand.cpp
index 00ecc3824e..c6725d3b26 100644
--- a/compute_kernel_writer/prototype/src/TensorOperand.cpp
+++ b/compute_kernel_writer/prototype/src/TensorOperand.cpp
@@ -25,6 +25,7 @@
#include "ckw/TensorOperand.h"
#include "ckw/Error.h"
#include "ckw/Kernel.h"
+#include "ckw/TensorInfo.h"
#include "ckw/TileOperand.h"
#include "src/Prototype.h"
@@ -34,11 +35,11 @@ namespace ckw
namespace
{
-inline TensorComponentOperand &get_or_create_component(std::unique_ptr<TensorComponentOperand> &ptr, const ::std::string &name, TensorComponent component)
+TensorComponentOperand &get_or_create_component(TensorOperand &tensor, std::unique_ptr<TensorComponentOperand> &ptr, TensorComponentType component)
{
if(ptr == nullptr)
{
- ptr = std::make_unique<TensorComponentOperand>(name, component);
+ ptr = std::make_unique<TensorComponentOperand>(tensor, component);
}
return *ptr;
@@ -50,8 +51,8 @@ inline TensorComponentOperand &get_or_create_component(std::unique_ptr<TensorCom
// TensorOperand
// =================================================================================================
-TensorOperand::TensorOperand(const std::string &name, const TensorInfo &info)
- : OperandBase(name), _info(info)
+TensorOperand::TensorOperand(const std::string &name, const TensorInfo &info, TensorStorageType storage_type)
+ : OperandBase(name), _info(info), _storage_type(storage_type)
{
}
@@ -71,6 +72,11 @@ TensorInfo &TensorOperand::info()
return _info;
}
+TensorStorageType TensorOperand::storage_type() const
+{
+ return _storage_type;
+}
+
DataType TensorOperand::data_type() const
{
return _info.data_type();
@@ -113,73 +119,88 @@ TensorOperand &TensorOperand::tile_sampler(const TensorTileSampler &value)
return *this;
}
-TileOperand &TensorOperand::stride1()
+TensorComponentOperand &TensorOperand::stride1()
{
- return get_or_create_component(_stride1, name(), TensorComponent::Stride1);
+ return get_or_create_component(*this, _stride1, TensorComponentType::Stride1);
}
-TileOperand &TensorOperand::stride2()
+TensorComponentOperand &TensorOperand::stride2()
{
- return get_or_create_component(_stride2, name(), TensorComponent::Stride2);
+ return get_or_create_component(*this, _stride2, TensorComponentType::Stride2);
}
-TileOperand &TensorOperand::stride3()
+TensorComponentOperand &TensorOperand::stride3()
{
- return get_or_create_component(_stride3, name(), TensorComponent::Stride3);
+ return get_or_create_component(*this, _stride3, TensorComponentType::Stride3);
}
-TileOperand &TensorOperand::stride4()
+TensorComponentOperand &TensorOperand::stride4()
{
- return get_or_create_component(_stride4, name(), TensorComponent::Stride4);
+ return get_or_create_component(*this, _stride4, TensorComponentType::Stride4);
}
-TileOperand &TensorOperand::dim0()
+TensorComponentOperand &TensorOperand::dim0()
{
- return get_or_create_component(_dim0, name(), TensorComponent::Dim0);
+ return get_or_create_component(*this, _dim0, TensorComponentType::Dim0);
}
-TileOperand &TensorOperand::dim1()
+TensorComponentOperand &TensorOperand::dim1()
{
- return get_or_create_component(_dim1, name(), TensorComponent::Dim1);
+ return get_or_create_component(*this, _dim1, TensorComponentType::Dim1);
}
-TileOperand &TensorOperand::dim2()
+TensorComponentOperand &TensorOperand::dim2()
{
- return get_or_create_component(_dim2, name(), TensorComponent::Dim2);
+ return get_or_create_component(*this, _dim2, TensorComponentType::Dim2);
}
-TileOperand &TensorOperand::dim3()
+TensorComponentOperand &TensorOperand::dim3()
{
- return get_or_create_component(_dim3, name(), TensorComponent::Dim3);
+ return get_or_create_component(*this, _dim3, TensorComponentType::Dim3);
}
-TileOperand &TensorOperand::dim4()
+TensorComponentOperand &TensorOperand::dim4()
{
- return get_or_create_component(_dim4, name(), TensorComponent::Dim4);
+ return get_or_create_component(*this, _dim4, TensorComponentType::Dim4);
}
-TileOperand &TensorOperand::dim1_dim2()
+TensorComponentOperand &TensorOperand::dim1_dim2()
{
- return get_or_create_component(_dim1_dim2, name(), TensorComponent::Dim1xDim2);
+ return get_or_create_component(*this, _dim1_dim2, TensorComponentType::Dim1xDim2);
}
-TileOperand &TensorOperand::dim1_dim2_dim3()
+TensorComponentOperand &TensorOperand::dim1_dim2_dim3()
{
- return get_or_create_component(_dim1_dim2_dim3, name(), TensorComponent::Dim1xDim2xDim3);
+ return get_or_create_component(*this, _dim1_dim2_dim3, TensorComponentType::Dim1xDim2xDim3);
}
-TileOperand &TensorOperand::offset_first_element_in_bytes()
+TensorComponentOperand &TensorOperand::offset_first_element_in_bytes()
{
- return get_or_create_component(_offset_first_element_in_bytes, name(), TensorComponent::OffsetFirstElement);
+ return get_or_create_component(*this, _offset_first_element_in_bytes, TensorComponentType::OffsetFirstElement);
}
// =================================================================================================
// TensorComponentOperand
// =================================================================================================
-TensorComponentOperand::TensorComponentOperand(const ::std::string &name, TensorComponent component)
- : TileOperand(name, DataType::Int32), _component(component)
+TensorComponentOperand::TensorComponentOperand(TensorOperand &tensor, TensorComponentType component)
+ : TileOperand(tensor.name(), DataType::Int32), _tensor(tensor), _component(component)
+{
+}
+
+TensorOperand &TensorComponentOperand::tensor()
+{
+ return _tensor;
+}
+
+const TensorOperand &TensorComponentOperand::tensor() const
+{
+ return _tensor;
+}
+
+TensorComponentType TensorComponentOperand::component_type() const
{
+ return _component;
}
prototype::Operand TensorComponentOperand::create_impl_operand(prototype::IGpuKernelWriter *writer) const
@@ -189,51 +210,51 @@ prototype::Operand TensorComponentOperand::create_impl_operand(prototype::IGpuKe
switch(_component)
{
- case TensorComponent::OffsetFirstElement:
+ case TensorComponentType::OffsetFirstElement:
type = prototype::OperandType::TensorDataOffset;
break;
- case TensorComponent::Stride1:
+ case TensorComponentType::Stride1:
type = prototype::OperandType::TensorStride1;
break;
- case TensorComponent::Stride2:
+ case TensorComponentType::Stride2:
type = prototype::OperandType::TensorStride2;
break;
- case TensorComponent::Stride3:
+ case TensorComponentType::Stride3:
type = prototype::OperandType::TensorStride3;
break;
- case TensorComponent::Stride4:
+ case TensorComponentType::Stride4:
type = prototype::OperandType::TensorStride4;
break;
- case TensorComponent::Dim0:
+ case TensorComponentType::Dim0:
type = prototype::OperandType::TensorDim0;
break;
- case TensorComponent::Dim1:
+ case TensorComponentType::Dim1:
type = prototype::OperandType::TensorDim1;
break;
- case TensorComponent::Dim2:
+ case TensorComponentType::Dim2:
type = prototype::OperandType::TensorDim2;
break;
- case TensorComponent::Dim3:
+ case TensorComponentType::Dim3:
type = prototype::OperandType::TensorDim3;
break;
- case TensorComponent::Dim4:
+ case TensorComponentType::Dim4:
type = prototype::OperandType::TensorDim4;
break;
- case TensorComponent::Dim1xDim2:
+ case TensorComponentType::Dim1xDim2:
type = prototype::OperandType::TensorDim1xDim2;
break;
- case TensorComponent::Dim1xDim2xDim3:
+ case TensorComponentType::Dim1xDim2xDim3:
type = prototype::OperandType::TensorDim1xDim2xDim3;
break;