MLIR  19.0.0git
InferTypeOpInterface.cpp
Go to the documentation of this file.
1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "llvm/Support/FormatVariadic.h"
18 
19 using namespace mlir;
20 
21 namespace mlir {
22 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
23 } // namespace mlir
24 
27  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
    // Only ops implementing ReifyRankedShapedTypeOpInterface can be reified;
    // any other op is reported as a failure.
28  auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
29  if (!reifiableOp)
30  return failure();
    // Delegate the actual reification to the op's interface implementation.
31  LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
    // Everything up to the matching #endif is a consistency check that is
    // compiled only when NDEBUG is not defined.
32 #ifndef NDEBUG
    // Nothing to validate when the interface call itself failed.
33  if (failed(status))
34  return failure();
35  // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced
36  // a correct result.
    // resultIdx indexes reifiedReturnShapes. It advances once per shaped
    // result (ranked or unranked) and is not advanced for non-shaped results.
37  int64_t resultIdx = 0;
38  for (OpResult result : op->getResults()) {
39  auto shapedType = dyn_cast<ShapedType>(result.getType());
    // Non-shaped results are skipped without consuming an entry.
40  if (!shapedType)
41  continue;
42  if (!shapedType.hasRank()) {
43  // Nothing to check for unranked shaped values.
44  ++resultIdx;
45  continue;
46  }
47  // Assert one OpFoldResult per dimension.
48  assert(shapedType.getRank() ==
49  static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
50  "incorrect implementation of ReifyRankedShapedTypeOpInterface");
51  for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
52  // reifyResultShapes must return:
53  // * Attribute for static dimensions
54  // * Value for dynamic dimensions
55  assert(shapedType.isDynamicDim(dim) ==
56  reifiedReturnShapes[resultIdx][dim].is<Value>() &&
57  "incorrect implementation of ReifyRankedShapedTypeOpInterface");
58  }
59  ++resultIdx;
60  }
61  // Assert that every shaped value result was reified.
62  assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) &&
63  "incorrect implementation of ReifyRankedShapedTypeOpInterface");
64 #endif // NDEBUG
    // Hand back whatever the interface implementation reported.
65  return status;
66 }
67 
68 bool ShapeAdaptor::hasRank() const {
69  if (val.isNull())
70  return false;
71  if (auto t = llvm::dyn_cast_if_present<Type>(val))
72  return cast<ShapedType>(t).hasRank();
73  if (val.is<Attribute>())
74  return true;
75  return val.get<ShapedTypeComponents *>()->hasRank();
76 }
77 
// Body of what appears to be the element-type accessor (its signature line is
// not visible here — TODO confirm). Yields nullptr when nothing is wrapped or
// when the shape is stored as an Attribute; otherwise forwards
// getElementType() to the wrapped ShapedType or ShapedTypeComponents.
79  if (val.isNull())
80  return nullptr;
    // Type-backed: the ShapedType supplies the element type.
81  if (auto t = llvm::dyn_cast_if_present<Type>(val))
82  return cast<ShapedType>(t).getElementType();
    // Attribute-backed: no element type is reported.
83  if (val.is<Attribute>())
84  return nullptr;
    // Remaining case: val is read as a ShapedTypeComponents pointer.
85  return val.get<ShapedTypeComponents *>()->getElementType();
86 }
87 
// Body of what appears to be the dimension getter that fills `res` (its
// signature line is not visible here — TODO confirm). Requires a ranked shape
// (asserted) and replaces the contents of `res` with the dimension sizes.
89  assert(hasRank());
90  if (auto t = llvm::dyn_cast_if_present<Type>(val)) {
    // Type-backed: copy the ShapedType's shape array.
91  ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
92  res.assign(vals.begin(), vals.end());
93  } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
    // Attribute-backed: each element of the dense integer attribute is one
    // dimension, read through APInt::getSExtValue().
