aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-04-21 22:49:57 +0000
committerJerry Ge <jerry.ge@arm.com>2023-05-17 22:46:57 +0000
commit264f7faa59709ffa8117541f5d55c99c5dba967d (patch)
treeae767b3e4375ab87d4323f18b63239a84ac857db
parent7e5968166a5105da30bc11c9241f271cb3dc1da9 (diff)
downloadreference_model-264f7faa59709ffa8117541f5d55c99c5dba967d.tar.gz
Add support for one dimension of size -1 in ReshapeOp
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I0ef7607f4266296a1204c5cccdb5be36f345b5ba
-rw-r--r--reference_model/src/ops/data_layout.cc42
-rw-r--r--reference_model/src/subgraph_traverser.h4
-rw-r--r--reference_model/src/tensor.h8
m---------thirdparty/serialization_lib0
-rw-r--r--verif/conformance/test_select.py2
-rw-r--r--verif/generator/tosa_arg_gen.py46
-rw-r--r--verif/generator/tosa_error_if.py58
-rw-r--r--verif/generator/tosa_test_gen.py2
8 files changed, 156 insertions, 6 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index fd19f96..86cd752 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -250,12 +250,50 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
return 1;
}
+ // -1 shape inferencing
+ auto inferred_size = -1;
+ auto inferred_dim = -1;
+ auto total_size = getInputs()[0]->getElementCount();
+ uint32_t accum_size = 1;
+
+ for (int32_t d = 0; d < OutRank; d++)
+ {
+ auto curr_new_shape = attribute->new_shape()[d];
+ if (curr_new_shape != -1) {
+ accum_size *= curr_new_shape;
+ } else {
+ ERROR_IF(inferred_dim != -1, "OpReshape: only 1 inferred dimension in output shape is supported");
+ inferred_dim = d;
+ }
+ }
+
+ ERROR_IF((total_size % accum_size) != 0, "OpReshape: shape inference failed, missing dimension would be non-integer");
+ inferred_size = total_size / accum_size;
+
+ if (inferred_dim != -1) {
+ getOutputs()[0]->setDimSize(inferred_dim, inferred_size);
+
+ // Need to also edit the serializedTensor's shape at inferred_dim
+ TosaSerializationTensor* serializedTensor;
+ for (auto region : parent_sgt->getTsh()->GetRegions()) {
+ for (auto block : region->GetBlocks()) {
+ if (block->GetTensorByName(getOutputs()[0]->getName())) {
+ serializedTensor = block->GetTensorByName(getOutputs()[0]->getName());
+ serializedTensor->SetDimSize(inferred_dim, inferred_size);
+ break;
+ }
+ }
+ }
+
+ }
+
ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
"Input tensor size does not match output tensor size");
for (uint32_t d = 0; d < OutRank; d++)
{
- ERROR_IF(attribute->new_shape()[d] != outputs[0]->getShape()[d],
+ auto curr_new_shape = attribute->new_shape()[d];
+ ERROR_IF(curr_new_shape != -1 && curr_new_shape != outputs[0]->getShape()[d],
"OpReshape: new_shape doesn't match output shape");
}
@@ -270,7 +308,7 @@ int OpReshape<InRank, OutRank, Dtype>::eval()
{
for (int32_t d = 0; d < OutRank; d++)
{
- array_shape[d] = attribute->new_shape()[OutRank - 1 - d];
+ array_shape[d] = getOutputs()[0]->getShape()[OutRank - 1 - d];
out_reverser[d] = OutRank - 1 - d;
}
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
index 543b008..1cf582e 100644
--- a/reference_model/src/subgraph_traverser.h
+++ b/reference_model/src/subgraph_traverser.h
@@ -63,6 +63,10 @@ public:
{
return block->GetRegionName();
}
+ TosaSerializationHandler* getTsh() const
+ {
+ return tsh;
+ }
int getNumInputTensors() const;
Tensor* getInputTensor(const unsigned int idx) const;
Tensor* getInputTensorByName(const std::string name) const;
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index b68a9b6..21cf148 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -96,6 +96,12 @@ public:
return shape;
}
+ void setDimSize(size_t dim, uint32_t new_size)
+ {
+ this->shape[dim] = new_size;
+ return;
+ }
+
std::string getShapeAsString() const
{
std::string shape_str("[");
@@ -269,7 +275,7 @@ public:
protected:
const std::string tensorName;
const DType serializationDtype;
- const std::vector<int> shape;
+ std::vector<int> shape;
const TOSA_REF_TYPE tensorDtype;
int isValid;
int isSubgraphInput;
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject cfcb20d08c4c409bbcd2d2dde6ca5ecdac29945
+Subproject ab8d234bdc64896297ceceb7b97ce74a783ac7a
diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py
index 66b2e56..9868a7f 100644
--- a/verif/conformance/test_select.py
+++ b/verif/conformance/test_select.py
@@ -733,7 +733,7 @@ class ReshapeOperator(Operator):
"""Test selector for the RESHAPE operator."""
name = "reshape"
- param_names = ["shape", "type", "perm", "rank"]
+ param_names = ["shape", "type", "perm", "rank", "out"]
class ResizeOperator(Operator):
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 2bbc349..9386ec2 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1878,17 +1878,27 @@ class TosaArgGen:
escape_counter = 0
while found:
newShape = []
+ new_shape_inferred = []
# Generate newShape ensuring it isn't a duplicate
remainingElements = totalElements
shuffledFactors = testGen.rng.permutation(factors)
+ inferred_dim = testGen.rng.integers(1, newRank + 1)
for i in range(1, newRank):
# pick rank-1 factors
newShape.append(shuffledFactors[0])
remainingElements = remainingElements // shuffledFactors[0]
+ if i == inferred_dim:
+ new_shape_inferred.append(-1)
+ else:
+ new_shape_inferred.append(shuffledFactors[0])
shuffledFactors = testGen.rng.permutation(
TosaArgGen.getFactors(remainingElements)
)
newShape.append(remainingElements)
+ if inferred_dim == newRank:
+ new_shape_inferred.append(-1)
+ else:
+ new_shape_inferred.append(remainingElements)
# Check for duplicates
found = False
@@ -1902,7 +1912,41 @@ class TosaArgGen:
break
if not found:
- arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))
+ if error_name in [
+ ErrorIf.ReshapeOutputSizeNonInteger,
+ ErrorIf.ReshapeOutputSizeMultiInference,
+ ]:
+ if newRank < 2:
+ # Need at least two dimensions
+ continue
+ # NOTE: Change inferred_dim starting offset from 1 to 0
+ inferred_dim -= 1
+ extra_dim = inferred_dim + testGen.rng.integers(1, newRank)
+ extra_dim = extra_dim % newRank
+ assert extra_dim != inferred_dim
+ if error_name == ErrorIf.ReshapeOutputSizeNonInteger:
+ elements = 1
+ for i, dim_value in enumerate(new_shape_inferred):
+ if i != inferred_dim and i != extra_dim:
+ elements *= dim_value
+ dim_value = new_shape_inferred[extra_dim]
+ while totalElements % (elements * dim_value) == 0:
+ dim_value += 1
+ new_shape_inferred[extra_dim] = dim_value
+ else:
+ assert error_name == ErrorIf.ReshapeOutputSizeMultiInference
+ new_shape_inferred[extra_dim] = -1
+ else:
+ arg_list.append(
+ ("perm{}_rank{}_outdefined".format(p, newRank), [newShape])
+ )
+ if error_name != ErrorIf.TensorSizeInputOutputMismatch:
+ arg_list.append(
+ (
+ "perm{}_rank{}_outinferred".format(p, newRank),
+ [new_shape_inferred],
+ )
+ )
return arg_list
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 8c40371..a0a9203 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -81,6 +81,8 @@ class ErrorIf(object):
KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
FFTInputShapeMismatch = "FFTInputShapeMismatch"
FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
+ ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
+ ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
class TosaErrorIfArgGen:
@@ -1822,13 +1824,17 @@ class TosaErrorValidator:
param_reqs = {"rank": None, "dtype": None, "shape": None}
error_result = False
error_reason = "Input tensor size does not match output tensor size"
+ op = kwargs["op"]
if check:
input_shape = kwargs["input_shape"]
output_shape = kwargs["output_shape"]
+ shape_inferencing = False
+ if -1 in output_shape and op["op"] == Op.RESHAPE:
+ shape_inferencing = True
input_size = np.prod(input_shape)
output_size = np.prod(output_shape)
- if input_size != output_size:
+ if input_size != output_size and not shape_inferencing:
error_result = True
info_dict = {
@@ -2510,6 +2516,56 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evReshapeOutputSizeMultiInference(check=False, **kwargs):
+ error_name = ErrorIf.ReshapeOutputSizeMultiInference
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Reshape output tensor contains more than one inferred dimension"
+
+ if check:
+ output_shape = kwargs["output_shape"]
+ inferences = 0
+ for dim in output_shape:
+ if dim == -1:
+ inferences += 1
+ if inferences > 1:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evReshapeOutputSizeNonInteger(check=False, **kwargs):
+ error_name = ErrorIf.ReshapeOutputSizeNonInteger
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Reshape inferred output tensor dimension is non-integer"
+
+ if check:
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs["output_shape"]
+ input_size = np.prod(input_shape)
+ output_size = 1
+ for dim in output_shape:
+ if dim != -1:
+ output_size *= dim
+ if -1 in output_shape and input_size % output_size != 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
class TosaInvalidValidator:
@staticmethod
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index c8c22c2..7691fdd 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -3693,6 +3693,8 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evReshapeOutputSizeMultiInference,
+ TosaErrorValidator.evReshapeOutputSizeNonInteger,
),
},
"reverse": {