about | summary | refs | log | tree | commit | diff
path: root/reference_model/src/ops/reduction.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/reduction.cc')
-rw-r--r-- reference_model/src/ops/reduction.cc | 50
1 files changed, 46 insertions, 4 deletions
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index eccba09..cd9d55f 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -80,10 +80,30 @@ int ReduceNode<Rank, Dtype>::checkTensorAttributes()
return 0;
}
+// These two reducers work around a regression introduced in Eigen between 3.3.7 and 3.4.0.
+// The built-in .any() and .all() operations now fail an assert at TensorMorphing.h:150,
+// which seems to be caused by incorrect data being passed internally as m_impl.
+struct AllReducer {
+ static const bool PacketAccess = false;
+ void reduce(const bool val, bool* accum) {
+ *accum = *accum && val;
+ }
+ bool initialize() const { return true; }
+ bool finalize(const bool accum) const { return accum; }
+};
+struct AnyReducer {
+ static const bool PacketAccess = false;
+ void reduce(const bool val, bool* accum) {
+ *accum = *accum || val;
+ }
+ bool initialize() const { return false; }
+ bool finalize(const bool accum) const { return accum; }
+};
+
template <int Rank, DType Dtype>
int OpReduceAll<Rank, Dtype>::eval()
{
- this->out->getTensor() = this->in->getTensor().all(this->dims).reshape(this->out->getTensor().dimensions());
+ this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions());
return GraphNode::eval();
}
@@ -91,7 +111,7 @@ int OpReduceAll<Rank, Dtype>::eval()
template <int Rank, DType Dtype>
int OpReduceAny<Rank, Dtype>::eval()
{
- this->out->getTensor() = this->in->getTensor().any(this->dims).reshape(this->out->getTensor().dimensions());
+ this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions());
return GraphNode::eval();
}
@@ -115,7 +135,16 @@ int OpReduceMin<Rank, Dtype>::eval()
template <int Rank, DType Dtype>
int OpReduceProduct<Rank, Dtype>::eval()
{
- this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
+ switch(Dtype)
+ {
+ case DType_FP16:
+ case DType_BF16:
+ this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
+ break;
+ default:
+ this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
+ break;
+ }
return GraphNode::eval();
}
@@ -123,7 +152,16 @@ int OpReduceProduct<Rank, Dtype>::eval()
template <int Rank, DType Dtype>
int OpReduceSum<Rank, Dtype>::eval()
{
- this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
+ switch(Dtype)
+ {
+ case DType_FP16:
+ case DType_BF16:
+ this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
+ break;
+ default:
+ this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
+ break;
+ }
return GraphNode::eval();
}
@@ -159,20 +197,24 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);