aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/reduction.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/reduction.cc')
-rw-r--r--reference_model/src/ops/reduction.cc27
1 files changed, 26 insertions, 1 deletions
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index 8c1c4d0..18fac44 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -128,6 +128,31 @@ int OpReduceSum<Rank, Dtype>::eval()
return GraphNode::eval();
}
+struct SumRequiresReducer {
+ static const bool PacketAccess = false;
+ SumRequiresReducer(SubgraphTraverser* parent_sgt) : parent_sgt(parent_sgt) {}
+ void reduce(const int32_t val, int32_t* accum) {
+ int64_t res_in_64 = static_cast<int64_t>(*accum) + val;
+ int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
+ int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
+ REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpReduceSum: result not in i32 range");
+ *accum = static_cast<int32_t>(res_in_64);
+ }
+ int32_t initialize() const { return 0; }
+ int32_t finalize(const int32_t accum) const { return accum; }
+
+ private:
+ SubgraphTraverser* parent_sgt;
+};
+
+template <int Rank, DType Dtype>
+int OpReduceSumInt<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().reduce(this->dims, SumRequiresReducer(this->parent_sgt)).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
// template explicit instantiation
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
@@ -146,4 +171,4 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT);
-DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);