aboutsummaryrefslogtreecommitdiff
path: root/reference_model
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-02-22 23:26:28 +0000
committerEric Kunze <eric.kunze@arm.com>2024-02-26 15:36:59 +0000
commitc5c2a7e4be64fef198c150063db9f241f21299d2 (patch)
treed46ee581b7992fb42c5484464e9c32a4a3adc646 /reference_model
parent97f1c0e602049ee43537ad6c8a08f476eb5c722b (diff)
downloadreference_model-c5c2a7e4be64fef198c150063db9f241f21299d2.tar.gz
[ref_model] Change resize attrs to inputs
This patch implements changes needed for resize op's scale/offset/border changing from attributes to inputs Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I20db0687fad40711f3ded644af51096292dd05b3
Diffstat (limited to 'reference_model')
-rw-r--r--reference_model/src/ops/image.cc44
-rw-r--r--reference_model/src/ops/image.h16
2 files changed, 32 insertions, 28 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index b1b762f..7c480ad 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -27,7 +27,7 @@ template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
OpResize<InDtype, OutDtype, resize_t>::OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_RESIZE, id_)
{
- setRequiredOperands(1, 1);
+ setRequiredOperands(4, 1);
setRequiredRank(4, 4);
INIT_ATTRIBUTE(Resize);
@@ -49,16 +49,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
return 1;
- if (this->attribute->scale().size() != 4)
- {
- printNodeValidationError("OpResize: illegal size for attribute scale");
- return 1;
- }
-
- scale = this->attribute->scale();
- offset = this->attribute->offset();
- border = this->attribute->border();
- mode = this->attribute->mode();
+ mode = this->attribute->mode();
if (this->mode == ResizeMode_BILINEAR)
{
@@ -79,8 +70,11 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
}
}
- in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ scale = dynamic_cast<TosaReference::TensorTemplate<TInShape>*>(inputs[1]);
+ offset = dynamic_cast<TosaReference::TensorTemplate<TInShape>*>(inputs[2]);
+ border = dynamic_cast<TosaReference::TensorTemplate<TInShape>*>(inputs[3]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
ASSERT_MEM(in && out);
@@ -90,6 +84,14 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
int OpResize<InDtype, OutDtype, resize_t>::eval()
{
+ // validate scale/offset/border number of elements
+ TInShape scale_val = this->scale->getTensor();
+ TInShape offset_val = this->offset->getTensor();
+ TInShape border_val = this->border->getTensor();
+ ERROR_IF(scale_val.size() != 4, "OpResize: illegal size for input scale");
+ ERROR_IF(offset_val.size() != 2, "OpResize: illegal size for input offset");
+ ERROR_IF(border_val.size() != 2, "OpResize: illegal size for input border");
+
int in_batch = in->getShape()[0];
int in_height = in->getShape()[1];
int in_width = in->getShape()[2];
@@ -100,16 +102,16 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
int out_width = out->getShape()[2];
int out_channels = out->getShape()[3];
- int16_t scale_y_n = scale[0];
- int16_t scale_y_d = scale[1];
- int16_t scale_x_n = scale[2];
- int16_t scale_x_d = scale[3];
+ int16_t scale_y_n = scale_val(0);
+ int16_t scale_y_d = scale_val(1);
+ int16_t scale_x_n = scale_val(2);
+ int16_t scale_x_d = scale_val(3);
- int16_t offset_y = offset[0];
- int16_t offset_x = offset[1];
+ int16_t offset_y = offset_val(0);
+ int16_t offset_x = offset_val(1);
- int16_t border_y = border[0];
- int16_t border_x = border[1];
+ int16_t border_y = border_val(0);
+ int16_t border_x = border_val(1);
ERROR_IF(std::max<int>({ in_height, in_width, out_height, out_width }) >= 16384,
"OpResize: exceeds maximum dimension");
diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h
index 6d5a418..9c08def 100644
--- a/reference_model/src/ops/image.h
+++ b/reference_model/src/ops/image.h
@@ -32,18 +32,20 @@ public:
virtual int checkTensorAttributes() final;
virtual int eval();
- using InEigenType = typename GetEigenType<InDtype>::type;
- using OutEigenType = typename GetEigenType<OutDtype>::type;
- using TIn = Eigen::Tensor<InEigenType, 4>;
- using TOut = Eigen::Tensor<OutEigenType, 4>;
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using InEigenShapeType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TInShape = Eigen::Tensor<InEigenShapeType, 1>;
+ using TOut = Eigen::Tensor<OutEigenType, 4>;
protected:
TosaResizeAttribute* attribute;
- std::vector<int16_t> scale;
- std::vector<int16_t> offset;
- std::vector<int16_t> border;
ResizeMode mode;
TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TInShape>* scale;
+ TosaReference::TensorTemplate<TInShape>* offset;
+ TosaReference::TensorTemplate<TInShape>* border;
TosaReference::TensorTemplate<TOut>* out;
};