diff options
author | James Ward <james.ward@arm.com> | 2022-10-19 12:20:31 +0100 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-11-09 12:19:51 +0000 |
commit | 24dbc420aae556649f50e645bd94489dab2cc75a (patch) | |
tree | 490345da43e9c5bae0f450ba05ffe85874077e0a /reference_model/src/ops/reduction.cc | |
parent | 3b0544c1e7463295c49a48a162ebb9a546326829 (diff) | |
download | reference_model-24dbc420aae556649f50e645bd94489dab2cc75a.tar.gz |
Add BF16 support to reference model
* Upgrade Eigen to 3.4.0 (for bfloat16 support) and add
  workarounds for reduce.any() and reduce.all() bugs (introduced
  between 3.3.7 and 3.4.0)
* Truncation to bfloat16 now performed in eval() methods
Signed-off-by: James Ward <james.ward@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe
Diffstat (limited to 'reference_model/src/ops/reduction.cc')
-rw-r--r-- | reference_model/src/ops/reduction.cc | 50 |
1 file changed, 46 insertions, 4 deletions
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc index eccba09..cd9d55f 100644 --- a/reference_model/src/ops/reduction.cc +++ b/reference_model/src/ops/reduction.cc @@ -80,10 +80,30 @@ int ReduceNode<Rank, Dtype>::checkTensorAttributes() return 0; } +// These 2 reducers are to overcome a bug introduced in Eigen between 3.3.7 and 3.4.0 +// The in-built .any and .all operations now fail on an assert in TensorMorphing.h:150 +// which seems to be due to incorrect data being passed internally as m_impl +struct AllReducer { + static const bool PacketAccess = false; + void reduce(const bool val, bool* accum) { + *accum = *accum && val; + } + bool initialize() const { return true; } + bool finalize(const bool accum) const { return accum; } +}; +struct AnyReducer { + static const bool PacketAccess = false; + void reduce(const bool val, bool* accum) { + *accum = *accum || val; + } + bool initialize() const { return false; } + bool finalize(const bool accum) const { return accum; } +}; + template <int Rank, DType Dtype> int OpReduceAll<Rank, Dtype>::eval() { - this->out->getTensor() = this->in->getTensor().all(this->dims).reshape(this->out->getTensor().dimensions()); + this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions()); return GraphNode::eval(); } @@ -91,7 +111,7 @@ int OpReduceAll<Rank, Dtype>::eval() template <int Rank, DType Dtype> int OpReduceAny<Rank, Dtype>::eval() { - this->out->getTensor() = this->in->getTensor().any(this->dims).reshape(this->out->getTensor().dimensions()); + this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions()); return GraphNode::eval(); } @@ -115,7 +135,16 @@ int OpReduceMin<Rank, Dtype>::eval() template <int Rank, DType Dtype> int OpReduceProduct<Rank, Dtype>::eval() { - this->out->getTensor() = 
this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()); + switch(Dtype) + { + case DType_FP16: + case DType_BF16: + this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);}); + break; + default: + this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()); + break; + } return GraphNode::eval(); } @@ -123,7 +152,16 @@ int OpReduceProduct<Rank, Dtype>::eval() template <int Rank, DType Dtype> int OpReduceSum<Rank, Dtype>::eval() { - this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()); + switch(Dtype) + { + case DType_FP16: + case DType_BF16: + this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);}); + break; + default: + this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()); + break; + } return GraphNode::eval(); } @@ -159,20 +197,24 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); 
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); |