MLIR  21.0.0git
InferTypeOpInterface.cpp
Go to the documentation of this file.
1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "llvm/Support/FormatVariadic.h"
18 
19 using namespace mlir;
20 
21 namespace mlir {
22 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
23 } // namespace mlir
24 
25 LogicalResult
27  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
28  auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
29  if (!reifiableOp)
30  return failure();
31  LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
32 #ifndef NDEBUG
33  if (failed(status))
34  return failure();
35  // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced
36  // a correct result.
37  int64_t resultIdx = 0;
38  for (OpResult result : op->getResults()) {
39  auto shapedType = dyn_cast<ShapedType>(result.getType());
40  if (!shapedType)
41  continue;
42  if (!shapedType.hasRank()) {
43  // Nothing to check for unranked shaped values.
44  ++resultIdx;
45  continue;
46  }
47  // Assert one OpFoldResult per dimension.
48  assert(shapedType.getRank() ==
49  static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
50  "incorrect implementation of ReifyRankedShapedTypeOpInterface");
51  ++resultIdx;
52  }
53  // Assert that every shaped value result was reified.
54  assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) &&
55  "incorrect implementation of ReifyRankedShapedTypeOpInterface");
56 #endif // NDEBUG
57  return status;
58 }
59 
60 bool ShapeAdaptor::hasRank() const {
61  if (val.isNull())
62  return false;
63  if (auto t = llvm::dyn_cast_if_present<Type>(val))
64  return cast<ShapedType>(t).hasRank();
65  if (isa<Attribute>(val))
66  return true;
67  return cast<ShapedTypeComponents *>(val)->hasRank();
68 }
69 
71  if (val.isNull())
72  return nullptr;
73  if (auto t = llvm::dyn_cast_if_present<Type>(val))
74  return cast<ShapedType>(t).getElementType();
75  if (isa<Attribute>(val))
76  return nullptr;
77  return cast<ShapedTypeComponents *>(val)->getElementType();
78 }
79 
81  assert(hasRank());
82  if (auto t = llvm::dyn_cast_if_present<Type>(val)) {
83  ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
84  res.assign(vals.begin(), vals.end());
85  } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
86  auto dattr = cast<DenseIntElementsAttr>(attr);
87  res.clear();
88  res.reserve(dattr.size());
89  for (auto it : dattr.getValues<APInt>())
90  res.push_back(it.getSExtValue());
91  } else {
92  auto vals = cast<ShapedTypeComponents *>(val)->getDims();
93  res.assign(vals.begin(), vals.end());
94  }
95 }
96 
98  assert(hasRank());
99  res.ranked = true;
100  getDims(res.dims);
101 }
102 
103 int64_t ShapeAdaptor::getDimSize(int index) const {
104  assert(hasRank());
105  if (auto t = llvm::dyn_cast_if_present<Type>(val))
106  return cast<ShapedType>(t).getDimSize(index);
107  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
108  return cast<DenseIntElementsAttr>(attr)
109  .getValues<APInt>()[index]
110  .getSExtValue();
111  auto *stc = cast<ShapedTypeComponents *>(val);
112  return stc->getDims()[index];
113 }
114 
115 int64_t ShapeAdaptor::getRank() const {
116  assert(hasRank());
117  if (auto t = llvm::dyn_cast_if_present<Type>(val))
118  return cast<ShapedType>(t).getRank();
119  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
120  return cast<DenseIntElementsAttr>(attr).size();
121  return cast<ShapedTypeComponents *>(val)->getDims().size();
122 }
123 
125  if (!hasRank())
126  return false;
127 
128  if (auto t = llvm::dyn_cast_if_present<Type>(val))
129  return cast<ShapedType>(t).hasStaticShape();
130  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
131  auto dattr = cast<DenseIntElementsAttr>(attr);
132  for (auto index : dattr.getValues<APInt>())
133  if (ShapedType::isDynamic(index.getSExtValue()))
134  return false;
135  return true;
136  }
137  auto *stc = cast<ShapedTypeComponents *>(val);
138  return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
139 }
140 
142  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
143 
144  if (auto t = llvm::dyn_cast_if_present<Type>(val))
145  return cast<ShapedType>(t).getNumElements();
146 
147  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
148  auto dattr = cast<DenseIntElementsAttr>(attr);
149  int64_t num = 1;
150  for (auto index : dattr.getValues<APInt>()) {
151  num *= index.getZExtValue();
152  assert(num >= 0 && "integer overflow in element count computation");
153  }
154  return num;
155  }
156 
157  auto *stc = cast<ShapedTypeComponents *>(val);
158  int64_t num = 1;
159  for (int64_t dim : stc->getDims()) {
160  num *= dim;
161  assert(num >= 0 && "integer overflow in element count computation");
162  }
163  return num;
164 }
165 
166 void ShapeAdaptor::dump() const {
167  if (!hasRank()) {
168  llvm::errs() << "<<unranked>>\n";
169  return;
170  }
171 
173  getDims(dims);
174  auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
175  if (ShapedType::isDynamic(dim))
176  return "?";
177  return llvm::formatv("{0}", dim).str();
178  });
179  llvm::errs() << "rank = " << getRank() << " dims = [";
180  llvm::interleave(mapped, llvm::errs(), "x");
181  llvm::errs() << "]\n";
182 }
183 
185  Value val = operator[](index);
186  if (valueToShape)
187  if (ShapeAdaptor ret = valueToShape(val))
188  return ret;
189 
191  if (!matchPattern(val, m_Constant(&attr)))
192  return nullptr;
193  if (attr.getType().getRank() != 1)
194  return nullptr;
195  return attr;
196 }
197 
199  if (operandShape)
200  if (ShapeAdaptor ret = operandShape(val))
201  return ret;
202  return val.getType();
203 }
204 
206  if (index < 0 || static_cast<size_t>(index) >= size())
207  return nullptr;
208  return getShape(operator[](index));
209 }
210 
212  ArrayRef<ShapedTypeComponents> retComponents,
213  SmallVectorImpl<Type> &inferredReturnTypes) {
214  for (const auto &shapeAndType : retComponents) {
215  Type elementTy = shapeAndType.getElementType();
216  assert(elementTy && "element type required to construct tensor");
217 
218  Attribute attr = shapeAndType.getAttribute();
219  if (shapeAndType.hasRank()) {
220  inferredReturnTypes.push_back(
221  RankedTensorType::get(shapeAndType.getDims(), elementTy, attr));
222  } else {
223  assert(attr == nullptr && "attribute not supported");
224  inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy));
225  }
226  }
227  return success();
228 }
229 
231  SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
232  auto retTypeFn = cast<InferTypeOpInterface>(op);
233  auto result = retTypeFn.refineReturnTypes(
234  op->getContext(), op->getLoc(), op->getOperands(),
236  inferredReturnTypes);
237  if (failed(result))
238  op->emitOpError() << "failed to infer returned types";
239 
240  return result;
241 }
242 
244  std::string buffer;
245  llvm::raw_string_ostream os(buffer);
246  os << "Failed to infer result type(s):\n";
247  os << "\"" << state.name << "\"(...) ";
248  os << state.attributes.getDictionary(state.location.getContext());
249  os << " : (";
250  llvm::interleaveComma(state.operands, os,
251  [&](Value val) { os << val.getType(); });
252  os << ") -> ( ??? )";
253  emitRemark(state.location, "location of op");
254  llvm::report_fatal_error(llvm::StringRef(buffer));
255 }
Attributes are known-constant values of operations.
Definition: Attributes.h:25
An attribute that represents a reference to a dense integer vector or tensor object.
This class helps build Operations.
Definition: Builders.h:205
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
Definition: Operation.h:509
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:901
Adaptor class to abstract the differences between whether a value comes from a ShapedType, a ShapedTypeComponents, or a DenseIntElementsAttr.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
void dump() const
Dumps a textual representation to stderr.
Type getElementType() const
Returns the element type.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
ShapeAdaptor getValueAsShape(int index)
Returns an argument as shape.
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
LogicalResult inferReturnTensorTypes(ArrayRef< ShapedTypeComponents > retComponents, SmallVectorImpl< Type > &inferredReturnTypes)
void reportFatalInferReturnTypesError(OperationState &state)
Report a fatal error indicating that the result types could not be inferred.
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
InFlightDiagnostic emitRemark(Location loc)
Utility method to emit a remark message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
This represents an operation in an abstracted form, suitable for use with the builder APIs.