diff options
Diffstat (limited to 'compute_kernel_writer/prototype/src/TileOperand.cpp')
-rw-r--r-- | compute_kernel_writer/prototype/src/TileOperand.cpp | 57 |
1 files changed, 44 insertions, 13 deletions
diff --git a/compute_kernel_writer/prototype/src/TileOperand.cpp b/compute_kernel_writer/prototype/src/TileOperand.cpp index fcb3cb6415..bf6a15b9df 100644 --- a/compute_kernel_writer/prototype/src/TileOperand.cpp +++ b/compute_kernel_writer/prototype/src/TileOperand.cpp @@ -30,22 +30,42 @@ namespace ckw { TileOperand::TileOperand(const std::string &name, const TileInfo &info) - : OperandBase(name), _info(info), _value{ 0 }, _constant(false) + : OperandBase(name), + _info(info), + _value{ std::vector<std::string>{ "0" } }, + _constant(false) { } TileOperand::TileOperand(const std::string &name, DataType data_type) - : OperandBase(name), _info(TileInfo{ data_type }), _value(0), _constant(false) + : OperandBase(name), + _info(TileInfo{ data_type }), + _value{ std::vector<std::string>{ "0" } }, + _constant(false) { } TileOperand::TileOperand(const std::string &name, int32_t value) - : OperandBase(name), _info(TileInfo{ DataType::Int32 }), _value(value), _constant(true) + : OperandBase(name), + _info(TileInfo{ DataType::Int32 }), + _value{ std::vector<std::string>{ std::to_string(value) } }, + _constant(true) { } TileOperand::TileOperand(const std::string &name, float value) - : OperandBase(name), _info(TileInfo{ DataType::Fp32 }), _value(value), _constant(true) + : OperandBase(name), + _info(TileInfo{ DataType::Fp32 }), + _value{ std::vector<std::string>{ std::to_string(value) } }, + _constant(true) +{ +} + +TileOperand::TileOperand(const std::string &name, const TileContainer &vals, DataType dt) + : OperandBase(name), + _info(TileInfo{ dt, static_cast<int32_t>(vals.size()), static_cast<int32_t>(vals[0].size()) }), + _value(vals), + _constant(true) { } @@ -55,17 +75,23 @@ prototype::Operand TileOperand::create_impl_operand(prototype::IGpuKernelWriter if(_constant) { - switch(_info.data_type()) + if(is_scalar()) { - case DataType::Int32: - return prototype::Operand(std::to_string(_value.get<int32_t>()), - prototype::OperandType::ScalarInt32); + switch(_info.data_type()) + { + case DataType::Int32: + return prototype::Operand(_value[0][0], prototype::OperandType::ScalarInt32); - case DataType::Fp32: - return prototype::Operand(std::to_string(_value.get<float>()), prototype::OperandType::ScalarFp32); + case DataType::Fp32: + return prototype::Operand(_value[0][0], prototype::OperandType::ScalarFp32); - default: - CKW_ASSERT(false); + default: + CKW_ASSERT(false); + } + } + else + { + return prototype::Operand(name()); } } else @@ -94,11 +120,16 @@ bool TileOperand::is_scalar() const return _info.width() == 1 && _info.height() == 1; } -ScalarValue TileOperand::scalar_value() const +std::string TileOperand::scalar_value() const { CKW_ASSERT(is_scalar()); CKW_ASSERT(is_constant()); + return _value[0][0]; +} + +const TileContainer &TileOperand::value() const +{ return _value; } |