aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/image.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/image.cc')
-rw-r--r--reference_model/src/ops/image.cc44
1 files changed, 23 insertions, 21 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index b1b762f..7c480ad 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -27,7 +27,7 @@ template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
OpResize<InDtype, OutDtype, resize_t>::OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_RESIZE, id_)
{
- setRequiredOperands(1, 1);
+ setRequiredOperands(4, 1);
setRequiredRank(4, 4);
INIT_ATTRIBUTE(Resize);
@@ -49,16 +49,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
return 1;
- if (this->attribute->scale().size() != 4)
- {
- printNodeValidationError("OpResize: illegal size for attribute scale");
- return 1;
- }
-
- scale = this->attribute->scale();
- offset = this->attribute->offset();
- border = this->attribute->border();
- mode = this->attribute->mode();
+ mode = this->attribute->mode();
if (this->mode == ResizeMode_BILINEAR)
{
@@ -79,8 +70,11 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
}
}
- in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ scale = dynamic_cast<TosaReference::TensorTemplate<TInShape>*>(inputs[1]);
+ offset = dynamic_cast<TosaReference::TensorTemplate<TInShape>*>(inputs[2]);
+ border = dynamic_cast<TosaReference::TensorTemplate<TInShape>*>(inputs[3]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
ASSERT_MEM(in && out);
@@ -90,6 +84,14 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
int OpResize<InDtype, OutDtype, resize_t>::eval()
{
+ // validate scale/offset/border number of elements
+ TInShape scale_val = this->scale->getTensor();
+ TInShape offset_val = this->offset->getTensor();
+ TInShape border_val = this->border->getTensor();
+ ERROR_IF(scale_val.size() != 4, "OpResize: illegal size for input scale");
+ ERROR_IF(offset_val.size() != 2, "OpResize: illegal size for input offset");
+ ERROR_IF(border_val.size() != 2, "OpResize: illegal size for input border");
+
int in_batch = in->getShape()[0];
int in_height = in->getShape()[1];
int in_width = in->getShape()[2];
@@ -100,16 +102,16 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
int out_width = out->getShape()[2];
int out_channels = out->getShape()[3];
- int16_t scale_y_n = scale[0];
- int16_t scale_y_d = scale[1];
- int16_t scale_x_n = scale[2];
- int16_t scale_x_d = scale[3];
+ int16_t scale_y_n = scale_val(0);
+ int16_t scale_y_d = scale_val(1);
+ int16_t scale_x_n = scale_val(2);
+ int16_t scale_x_d = scale_val(3);
- int16_t offset_y = offset[0];
- int16_t offset_x = offset[1];
+ int16_t offset_y = offset_val(0);
+ int16_t offset_x = offset_val(1);
- int16_t border_y = border[0];
- int16_t border_x = border[1];
+ int16_t border_y = border_val(0);
+ int16_t border_x = border_val(1);
ERROR_IF(std::max<int>({ in_height, in_width, out_height, out_width }) >= 16384,
"OpResize: exceeds maximum dimension");