diff options
Diffstat (limited to 'reference_model/src/ops/image.cc')
-rw-r--r-- | reference_model/src/ops/image.cc | 34 |
1 files changed, 20 insertions, 14 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index 829a6e0..f4decae 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -22,8 +22,11 @@ using namespace Eigen; using namespace tosa; template <DType InDtype, DType OutDtype> -OpResize<InDtype, OutDtype>::OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_RESIZE, id_) +OpResize<InDtype, OutDtype>::OpResize(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_RESIZE, id_) { setRequiredOperands(1, 1); setRequiredRank(4, 4); @@ -102,10 +105,13 @@ int OpResize<InDtype, OutDtype>::eval() int out_width = out->getShape()[2]; int out_channels = out->getShape()[3]; - ASSERT_MSG_NODE(shift > 0 && shift <= 11, "OpResize: attribute shift should be within [1, 11]"); - ASSERT_MSG_NODE(stride[0] > 0 && stride[1] > 0, "OpResize: invalid attribute stride"); - ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch"); - ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch"); + ERROR_IF(shift < 1 || shift > 11, "OpResize: attribute shift should be within [1, 11]"); + ERROR_IF(stride[0] <= 0 || stride[0] >= (16 << shift), "OpResize: invalid attribute stride_x"); + ERROR_IF(stride[1] <= 0 || stride[1] >= (16 << shift), "OpResize: invalid attribute stride_y"); + ERROR_IF(offset[0] <= (-16 << shift) || offset[0] >= (16 << shift), "OpResize: invalid attribute offset_x"); + ERROR_IF(offset[1] <= (-16 << shift) || offset[1] >= (16 << shift), "OpResize: invalid attribute offset_y"); + ERROR_IF(in_batch != out_batch, "OpResize: output tensor batch mismatch"); + ERROR_IF(in_channels != out_channels, "OpResize: output tensor channel mismatch"); for (int b = 0; b < out_batch; b++) for (int c = 0; c < out_channels; c++) @@ -125,8 +131,8 @@ int OpResize<InDtype, OutDtype>::eval() int32_t ix0 = MAX(ix, 0); int32_t ix1 = MIN(ix + 1, in_width - 1); - ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", - iy0, iy1, ix0, ix1); + REQUIRE(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", iy0, + iy1, ix0, ix1); OutEigenType acc; if (mode == ResizeMode_BILINEAR) @@ -167,10 +173,10 @@ int OpResize<DType_FLOAT, DType_FLOAT>::eval() int out_width = out->getShape()[2]; int out_channels = out->getShape()[3]; - ASSERT_MSG_NODE(shift == 0, "OpResize: float mode must have 0 shift"); - ASSERT_MSG_NODE(stride_fp[0] > 0.0f && stride_fp[1] > 0.0f, "OpResize: invalid attribute stride"); - ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch"); - ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch"); + ERROR_IF(shift != 0, "OpResize: float mode must have 0 shift"); + ERROR_IF(stride_fp[0] <= 0.0f || stride_fp[1] <= 0.0f, "OpResize: invalid attribute stride"); + ERROR_IF(in_batch != out_batch, "OpResize: output tensor batch mismatch"); + ERROR_IF(in_channels != out_channels, "OpResize: output tensor channel mismatch"); for (int b = 0; b < out_batch; b++) for (int c = 0; c < out_channels; c++) @@ -190,8 +196,8 @@ int OpResize<DType_FLOAT, DType_FLOAT>::eval() int32_t ix0 = MAX(ix, 0); int32_t ix1 = MIN(ix + 1, in_width - 1); - ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", - iy0, iy1, ix0, ix1); + REQUIRE(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", iy0, + iy1, ix0, ix1); OutEigenType acc; if (mode == ResizeMode_BILINEAR) |