aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_binary.cc
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-05-24 17:20:01 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-05-24 17:20:21 -0700
commit571f7182a10a974f1ce993d83b01070153f142cc (patch)
tree1cdc94294b24ec5d4719219e46d39ddc2e674652 /reference_model/src/ops/ewise_binary.cc
parent47315e1af6947dd93729c6dbd034c7db1af7f312 (diff)
downloadreference_model-571f7182a10a974f1ce993d83b01070153f142cc.tar.gz
Support 8-bit TABLE op.
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: If577035d71c5f9970df5b6a78640a3028c3f83c0
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r--reference_model/src/ops/ewise_binary.cc86
1 files changed, 60 insertions, 26 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 76cebeb..3379ffe 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -519,20 +519,20 @@ int OpSub<Rank, Dtype>::register_fcn()
return 0;
}
-template <int Rank>
-OpTable<Rank>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+template <int Rank, DType InDtype>
+OpTable<Rank, InDtype>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
: GraphNode(Op_TABLE, id_)
{
setRequiredOperands(2, 1);
setRequiredRank(0, 6);
}
-template <int Rank>
-OpTable<Rank>::~OpTable()
+template <int Rank, DType InDtype>
+OpTable<Rank, InDtype>::~OpTable()
{}
-template <int Rank>
-int OpTable<Rank>::checkTensorAttributes()
+template <int Rank, DType InDtype>
+int OpTable<Rank, InDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -542,12 +542,29 @@ int OpTable<Rank>::checkTensorAttributes()
return 1;
}
- if (inputs[1]->getRank() != 1 || inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
+ if (inputs[1]->getRank() != 1)
{
- FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries");
+ printNodeValidationError("OpTable: Table must be rank 1 tensor");
return 1;
}
+ if (inputs[0]->getDtype() == DType_INT8)
+ {
+ if (inputs[1]->getElementCount() != 256 || inputs[1]->getDtype() != DType_INT8)
+ {
+ printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8");
+ return 1;
+ }
+ }
+ else if (inputs[0]->getDtype() == DType_INT16)
+ {
+ if (inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
+ {
+ printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16");
+ return 1;
+ }
+ }
+
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
@@ -557,25 +574,41 @@ int OpTable<Rank>::checkTensorAttributes()
return 0;
}
-template <int Rank>
-int OpTable<Rank>::eval()
+template <int Rank, DType InDtype>
+int OpTable<Rank, InDtype>::eval()
{
- this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
- // 1. make sure input is int16 range
- int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
-
- // 2. calculate index and interpolation fraction
- int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1));
- index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
- int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
-
- // 3. interpolate, generate 16.7 (23-bit) output
- int32_t base = this->table->getTensor()(index);
- int32_t next = this->table->getTensor()(index + 1);
- int32_t value = (base << 7) + (next - base) * frac;
+ switch (InDtype)
+ {
+ case DType_INT8:
+ this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
+ int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
+ int32_t index = input_truncated - QInMin;
+ int32_t value = this->table->getTensor()(index);
- return value;
- });
+ return value;
+ });
+ break;
+ case DType_INT16:
+ this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
+ // 1. make sure input is int16 range
+ int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
+
+ // 2. calculate index and interpolation fraction
+ int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
+ index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
+ int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
+
+ // 3. interpolate, generate 16.7 (23-bit) output
+ int32_t base = this->table->getTensor()(index);
+ int32_t next = this->table->getTensor()(index + 1);
+ int32_t value = (base << 7) + (next - base) * frac;
+
+ return value;
+ });
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ }
return GraphNode::eval();
}
@@ -632,7 +665,8 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
-DEF_INSTANTIATE_ONE_RANK_0_6(OpTable);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);