diff options
Diffstat (limited to 'reference_model/src/ops/image.cc')
-rw-r--r-- | reference_model/src/ops/image.cc | 116 |
1 files changed, 87 insertions, 29 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index d3352ce..829a6e0 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -51,6 +51,8 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes() stride = this->attribute->stride(); offset = this->attribute->offset(); shift = this->attribute->shift(); + stride_fp = this->attribute->stride_fp(); + offset_fp = this->attribute->offset_fp(); mode = this->attribute->mode(); int output_height = outputs[0]->getShape()[1]; @@ -58,7 +60,7 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes() if (this->mode == ResizeMode_BILINEAR) { - if (OutDtype != DType_INT32 && OutDtype != DType_INT48) + if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FLOAT) { printNodeValidationError("OpResize: invalid data type for BILINEAR"); return 1; @@ -66,7 +68,7 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes() } else { - if (OutDtype != DType_INT8 && OutDtype != DType_INT16) + if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FLOAT) { printNodeValidationError("OpResize: invalid data type for NEAREST"); return 1; @@ -79,18 +81,6 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes() return 1; } - if (shift < 1 || shift > 11) - { - printNodeValidationError("OpResize: attribute shift should be within [1, 11]"); - return 1; - } - - if (stride[0] <= 0 || stride[1] <= 0) - { - printNodeValidationError("OpResize: invalid attribute stride"); - return 1; - } - in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); @@ -112,6 +102,8 @@ int OpResize<InDtype, OutDtype>::eval() int out_width = out->getShape()[2]; int out_channels = out->getShape()[3]; + ASSERT_MSG_NODE(shift > 0 && shift <= 11, "OpResize: attribute shift should be within [1, 11]"); + ASSERT_MSG_NODE(stride[0] > 0 && stride[1] > 0, "OpResize: invalid attribute stride"); ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch"); ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch"); @@ -120,30 +112,30 @@ int OpResize<InDtype, OutDtype>::eval() for (int oy = 0; oy < out_height; oy++) for (int ox = 0; ox < out_width; ox++) { - int y = oy * stride[0] + offset[0]; - int x = ox * stride[1] + offset[1]; + int32_t y = oy * stride[0] + offset[0]; + int32_t x = ox * stride[1] + offset[1]; - int iy = y >> shift; - int dy = y - (iy << shift); - int ix = x >> shift; - int dx = x - (ix << shift); + int32_t iy = y >> shift; + int32_t dy = y - (iy << shift); + int32_t ix = x >> shift; + int32_t dx = x - (ix << shift); - int iy0 = MAX(iy, 0); - int iy1 = MIN(iy + 1, in_height - 1); - int ix0 = MAX(ix, 0); - int ix1 = MIN(ix + 1, in_width - 1); + int32_t iy0 = MAX(iy, 0); + int32_t iy1 = MIN(iy + 1, in_height - 1); + int32_t ix0 = MAX(ix, 0); + int32_t ix1 = MIN(ix + 1, in_width - 1); ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", iy0, iy1, ix0, ix1); - InEigenType v00 = in->getTensor()(b, iy0, ix0, c); - InEigenType v01 = in->getTensor()(b, iy0, ix1, c); - InEigenType v10 = in->getTensor()(b, iy1, ix0, c); - InEigenType v11 = in->getTensor()(b, iy1, ix1, c); - OutEigenType acc; if (mode == ResizeMode_BILINEAR) { + InEigenType v00 = in->getTensor()(b, iy0, ix0, c); + InEigenType v01 = in->getTensor()(b, iy0, ix1, c); + InEigenType v10 = in->getTensor()(b, iy1, ix0, c); + InEigenType v11 = in->getTensor()(b, iy1, ix1, c); + acc = (OutEigenType)v00 * ((1 << shift) - dy) * ((1 << shift) - dx); acc = acc + (OutEigenType)v01 * ((1 << shift) - dy) * dx; acc = acc + (OutEigenType)v10 * dy * ((1 << shift) - dx); @@ -162,8 +154,74 @@ int OpResize<InDtype, OutDtype>::eval() return GraphNode::eval(); } +template <> +int OpResize<DType_FLOAT, DType_FLOAT>::eval() +{ + int in_batch = in->getShape()[0]; + int in_height = in->getShape()[1]; + int in_width = in->getShape()[2]; + int in_channels = in->getShape()[3]; + + int out_batch = out->getShape()[0]; + int out_height = out->getShape()[1]; + int out_width = out->getShape()[2]; + int out_channels = out->getShape()[3]; + + ASSERT_MSG_NODE(shift == 0, "OpResize: float mode must have 0 shift"); + ASSERT_MSG_NODE(stride_fp[0] > 0.0f && stride_fp[1] > 0.0f, "OpResize: invalid attribute stride"); + ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch"); + ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch"); + + for (int b = 0; b < out_batch; b++) + for (int c = 0; c < out_channels; c++) + for (int oy = 0; oy < out_height; oy++) + for (int ox = 0; ox < out_width; ox++) + { + float y = oy * stride_fp[0] + offset_fp[0]; + float x = ox * stride_fp[1] + offset_fp[1]; + + int32_t iy = static_cast<int32_t>(std::floor(y)); + float dy = y - static_cast<float>(iy); + int32_t ix = static_cast<int32_t>(std::floor(x)); + float dx = x - static_cast<float>(ix); + + int32_t iy0 = MAX(iy, 0); + int32_t iy1 = MIN(iy + 1, in_height - 1); + int32_t ix0 = MAX(ix, 0); + int32_t ix1 = MIN(ix + 1, in_width - 1); + + ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", + iy0, iy1, ix0, ix1); + + OutEigenType acc; + if (mode == ResizeMode_BILINEAR) + { + InEigenType v00 = in->getTensor()(b, iy0, ix0, c); + InEigenType v01 = in->getTensor()(b, iy0, ix1, c); + InEigenType v10 = in->getTensor()(b, iy1, ix0, c); + InEigenType v11 = in->getTensor()(b, iy1, ix1, c); + + acc = (OutEigenType)v00 * (1.0 - dy) * (1.0 - dx); + acc = acc + (OutEigenType)v01 * (1.0 - dy) * dx; + acc = acc + (OutEigenType)v10 * dy * (1.0 - dx); + acc = acc + (OutEigenType)v11 * dy * dx; + } + else + { + iy = (dy >= 0.5) ? iy1 : iy0; + ix = (dx >= 0.5) ? ix1 : ix0; + acc = in->getTensor()(b, iy, ix, c); + } + + out->getTensor()(b, oy, ox, c) = acc; + } + + return GraphNode::eval(); +} + // template explicit instantiation DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT8); DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT48); DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT16); +DEF_INSTANTIATE_TWO_TYPE(OpResize, FLOAT, FLOAT); |