author    Tai Ly <tai.ly@arm.com>    2023-12-18 20:40:24 +0000
committer Tai Ly <tai.ly@arm.com>    2024-01-18 23:50:04 +0000
commit    8690a0873fac28acccbb0acb15c16e8337e14df1 (patch)
tree      a13d5e195d8b7becffc23da98fde7449e91c96e4
parent    9f5febe05901bfbd3919ef17f2caea8087cd9ccf (diff)
download  reference_model-8690a0873fac28acccbb0acb15c16e8337e14df1.tar.gz
[reference model] Add shape operators
- fixed up reshape conformance tests to use shape input instead of attribute
- fixed up tile conformance tests to use shape input instead of attribute
- fixed output and output rank of dim op
- allow rank 0 and rank 1 tensors for tosa.shape values (for shape = {})
- added initialization of rank 0 const_shape tensors (for shape = {})
- Update conformance tests to use new rescale attributes

Signed-off-by: Tai Ly <tai.ly@arm.com>
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I6cce0d2a9ab066fe20a2abf9d2cfde3eb3d8c18b
-rw-r--r--  reference_model/CMakeLists.txt               1
-rw-r--r--  reference_model/src/ops/data_layout.cc      55
-rw-r--r--  reference_model/src/ops/data_layout.h       16
-rw-r--r--  reference_model/src/ops/op_factory.cc       16
-rw-r--r--  reference_model/src/ops/shape.cc           198
-rw-r--r--  reference_model/src/ops/shape.h            120
-rw-r--r--  reference_model/src/subgraph_traverser.cc   19
-rw-r--r--  reference_model/src/tensor.h                14
m---------  thirdparty/serialization_lib                 0
-rw-r--r--  verif/generator/tosa_test_gen.py            60
10 files changed, 427 insertions(+), 72 deletions(-)
diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt
index 41e4d57..178594b 100644
--- a/reference_model/CMakeLists.txt
+++ b/reference_model/CMakeLists.txt
@@ -97,6 +97,7 @@ set(CXX_SOURCE
src/ops/data_nodes.cc
src/ops/custom.cc
src/ops/control_flow.cc
+ src/ops/shape.cc
)
set(PUBLIC_INCLUDE_DIRS
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index fa99d21..a4b4e0a 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -258,7 +258,7 @@ int OpDim<Rank, Dtype>::eval()
int32_t axis = attribute->axis();
int64_t out_val = in->getShape()[axis];
- this->out->getTensor().setConstant(out_val);
+ this->out->getTensor().setValues({ out_val });
return GraphNode::eval();
}
@@ -267,17 +267,12 @@ template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_RESHAPE, id_)
{
- setRequiredOperands(1, 1);
-
- INIT_ATTRIBUTE(Reshape);
+ setRequiredOperands(2, 1);
}
template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
OpReshape<InRank, OutRank, Dtype>::~OpReshape()
-{
- if (attribute)
- delete attribute;
-}
+{}
template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
@@ -297,25 +292,17 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
return 1;
}
- // Check for unsupported -1 shape inferencing
- for (int32_t d = 0; d < OutRank; d++)
- {
- auto curr_new_dim = attribute->new_shape()[d];
- ERROR_IF(curr_new_dim == -1, "OpReshape: inferred dimensions in output shape are unsupported")
- }
-
ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
"Input tensor size does not match output tensor size");
- for (uint32_t d = 0; d < OutRank; d++)
- {
- auto curr_new_dim = attribute->new_shape()[d];
- ERROR_IF(curr_new_dim != outputs[0]->getShape()[d], "OpReshape: new_shape doesn't match output shape");
- }
-
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ // note: do not assert mem on the shape input, because it may be {} for a reshape
+ // to scalar, and because the shape input is not actually used in eval()
+
+ ASSERT_MEM(in && out);
+
return 0;
}
@@ -506,18 +493,13 @@ template <int Rank, TOSA_REF_TYPE Dtype>
OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_TILE, id_)
{
- setRequiredOperands(1, 1);
+ setRequiredOperands(2, 1);
setRequiredRank(1);
-
- INIT_ATTRIBUTE(Tile);
}
template <int Rank, TOSA_REF_TYPE Dtype>
OpTileBase<Rank, Dtype>::~OpTileBase()
-{
- if (attribute)
- delete attribute;
-}
+{}
template <int Rank, TOSA_REF_TYPE Dtype>
int OpTileBase<Rank, Dtype>::checkTensorAttributes()
@@ -541,23 +523,18 @@ int OpTileBase<Rank, Dtype>::checkTensorAttributes()
return 1;
}
- in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ multiples = dynamic_cast<TosaReference::TensorTemplate<TInMultiples>*>(inputs[1]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- ASSERT_MEM(in && out);
+ ASSERT_MEM(in && multiples && out);
- if (attribute->multiples().size() != Rank)
+ if (multiples->getElementCount() != Rank)
{
printNodeValidationError("1D list 'multiples' must have size equal to input rank");
return 1;
}
- for (int32_t d = 0; d < Rank; d++)
- {
- ERROR_IF(in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d],
- "Output shape not equal to input * multiples;")
- }
-
return 0;
}
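
The Dim change above makes the operator produce a rank-1 tosa.shape value holding a single element (setValues({ out_val }) on a rank-1 TOut) instead of a rank-0 scalar. A minimal standalone sketch of the resulting semantics, not part of the commit; the helper name is hypothetical:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Return the size of the given axis as a one-element shape value,
    // mirroring OpDim's new rank-1 output.
    std::vector<int64_t> dimOp(const std::vector<int64_t>& inShape, int32_t axis)
    {
        assert(axis >= 0 && static_cast<size_t>(axis) < inShape.size());
        return { inShape[static_cast<size_t>(axis)] };
    }

    // Example: dimOp({ 2, 3, 5 }, 1) yields { 3 }, a value of shape [1].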
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
index 024f9a2..9341709 100644
--- a/reference_model/src/ops/data_layout.h
+++ b/reference_model/src/ops/data_layout.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -79,7 +79,7 @@ public:
using InEigenType = typename GetEigenType<Dtype>::type;
using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
using TIn = Eigen::Tensor<InEigenType, Rank>;
- using TOut = Eigen::Tensor<OutEigenType, 0>;
+ using TOut = Eigen::Tensor<OutEigenType, 1>;
protected:
TosaReference::TensorTemplate<TIn>* in;
@@ -107,7 +107,6 @@ protected:
Eigen::array<Eigen::Index, InRank> in_reverser;
Eigen::array<Eigen::Index, OutRank> out_reverser;
TosaReference::TensorTemplate<TIn>* in;
- TosaReshapeAttribute* attribute;
TosaReference::TensorTemplate<TOut>* out;
};
@@ -165,14 +164,17 @@ public:
virtual int checkTensorAttributes();
- using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<Dtype>::type;
- using TIn = Eigen::Tensor<InEigenType, Rank>;
- using TOut = Eigen::Tensor<OutEigenType, Rank>;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using InEigenShapeType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TInMultiples = Eigen::Tensor<InEigenShapeType, 1>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
protected:
TosaTileAttribute* attribute;
TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TInMultiples>* multiples;
TosaReference::TensorTemplate<TOut>* out;
};
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 34db903..af8332e 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@
#include "image.h"
#include "reduction.h"
#include "scatter_gather.h"
+#include "shape.h"
#include "tensor_ops.h"
#include "type_conversion.h"
@@ -600,6 +601,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_WHILE_LOOP:
return new OpWhileLoop(sgt, tsh, attribute, id);
+ case Op_CONST_SHAPE:
+ return new OpConstShape(sgt, id);
+ case Op_CONCAT_SHAPE:
+ return new OpConcatShape(sgt, id);
+ case Op_ADD_SHAPE:
+ return new OpAddShape(sgt, id);
+ case Op_SUB_SHAPE:
+ return new OpSubShape(sgt, id);
+ case Op_MUL_SHAPE:
+ return new OpMulShape(sgt, id);
+ case Op_DIV_SHAPE:
+ return new OpDivShape(sgt, id);
+
// Ops not recognized
default:
goto done;
diff --git a/reference_model/src/ops/shape.cc b/reference_model/src/ops/shape.cc
new file mode 100644
index 0000000..b087dd8
--- /dev/null
+++ b/reference_model/src/ops/shape.cc
@@ -0,0 +1,198 @@
+// Copyright (c) 2023-2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "shape.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+OpConstShape::OpConstShape(SubgraphTraverser* sgt_, uint64_t id_)
+ : GraphNode(sgt_, Op_CONST, id_)
+{
+ setRequiredOperands(0, 1);
+}
+
+OpConstShape::~OpConstShape()
+{}
+
+int OpConstShape::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ return 0;
+}
+
+int OpConstShape::eval()
+{
+ for (auto ct : getOutputs())
+ {
+ if (!ct->getIsValid())
+ {
+ std::string err = "Constant Shape tensor " + ct->getName() + " not correctly initialized";
+ printNodeValidationError(err.c_str());
+ return 1;
+ }
+ }
+
+ // Evaluation is trivial for constants
+ return GraphNode::eval();
+}
+
+OpConcatShape::OpConcatShape(SubgraphTraverser* sgt_, uint64_t id_)
+ : GraphNode(sgt_, Op_CONCAT_SHAPE, id_)
+{
+ setRequiredOperands(-1, 1);
+ setRequiredRank(1, 1);
+}
+
+OpConcatShape::~OpConcatShape()
+{}
+
+int OpConcatShape::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (inputs.empty())
+ {
+ printNodeValidationError("ConcatShape operator must have at least one input tensor");
+ return 1;
+ }
+
+ int32_t num_inputs = inputs.size();
+ int32_t elements_count = 0;
+ for (int32_t i = 0; i < num_inputs; i++)
+ {
+ if (validateRequiredRank(inputs[i]))
+ return 1;
+ ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
+ elements_count += inputs[i]->getShape()[0];
+ }
+
+ ERROR_IF(elements_count != outputs[0]->getShape()[0],
+ "OpConcatShape: sum of input elements not equal to output number of elements");
+
+ num_dims = outputs[0]->getShape()[0];
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ return 0;
+}
+
+int OpConcatShape::eval()
+{
+ ETensor1<EigenType> out_tensor(num_dims);
+ int32_t out_idx = 0;
+ for (size_t i = 0; i < ins.size(); i++)
+ {
+ // all tosa.shape values are 1-d tensors
+ // iterate in_idx over the range [0, size of the 1-d tensor)
+ for (int32_t in_idx = 0; in_idx < inputs[i]->getShape()[0]; in_idx++)
+ {
+ out_tensor(out_idx) = ins[i]->getTensor()(in_idx);
+ out_idx++;
+ }
+ }
+ out->getTensor() = out_tensor;
+ return GraphNode::eval();
+}
+
+ShapeBinaryNodeBase::ShapeBinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_)
+ : GraphNode(sgt_, op_, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(1, 1);
+
+ fcn = [](EigenType a, EigenType b) -> EigenType { return EigenType(); };
+}
+
+ShapeBinaryNodeBase::~ShapeBinaryNodeBase()
+{}
+
+int ShapeBinaryNodeBase::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ return 1;
+
+ num_dims = outputs[0]->getShape()[0];
+
+ if (inputs[0]->getShape()[0] != num_dims)
+ {
+ std::string err = "Binary shape operators " + std::string(EnumNamesOp()[nodeType]) +
+ " lhs input and output rank/shape must match";
+ printNodeValidationError(err.c_str());
+ return 1;
+ }
+
+ if (inputs[1]->getShape()[0] != num_dims)
+ {
+ std::string err = "Binary shape operators " + std::string(EnumNamesOp()[nodeType]) +
+ " rhs input and output rank/shape must match";
+ printNodeValidationError(err.c_str());
+ return 1;
+ }
+
+ a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+ result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(a && b && result);
+
+ return 0;
+}
+
+int ShapeBinaryNodeBase::eval()
+{
+ auto ia = a->getTensor();
+ auto ib = b->getTensor();
+ ETensor1<EigenType> out_tens(num_dims);
+ for (int32_t i = 0; i < num_dims; i++)
+ {
+ EigenType lhs = ia(i);
+ EigenType rhs = ib(i);
+ out_tens(i) = (lhs < 0 || rhs < 0) ? static_cast<EigenType>(-1) : fcn(lhs, rhs);
+ }
+
+ result->getTensor() = out_tens;
+ return GraphNode::eval();
+}
+
+int OpAddShape::register_fcn()
+{
+ fcn = [](EigenType a, EigenType b) -> EigenType { return a + b; };
+ return 0;
+}
+
+int OpSubShape::register_fcn()
+{
+ fcn = [](EigenType a, EigenType b) -> EigenType { return a - b; };
+ return 0;
+}
+
+int OpMulShape::register_fcn()
+{
+ fcn = [](EigenType a, EigenType b) -> EigenType { return a * b; };
+ return 0;
+}
+
+int OpDivShape::register_fcn()
+{
+ fcn = [](EigenType a, EigenType b) -> EigenType {
+ return (b == static_cast<EigenType>(0)) ? static_cast<EigenType>(-1) : (a / b);
+ };
+ return 0;
+}
\ No newline at end of file
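
A standalone sketch (not part of the commit) of the elementwise semantics that ShapeBinaryNodeBase::eval() and OpDivShape::register_fcn() implement above: a negative operand propagates as -1, and division by zero also yields -1:

    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <vector>

    using Shape = std::vector<int64_t>;

    Shape shapeBinaryOp(const Shape& a, const Shape& b,
                        const std::function<int64_t(int64_t, int64_t)>& fcn)
    {
        assert(a.size() == b.size());
        Shape out(a.size());
        for (size_t i = 0; i < a.size(); i++)
        {
            // A negative dimension is "unknown" and poisons the result.
            out[i] = (a[i] < 0 || b[i] < 0) ? -1 : fcn(a[i], b[i]);
        }
        return out;
    }

    int main()
    {
        auto div = [](int64_t x, int64_t y) { return y == 0 ? int64_t(-1) : x / y; };
        for (int64_t d : shapeBinaryOp({ 8, 4, -1 }, { 2, 0, 3 }, div))
            std::cout << d << ' ';    // prints: 4 -1 -1
        return 0;
    }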
diff --git a/reference_model/src/ops/shape.h b/reference_model/src/ops/shape.h
new file mode 100644
index 0000000..38ecda8
--- /dev/null
+++ b/reference_model/src/ops/shape.h
@@ -0,0 +1,120 @@
+// Copyright (c) 2023-2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_SHAPES_H
+#define OPS_SHAPES_H
+
+#include "graph_node.h"
+
+namespace TosaReference
+{
+
+class OpConstShape : public GraphNode
+{
+public:
+ OpConstShape(SubgraphTraverser* sgt_, uint64_t id_);
+ virtual ~OpConstShape();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+};
+
+class OpConcatShape : public GraphNode
+{
+public:
+ OpConcatShape(SubgraphTraverser* sgt_, uint64_t id_);
+ virtual ~OpConcatShape();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using EigenType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
+ using TIn = Eigen::Tensor<EigenType, 1>;
+ using TOut = Eigen::Tensor<EigenType, 1>;
+
+protected:
+ int32_t num_dims; // number of dimensions in concat_shape output
+ std::vector<TosaReference::TensorTemplate<TIn>*> ins;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+class ShapeBinaryNodeBase : public GraphNode
+{
+public:
+ ShapeBinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_);
+ virtual ~ShapeBinaryNodeBase();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval();
+ virtual int register_fcn() = 0;
+
+ using EigenType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
+ using TIn = Eigen::Tensor<EigenType, 1>;
+ using TOut = Eigen::Tensor<EigenType, 1>;
+
+protected:
+ int32_t num_dims; // number of dimensions in shape op's result
+ std::function<EigenType(EigenType, EigenType)> fcn;
+ TosaReference::TensorTemplate<TIn>* a;
+ TosaReference::TensorTemplate<TIn>* b;
+ TosaReference::TensorTemplate<TOut>* result;
+};
+
+class OpAddShape : public ShapeBinaryNodeBase
+{
+public:
+ OpAddShape(SubgraphTraverser* sgt_, uint64_t id_)
+ : ShapeBinaryNodeBase(sgt_, Op_ADD_SHAPE, id_)
+ {
+ register_fcn();
+ }
+ virtual int register_fcn();
+};
+
+class OpSubShape : public ShapeBinaryNodeBase
+{
+public:
+ OpSubShape(SubgraphTraverser* sgt_, uint64_t id_)
+ : ShapeBinaryNodeBase(sgt_, Op_SUB_SHAPE, id_)
+ {
+ register_fcn();
+ }
+ virtual int register_fcn();
+};
+
+class OpMulShape : public ShapeBinaryNodeBase
+{
+public:
+ OpMulShape(SubgraphTraverser* sgt_, uint64_t id_)
+ : ShapeBinaryNodeBase(sgt_, Op_MUL_SHAPE, id_)
+ {
+ register_fcn();
+ }
+ virtual int register_fcn();
+};
+
+class OpDivShape : public ShapeBinaryNodeBase
+{
+public:
+ OpDivShape(SubgraphTraverser* sgt_, uint64_t id_)
+ : ShapeBinaryNodeBase(sgt_, Op_DIV_SHAPE, id_)
+ {
+ register_fcn();
+ }
+ virtual int register_fcn();
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 745213e..fae0b30 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -510,6 +510,12 @@ int SubgraphTraverser::allocateTensor(std::string name)
FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
}
+ // set valid for constant tensors:
+ if ((ts->GetShape().empty() && ts->GetDtype() == DType_SHAPE))
+ {
+ // corner case: const_shape {} has no data
+ tensor->setIsValid();
+ }
if (!ts->GetData().empty())
{
if (ts->GetVariable() && g_func_config.initialize_variable_tensor_from_numpy)
@@ -545,13 +551,18 @@ int SubgraphTraverser::allocateTensor(std::string name)
tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
}
break;
- case DType_INT48:
- case DType_SHAPE: {
+ case DType_INT48: {
std::vector<int64_t> i64_data;
TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
}
break;
+ case DType_SHAPE: {
+ std::vector<int64_t> i64_data;
+ TosaSerializationHandler::ConvertU8toI64(ts->GetData(), tensor->getElementCount(), i64_data);
+ tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
+ }
+ break;
case DType_FP16: {
// Interpret f16 data as float
std::vector<float> f16_data;
@@ -617,6 +628,10 @@ int SubgraphTraverser::allocateTensor(std::string name)
EnumNameDType(ts->GetDtype()));
}
tensor->setIsValid();
+ }
+
+ if (tensor->getIsValid())
+ {
// Push ready consumers to the next node list
for (auto gn : tensor->getConsumers())
{
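
The new DType_SHAPE case above converts the raw byte buffer with ConvertU8toI64 rather than reusing the INT48 path. A standalone sketch of why the two cannot share a conversion, assuming little-endian packing with 6 bytes per INT48 value and 8 bytes per int64 value; the helper below is hypothetical and only illustrates the width difference:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Read `count` little-endian values of `width` bytes each and
    // sign-extend them to int64_t. width == 6 models the INT48
    // conversion; width == 8 models the int64 (SHAPE) conversion.
    std::vector<int64_t> unpackLE(const std::vector<uint8_t>& data,
                                  size_t count, size_t width)
    {
        std::vector<int64_t> out(count);
        for (size_t i = 0; i < count; i++)
        {
            uint64_t v = 0;
            for (size_t b = 0; b < width; b++)
                v |= static_cast<uint64_t>(data[i * width + b]) << (8 * b);
            if (width < 8)
            {
                // Sign-extend from the top bit of the packed value.
                const uint64_t signBit = uint64_t(1) << (8 * width - 1);
                if (v & signBit)
                    v |= ~((signBit << 1) - 1);
            }
            out[i] = static_cast<int64_t>(v);
        }
        return out;
    }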
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 5bcd1b2..cd71f9f 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -855,8 +855,16 @@ public:
}
break;
case TOSA_REF_TYPE_SHAPE:
- assert(rank == 0);
- return new Tensor0<int64_t>(tensorName_, dtype_, shape_);
+ switch (rank)
+ {
+ case 0:
+ return new Tensor0<int64_t>(tensorName_, dtype_, shape_);
+ case 1:
+ return new Tensor1<int64_t>(tensorName_, dtype_, shape_);
+ default:
+ assert(0); // shape tensors must have rank of 0 or 1
+ }
+ break;
case TOSA_REF_TYPE_BOOL:
switch (rank)
{
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject f5dfad14f0cdc9556785b610674350c2e5a3355
+Subproject 5d580faec02bcef56164587accb5fd88a3e80d8
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index b1c53f5..67ac367 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -263,7 +263,7 @@ class TosaTestGen:
return gtu.vect_f32_to_bf16(rand_f32)
elif dtype == DType.BOOL:
return self.rng.choice([False, True])
- elif dtype == DType.INT48:
+ elif dtype == DType.INT48 or dtype == DType.SHAPE:
# Special size
return np.int64(self.rng.integers(low, high, size=1))[0]
@@ -1556,15 +1556,24 @@ class TosaTestGen:
def build_reshape(
self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
- assert len(inputs) == 1
+ assert len(inputs) == 2
a = inputs[0]
- new_shape = args_dict["new_shape"]
+ # second input is not properly generated yet
+ # new_shape = inputs[1]
+
+ # replace inputs[1] with a shape tensor built from the new_shape arg value
+ new_shape_attr = args_dict["new_shape"]
+ shape_array = np.array(new_shape_attr)
+ shape = shape_array.shape
+ new_shape = self.ser.addPlaceholder(shape, DType.SHAPE, shape_array)
+ inputs[1] = new_shape
+
result_tensor = OutputShaper.reshapeOp(
- self.ser, self.rng, a, new_shape, error_name
+ self.ser, self.rng, a, new_shape_attr, error_name
)
# Invalidate Input/Output list for error if checks.
- input_list = [a.name]
+ input_list = [a.name, new_shape.name]
output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
@@ -1588,10 +1597,7 @@ class TosaTestGen:
):
return None
- attr = ts.TosaSerializerAttribute()
- attr.ReshapeAttribute(new_shape)
-
- self.ser.addOperator(op["op"], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list)
compliance = self.tensorComplianceMetaData(
op, a.dtype, args_dict, result_tensor, error_name
@@ -1717,16 +1723,24 @@ class TosaTestGen:
def build_tile(
self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
- assert len(inputs) == 1
+ assert len(inputs) == 2
a = inputs[0]
- multiples = args_dict["multiples"]
+ # second input is not properly generated yet
+ # multiples = inputs[1]
+
+ # replace inputs[1] with a shape tensor built from the multiples arg value
+ multiples_attr = args_dict["multiples"]
+ shape_array = np.int64(np.array(multiples_attr))
+ shape = shape_array.shape
+ multiples = self.ser.addPlaceholder(shape, DType.SHAPE, shape_array)
+ inputs[1] = multiples
result_tensor = OutputShaper.tileOp(
- self.ser, self.rng, a, multiples, error_name
+ self.ser, self.rng, a, multiples_attr, error_name
)
# Invalidate Input/Output list for error if checks.
- input_list = [a.name]
+ input_list = [a.name, multiples.name]
output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
@@ -1751,10 +1765,7 @@ class TosaTestGen:
):
return None
- attr = ts.TosaSerializerAttribute()
- attr.TileAttribute(multiples)
-
- self.ser.addOperator(op["op"], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list)
compliance = self.tensorComplianceMetaData(
op, a.dtype, args_dict, result_tensor, error_name
@@ -1989,12 +2000,16 @@ class TosaTestGen:
in_type_width = self.typeWidth(val.dtype)
out_type_width = self.typeWidth(out_dtype)
+ input_unsigned = False
+ output_unsigned = False
+
if val.dtype == DType.INT8:
input_zp = self.randInt(-128, 128)
in_type_width += 1
elif val.dtype == DType.UINT8:
input_zp = self.randInt(0, 256)
in_type_width += 1
+ input_unsigned = True
elif error_name in [
ErrorIf.InputZeroPointNotZero,
ErrorIf.U16InputZeroPointNotValid,
@@ -2007,6 +2022,7 @@ class TosaTestGen:
# Must come after ErrorIf.U16InputZeroPointNotValid check
input_zp = self.rng.choice([0, 32768])
in_type_width += 1
+ input_unsigned = True
else:
input_zp = 0
@@ -2016,6 +2032,7 @@ class TosaTestGen:
elif out_dtype == DType.UINT8:
output_zp = self.randInt(0, 256)
out_type_width += 1
+ output_unsigned = True
elif error_name in [
ErrorIf.OutputZeroPointNotZero,
ErrorIf.U16OutputZeroPointNotValid,
@@ -2028,6 +2045,7 @@ class TosaTestGen:
# Must come after ErrorIf.U16OutputZeroPointNotValid check
output_zp = self.rng.choice([0, 32768])
out_type_width += 1
+ output_unsigned = True
else:
output_zp = 0
@@ -2116,6 +2134,8 @@ class TosaTestGen:
scale32,
double_round,
per_channel,
+ input_unsigned,
+ output_unsigned,
)
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -4212,7 +4232,7 @@ class TosaTestGen:
},
"reshape": {
"op": Op.RESHAPE,
- "operands": (1, 0),
+ "operands": (2, 0),
"build_fcn": (
build_reshape,
TosaTensorGen.tgBasic,
@@ -4277,7 +4297,7 @@ class TosaTestGen:
},
"tile": {
"op": Op.TILE,
- "operands": (1, 0),
+ "operands": (2, 0),
"rank": (1, 6),
"build_fcn": (
build_tile,
@@ -5141,7 +5161,7 @@ class OutputShaper:
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
- output_shape = []
+ output_shape = [1]
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [