diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-10-18 21:51:55 +0000 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2021-10-28 16:51:33 +0000 |
commit | fe392ce8e714e616b5ab5b8a519d3eb84623273d (patch) | |
tree | 15b909b07cbe2b0fb435a2e0c3b513e13a8727b7 /reference_model/src/ops/ewise_binary.cc | |
parent | 1009674513d09af1a699a8bf0f646c7130d7a0ac (diff) | |
download | reference_model-fe392ce8e714e616b5ab5b8a519d3eb84623273d.tar.gz |
Changes for 0.23.0 release
- update serialization_lib hash
- PAD:
1. make padding as an attribute instead of tensor.
2. add pad_const_int (for non-float type) / pad_const_fp (for float type)
- TRANSPOSE: make perm as an attribute instead of tensor
- TABLE: make table as attribute instead of tensor
- update examples/ tests
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Iddc446db4b356ba2f36ea4a79b7220b9cfc2aa4e
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 42 |
1 files changed, 13 insertions, 29 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 6808604..415cd1c 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -490,8 +490,10 @@ OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_, uint64_t id_) : GraphNode(sgt_, Op_TABLE, id_) { - setRequiredOperands(2, 1); + setRequiredOperands(1, 1); setRequiredRank(0, 6); + + INIT_ATTRIBUTE(Table); } template <int Rank, DType InDtype> @@ -509,36 +511,18 @@ int OpTable<Rank, InDtype>::checkTensorAttributes() return 1; } - if (inputs[1]->getRank() != 1) - { - printNodeValidationError("OpTable: Table must be rank 1 tensor"); - return 1; - } + ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type"); + ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries); - if (inputs[0]->getDtype() == DType_INT8) + for (uint32_t i = 0; i < TableNumEntries; i++) { - if (inputs[1]->getElementCount() != 256 || inputs[1]->getDtype() != DType_INT8) - { - printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8"); - return 1; - } - ERROR_IF(outputs[0]->getDtype() != DType_INT8, "OpTable: output tensor must be INT8"); - } - else if (inputs[0]->getDtype() == DType_INT16) - { - if (inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16) - { - printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16"); - return 1; - } - ERROR_IF(outputs[0]->getDtype() != DType_INT32, "OpTable: output tensor must be INT32"); + table[i] = (TableEigenType)attribute->table()[i]; } - in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); - table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]); - out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); - ASSERT_MEM(in && table && out); + ASSERT_MEM(in && out); return 0; } @@ -552,7 +536,7 @@ int OpTable<Rank, InDtype>::eval() this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax); int32_t index = input_truncated - QInMin; - int32_t value = this->table->getTensor()(index); + int32_t value = table[index]; return value; }); @@ -568,8 +552,8 @@ int OpTable<Rank, InDtype>::eval() int32_t frac = (input_truncated)&0x7F; // 7-bit fraction // 3. interpolate, generate 16.7 (23-bit) output - int32_t base = this->table->getTensor()(index); - int32_t next = this->table->getTensor()(index + 1); + int32_t base = table[index]; + int32_t next = table[index + 1]; int32_t value = (base << 7) + (next - base) * frac; return value; |