diff options
-rw-r--r-- | reference_model/src/graph_node.h | 9 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 86 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_binary.h | 7 | ||||
-rw-r--r-- | reference_model/src/ops/op_factory.cc | 3 |
4 files changed, 65 insertions, 40 deletions
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h index eee5464..bf80859 100644 --- a/reference_model/src/graph_node.h +++ b/reference_model/src/graph_node.h @@ -33,15 +33,6 @@ #define DEF_INSTANTIATE_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \ template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>; -#define DEF_INSTANTIATE_ONE_RANK_0_6(OP) \ - template class TosaReference::OP<0>; \ - template class TosaReference::OP<1>; \ - template class TosaReference::OP<2>; \ - template class TosaReference::OP<3>; \ - template class TosaReference::OP<4>; \ - template class TosaReference::OP<5>; \ - template class TosaReference::OP<6>; - #define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP<DType_##DTYPE>; #define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2>; diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 76cebeb..3379ffe 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -519,20 +519,20 @@ int OpSub<Rank, Dtype>::register_fcn() return 0; } -template <int Rank> -OpTable<Rank>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) +template <int Rank, DType InDtype> +OpTable<Rank, InDtype>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(Op_TABLE, id_) { setRequiredOperands(2, 1); setRequiredRank(0, 6); } -template <int Rank> -OpTable<Rank>::~OpTable() +template <int Rank, DType InDtype> +OpTable<Rank, InDtype>::~OpTable() {} -template <int Rank> -int OpTable<Rank>::checkTensorAttributes() +template <int Rank, DType InDtype> +int OpTable<Rank, InDtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -542,12 +542,29 @@ int OpTable<Rank>::checkTensorAttributes() return 1; } - if (inputs[1]->getRank() != 1 || inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16) + if (inputs[1]->getRank() != 1) { - FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries"); + printNodeValidationError("OpTable: Table must be rank 1 tensor"); return 1; } + if (inputs[0]->getDtype() == DType_INT8) + { + if (inputs[1]->getElementCount() != 256 || inputs[1]->getDtype() != DType_INT8) + { + printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8"); + return 1; + } + } + else if (inputs[0]->getDtype() == DType_INT16) + { + if (inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16) + { + printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16"); + return 1; + } + } + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]); out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); @@ -557,25 +574,41 @@ int OpTable<Rank>::checkTensorAttributes() return 0; } -template <int Rank> -int OpTable<Rank>::eval() +template <int Rank, DType InDtype> +int OpTable<Rank, InDtype>::eval() { - this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { - // 1. make sure input is int16 range - int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax); - - // 2. calculate index and interpolation fraction - int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1)); - index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index - int32_t frac = (input_truncated)&0x7F; // 7-bit fraction - - // 3. interpolate, generate 16.7 (23-bit) output - int32_t base = this->table->getTensor()(index); - int32_t next = this->table->getTensor()(index + 1); - int32_t value = (base << 7) + (next - base) * frac; + switch (InDtype) + { + case DType_INT8: + this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { + int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax); + int32_t index = input_truncated - QInMin; + int32_t value = this->table->getTensor()(index); - return value; - }); + return value; + }); + break; + case DType_INT16: + this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { + // 1. make sure input is int16 range + int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax); + + // 2. calculate index and interpolation fraction + int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1)); + index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index + int32_t frac = (input_truncated)&0x7F; // 7-bit fraction + + // 3. interpolate, generate 16.7 (23-bit) output + int32_t base = this->table->getTensor()(index); + int32_t next = this->table->getTensor()(index + 1); + int32_t value = (base << 7) + (next - base) * frac; + + return value; + }); + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + } return GraphNode::eval(); } @@ -632,7 +665,8 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); -DEF_INSTANTIATE_ONE_RANK_0_6(OpTable); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL); diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 6b9c98d..a5b1059 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -176,7 +176,7 @@ protected: TosaMulAttribute* attribute; }; -template <int Rank> +template <int Rank, DType InDtype> class OpTable : public GraphNode { public: @@ -186,9 +186,8 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - static constexpr DType InDtype = DType_INT16; - static constexpr DType TableDtype = DType_INT16; - static constexpr DType OutDtype = DType_INT32; + static constexpr DType TableDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16; + static constexpr DType OutDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32; using InEigenType = typename GetEigenType<InDtype>::type; using TableEigenType = typename GetEigenType<TableDtype>::type; using OutEigenType = typename GetEigenType<OutDtype>::type; diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 440d624..726ab7c 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -178,7 +178,8 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); break; case Op_TABLE: - DEF_FACTORY_ONE_RANK_0_6(OpTable); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16); break; // ewise_unary |