MLIR  21.0.0git
InferTypeOpInterface.cpp
Go to the documentation of this file.
1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/MathExtras.h"
19 
20 using namespace mlir;
21 
22 namespace mlir {
23 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
24 } // namespace mlir
25 
26 LogicalResult
28  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
29  auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
30  if (!reifiableOp)
31  return failure();
32  LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
33 #ifndef NDEBUG
34  if (failed(status))
35  return failure();
36  // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced
37  // a correct result.
38  int64_t resultIdx = 0;
39  for (OpResult result : op->getResults()) {
40  auto shapedType = dyn_cast<ShapedType>(result.getType());
41  if (!shapedType)
42  continue;
43  if (!shapedType.hasRank()) {
44  // Nothing to check for unranked shaped values.
45  ++resultIdx;
46  continue;
47  }
48  // Assert one OpFoldResult per dimension.
49  assert(shapedType.getRank() ==
50  static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
51  "incorrect implementation of ReifyRankedShapedTypeOpInterface");
52  ++resultIdx;
53  }
54  // Assert that every shaped value result was reified.
55  assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) &&
56  "incorrect implementation of ReifyRankedShapedTypeOpInterface");
57 #endif // NDEBUG
58  return status;
59 }
60 
61 bool ShapeAdaptor::hasRank() const {
62  if (val.isNull())
63  return false;
64  if (auto t = llvm::dyn_cast_if_present<Type>(val))
65  return cast<ShapedType>(t).hasRank();
66  if (isa<Attribute>(val))
67  return true;
68  return cast<ShapedTypeComponents *>(val)->hasRank();
69 }
70 
72  if (val.isNull())
73  return nullptr;
74  if (auto t = llvm::dyn_cast_if_present<Type>(val))
75  return cast<ShapedType>(t).getElementType();
76  if (isa<Attribute>(val))
77  return nullptr;
78  return cast<ShapedTypeComponents *>(val)->getElementType();
79 }
80 
82  assert(hasRank());
83  if (auto t = llvm::dyn_cast_if_present<Type>(val)) {
84  ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
85  res.assign(vals.begin(), vals.end());
86  } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
87  auto dattr = cast<DenseIntElementsAttr>(attr);
88  res.clear();
89  res.reserve(dattr.size());
90  for (auto it : dattr.getValues<APInt>())
91  res.push_back(it.getSExtValue());
92  } else {
93  auto vals = cast<ShapedTypeComponents *>(val)->getDims();
94  res.assign(vals.begin(), vals.end());
95  }
96 }
97 
99  assert(hasRank());
100  res.ranked = true;
101  getDims(res.dims);
102 }
103 
104 int64_t ShapeAdaptor::getDimSize(int index) const {
105  assert(hasRank());
106  if (auto t = llvm::dyn_cast_if_present<Type>(val))
107  return cast<ShapedType>(t).getDimSize(index);
108  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
109  return cast<DenseIntElementsAttr>(attr)
110  .getValues<APInt>()[index]
111  .getSExtValue();
112  auto *stc = cast<ShapedTypeComponents *>(val);
113  return stc->getDims()[index];
114 }
115 
116 int64_t ShapeAdaptor::getRank() const {
117  assert(hasRank());
118  if (auto t = llvm::dyn_cast_if_present<Type>(val))
119  return cast<ShapedType>(t).getRank();
120  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
121  return cast<DenseIntElementsAttr>(attr).size();
122  return cast<ShapedTypeComponents *>(val)->getDims().size();
123 }
124 
126  if (!hasRank())
127  return false;
128 
129  if (auto t = llvm::dyn_cast_if_present<Type>(val))
130  return cast<ShapedType>(t).hasStaticShape();
131  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
132  auto dattr = cast<DenseIntElementsAttr>(attr);
133  for (auto index : dattr.getValues<APInt>())
134  if (ShapedType::isDynamic(index.getSExtValue()))
135  return false;
136  return true;
137  }
138  auto *stc = cast<ShapedTypeComponents *>(val);
139  return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
140 }
141 
143  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
144 
145  if (auto t = llvm::dyn_cast_if_present<Type>(val))
146  return cast<ShapedType>(t).getNumElements();
147 
148  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
149  auto dattr = cast<DenseIntElementsAttr>(attr);
150  int64_t num = 1;
151  for (auto index : dattr.getValues<APInt>()) {
152  num *= index.getZExtValue();
153  assert(num >= 0 && "integer overflow in element count computation");
154  }
155  return num;
156  }
157 
158  auto *stc = cast<ShapedTypeComponents *>(val);
159  int64_t num = 1;
160  for (int64_t dim : stc->getDims()) {
161  num *= dim;
162  assert(num >= 0 && "integer overflow in element count computation");
163  }
164  return num;
165 }
166 
167 void ShapeAdaptor::dump() const {
168  if (!hasRank()) {
169  llvm::errs() << "<<unranked>>\n";
170  return;
171  }
172 
174  getDims(dims);
175  auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
176  if (ShapedType::isDynamic(dim))
177  return "?";
178  return llvm::formatv("{0}", dim).str();
179  });
180  llvm::errs() << "rank = " << getRank()
181  << " dims = " << llvm::interleaved_array(mapped, "x") << "\n";
182 }
183 
185  Value val = operator[](index);
186  if (valueToShape)
187  if (ShapeAdaptor ret = valueToShape(val))
188  return ret;
189 
191  if (!matchPattern(val, m_Constant(&attr)))
192  return nullptr;
193  if (attr.getType().getRank() != 1)
194  return nullptr;
195  return attr;
196 }
197 
199  if (operandShape)
200  if (ShapeAdaptor ret = operandShape(val))
201  return ret;
202  return val.getType();
203 }
204 
206  if (index < 0 || static_cast<size_t>(index) >= size())
207  return nullptr;
208  return getShape(operator[](index));
209 }
210 
212  ArrayRef<ShapedTypeComponents> retComponents,
213  SmallVectorImpl<Type> &inferredReturnTypes) {
214  for (const auto &shapeAndType : retComponents) {
215  Type elementTy = shapeAndType.getElementType();
216  assert(elementTy && "element type required to construct tensor");
217 
218  Attribute attr = shapeAndType.getAttribute();
219  if (shapeAndType.hasRank()) {
220  inferredReturnTypes.push_back(
221  RankedTensorType::get(shapeAndType.getDims(), elementTy, attr));
222  } else {
223  assert(attr == nullptr && "attribute not supported");
224  inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy));
225  }
226  }
227  return success();
228 }
229 
231  SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
232  auto retTypeFn = cast<InferTypeOpInterface>(op);
233  auto result = retTypeFn.refineReturnTypes(
234  op->getContext(), op->getLoc(), op->getOperands(),
236  inferredReturnTypes);
237  if (failed(result))
238  op->emitOpError() << "failed to infer returned types";
239 
240  return result;
241 }
242 
244  std::string buffer;
245  llvm::raw_string_ostream os(buffer);
246  os << "Failed to infer result type(s):\n"
247  << "\"" << state.name << "\"(...) "
248  << state.attributes.getDictionary(state.location.getContext()) << " : ("
249  << llvm::interleaved(llvm::map_range(
250  state.operands, [](Value val) { return val.getType(); }))
251  << ") -> ( ??? )";
252  emitRemark(state.location, "location of op");
253  llvm::report_fatal_error(llvm::StringRef(buffer));
254 }
Attributes are known-constant values of operations.
Definition: Attributes.h:25
An attribute that represents a reference to a dense integer vector or tensor object.
This class helps build Operations.
Definition: Builders.h:205
This is a value defined by a result of an operation.
Definition: Value.h:433
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
Definition: Operation.h:509
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:901
Adaptor class that abstracts over whether a shape comes from a ShapedType, a ShapedTypeComponents, or a DenseIntElementsAttr.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
void dump() const
Dumps a textual representation to stderr.
Type getElementType() const
Returns the element type.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
ShapeAdaptor getValueAsShape(int index)
Returns an argument as shape.
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
LogicalResult inferReturnTensorTypes(ArrayRef< ShapedTypeComponents > retComponents, SmallVectorImpl< Type > &inferredReturnTypes)
void reportFatalInferReturnTypesError(OperationState &state)
Report a fatal error indicating that the result types could not be inferred.
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
InFlightDiagnostic emitRemark(Location loc)
Utility method to emit a remark message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
This represents an operation in an abstracted form, suitable for use with the builder APIs.