diff options
author | evacha01 <evan.chandler@arm.com> | 2024-03-08 16:39:24 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-04-16 16:02:16 +0000 |
commit | 4a2051146f498cb9ec35d7213720540c5c3e81e2 (patch) | |
tree | 543000b3ef22bd587c3c7702100742e4b94eb5fb /reference_model/src/generate/generate_fp_special.cc | |
parent | 5d0e9c7f3748e80d6f14a3eeaef858eeb912e1fd (diff) | |
download | reference_model-4a2051146f498cb9ec35d7213720540c5c3e81e2.tar.gz |
SPECIAL data gen mode for FP16 and FP32
Signed-off-by: evacha01 <evan.chandler@arm.com>
Change-Id: I5a9a1c63345bd83ca04bc6c2a99b0ef3612971ee
Diffstat (limited to 'reference_model/src/generate/generate_fp_special.cc')
-rw-r--r-- | reference_model/src/generate/generate_fp_special.cc | 180 |
1 files changed, 180 insertions, 0 deletions
diff --git a/reference_model/src/generate/generate_fp_special.cc b/reference_model/src/generate/generate_fp_special.cc new file mode 100644 index 0000000..3602f51 --- /dev/null +++ b/reference_model/src/generate/generate_fp_special.cc @@ -0,0 +1,180 @@ +// Copyright (c) 2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "generate_fp_special.h" +#include "half.hpp" + +#include <map> + +namespace +{ + +class SpecialValue +{ +public: + enum SpecialValsEnum + { + Zero, + Inf, + NaN, + Min, + Max, + One, + }; + + SpecialValue() = default; + SpecialValue(SpecialValsEnum v) + : value(v) + {} + operator SpecialValsEnum() const + { + return value; + } + SpecialValue& operator=(SpecialValsEnum v) + { + value = v; + return *this; + } + bool operator==(const SpecialValsEnum v) const + { + return value == v; + } + bool operator!=(const SpecialValsEnum v) const + { + return value != v; + } + SpecialValue operator-() + { + negative = !negative; + return *this; + } + + template <typename DataType> + DataType evaluate() + { + switch (value) + { + case Zero: + return static_cast<DataType>(negative ? -0.0 : 0.0); + case Inf: + return negative ? -std::numeric_limits<DataType>::infinity() + : std::numeric_limits<DataType>::infinity(); + case NaN: + return std::numeric_limits<DataType>::quiet_NaN(); + case Min: + return negative ? -std::numeric_limits<DataType>::min() : std::numeric_limits<DataType>::min(); + case Max: + return negative ? -std::numeric_limits<DataType>::max() : std::numeric_limits<DataType>::max(); + case One: + return static_cast<DataType>(negative ? -1.0 : 1.0); + default: + WARNING("[Generator][FS] Uninitialised special value."); + return static_cast<DataType>(0.0); + } + } + +private: + SpecialValsEnum value; + bool negative = false; +}; + +/* +Test vals format + +I: Number of inputs to an op - referenced by cfg.inputPos +T: Number of test cases defined for the op + +vector of test inputs: { + vector of values for test 0: { valueForinputPos0, valueForinputPos1, ..., valueForinputPosI-1 }, + vector of values for test 1: { valueForinputPos0, valueForinputPos1, ..., valueForinputPosI-1 }, + ... + vector of values for test T-1: { valueForinputPos0, valueForinputPos1, ..., valueForinputPosI-1 }, +} +*/ +using TestValues = std::vector<std::vector<SpecialValue>>; + +TestValues equalOpsTestVals{ { SpecialValue(SpecialValue::Zero), -SpecialValue(SpecialValue::Zero) }, + { SpecialValue(SpecialValue::Inf), -SpecialValue(SpecialValue::Inf) } }; + +TestValues addTestVals{ { SpecialValue(SpecialValue::Max), SpecialValue(SpecialValue::One) }, + { SpecialValue(SpecialValue::Inf), -SpecialValue(SpecialValue::Inf) } }; + +TestValues defaultTestVals{ { SpecialValue(SpecialValue::Zero) }, { -SpecialValue(SpecialValue::Zero) }, + { SpecialValue(SpecialValue::Inf) }, { -SpecialValue(SpecialValue::Inf) }, + { SpecialValue(SpecialValue::NaN) }, { SpecialValue(SpecialValue::Min) }, + { SpecialValue(SpecialValue::Max) } }; + +std::map<Op, TestValues> testValues = { { Op::Op_EQUAL, equalOpsTestVals }, + { Op::Op_GREATER, equalOpsTestVals }, + { Op::Op_GREATER_EQUAL, equalOpsTestVals }, + { Op::Op_ADD, addTestVals } }; + +template <typename DataType> +bool generate(const TosaReference::GenerateConfig& cfg, DataType* data, size_t size) +{ + const TosaReference::FpSpecialInfo& fsinfo = cfg.fpSpecialInfo; + uint8_t startIndex = fsinfo.startIndex; + + std::vector<DataType> values; + auto testValuesResult = testValues.find(cfg.opType); + TestValues opTestVals = defaultTestVals; + size_t inputIndex = 0; + if (testValuesResult != testValues.end()) + { + // When an op has an entry in testValues we use its op specific special test values, otherwise default values are used + opTestVals = testValuesResult->second; + inputIndex = cfg.inputPos; + } + + for (std::vector<SpecialValue> inputs : opTestVals) + { + values.push_back(inputs[inputIndex].evaluate<DataType>()); + } + + const auto T = TosaReference::numElementsFromShape(cfg.shape); + for (auto t = 0; t < T; ++t) + { + data[t] = values[(t + startIndex) % values.size()]; + } + return true; +} +} // namespace + +namespace TosaReference +{ +bool generateFpSpecial(const GenerateConfig& cfg, void* data, size_t size) +{ + // Check we support the operator + if (cfg.opType == Op::Op_UNKNOWN) + { + WARNING("[Generator][FS] Unknown operator."); + return false; + } + + switch (cfg.dataType) + { + case DType::DType_FP16: { + half_float::half* outData = reinterpret_cast<half_float::half*>(data); + return generate(cfg, outData, size); + } + case DType::DType_FP32: { + float* outData = reinterpret_cast<float*>(data); + return generate(cfg, outData, size); + } + default: + WARNING("[Generator][FS] Unsupported type."); + return false; + } +} +} // namespace TosaReference
\ No newline at end of file |