diff options
Diffstat (limited to 'compute_kernel_writer/prototype/src/TensorOperand.cpp')
-rw-r--r-- | compute_kernel_writer/prototype/src/TensorOperand.cpp | 105 |
1 files changed, 63 insertions, 42 deletions
diff --git a/compute_kernel_writer/prototype/src/TensorOperand.cpp b/compute_kernel_writer/prototype/src/TensorOperand.cpp index 00ecc3824e..c6725d3b26 100644 --- a/compute_kernel_writer/prototype/src/TensorOperand.cpp +++ b/compute_kernel_writer/prototype/src/TensorOperand.cpp @@ -25,6 +25,7 @@ #include "ckw/TensorOperand.h" #include "ckw/Error.h" #include "ckw/Kernel.h" +#include "ckw/TensorInfo.h" #include "ckw/TileOperand.h" #include "src/Prototype.h" @@ -34,11 +35,11 @@ namespace ckw namespace { -inline TensorComponentOperand &get_or_create_component(std::unique_ptr<TensorComponentOperand> &ptr, const ::std::string &name, TensorComponent component) +TensorComponentOperand &get_or_create_component(TensorOperand &tensor, std::unique_ptr<TensorComponentOperand> &ptr, TensorComponentType component) { if(ptr == nullptr) { - ptr = std::make_unique<TensorComponentOperand>(name, component); + ptr = std::make_unique<TensorComponentOperand>(tensor, component); } return *ptr; @@ -50,8 +51,8 @@ inline TensorComponentOperand &get_or_create_component(std::unique_ptr<TensorCom // TensorOperand // ================================================================================================= -TensorOperand::TensorOperand(const std::string &name, const TensorInfo &info) - : OperandBase(name), _info(info) +TensorOperand::TensorOperand(const std::string &name, const TensorInfo &info, TensorStorageType storage_type) + : OperandBase(name), _info(info), _storage_type(storage_type) { } @@ -71,6 +72,11 @@ TensorInfo &TensorOperand::info() return _info; } +TensorStorageType TensorOperand::storage_type() const +{ + return _storage_type; +} + DataType TensorOperand::data_type() const { return _info.data_type(); @@ -113,73 +119,88 @@ TensorOperand &TensorOperand::tile_sampler(const TensorTileSampler &value) return *this; } -TileOperand &TensorOperand::stride1() +TensorComponentOperand &TensorOperand::stride1() { - return get_or_create_component(_stride1, name(), TensorComponent::Stride1); + return get_or_create_component(*this, _stride1, TensorComponentType::Stride1); } -TileOperand &TensorOperand::stride2() +TensorComponentOperand &TensorOperand::stride2() { - return get_or_create_component(_stride2, name(), TensorComponent::Stride2); + return get_or_create_component(*this, _stride2, TensorComponentType::Stride2); } -TileOperand &TensorOperand::stride3() +TensorComponentOperand &TensorOperand::stride3() { - return get_or_create_component(_stride3, name(), TensorComponent::Stride3); + return get_or_create_component(*this, _stride3, TensorComponentType::Stride3); } -TileOperand &TensorOperand::stride4() +TensorComponentOperand &TensorOperand::stride4() { - return get_or_create_component(_stride4, name(), TensorComponent::Stride4); + return get_or_create_component(*this, _stride4, TensorComponentType::Stride4); } -TileOperand &TensorOperand::dim0() +TensorComponentOperand &TensorOperand::dim0() { - return get_or_create_component(_dim0, name(), TensorComponent::Dim0); + return get_or_create_component(*this, _dim0, TensorComponentType::Dim0); } -TileOperand &TensorOperand::dim1() +TensorComponentOperand &TensorOperand::dim1() { - return get_or_create_component(_dim1, name(), TensorComponent::Dim1); + return get_or_create_component(*this, _dim1, TensorComponentType::Dim1); } -TileOperand &TensorOperand::dim2() +TensorComponentOperand &TensorOperand::dim2() { - return get_or_create_component(_dim2, name(), TensorComponent::Dim2); + return get_or_create_component(*this, _dim2, TensorComponentType::Dim2); } -TileOperand &TensorOperand::dim3() +TensorComponentOperand &TensorOperand::dim3() { - return get_or_create_component(_dim3, name(), TensorComponent::Dim3); + return get_or_create_component(*this, _dim3, TensorComponentType::Dim3); } -TileOperand &TensorOperand::dim4() +TensorComponentOperand &TensorOperand::dim4() { - return get_or_create_component(_dim4, name(), TensorComponent::Dim4); + return get_or_create_component(*this, _dim4, TensorComponentType::Dim4); } -TileOperand &TensorOperand::dim1_dim2() +TensorComponentOperand &TensorOperand::dim1_dim2() { - return get_or_create_component(_dim1_dim2, name(), TensorComponent::Dim1xDim2); + return get_or_create_component(*this, _dim1_dim2, TensorComponentType::Dim1xDim2); } -TileOperand &TensorOperand::dim1_dim2_dim3() +TensorComponentOperand &TensorOperand::dim1_dim2_dim3() { - return get_or_create_component(_dim1_dim2_dim3, name(), TensorComponent::Dim1xDim2xDim3); + return get_or_create_component(*this, _dim1_dim2_dim3, TensorComponentType::Dim1xDim2xDim3); } -TileOperand &TensorOperand::offset_first_element_in_bytes() +TensorComponentOperand &TensorOperand::offset_first_element_in_bytes() { - return get_or_create_component(_offset_first_element_in_bytes, name(), TensorComponent::OffsetFirstElement); + return get_or_create_component(*this, _offset_first_element_in_bytes, TensorComponentType::OffsetFirstElement); } // ================================================================================================= // TensorComponentOperand // ================================================================================================= -TensorComponentOperand::TensorComponentOperand(const ::std::string &name, TensorComponent component) - : TileOperand(name, DataType::Int32), _component(component) +TensorComponentOperand::TensorComponentOperand(TensorOperand &tensor, TensorComponentType component) + : TileOperand(tensor.name(), DataType::Int32), _tensor(tensor), _component(component) +{ +} + +TensorOperand &TensorComponentOperand::tensor() +{ + return _tensor; +} + +const TensorOperand &TensorComponentOperand::tensor() const +{ + return _tensor; +} + +TensorComponentType TensorComponentOperand::component_type() const { + return _component; } prototype::Operand TensorComponentOperand::create_impl_operand(prototype::IGpuKernelWriter *writer) const @@ -189,51 +210,51 @@ prototype::Operand TensorComponentOperand::create_impl_operand(prototype::IGpuKe switch(_component) { - case TensorComponent::OffsetFirstElement: + case TensorComponentType::OffsetFirstElement: type = prototype::OperandType::TensorDataOffset; break; - case TensorComponent::Stride1: + case TensorComponentType::Stride1: type = prototype::OperandType::TensorStride1; break; - case TensorComponent::Stride2: + case TensorComponentType::Stride2: type = prototype::OperandType::TensorStride2; break; - case TensorComponent::Stride3: + case TensorComponentType::Stride3: type = prototype::OperandType::TensorStride3; break; - case TensorComponent::Stride4: + case TensorComponentType::Stride4: type = prototype::OperandType::TensorStride4; break; - case TensorComponent::Dim0: + case TensorComponentType::Dim0: type = prototype::OperandType::TensorDim0; break; - case TensorComponent::Dim1: + case TensorComponentType::Dim1: type = prototype::OperandType::TensorDim1; break; - case TensorComponent::Dim2: + case TensorComponentType::Dim2: type = prototype::OperandType::TensorDim2; break; - case TensorComponent::Dim3: + case TensorComponentType::Dim3: type = prototype::OperandType::TensorDim3; break; - case TensorComponent::Dim4: + case TensorComponentType::Dim4: type = prototype::OperandType::TensorDim4; break; - case TensorComponent::Dim1xDim2: + case TensorComponentType::Dim1xDim2: type = prototype::OperandType::TensorDim1xDim2; break; - case TensorComponent::Dim1xDim2xDim3: + case TensorComponentType::Dim1xDim2xDim3: type = prototype::OperandType::TensorDim1xDim2xDim3; break; |