aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/image.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/image.cc')
-rw-r--r--reference_model/src/ops/image.cc116
1 files changed, 87 insertions, 29 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index d3352ce..829a6e0 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -51,6 +51,8 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes()
stride = this->attribute->stride();
offset = this->attribute->offset();
shift = this->attribute->shift();
+ stride_fp = this->attribute->stride_fp();
+ offset_fp = this->attribute->offset_fp();
mode = this->attribute->mode();
int output_height = outputs[0]->getShape()[1];
@@ -58,7 +60,7 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes()
if (this->mode == ResizeMode_BILINEAR)
{
- if (OutDtype != DType_INT32 && OutDtype != DType_INT48)
+ if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FLOAT)
{
printNodeValidationError("OpResize: invalid data type for BILINEAR");
return 1;
@@ -66,7 +68,7 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes()
}
else
{
- if (OutDtype != DType_INT8 && OutDtype != DType_INT16)
+ if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FLOAT)
{
printNodeValidationError("OpResize: invalid data type for NEAREST");
return 1;
@@ -79,18 +81,6 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes()
return 1;
}
- if (shift < 1 || shift > 11)
- {
- printNodeValidationError("OpResize: attribute shift should be within [1, 11]");
- return 1;
- }
-
- if (stride[0] <= 0 || stride[1] <= 0)
- {
- printNodeValidationError("OpResize: invalid attribute stride");
- return 1;
- }
-
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
@@ -112,6 +102,8 @@ int OpResize<InDtype, OutDtype>::eval()
int out_width = out->getShape()[2];
int out_channels = out->getShape()[3];
+ ASSERT_MSG_NODE(shift > 0 && shift <= 11, "OpResize: attribute shift should be within [1, 11]");
+ ASSERT_MSG_NODE(stride[0] > 0 && stride[1] > 0, "OpResize: invalid attribute stride");
ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch");
ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch");
@@ -120,30 +112,30 @@ int OpResize<InDtype, OutDtype>::eval()
for (int oy = 0; oy < out_height; oy++)
for (int ox = 0; ox < out_width; ox++)
{
- int y = oy * stride[0] + offset[0];
- int x = ox * stride[1] + offset[1];
+ int32_t y = oy * stride[0] + offset[0];
+ int32_t x = ox * stride[1] + offset[1];
- int iy = y >> shift;
- int dy = y - (iy << shift);
- int ix = x >> shift;
- int dx = x - (ix << shift);
+ int32_t iy = y >> shift;
+ int32_t dy = y - (iy << shift);
+ int32_t ix = x >> shift;
+ int32_t dx = x - (ix << shift);
- int iy0 = MAX(iy, 0);
- int iy1 = MIN(iy + 1, in_height - 1);
- int ix0 = MAX(ix, 0);
- int ix1 = MIN(ix + 1, in_width - 1);
+ int32_t iy0 = MAX(iy, 0);
+ int32_t iy1 = MIN(iy + 1, in_height - 1);
+ int32_t ix0 = MAX(ix, 0);
+ int32_t ix1 = MIN(ix + 1, in_width - 1);
ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)",
iy0, iy1, ix0, ix1);
- InEigenType v00 = in->getTensor()(b, iy0, ix0, c);
- InEigenType v01 = in->getTensor()(b, iy0, ix1, c);
- InEigenType v10 = in->getTensor()(b, iy1, ix0, c);
- InEigenType v11 = in->getTensor()(b, iy1, ix1, c);
-
OutEigenType acc;
if (mode == ResizeMode_BILINEAR)
{
+ InEigenType v00 = in->getTensor()(b, iy0, ix0, c);
+ InEigenType v01 = in->getTensor()(b, iy0, ix1, c);
+ InEigenType v10 = in->getTensor()(b, iy1, ix0, c);
+ InEigenType v11 = in->getTensor()(b, iy1, ix1, c);
+
acc = (OutEigenType)v00 * ((1 << shift) - dy) * ((1 << shift) - dx);
acc = acc + (OutEigenType)v01 * ((1 << shift) - dy) * dx;
acc = acc + (OutEigenType)v10 * dy * ((1 << shift) - dx);
@@ -162,8 +154,74 @@ int OpResize<InDtype, OutDtype>::eval()
return GraphNode::eval();
}
+template <>
+int OpResize<DType_FLOAT, DType_FLOAT>::eval()
+{
+ int in_batch = in->getShape()[0];
+ int in_height = in->getShape()[1];
+ int in_width = in->getShape()[2];
+ int in_channels = in->getShape()[3];
+
+ int out_batch = out->getShape()[0];
+ int out_height = out->getShape()[1];
+ int out_width = out->getShape()[2];
+ int out_channels = out->getShape()[3];
+
+ ASSERT_MSG_NODE(shift == 0, "OpResize: float mode must have 0 shift");
+ ASSERT_MSG_NODE(stride_fp[0] > 0.0f && stride_fp[1] > 0.0f, "OpResize: invalid attribute stride");
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch");
+ ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch");
+
+ for (int b = 0; b < out_batch; b++)
+ for (int c = 0; c < out_channels; c++)
+ for (int oy = 0; oy < out_height; oy++)
+ for (int ox = 0; ox < out_width; ox++)
+ {
+ float y = oy * stride_fp[0] + offset_fp[0];
+ float x = ox * stride_fp[1] + offset_fp[1];
+
+ int32_t iy = static_cast<int32_t>(std::floor(y));
+ float dy = y - static_cast<float>(iy);
+ int32_t ix = static_cast<int32_t>(std::floor(x));
+ float dx = x - static_cast<float>(ix);
+
+ int32_t iy0 = MAX(iy, 0);
+ int32_t iy1 = MIN(iy + 1, in_height - 1);
+ int32_t ix0 = MAX(ix, 0);
+ int32_t ix1 = MIN(ix + 1, in_width - 1);
+
+ ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)",
+ iy0, iy1, ix0, ix1);
+
+ OutEigenType acc;
+ if (mode == ResizeMode_BILINEAR)
+ {
+ InEigenType v00 = in->getTensor()(b, iy0, ix0, c);
+ InEigenType v01 = in->getTensor()(b, iy0, ix1, c);
+ InEigenType v10 = in->getTensor()(b, iy1, ix0, c);
+ InEigenType v11 = in->getTensor()(b, iy1, ix1, c);
+
+ acc = (OutEigenType)v00 * (1.0 - dy) * (1.0 - dx);
+ acc = acc + (OutEigenType)v01 * (1.0 - dy) * dx;
+ acc = acc + (OutEigenType)v10 * dy * (1.0 - dx);
+ acc = acc + (OutEigenType)v11 * dy * dx;
+ }
+ else
+ {
+ iy = (dy >= 0.5) ? iy1 : iy0;
+ ix = (dx >= 0.5) ? ix1 : ix0;
+ acc = in->getTensor()(b, iy, ix, c);
+ }
+
+ out->getTensor()(b, oy, ox, c) = acc;
+ }
+
+ return GraphNode::eval();
+}
+
// template explicit instantiation
DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT8);
DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT48);
DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT16);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, FLOAT, FLOAT);