From 571f7182a10a974f1ce993d83b01070153f142cc Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Mon, 24 May 2021 17:20:01 -0700 Subject: Support 8-bit TABLE op. Signed-off-by: Kevin Cheng Change-Id: If577035d71c5f9970df5b6a78640a3028c3f83c0 --- reference_model/src/ops/ewise_binary.cc | 86 +++++++++++++++++++++++---------- reference_model/src/ops/ewise_binary.h | 7 ++- reference_model/src/ops/op_factory.cc | 3 +- 3 files changed, 65 insertions(+), 31 deletions(-) (limited to 'reference_model/src/ops') diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 76cebeb..3379ffe 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -519,20 +519,20 @@ int OpSub::register_fcn() return 0; } -template -OpTable::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) +template +OpTable::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(Op_TABLE, id_) { setRequiredOperands(2, 1); setRequiredRank(0, 6); } -template -OpTable::~OpTable() +template +OpTable::~OpTable() {} -template -int OpTable::checkTensorAttributes() +template +int OpTable::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -542,12 +542,29 @@ int OpTable::checkTensorAttributes() return 1; } - if (inputs[1]->getRank() != 1 || inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16) + if (inputs[1]->getRank() != 1) { - FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries"); + printNodeValidationError("OpTable: Table must be rank 1 tensor"); return 1; } + if (inputs[0]->getDtype() == DType_INT8) + { + if (inputs[1]->getElementCount() != 256 || inputs[1]->getDtype() != DType_INT8) + { + printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8"); + return 1; + } + } + else if (inputs[0]->getDtype() == DType_INT16) + { + if (inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16) + { + printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16"); + return 1; + } + } + in = dynamic_cast*>(inputs[0]); table = dynamic_cast*>(inputs[1]); out = dynamic_cast*>(outputs[0]); @@ -557,25 +574,41 @@ int OpTable::checkTensorAttributes() return 0; } -template -int OpTable::eval() +template +int OpTable::eval() { - this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { - // 1. make sure input is int16 range - int32_t input_truncated = std::min(std::max(in, QInMin), QInMax); - - // 2. calculate index and interpolation fraction - int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1)); - index = std::min(std::max(index, 0), NumTableEntries - 1); // 9-bit index - int32_t frac = (input_truncated)&0x7F; // 7-bit fraction - - // 3. interpolate, generate 16.7 (23-bit) output - int32_t base = this->table->getTensor()(index); - int32_t next = this->table->getTensor()(index + 1); - int32_t value = (base << 7) + (next - base) * frac; + switch (InDtype) + { + case DType_INT8: + this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { + int32_t input_truncated = std::min(std::max(in, QInMin), QInMax); + int32_t index = input_truncated - QInMin; + int32_t value = this->table->getTensor()(index); - return value; - }); + return value; + }); + break; + case DType_INT16: + this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { + // 1. make sure input is int16 range + int32_t input_truncated = std::min(std::max(in, QInMin), QInMax); + + // 2. calculate index and interpolation fraction + int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1)); + index = std::min(std::max(index, 0), NumTableEntries - 1); // 9-bit index + int32_t frac = (input_truncated)&0x7F; // 7-bit fraction + + // 3. interpolate, generate 16.7 (23-bit) output + int32_t base = this->table->getTensor()(index); + int32_t next = this->table->getTensor()(index + 1); + int32_t value = (base << 7) + (next - base) * frac; + + return value; + }); + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + } return GraphNode::eval(); } @@ -632,7 +665,8 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); -DEF_INSTANTIATE_ONE_RANK_0_6(OpTable); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL); diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 6b9c98d..a5b1059 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -176,7 +176,7 @@ protected: TosaMulAttribute* attribute; }; -template +template class OpTable : public GraphNode { public: @@ -186,9 +186,8 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - static constexpr DType InDtype = DType_INT16; - static constexpr DType TableDtype = DType_INT16; - static constexpr DType OutDtype = DType_INT32; + static constexpr DType TableDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16; + static constexpr DType OutDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32; using InEigenType = typename GetEigenType::type; using TableEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 440d624..726ab7c 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -178,7 +178,8 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); break; case Op_TABLE: - DEF_FACTORY_ONE_RANK_0_6(OpTable); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16); break; // ewise_unary -- cgit v1.2.1