aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-05-24 17:20:01 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-05-24 17:20:21 -0700
commit571f7182a10a974f1ce993d83b01070153f142cc (patch)
tree1cdc94294b24ec5d4719219e46d39ddc2e674652
parent47315e1af6947dd93729c6dbd034c7db1af7f312 (diff)
downloadreference_model-571f7182a10a974f1ce993d83b01070153f142cc.tar.gz
Support 8-bit TABLE op.
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: If577035d71c5f9970df5b6a78640a3028c3f83c0
-rw-r--r--reference_model/src/graph_node.h9
-rw-r--r--reference_model/src/ops/ewise_binary.cc86
-rw-r--r--reference_model/src/ops/ewise_binary.h7
-rw-r--r--reference_model/src/ops/op_factory.cc3
4 files changed, 65 insertions, 40 deletions
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index eee5464..bf80859 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -33,15 +33,6 @@
#define DEF_INSTANTIATE_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>;
-#define DEF_INSTANTIATE_ONE_RANK_0_6(OP) \
- template class TosaReference::OP<0>; \
- template class TosaReference::OP<1>; \
- template class TosaReference::OP<2>; \
- template class TosaReference::OP<3>; \
- template class TosaReference::OP<4>; \
- template class TosaReference::OP<5>; \
- template class TosaReference::OP<6>;
-
#define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP<DType_##DTYPE>;
#define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2>;
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 76cebeb..3379ffe 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -519,20 +519,20 @@ int OpSub<Rank, Dtype>::register_fcn()
return 0;
}
-template <int Rank>
-OpTable<Rank>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+template <int Rank, DType InDtype>
+OpTable<Rank, InDtype>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
: GraphNode(Op_TABLE, id_)
{
setRequiredOperands(2, 1);
setRequiredRank(0, 6);
}
-template <int Rank>
-OpTable<Rank>::~OpTable()
+template <int Rank, DType InDtype>
+OpTable<Rank, InDtype>::~OpTable()
{}
-template <int Rank>
-int OpTable<Rank>::checkTensorAttributes()
+template <int Rank, DType InDtype>
+int OpTable<Rank, InDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -542,12 +542,29 @@ int OpTable<Rank>::checkTensorAttributes()
return 1;
}
- if (inputs[1]->getRank() != 1 || inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
+ if (inputs[1]->getRank() != 1)
{
- FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries");
+ printNodeValidationError("OpTable: Table must be rank 1 tensor");
return 1;
}
+ if (inputs[0]->getDtype() == DType_INT8)
+ {
+ if (inputs[1]->getElementCount() != 256 || inputs[1]->getDtype() != DType_INT8)
+ {
+ printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8");
+ return 1;
+ }
+ }
+ else if (inputs[0]->getDtype() == DType_INT16)
+ {
+ if (inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
+ {
+ printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16");
+ return 1;
+ }
+ }
+
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
@@ -557,25 +574,41 @@ int OpTable<Rank>::checkTensorAttributes()
return 0;
}
-template <int Rank>
-int OpTable<Rank>::eval()
+template <int Rank, DType InDtype>
+int OpTable<Rank, InDtype>::eval()
{
- this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
- // 1. make sure input is int16 range
- int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
-
- // 2. calculate index and interpolation fraction
- int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1));
- index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
- int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
-
- // 3. interpolate, generate 16.7 (23-bit) output
- int32_t base = this->table->getTensor()(index);
- int32_t next = this->table->getTensor()(index + 1);
- int32_t value = (base << 7) + (next - base) * frac;
+ switch (InDtype)
+ {
+ case DType_INT8:
+ this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
+ int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
+ int32_t index = input_truncated - QInMin;
+ int32_t value = this->table->getTensor()(index);
- return value;
- });
+ return value;
+ });
+ break;
+ case DType_INT16:
+ this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
+ // 1. make sure input is int16 range
+ int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
+
+ // 2. calculate index and interpolation fraction
+ int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
+ index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
+ int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
+
+ // 3. interpolate, generate 16.7 (23-bit) output
+ int32_t base = this->table->getTensor()(index);
+ int32_t next = this->table->getTensor()(index + 1);
+ int32_t value = (base << 7) + (next - base) * frac;
+
+ return value;
+ });
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ }
return GraphNode::eval();
}
@@ -632,7 +665,8 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
-DEF_INSTANTIATE_ONE_RANK_0_6(OpTable);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 6b9c98d..a5b1059 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -176,7 +176,7 @@ protected:
TosaMulAttribute* attribute;
};
-template <int Rank>
+template <int Rank, DType InDtype>
class OpTable : public GraphNode
{
public:
@@ -186,9 +186,8 @@ public:
virtual int checkTensorAttributes();
virtual int eval();
- static constexpr DType InDtype = DType_INT16;
- static constexpr DType TableDtype = DType_INT16;
- static constexpr DType OutDtype = DType_INT32;
+ static constexpr DType TableDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16;
+ static constexpr DType OutDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32;
using InEigenType = typename GetEigenType<InDtype>::type;
using TableEigenType = typename GetEigenType<TableDtype>::type;
using OutEigenType = typename GetEigenType<OutDtype>::type;
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 440d624..726ab7c 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -178,7 +178,8 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
break;
case Op_TABLE:
- DEF_FACTORY_ONE_RANK_0_6(OpTable);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
break;
// ewise_unary