aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/image.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/image.cc')
-rw-r--r--reference_model/src/ops/image.cc169
1 files changed, 169 insertions, 0 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
new file mode 100644
index 0000000..d3352ce
--- /dev/null
+++ b/reference_model/src/ops/image.cc
@@ -0,0 +1,169 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "image.h"
+#include "arith_util.h"
+#include "quant_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <DType InDtype, DType OutDtype>
+OpResize<InDtype, OutDtype>::OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_RESIZE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(4, 4);
+
+ INIT_ATTRIBUTE(Resize);
+}
+
+template <DType InDtype, DType OutDtype>
+OpResize<InDtype, OutDtype>::~OpResize()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <DType InDtype, DType OutDtype>
+int OpResize<InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ return 1;
+
+ output_size = this->attribute->output_size();
+ stride = this->attribute->stride();
+ offset = this->attribute->offset();
+ shift = this->attribute->shift();
+ mode = this->attribute->mode();
+
+ int output_height = outputs[0]->getShape()[1];
+ int output_width = outputs[0]->getShape()[2];
+
+ if (this->mode == ResizeMode_BILINEAR)
+ {
+ if (OutDtype != DType_INT32 && OutDtype != DType_INT48)
+ {
+ printNodeValidationError("OpResize: invalid data type for BILINEAR");
+ return 1;
+ }
+ }
+ else
+ {
+ if (OutDtype != DType_INT8 && OutDtype != DType_INT16)
+ {
+ printNodeValidationError("OpResize: invalid data type for NEAREST");
+ return 1;
+ }
+ }
+
+ if (output_size[0] != output_height || output_size[1] != output_width)
+ {
+ printNodeValidationError("OpResize: attribute output_size doesn't match output [height, width]");
+ return 1;
+ }
+
+ if (shift < 1 || shift > 11)
+ {
+ printNodeValidationError("OpResize: attribute shift should be within [1, 11]");
+ return 1;
+ }
+
+ if (stride[0] <= 0 || stride[1] <= 0)
+ {
+ printNodeValidationError("OpResize: invalid attribute stride");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ return 0;
+}
+
+template <DType InDtype, DType OutDtype>
+int OpResize<InDtype, OutDtype>::eval()
+{
+ int in_batch = in->getShape()[0];
+ int in_height = in->getShape()[1];
+ int in_width = in->getShape()[2];
+ int in_channels = in->getShape()[3];
+
+ int out_batch = out->getShape()[0];
+ int out_height = out->getShape()[1];
+ int out_width = out->getShape()[2];
+ int out_channels = out->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch");
+ ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch");
+
+ for (int b = 0; b < out_batch; b++)
+ for (int c = 0; c < out_channels; c++)
+ for (int oy = 0; oy < out_height; oy++)
+ for (int ox = 0; ox < out_width; ox++)
+ {
+ int y = oy * stride[0] + offset[0];
+ int x = ox * stride[1] + offset[1];
+
+ int iy = y >> shift;
+ int dy = y - (iy << shift);
+ int ix = x >> shift;
+ int dx = x - (ix << shift);
+
+ int iy0 = MAX(iy, 0);
+ int iy1 = MIN(iy + 1, in_height - 1);
+ int ix0 = MAX(ix, 0);
+ int ix1 = MIN(ix + 1, in_width - 1);
+
+ ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)",
+ iy0, iy1, ix0, ix1);
+
+ InEigenType v00 = in->getTensor()(b, iy0, ix0, c);
+ InEigenType v01 = in->getTensor()(b, iy0, ix1, c);
+ InEigenType v10 = in->getTensor()(b, iy1, ix0, c);
+ InEigenType v11 = in->getTensor()(b, iy1, ix1, c);
+
+ OutEigenType acc;
+ if (mode == ResizeMode_BILINEAR)
+ {
+ acc = (OutEigenType)v00 * ((1 << shift) - dy) * ((1 << shift) - dx);
+ acc = acc + (OutEigenType)v01 * ((1 << shift) - dy) * dx;
+ acc = acc + (OutEigenType)v10 * dy * ((1 << shift) - dx);
+ acc = acc + (OutEigenType)v11 * dy * dx;
+ }
+ else
+ {
+ iy = (dy >> (shift - 1)) != 0 ? iy1 : iy0;
+ ix = (dx >> (shift - 1)) != 0 ? ix1 : ix0;
+ acc = in->getTensor()(b, iy, ix, c);
+ }
+
+ out->getTensor()(b, oy, ox, c) = acc;
+ }
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT48);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT16);