diff options
author | Tai Ly <tai.ly@arm.com> | 2024-02-22 23:26:28 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-02-26 15:36:59 +0000 |
commit | c5c2a7e4be64fef198c150063db9f241f21299d2 (patch) | |
tree | d46ee581b7992fb42c5484464e9c32a4a3adc646 /reference_model | |
parent | 97f1c0e602049ee43537ad6c8a08f476eb5c722b (diff) | |
download | reference_model-c5c2a7e4be64fef198c150063db9f241f21299d2.tar.gz |
[ref_model] Change resize attrs to inputs
This patch implements changes needed for resize op's
scale/offset/border changing from attributes to inputs
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I20db0687fad40711f3ded644af51096292dd05b3
Diffstat (limited to 'reference_model')
-rw-r--r-- | reference_model/src/ops/image.cc | 44 | ||||
-rw-r--r-- | reference_model/src/ops/image.h | 16 |
2 files changed, 32 insertions, 28 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index b1b762f..7c480ad 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -27,7 +27,7 @@ template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t> OpResize<InDtype, OutDtype, resize_t>::OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_RESIZE, id_) { - setRequiredOperands(1, 1); + setRequiredOperands(4, 1); setRequiredRank(4, 4); INIT_ATTRIBUTE(Resize); @@ -49,16 +49,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes() if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) return 1; - if (this->attribute->scale().size() != 4) - { - printNodeValidationError("OpResize: illegal size for attribute scale"); - return 1; - } - - scale = this->attribute->scale(); - offset = this->attribute->offset(); - border = this->attribute->border(); - mode = this->attribute->mode(); + mode = this->attribute->mode(); if (this->mode == ResizeMode_BILINEAR) { @@ -79,8 +70,11 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes() } } - in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); - out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + scale = dynamic_cast<TosaReference::TensorTemplate<TInShape>*>(inputs[1]); + offset = dynamic_cast<TosaReference::TensorTemplate<TInShape>*>(inputs[2]); + border = dynamic_cast<TosaReference::TensorTemplate<TInShape>*>(inputs[3]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); ASSERT_MEM(in && out); @@ -90,6 +84,14 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes() template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t> int OpResize<InDtype, OutDtype, resize_t>::eval() { + // validate scale/offset/border number of elements + TInShape scale_val = this->scale->getTensor(); + TInShape offset_val = this->offset->getTensor(); + TInShape border_val = this->border->getTensor(); + ERROR_IF(scale_val.size() != 4, "OpResize: illegal size for input scale"); + ERROR_IF(offset_val.size() != 2, "OpResize: illegal size for input offset"); + ERROR_IF(border_val.size() != 2, "OpResize: illegal size for input border"); + int in_batch = in->getShape()[0]; int in_height = in->getShape()[1]; int in_width = in->getShape()[2]; @@ -100,16 +102,16 @@ int OpResize<InDtype, OutDtype, resize_t>::eval() int out_width = out->getShape()[2]; int out_channels = out->getShape()[3]; - int16_t scale_y_n = scale[0]; - int16_t scale_y_d = scale[1]; - int16_t scale_x_n = scale[2]; - int16_t scale_x_d = scale[3]; + int16_t scale_y_n = scale_val(0); + int16_t scale_y_d = scale_val(1); + int16_t scale_x_n = scale_val(2); + int16_t scale_x_d = scale_val(3); - int16_t offset_y = offset[0]; - int16_t offset_x = offset[1]; + int16_t offset_y = offset_val(0); + int16_t offset_x = offset_val(1); - int16_t border_y = border[0]; - int16_t border_x = border[1]; + int16_t border_y = border_val(0); + int16_t border_x = border_val(1); ERROR_IF(std::max<int>({ in_height, in_width, out_height, out_width }) >= 16384, "OpResize: exceeds maximum dimension"); diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h index 6d5a418..9c08def 100644 --- a/reference_model/src/ops/image.h +++ b/reference_model/src/ops/image.h @@ -32,18 +32,20 @@ public: virtual int checkTensorAttributes() final; virtual int eval(); - using InEigenType = typename GetEigenType<InDtype>::type; - using OutEigenType = typename GetEigenType<OutDtype>::type; - using TIn = Eigen::Tensor<InEigenType, 4>; - using TOut = Eigen::Tensor<OutEigenType, 4>; + using InEigenType = typename GetEigenType<InDtype>::type; + using InEigenShapeType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type; + using OutEigenType = typename GetEigenType<OutDtype>::type; + using TIn = Eigen::Tensor<InEigenType, 4>; + using TInShape = Eigen::Tensor<InEigenShapeType, 1>; + using TOut = Eigen::Tensor<OutEigenType, 4>; protected: TosaResizeAttribute* attribute; - std::vector<int16_t> scale; - std::vector<int16_t> offset; - std::vector<int16_t> border; ResizeMode mode; TosaReference::TensorTemplate<TIn>* in; + TosaReference::TensorTemplate<TInShape>* scale; + TosaReference::TensorTemplate<TInShape>* offset; + TosaReference::TensorTemplate<TInShape>* border; TosaReference::TensorTemplate<TOut>* out; }; |