diff options
Diffstat (limited to 'reference_model/src')
-rw-r--r-- | reference_model/src/generate/generate_dot_product_states.cc | 12 | ||||
-rw-r--r-- | reference_model/src/generate/generate_fp_special.cc | 19 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 60 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_binary.h | 3 | ||||
-rw-r--r-- | reference_model/src/ops/type_conversion.cc | 12 |
5 files changed, 74 insertions, 32 deletions
diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc index f372bde..fb4d1dd 100644 --- a/reference_model/src/generate/generate_dot_product_states.cc +++ b/reference_model/src/generate/generate_dot_product_states.cc @@ -86,8 +86,8 @@ class GeneratorS0 : public TosaReference::IDotProductGenerator public: GeneratorS0(uint32_t p) : _p(p) - , _set_data0(2 * 0) - , _set_data1(2 * 0 + 1) + , _set_data0(3 * 0) + , _set_data1(3 * 0 + 1) {} float operator()(uint32_t k) override { @@ -157,7 +157,7 @@ public: GeneratorS2(uint32_t p, uint32_t KS) : _p(p) , _KS(KS) - , _set_data(2 * 2 + p) + , _set_data(3 * 2 + p) {} float operator()(uint32_t k) override { @@ -188,7 +188,7 @@ class GeneratorS3 : public TosaReference::IDotProductGenerator public: GeneratorS3(uint32_t p) : _p(p) - , _set_data(2 * 3 + p) + , _set_data(3 * 3 + p) {} float operator()(uint32_t k) override { @@ -225,8 +225,8 @@ public: : _p(p) , _KS(KS) , _B(B) - , _set_data0(2 * 4 + 0) - , _set_data1(2 * 4 + 1) + , _set_data0(3 * 4 + 0) + , _set_data1(3 * 4 + 1) {} float operator()(uint32_t k) override { diff --git a/reference_model/src/generate/generate_fp_special.cc b/reference_model/src/generate/generate_fp_special.cc index 3602f51..ff0c4a4 100644 --- a/reference_model/src/generate/generate_fp_special.cc +++ b/reference_model/src/generate/generate_fp_special.cc @@ -28,9 +28,10 @@ public: Zero, Inf, NaN, - Min, - Max, + Min, // Smallest positive normal floating point value + Max, // Largest positive floating point value One, + MinDenorm, // Smallest positive denormal floating point value }; SpecialValue() = default; @@ -78,6 +79,9 @@ public: return negative ? -std::numeric_limits<DataType>::max() : std::numeric_limits<DataType>::max(); case One: return static_cast<DataType>(negative ? -1.0 : 1.0); + case MinDenorm: + return negative ? -std::numeric_limits<DataType>::denorm_min() + : std::numeric_limits<DataType>::denorm_min(); default: WARNING("[Generator][FS] Uninitialised special value."); return static_cast<DataType>(0.0); @@ -110,10 +114,13 @@ TestValues equalOpsTestVals{ { SpecialValue(SpecialValue::Zero), -SpecialValue(S TestValues addTestVals{ { SpecialValue(SpecialValue::Max), SpecialValue(SpecialValue::One) }, { SpecialValue(SpecialValue::Inf), -SpecialValue(SpecialValue::Inf) } }; -TestValues defaultTestVals{ { SpecialValue(SpecialValue::Zero) }, { -SpecialValue(SpecialValue::Zero) }, - { SpecialValue(SpecialValue::Inf) }, { -SpecialValue(SpecialValue::Inf) }, - { SpecialValue(SpecialValue::NaN) }, { SpecialValue(SpecialValue::Min) }, - { SpecialValue(SpecialValue::Max) } }; +TestValues defaultTestVals{ { SpecialValue(SpecialValue::Zero) }, { -SpecialValue(SpecialValue::Zero) }, + { SpecialValue(SpecialValue::Inf) }, { -SpecialValue(SpecialValue::Inf) }, + { SpecialValue(SpecialValue::Min) }, { -SpecialValue(SpecialValue::Min) }, + { SpecialValue(SpecialValue::Max) }, { -SpecialValue(SpecialValue::Max) }, + { SpecialValue(SpecialValue::MinDenorm) }, { -SpecialValue(SpecialValue::MinDenorm) }, + { SpecialValue(SpecialValue::One) }, { -SpecialValue(SpecialValue::One) }, + { SpecialValue(SpecialValue::NaN) } }; std::map<Op, TestValues> testValues = { { Op::Op_EQUAL, equalOpsTestVals }, { Op::Op_GREATER, equalOpsTestVals }, diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 8cc1319..bc63535 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -411,6 +411,22 @@ int OpMaximum<Rank, Dtype>::register_fcn() case TOSA_REF_TYPE_BF16: case TOSA_REF_TYPE_FP32: case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { + if (isnan(a)) + { + return a; + } + else if (isnan(b)) + { + return b; + } + else + { + return a > b ? a : b; + } + }; + break; + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; break; @@ -430,6 +446,21 @@ int OpMinimum<Rank, Dtype>::register_fcn() case TOSA_REF_TYPE_BF16: case TOSA_REF_TYPE_FP32: case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { + if (isnan(a)) + { + return a; + } + else if (isnan(b)) + { + return b; + } + else + { + return a < b ? a : b; + } + }; + break; case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; }; break; @@ -635,18 +666,13 @@ template <int Rank, TOSA_REF_TYPE InDtype> OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_TABLE, id_) { - setRequiredOperands(1, 1); + setRequiredOperands(2, 1); setRequiredRank(0, 6); - - INIT_ATTRIBUTE(Table); } template <int Rank, TOSA_REF_TYPE InDtype> OpTable<Rank, InDtype>::~OpTable() -{ - if (attribute) - delete attribute; -} +{} template <int Rank, TOSA_REF_TYPE InDtype> int OpTable<Rank, InDtype>::checkTensorAttributes() @@ -664,16 +690,12 @@ int OpTable<Rank, InDtype>::checkTensorAttributes() } ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type"); + ERROR_IF(inputs[1]->getDtype() != TableDtype, "OpTable: Unexpected table type"); ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTable: Unexpected output type"); - ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries); - - for (uint32_t i = 0; i < TableNumEntries; i++) - { - table[i] = (TableEigenType)attribute->table()[i]; - } - in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); - out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); ASSERT_MEM(in && out); @@ -683,13 +705,15 @@ int OpTable<Rank, InDtype>::checkTensorAttributes() template <int Rank, TOSA_REF_TYPE InDtype> int OpTable<Rank, InDtype>::eval() { + ERROR_IF(this->table->getTensor().size() != TableNumEntries, "OpTable: table tensor size must be %u", + TableNumEntries); switch (InDtype) { case TOSA_REF_TYPE_INT8: this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax); int32_t index = input_truncated - QInMin; - int32_t value = table[index]; + int32_t value = this->table->getTensor()(index); return value; }); @@ -705,8 +729,8 @@ int OpTable<Rank, InDtype>::eval() int32_t frac = (input_truncated)&0x7F; // 7-bit fraction // 3. Add REQUIRE CHECK for extreme large/small slopes - int32_t base = table[index]; - int32_t next = table[index + 1]; + int32_t base = this->table->getTensor()(index); + int32_t next = this->table->getTensor()(index + 1); int32_t slope = next - base; REQUIRE(slope <= std::numeric_limits<int16_t>::max() && slope >= std::numeric_limits<int16_t>::min(), "OpTable: slope out of int16_t range"); diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 7ebd852..54c05e3 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -215,8 +215,7 @@ public: protected: TosaReference::TensorTemplate<TIn>* in; TosaReference::TensorTemplate<TOut>* out; - TosaTableAttribute* attribute; - std::array<TableEigenType, TableNumEntries> table; + TosaReference::TensorTemplate<TTable>* table; }; }; // namespace TosaReference diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 9719f07..7c77c22 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -843,7 +843,19 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, UINT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, UINT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, UINT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, UINT16); |