aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_binary.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r--reference_model/src/ops/ewise_binary.cc42
1 files changed, 13 insertions, 29 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 6808604..415cd1c 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -490,8 +490,10 @@ OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
uint64_t id_)
: GraphNode(sgt_, Op_TABLE, id_)
{
- setRequiredOperands(2, 1);
+ setRequiredOperands(1, 1);
setRequiredRank(0, 6);
+
+ INIT_ATTRIBUTE(Table);
}
template <int Rank, DType InDtype>
@@ -509,36 +511,18 @@ int OpTable<Rank, InDtype>::checkTensorAttributes()
return 1;
}
- if (inputs[1]->getRank() != 1)
- {
- printNodeValidationError("OpTable: Table must be rank 1 tensor");
- return 1;
- }
+ ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type");
+ ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries);
- if (inputs[0]->getDtype() == DType_INT8)
+ for (uint32_t i = 0; i < TableNumEntries; i++)
{
- if (inputs[1]->getElementCount() != 256 || inputs[1]->getDtype() != DType_INT8)
- {
- printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8");
- return 1;
- }
- ERROR_IF(outputs[0]->getDtype() != DType_INT8, "OpTable: output tensor must be INT8");
- }
- else if (inputs[0]->getDtype() == DType_INT16)
- {
- if (inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
- {
- printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16");
- return 1;
- }
- ERROR_IF(outputs[0]->getDtype() != DType_INT32, "OpTable: output tensor must be INT32");
+ table[i] = (TableEigenType)attribute->table()[i];
}
- in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
- out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- ASSERT_MEM(in && table && out);
+ ASSERT_MEM(in && out);
return 0;
}
@@ -552,7 +536,7 @@ int OpTable<Rank, InDtype>::eval()
this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
int32_t index = input_truncated - QInMin;
- int32_t value = this->table->getTensor()(index);
+ int32_t value = table[index];
return value;
});
@@ -568,8 +552,8 @@ int OpTable<Rank, InDtype>::eval()
int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
// 3. interpolate, generate 16.7 (23-bit) output
- int32_t base = this->table->getTensor()(index);
- int32_t next = this->table->getTensor()(index + 1);
+ int32_t base = table[index];
+ int32_t next = table[index + 1];
int32_t value = (base << 7) + (next - base) * frac;
return value;