94  auto dattr = cast<DenseIntElementsAttr>(attr);
95  res.clear();
96  res.reserve(dattr.size());
97  for (auto it : dattr.getValues<APInt>())
98  res.push_back(it.getSExtValue());
99  } else {
    // Remaining case: copy the dims stored in the ShapedTypeComponents.
100  auto vals = val.get<ShapedTypeComponents *>()->getDims();
101  res.assign(vals.begin(), vals.end());
102  }
103 }
104 
// Body of what appears to be a getDims overload writing into a components
// object `res` (its signature line is not visible here — TODO confirm).
// Requires a ranked shape (asserted), flags `res` as ranked and fills
// `res.dims` through the getDims call below.
106  assert(hasRank());
107  res.ranked = true;
108  getDims(res.dims);
109 }
110 
111 int64_t ShapeAdaptor::getDimSize(int index) const {
112  assert(hasRank());
113  if (auto t = llvm::dyn_cast_if_present<Type>(val))
114  return cast<ShapedType>(t).getDimSize(index);
115  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
116  return cast<DenseIntElementsAttr>(attr)
117  .getValues<APInt>()[index]
118  .getSExtValue();
119  auto *stc = val.get<ShapedTypeComponents *>();
120  return stc->getDims()[index];
121 }
122 
123 int64_t ShapeAdaptor::getRank() const {
124  assert(hasRank());
125  if (auto t = llvm::dyn_cast_if_present<Type>(val))
126  return cast<ShapedType>(t).getRank();
127  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
128  return cast<DenseIntElementsAttr>(attr).size();
129  return val.get<ShapedTypeComponents *>()->getDims().size();
130 }
131 
// Body of what appears to be the static-shape query (its signature line is not
// visible here — TODO confirm). Returns false for unranked shapes; otherwise
// true only when no dimension is reported dynamic by ShapedType::isDynamic.
133  if (!hasRank())
134  return false;
135 
    // Type-backed: the ShapedType answers directly.
136  if (auto t = llvm::dyn_cast_if_present<Type>(val))
137  return cast<ShapedType>(t).hasStaticShape();
138  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
    // Attribute-backed: any dynamic element makes the shape non-static.
139  auto dattr = cast<DenseIntElementsAttr>(attr);
140  for (auto index : dattr.getValues<APInt>())
141  if (ShapedType::isDynamic(index.getSExtValue()))
142  return false;
143  return true;
144  }
    // Remaining case: check every dim held by the ShapedTypeComponents.
145  auto *stc = val.get<ShapedTypeComponents *>();
146  return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
147 }
148 
// Body of what appears to be the element-count query (its signature line is
// not visible here — TODO confirm). Requires a static shape (asserted) and
// returns the product of all dimension sizes.
150  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
151 
    // Type-backed: the ShapedType computes the count itself.
152  if (auto t = llvm::dyn_cast_if_present<Type>(val))
153  return cast<ShapedType>(t).getNumElements();
154 
155  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
    // Attribute-backed: multiply the attribute's elements together. The
    // assert treats a negative running product as an overflow.
156  auto dattr = cast<DenseIntElementsAttr>(attr);
157  int64_t num = 1;
158  for (auto index : dattr.getValues<APInt>()) {
159  num *= index.getZExtValue();
160  assert(num >= 0 && "integer overflow in element count computation");
161  }
162  return num;
163  }
164 
    // Remaining case: multiply the dims held by the ShapedTypeComponents.
    // NOTE(review): both operands here are int64_t, so an overflowing
    // multiplication is undefined behavior before the assert below is even
    // evaluated — confirm whether a pre-multiplication check is wanted.
165  auto *stc = val.get<ShapedTypeComponents *>();
166  int64_t num = 1;
167  for (int64_t dim : stc->getDims()) {
168  num *= dim;
169  assert(num >= 0 && "integer overflow in element count computation");
170  }
171  return num;
172 }
173 
// Writes a textual form of the shape to llvm::errs(): "<<unranked>>" for an
// unranked shape, otherwise "rank = N dims = [d0xd1x...]" with "?" standing in
// for each dynamic dimension.
174 void ShapeAdaptor::dump() const {
175  if (!hasRank()) {
176  llvm::errs() << "<<unranked>>\n";
177  return;
178  }
179 
    // NOTE(review): the declaration of `dims` is not visible in this listing;
    // confirm it against the original source.
181  getDims(dims);
    // Render each dimension as text: "?" when dynamic, its decimal value
    // otherwise.
182  auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
183  if (ShapedType::isDynamic(dim))
184  return "?";
185  return llvm::formatv("{0}", dim).str();
186  });
187  llvm::errs() << "rank = " << getRank() << " dims = [";
    // Dimensions are separated by "x".
