MLIR  14.0.0git
InferTypeOpInterface.cpp
Go to the documentation of this file.
1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "llvm/Support/FormatVariadic.h"
18 
19 using namespace mlir;
20 
21 namespace mlir {
22 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
23 } // namespace mlir
24 
25 bool ShapeAdaptor::hasRank() const {
26  if (val.isNull())
27  return false;
28  if (auto t = val.dyn_cast<Type>())
29  return t.cast<ShapedType>().hasRank();
30  if (val.is<Attribute>())
31  return true;
32  return val.get<ShapedTypeComponents *>()->hasRank();
33 }
34 
36  if (val.isNull())
37  return nullptr;
38  if (auto t = val.dyn_cast<Type>())
39  return t.cast<ShapedType>().getElementType();
40  if (val.is<Attribute>())
41  return nullptr;
42  return val.get<ShapedTypeComponents *>()->getElementType();
43 }
44 
46  assert(hasRank());
47  if (auto t = val.dyn_cast<Type>()) {
48  ArrayRef<int64_t> vals = t.cast<ShapedType>().getShape();
49  res.assign(vals.begin(), vals.end());
50  } else if (auto attr = val.dyn_cast<Attribute>()) {
51  auto dattr = attr.cast<DenseIntElementsAttr>();
52  res.clear();
53  res.reserve(dattr.size());
54  for (auto it : dattr.getValues<APInt>())
55  res.push_back(it.getSExtValue());
56  } else {
57  auto vals = val.get<ShapedTypeComponents *>()->getDims();
58  res.assign(vals.begin(), vals.end());
59  }
60 }
61 
63  assert(hasRank());
64  res.ranked = true;
65  getDims(res.dims);
66 }
67 
68 int64_t ShapeAdaptor::getDimSize(int index) const {
69  assert(hasRank());
70  if (auto t = val.dyn_cast<Type>())
71  return t.cast<ShapedType>().getDimSize(index);
72  if (auto attr = val.dyn_cast<Attribute>())
73  return attr.cast<DenseIntElementsAttr>()
74  .getValues<APInt>()[index]
75  .getSExtValue();
76  auto *stc = val.get<ShapedTypeComponents *>();
77  return stc->getDims()[index];
78 }
79 
80 int64_t ShapeAdaptor::getRank() const {
81  assert(hasRank());
82  if (auto t = val.dyn_cast<Type>())
83  return t.cast<ShapedType>().getRank();
84  if (auto attr = val.dyn_cast<Attribute>())
85  return attr.cast<DenseIntElementsAttr>().size();
86  return val.get<ShapedTypeComponents *>()->getDims().size();
87 }
88 
90  if (!hasRank())
91  return false;
92 
93  if (auto t = val.dyn_cast<Type>())
94  return t.cast<ShapedType>().hasStaticShape();
95  if (auto attr = val.dyn_cast<Attribute>()) {
96  auto dattr = attr.cast<DenseIntElementsAttr>();
97  for (auto index : dattr.getValues<APInt>())
98  if (ShapedType::isDynamic(index.getSExtValue()))
99  return false;
100  return true;
101  }
102  auto *stc = val.get<ShapedTypeComponents *>();
103  for (int64_t dim : stc->getDims())
104  if (ShapedType::isDynamic(dim))
105  return false;
106  return true;
107 }
108 
110  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
111 
112  if (auto t = val.dyn_cast<Type>())
113  return t.cast<ShapedType>().getNumElements();
114 
115  if (auto attr = val.dyn_cast<Attribute>()) {
116  auto dattr = attr.cast<DenseIntElementsAttr>();
117  int64_t num = 1;
118  for (auto index : dattr.getValues<APInt>()) {
119  num *= index.getZExtValue();
120  assert(num >= 0 && "integer overflow in element count computation");
121  }
122  return num;
123  }
124 
125  auto *stc = val.get<ShapedTypeComponents *>();
126  int64_t num = 1;
127  for (int64_t dim : stc->getDims()) {
128  num *= dim;
129  assert(num >= 0 && "integer overflow in element count computation");
130  }
131  return num;
132 }
133 
134 void ShapeAdaptor::dump() const {
135  if (!hasRank()) {
136  llvm::errs() << "<<unranked>>\n";
137  return;
138  }
139 
141  getDims(dims);
142  auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
143  if (ShapedType::isDynamic(dim))
144  return "?";
145  return llvm::formatv("{0}", dim).str();
146  });
147  llvm::errs() << "rank = " << getRank() << " dims = [";
148  llvm::interleave(mapped, llvm::errs(), "x");
149  llvm::errs() << "]\n";
150 }
151 
153  Value val = operator[](index);
154  if (valueToShape)
155  if (ShapeAdaptor ret = valueToShape(val))
156  return ret;
157 
159  if (!matchPattern(val, m_Constant(&attr)))
160  return nullptr;
161  if (attr.getType().getRank() != 1)
162  return nullptr;
163  return attr;
164 }
165 
167  if (operandShape)
168  if (ShapeAdaptor ret = operandShape(val))
169  return ret;
170  return val.getType();
171 }
172 
174  if (index < 0 || static_cast<size_t>(index) >= size())
175  return nullptr;
176  return getShape(operator[](index));
177 }
178 
181  MLIRContext *, Optional<Location> location, ValueShapeRange operands,
182  DictionaryAttr attributes, RegionRange regions,
184  componentTypeFn,
185  MLIRContext *context, Optional<Location> location, ValueRange operands,
186  DictionaryAttr attributes, RegionRange regions,
187  SmallVectorImpl<Type> &inferredReturnTypes) {
189  if (failed(componentTypeFn(context, location, operands, attributes, regions,
190  retComponents)))
191  return failure();
192  for (const auto &shapeAndType : retComponents) {
193  assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
194  if (shapeAndType.hasRank())
195  inferredReturnTypes.push_back(RankedTensorType::get(
196  shapeAndType.getDims(), shapeAndType.getElementType()));
197  else
198  inferredReturnTypes.push_back(
199  UnrankedTensorType::get(shapeAndType.getElementType()));
200  }
201  return success();
202 }
203 
205  SmallVector<Type, 4> inferredReturnTypes;
206  auto retTypeFn = cast<InferTypeOpInterface>(op);
207  if (failed(retTypeFn.inferReturnTypes(
208  op->getContext(), op->getLoc(), op->getOperands(),
209  op->getAttrDictionary(), op->getRegions(), inferredReturnTypes)))
210  return failure();
211  if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes,
212  op->getResultTypes()))
213  return op->emitOpError("inferred type(s) ")
214  << inferredReturnTypes
215  << " are incompatible with return type(s) of operation "
216  << op->getResultTypes();
217  return success();
218 }
Include the generated interface declarations.
LogicalResult inferReturnTensorTypes(function_ref< LogicalResult(MLIRContext *, Optional< Location > location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl< ShapedTypeComponents > &retComponents)> componentTypeFn, MLIRContext *context, Optional< Location > location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl< Type > &inferredReturnTypes)
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:247
ShapedTypeComponents that represents the components of a ShapedType.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
int64_t getNumElements() const
Returns the number of elements in the shape.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
Range of values and shapes (corresponding effectively to the Shape dialect's ValueShape type concept)...
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:99
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
ShapeAdaptor getValueAsShape(int index)
Returns an argument as shape.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
bool hasRank() const
Returns whether the shape has a rank.
Type getElementType() const
Returns the element type.
ShapeAdaptor getShape(int index) const
Returns the shape of the index'th operand.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:231
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.h:311
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
void dump() const
Dumps a textual representation to stderr.
Type getType() const
Return the type of this value.
Definition: Value.h:117
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:266
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:322
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "&#39;dim&#39; op " which is convenient for verifiers...
Definition: Operation.cpp:518
This class provides an abstraction over the different types of ranges over Values.
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op. ...
result_type_range getResultTypes()
Definition: Operation.h:297
An attribute that represents a reference to a dense integer vector or tensor object.
U cast() const
Definition: Types.h:250
int64_t getRank() const
Returns the rank of the shape.