aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/data_layout.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r--reference_model/src/ops/data_layout.cc55
1 files changed, 16 insertions, 39 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index fa99d21..a4b4e0a 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -258,7 +258,7 @@ int OpDim<Rank, Dtype>::eval()
int32_t axis = attribute->axis();
int64_t out_val = in->getShape()[axis];
- this->out->getTensor().setConstant(out_val);
+ this->out->getTensor().setValues({ out_val });
return GraphNode::eval();
}
@@ -267,17 +267,12 @@ template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_RESHAPE, id_)
{
- setRequiredOperands(1, 1);
-
- INIT_ATTRIBUTE(Reshape);
+ setRequiredOperands(2, 1);
}
template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
OpReshape<InRank, OutRank, Dtype>::~OpReshape()
-{
- if (attribute)
- delete attribute;
-}
+{}
template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
@@ -297,25 +292,17 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
return 1;
}
- // Check for unsupported -1 shape inferencing
- for (int32_t d = 0; d < OutRank; d++)
- {
- auto curr_new_dim = attribute->new_shape()[d];
- ERROR_IF(curr_new_dim == -1, "OpReshape: inferred dimensions in output shape are unsupported")
- }
-
ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
"Input tensor size does not match output tensor size");
- for (uint32_t d = 0; d < OutRank; d++)
- {
- auto curr_new_dim = attribute->new_shape()[d];
- ERROR_IF(curr_new_dim != outputs[0]->getShape()[d], "OpReshape: new_shape doesn't match output shape");
- }
-
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ // note: do not assert mem on shape input, because it may be {} for reshape to scalar
+ // and also, because the shape input is not actually used in eval()
+
+ ASSERT_MEM(in && out)
+
return 0;
}
@@ -506,18 +493,13 @@ template <int Rank, TOSA_REF_TYPE Dtype>
OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_TILE, id_)
{
- setRequiredOperands(1, 1);
+ setRequiredOperands(2, 1);
setRequiredRank(1);
-
- INIT_ATTRIBUTE(Tile);
}
template <int Rank, TOSA_REF_TYPE Dtype>
OpTileBase<Rank, Dtype>::~OpTileBase()
-{
- if (attribute)
- delete attribute;
-}
+{}
template <int Rank, TOSA_REF_TYPE Dtype>
int OpTileBase<Rank, Dtype>::checkTensorAttributes()
@@ -541,23 +523,18 @@ int OpTileBase<Rank, Dtype>::checkTensorAttributes()
return 1;
}
- in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ multiples = dynamic_cast<TosaReference::TensorTemplate<TInMultiples>*>(inputs[1]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- ASSERT_MEM(in && out);
+ ASSERT_MEM(in && multiples && out);
- if (attribute->multiples().size() != Rank)
+ if (multiples->getElementCount() != Rank)
{
printNodeValidationError("1D list 'multiples' must have size equal to input rank");
return 1;
}
- for (int32_t d = 0; d < Rank; d++)
- {
- ERROR_IF(in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d],
- "Output shape not equal to input * multiples;")
- }
-
return 0;
}