diff options
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 86 |
1 files changed, 60 insertions, 26 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 76cebeb..3379ffe 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -519,20 +519,20 @@ int OpSub<Rank, Dtype>::register_fcn() return 0; } -template <int Rank> -OpTable<Rank>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) +template <int Rank, DType InDtype> +OpTable<Rank, InDtype>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(Op_TABLE, id_) { setRequiredOperands(2, 1); setRequiredRank(0, 6); } -template <int Rank> -OpTable<Rank>::~OpTable() +template <int Rank, DType InDtype> +OpTable<Rank, InDtype>::~OpTable() {} -template <int Rank> -int OpTable<Rank>::checkTensorAttributes() +template <int Rank, DType InDtype> +int OpTable<Rank, InDtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -542,12 +542,29 @@ int OpTable<Rank>::checkTensorAttributes() return 1; } - if (inputs[1]->getRank() != 1 || inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16) + if (inputs[1]->getRank() != 1) { - FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries"); + printNodeValidationError("OpTable: Table must be rank 1 tensor"); return 1; } + if (inputs[0]->getDtype() == DType_INT8) + { + if (inputs[1]->getElementCount() != 256 || inputs[1]->getDtype() != DType_INT8) + { + printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8"); + return 1; + } + } + else if (inputs[0]->getDtype() == DType_INT16) + { + if (inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16) + { + printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16"); + return 1; + } + } + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]); out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); @@ -557,25 +574,41 @@ int OpTable<Rank>::checkTensorAttributes() return 0; } -template <int Rank> -int OpTable<Rank>::eval() +template <int Rank, DType InDtype> +int OpTable<Rank, InDtype>::eval() { - this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { - // 1. make sure input is int16 range - int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax); - - // 2. calculate index and interpolation fraction - int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1)); - index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index - int32_t frac = (input_truncated)&0x7F; // 7-bit fraction - - // 3. interpolate, generate 16.7 (23-bit) output - int32_t base = this->table->getTensor()(index); - int32_t next = this->table->getTensor()(index + 1); - int32_t value = (base << 7) + (next - base) * frac; + switch (InDtype) + { + case DType_INT8: + this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { + int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax); + int32_t index = input_truncated - QInMin; + int32_t value = this->table->getTensor()(index); - return value; - }); + return value; + }); + break; + case DType_INT16: + this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { + // 1. make sure input is int16 range + int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax); + + // 2. calculate index and interpolation fraction + int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1)); + index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index + int32_t frac = (input_truncated)&0x7F; // 7-bit fraction + + // 3. interpolate, generate 16.7 (23-bit) output + int32_t base = this->table->getTensor()(index); + int32_t next = this->table->getTensor()(index + 1); + int32_t value = (base << 7) + (next - base) * frac; + + return value; + }); + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + } return GraphNode::eval(); } @@ -632,7 +665,8 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); -DEF_INSTANTIATE_ONE_RANK_0_6(OpTable); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL); |