188  llvm::interleave(mapped, llvm::errs(), "x");
189  llvm::errs() << "]\n";
190 }
191 
// Body of what appears to be getValueAsShape(int index) (its signature line is
// not visible here — TODO confirm). Interprets the index-th value as a shape.
193  Value val = operator[](index);
    // Prefer the valueToShape callback when one is set and it produces a
    // non-empty adaptor.
194  if (valueToShape)
195  if (ShapeAdaptor ret = valueToShape(val))
196  return ret;
197 
    // Otherwise try to match the value as a constant. Only a constant whose
    // type has rank 1 is returned as a shape; anything else yields nullptr.
    // NOTE(review): the declaration of `attr` is not visible in this listing.
199  if (!matchPattern(val, m_Constant(&attr)))
200  return nullptr;
201  if (attr.getType().getRank() != 1)
202  return nullptr;
203  return attr;
204 }
205 
// Body of what appears to be a getShape overload taking a value `val` (its
// signature line is not visible here — TODO confirm). Uses the operandShape
// callback when one is set and it produces a non-empty adaptor; otherwise
// falls back to the value's own type.
207  if (operandShape)
208  if (ShapeAdaptor ret = operandShape(val))
209  return ret;
210  return val.getType();
211 }
212 
// Body of what appears to be getShape(int index) (its signature line is not
// visible here — TODO confirm). An index outside [0, size()) yields nullptr;
// otherwise the call is forwarded to getShape on the index-th value.
214  if (index < 0 || static_cast<size_t>(index) >= size())
215  return nullptr;
216  return getShape(operator[](index));
217 }
218 
220  ArrayRef<ShapedTypeComponents> retComponents,
221  SmallVectorImpl<Type> &inferredReturnTypes) {
    // Appends one tensor type to inferredReturnTypes per entry of
    // retComponents, in order, and always returns success().
222  for (const auto &shapeAndType : retComponents) {
    // Every component must carry an element type (asserted).
223  Type elementTy = shapeAndType.getElementType();
224  assert(elementTy && "element type required to construct tensor");
225 
226  Attribute attr = shapeAndType.getAttribute();
227  if (shapeAndType.hasRank()) {
    // Ranked component: build a RankedTensorType from dims, element type and
    // the component's attribute.
228  inferredReturnTypes.push_back(
229  RankedTensorType::get(shapeAndType.getDims(), elementTy, attr));
230  } else {
    // Unranked component: only the element type is used, so a non-null
    // attribute is rejected by the assert.
231  assert(attr == nullptr && "attribute not supported");
232  inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy));
233  }
234  }
235  return success();
236 }
237 
// Body of what appears to be verifyInferredResultTypes(Operation *op) (its
// signature line is not visible here — TODO confirm).
// Seed the vector with the op's current result types, then let the op's
// InferTypeOpInterface implementation refine them.
239  SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
240  auto retTypeFn = cast<InferTypeOpInterface>(op);
241  auto result = retTypeFn.refineReturnTypes(
242  op->getContext(), op->getLoc(), op->getOperands(),
244  inferredReturnTypes);
    // On failure, attach an op-level diagnostic before returning.
245  if (failed(result))
246  op->emitOpError() << "failed to infer returned types";
247 
248  return result;
249 }
Attributes are known-constant values of operations.
Definition: Attributes.h:25
An attribute that represents a reference to a dense integer vector or tensor object.
This class helps build Operations.
Definition: Builders.h:209
This is a value defined by a result of an operation.
Definition: Value.h:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
Definition: Operation.h:504
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:896
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
void dump() const
Dumps textual representation to stderr.
Type getElementType() const
Returns the element type.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
ShapeAdaptor getValueAsShape(int index)
Returns an argument as shape.
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
LogicalResult inferReturnTensorTypes(ArrayRef< ShapedTypeComponents > retComponents, SmallVectorImpl< Type > &inferredReturnTypes)
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26