diff options
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r-- | reference_model/src/ops/data_layout.cc | 55 |
1 files changed, 16 insertions, 39 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index fa99d21..a4b4e0a 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -258,7 +258,7 @@ int OpDim<Rank, Dtype>::eval() int32_t axis = attribute->axis(); int64_t out_val = in->getShape()[axis]; - this->out->getTensor().setConstant(out_val); + this->out->getTensor().setValues({ out_val }); return GraphNode::eval(); } @@ -267,17 +267,12 @@ template <int InRank, int OutRank, TOSA_REF_TYPE Dtype> OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_RESHAPE, id_) { - setRequiredOperands(1, 1); - - INIT_ATTRIBUTE(Reshape); + setRequiredOperands(2, 1); } template <int InRank, int OutRank, TOSA_REF_TYPE Dtype> OpReshape<InRank, OutRank, Dtype>::~OpReshape() -{ - if (attribute) - delete attribute; -} +{} template <int InRank, int OutRank, TOSA_REF_TYPE Dtype> int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes() @@ -297,25 +292,17 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes() return 1; } - // Check for unsupported -1 shape inferencing - for (int32_t d = 0; d < OutRank; d++) - { - auto curr_new_dim = attribute->new_shape()[d]; - ERROR_IF(curr_new_dim == -1, "OpReshape: inferred dimensions in output shape are unsupported") - } - ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(), "Input tensor size does not match output tensor size"); - for (uint32_t d = 0; d < OutRank; d++) - { - auto curr_new_dim = attribute->new_shape()[d]; - ERROR_IF(curr_new_dim != outputs[0]->getShape()[d], "OpReshape: new_shape doesn't match output shape"); - } - in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + // note: do not assert mem on shape input, because it may be {} for reshape to scalar + // and also, because the shape input is not actually used in eval() + + ASSERT_MEM(in && out) + return 0; } @@ -506,18 +493,13 @@ template <int Rank, TOSA_REF_TYPE Dtype> OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_TILE, id_) { - setRequiredOperands(1, 1); + setRequiredOperands(2, 1); setRequiredRank(1); - - INIT_ATTRIBUTE(Tile); } template <int Rank, TOSA_REF_TYPE Dtype> OpTileBase<Rank, Dtype>::~OpTileBase() -{ - if (attribute) - delete attribute; -} +{} template <int Rank, TOSA_REF_TYPE Dtype> int OpTileBase<Rank, Dtype>::checkTensorAttributes() @@ -541,23 +523,18 @@ int OpTileBase<Rank, Dtype>::checkTensorAttributes() return 1; } - in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); - out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + multiples = dynamic_cast<TosaReference::TensorTemplate<TInMultiples>*>(inputs[1]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); - ASSERT_MEM(in && out); + ASSERT_MEM(in && multiples && out); - if (attribute->multiples().size() != Rank) + if (multiples->getElementCount() != Rank) { printNodeValidationError("1D list 'multiples' must have size equal to input rank"); return 1; } - for (int32_t d = 0; d < Rank; d++) - { - ERROR_IF(in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d], - "Output shape not equal to input * multiples;") - } - return 0; } |