diff options
Diffstat (limited to 'reference_model/src/ops/reduction.cc')
-rw-r--r-- | reference_model/src/ops/reduction.cc | 50 |
1 files changed, 46 insertions, 4 deletions
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index eccba09..cd9d55f 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -80,10 +80,30 @@ int ReduceNode<Rank, Dtype>::checkTensorAttributes()
     return 0;
 }
 
+// These 2 reducers are to overcome a bug introduced in Eigen between 3.3.7 and 3.4.0
+// The in-built .any and .all operations now fail on an assert in TensorMorphing.h:150
+// which seems to be due to incorrect data being passed internally as m_impl
+struct AllReducer {
+    static const bool PacketAccess = false;
+    void reduce(const bool val, bool* accum) {
+        *accum = *accum && val;
+    }
+    bool initialize() const { return true; }
+    bool finalize(const bool accum) const { return accum; }
+};
+struct AnyReducer {
+    static const bool PacketAccess = false;
+    void reduce(const bool val, bool* accum) {
+        *accum = *accum || val;
+    }
+    bool initialize() const { return false; }
+    bool finalize(const bool accum) const { return accum; }
+};
+
 template <int Rank, DType Dtype>
 int OpReduceAll<Rank, Dtype>::eval()
 {
-    this->out->getTensor() = this->in->getTensor().all(this->dims).reshape(this->out->getTensor().dimensions());
+    this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions());
 
     return GraphNode::eval();
 }
@@ -91,7 +111,7 @@ int OpReduceAll<Rank, Dtype>::eval()
 template <int Rank, DType Dtype>
 int OpReduceAny<Rank, Dtype>::eval()
 {
-    this->out->getTensor() = this->in->getTensor().any(this->dims).reshape(this->out->getTensor().dimensions());
+    this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions());
 
     return GraphNode::eval();
 }
@@ -115,7 +135,16 @@ int OpReduceMin<Rank, Dtype>::eval()
 template <int Rank, DType Dtype>
 int OpReduceProduct<Rank, Dtype>::eval()
 {
-    this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
+    switch(Dtype)
+    {
+        case DType_FP16:
+        case DType_BF16:
+            this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
+            break;
+        default:
+            this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
+            break;
+    }
 
     return GraphNode::eval();
 }
@@ -123,7 +152,16 @@ int OpReduceProduct<Rank, Dtype>::eval()
 template <int Rank, DType Dtype>
 int OpReduceSum<Rank, Dtype>::eval()
 {
-    this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
+    switch(Dtype)
+    {
+        case DType_FP16:
+        case DType_BF16:
+            this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
+            break;
+        default:
+            this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
+            break;
+    }
 
     return GraphNode::eval();
 }
@@ -159,20 +197,24 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
|