From 7de9b456620c0b9df20c1bed466779149c4112fd Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Tue, 5 Apr 2022 14:31:37 +0100 Subject: Add missing REQUIREs check to REDUCE_SUM in refmodel And limit REDUCE_SUM test values to within int32 Signed-off-by: Jeremy Johnson Change-Id: I4d902b245d17eb343cfb2bbc23d9db28c1d1f4c3 --- reference_model/src/ops/reduction.cc | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) (limited to 'reference_model/src/ops/reduction.cc') diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc index 8c1c4d0..18fac44 100644 --- a/reference_model/src/ops/reduction.cc +++ b/reference_model/src/ops/reduction.cc @@ -128,6 +128,31 @@ int OpReduceSum::eval() return GraphNode::eval(); } +struct SumRequiresReducer { + static const bool PacketAccess = false; + SumRequiresReducer(SubgraphTraverser* parent_sgt) : parent_sgt(parent_sgt) {} + void reduce(const int32_t val, int32_t* accum) { + int64_t res_in_64 = static_cast(*accum) + val; + int64_t i32_max_in_64 = static_cast(std::numeric_limits::max()); + int64_t i32_min_in_64 = static_cast(std::numeric_limits::min()); + REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpReduceSum: result not in i32 range"); + *accum = static_cast(res_in_64); + } + int32_t initialize() const { return 0; } + int32_t finalize(const int32_t accum) const { return accum; } + + private: + SubgraphTraverser* parent_sgt; +}; + +template +int OpReduceSumInt::eval() +{ + this->out->getTensor() = this->in->getTensor().reduce(this->dims, SumRequiresReducer(this->parent_sgt)).reshape(this->out->getTensor().dimensions()); + + return GraphNode::eval(); +} + // template explicit instantiation DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL); @@ -146,4 +171,4 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT); -DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); -- cgit v1.2.1