aboutsummaryrefslogtreecommitdiff
path: root/reference_model
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2024-05-12 02:35:04 -0700
committerTatWai Chong <tatwai.chong@arm.com>2024-05-28 09:45:36 -0700
commit51d880e7b11e6c10e0f332afc8830015e7c57bd8 (patch)
tree482b82923b222dc83cd83b1441d6634d160f182c /reference_model
parent359fac9c00aab8d29f7da6a060d36bcdaa491584 (diff)
downloadreference_model-main.tar.gz
Change the table parameter from attribute to tensor typeHEADmain
also add testing support for table parameter as input. Signed-off-by: TatWai Chong <tatwai.chong@arm.com> Change-Id: Ie4f6d3cf0b68803fa3353cfa0e9f7f38a83b1539
Diffstat (limited to 'reference_model')
-rw-r--r--reference_model/src/ops/ewise_binary.cc29
-rw-r--r--reference_model/src/ops/ewise_binary.h3
2 files changed, 12 insertions, 20 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index d4a9f2f..bc63535 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -666,18 +666,13 @@ template <int Rank, TOSA_REF_TYPE InDtype>
OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_TABLE, id_)
{
- setRequiredOperands(1, 1);
+ setRequiredOperands(2, 1);
setRequiredRank(0, 6);
-
- INIT_ATTRIBUTE(Table);
}
template <int Rank, TOSA_REF_TYPE InDtype>
OpTable<Rank, InDtype>::~OpTable()
-{
- if (attribute)
- delete attribute;
-}
+{}
template <int Rank, TOSA_REF_TYPE InDtype>
int OpTable<Rank, InDtype>::checkTensorAttributes()
@@ -695,16 +690,12 @@ int OpTable<Rank, InDtype>::checkTensorAttributes()
}
ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type");
+ ERROR_IF(inputs[1]->getDtype() != TableDtype, "OpTable: Unexpected table type");
ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTable: Unexpected output type");
- ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries);
-
- for (uint32_t i = 0; i < TableNumEntries; i++)
- {
- table[i] = (TableEigenType)attribute->table()[i];
- }
- in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
ASSERT_MEM(in && out);
@@ -714,13 +705,15 @@ int OpTable<Rank, InDtype>::checkTensorAttributes()
template <int Rank, TOSA_REF_TYPE InDtype>
int OpTable<Rank, InDtype>::eval()
{
+ ERROR_IF(this->table->getTensor().size() != TableNumEntries, "OpTable: table tensor size must be %u",
+ TableNumEntries);
switch (InDtype)
{
case TOSA_REF_TYPE_INT8:
this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
int32_t index = input_truncated - QInMin;
- int32_t value = table[index];
+ int32_t value = this->table->getTensor()(index);
return value;
});
@@ -736,8 +729,8 @@ int OpTable<Rank, InDtype>::eval()
int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
// 3. Add REQUIRE CHECK for extreme large/small slopes
- int32_t base = table[index];
- int32_t next = table[index + 1];
+ int32_t base = this->table->getTensor()(index);
+ int32_t next = this->table->getTensor()(index + 1);
int32_t slope = next - base;
REQUIRE(slope <= std::numeric_limits<int16_t>::max() && slope >= std::numeric_limits<int16_t>::min(),
"OpTable: slope out of int16_t range");
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 7ebd852..54c05e3 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -215,8 +215,7 @@ public:
protected:
TosaReference::TensorTemplate<TIn>* in;
TosaReference::TensorTemplate<TOut>* out;
- TosaTableAttribute* attribute;
- std::array<TableEigenType, TableNumEntries> table;
+ TosaReference::TensorTemplate<TTable>* table;
};
}; // namespace TosaReference