// MLIR 22.0.0git — TosaInferShapes.cpp (Doxygen-generated source listing)
1 //===- TosaInferShapes.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Propagate shapes forward along TOSA operations to resolve dynamic shape
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Dialect/Tosa/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
24 namespace mlir {
25 namespace tosa {
26 #define GEN_PASS_DEF_TOSAINFERSHAPESPASS
27 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
28 } // namespace tosa
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::tosa;
33 
34 namespace {
35 
36 // Check whether this use case is replaceable. We define an op as
37 // being replaceable if it is used by a TosaOp, or an op with a
38 // type-inference related interface.
39 // When a non-replaceable use is encountered, the value is wrapped in a
40 // cast back to the original type after inference.
41 bool canBeRefined(Operation *user) {
42  if (!user->getDialect())
43  return false;
44  return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() ||
45  isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
46 }
47 
48 // During type propagation, the types of values in the operator graph are
49 // updated. For the tosa.while_loop operation, types are speculatively updated
50 // within the body region to determine the output type of the while_loop. This
51 // process is performed until a fixed point is reached, then the types are
52 // rolled back.
53 //
54 // This class encapsulates the state information needed to perform the roll back
55 // process or to commit to the final changes.
56 class TypeModificationState {
57 public:
58  TypeModificationState() = default;
59 
60  ~TypeModificationState() {
61  // Ensure the recorded modifications are either committed or rolled back.
62  assert(oldTypes.empty() && "unhandled type modifications");
63  }
64 
65  // Update the state of the value and record the old type.
66  void setType(Value value, Type type) {
67  if (value.getType() != type) {
68  oldTypes.emplace_back(value, value.getType());
69  value.setType(type);
70  }
71  }
72 
73  // Roll back changes made to the types in the IR by setting all the affected
74  // values to their old types.
75  void rollBack() {
76  for (auto [value, type] : oldTypes)
77  value.setType(type);
78 
79  oldTypes.clear();
80  }
81 
82  // Commit the changes to the types in the IR.
83  // This requires inserting tensor.cast operations to mediate the newly
84  // inferred result types with users that do not support type inference.
85  void commit() {
86  // For each use whose type changed, cast the value with the new type back to
87  // the old type.
88  for (auto [value, oldType] : oldTypes) {
89  // The call to 'use->set()' in the body of the loop below invalidates the
90  // iterator used to traverse op uses, so it is important to make a copy of
91  // these first.
92  llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector(
93  value.getUses(),
94  [](OpOperand &use) -> OpOperand * {
95  return &use;
96  });
97 
98  // A 'tensor.cast' op is emitted only if needed. Once emitted, it is
99  // cached and reused by all consumers.
100  tensor::CastOp castValue;
101 
102  // Traverse all uses
103  for (OpOperand *use : uses) {
104  if (canBeRefined(use->getOwner()))
105  continue;
106 
107  if (!castValue) {
108  // Set the insertion point as far back as possible, since new
109  // consumers of the 'tensor.cast' op generated in future iterations
110  // are likely to be further up in the code due to the order in which
111  // they appear in the use list.
112  OpBuilder builder{value.getContext()};
113  builder.setInsertionPointAfter(value.getDefiningOp());
114  castValue =
115  tensor::CastOp::create(builder, value.getLoc(), oldType, value);
116  }
117 
118  use->set(castValue);
119  }
120  }
121 
122  oldTypes.clear();
123  }
124 
125 private:
126  // A record of each value whose type was updated along with that value's
127  // previous type.
129 };
130 
131 void propagateShapesInRegion(Region &region, TypeModificationState &state);
132 
133 void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
134  IfOp ifOp = dyn_cast<IfOp>(op);
135  if (!ifOp)
136  return;
137 
138  for (auto &region : op.getRegions()) {
139  Block &frontBlock = region.front();
140  if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
141  return;
142 
143  for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
144  auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
145  auto blockArg = frontBlock.getArgument(i - 1);
146  auto oldType = cast<ShapedType>(blockArg.getType());
147 
148  if (inferredTy.hasRank()) {
149  Type newType = oldType.clone(inferredTy.getShape());
150  state.setType(blockArg, newType);
151  }
152  }
153 
154  for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
156  ifOp.getOperand(i + 1).getType());
158  frontBlock.getArgument(i).getType());
159  ValueKnowledge joinedKnowledge =
160  ValueKnowledge::join(operandKnowledge, blockKnowledge);
161  if (!joinedKnowledge)
162  continue;
163  state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
164  }
165 
166  propagateShapesInRegion(region, state);
167  }
168 }
169 
170 void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
171  WhileOp whileOp = dyn_cast<WhileOp>(op);
172  if (!whileOp)
173  return;
174 
175  // Determine what the expected argument types are to the cond/body blocks.
176  // The expected arguments should be compatible with ever iteration of the
177  // loop body / condition for tosa.while.
178  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
179 
180  bool hasNewTypes = true;
181  while (hasNewTypes) {
182  TypeModificationState localState;
183 
184  // Set types on the block args.
185  Region &bodyRegion = op.getRegion(1);
186  Block &block = bodyRegion.front();
187  for (int i = 0, s = argTypes.size(); i < s; i++) {
188  localState.setType(block.getArgument(i), argTypes[i]);
189  }
190 
191  // Propagate to the end.
192  propagateShapesInRegion(bodyRegion, localState);
193 
194  // Find all the tosa yield types and verify there is a single one.
196  for (auto &block : bodyRegion)
197  if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
198  yieldOps.push_back(yieldOp);
199 
200  assert(yieldOps.size() == 1 && "missing or non-unique yield op");
201  // Using the new tosa.yield operand types, infer the new subtypes.
202  llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
203  for (auto ty : argTypes) {
204  yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
205  }
206 
207  for (auto yieldOp : yieldOps) {
208  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
209  auto newKnowledge =
210  ValueKnowledge::getKnowledgeFromType(it.value().getType());
211  yieldTypeInfo[it.index()] =
212  ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
213  }
214  }
215 
216  // This should never happen.
217  if (yieldTypeInfo.size() != argTypes.size()) {
218  op.emitWarning("has a tosa.yield with the incorrect number of operands");
219  return;
220  }
221 
222  // Determine the new block args and see if any changed.
223  hasNewTypes = false;
224  for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
225  Type newType = yieldTypeInfo[i].getType();
226  hasNewTypes |= (newType != argTypes[i]);
227  argTypes[i] = newType;
228  }
229 
230  // Roll back all changes made during the speculative part of the algorithm.
231  localState.rollBack();
232  }
233 
234  // We now set the block arguments according to the most recent shape
235  // inference results. This gives us the block arg types for the next
236  // iteration.
237  for (auto &region : op.getRegions()) {
238  for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
239  state.setType(region.front().getArgument(i), argTypes[i]);
240  }
241 
242  propagateShapesInRegion(region, state);
243  }
244 }
245 
246 void propagateShapesInRegion(Region &region, TypeModificationState &state) {
247  Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
248 
249  for (auto &block : region) {
250  for (Operation &op : block) {
251  if (op.getDialect() != tosaDialect)
252  continue;
253 
254  propagateShapesToTosaIf(op, state);
255  propagateShapesToTosaWhile(op, state);
256 
257  InferShapedTypeOpInterface shapeInterface =
258  dyn_cast<InferShapedTypeOpInterface>(op);
259  if (!shapeInterface)
260  continue;
261 
262  SmallVector<ShapedTypeComponents> returnedShapes;
263 
264  if (shapeInterface
265  .inferReturnTypeComponents(
266  op.getContext(), op.getLoc(), op.getOperands(),
268  op.getRegions(), returnedShapes)
269  .succeeded()) {
270  for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
271  Value result = std::get<0>(it);
272  ShapedTypeComponents predictedShape = std::get<1>(it);
273 
274  // Determine the knowledge based on the output type.
275  // TODO: should also query WIP type probably
276  Type resultTy = result.getType();
277  auto currentKnowledge =
279 
280  // Compute the knowledge based on the inferred type.
281  auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
282  inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
283  inferredKnowledge.hasRank = predictedShape.hasRank();
284  if (predictedShape.hasRank()) {
285  for (auto dim : predictedShape.getDims()) {
286  inferredKnowledge.sizes.push_back(dim);
287  }
288  }
289 
290  // Compute the new type based on the joined version.
291  auto newKnowledge =
292  ValueKnowledge::join(currentKnowledge, inferredKnowledge);
293  if (!newKnowledge)
294  continue;
295 
296  // Set new type
297  state.setType(result, newKnowledge.getType());
298  }
299  }
300  }
301  }
302 }
303 
304 /// Recursively validate tosa ops with SameOperandsAndResultRank trait in region
305 /// and all nested regions
306 void validateSameOperandsAndResultRankTrait(Region &region) {
307  int errs = 0;
308  for (auto &block : region) {
309  for (auto &op : block) {
310  if (!op.getDialect() ||
311  op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
312  continue;
315  errs++;
316  (void)errs;
317  }
318  }
319  WhileOp whileOp = dyn_cast<WhileOp>(op);
320  IfOp ifOp = dyn_cast<IfOp>(op);
321  if (whileOp || ifOp) {
322  // recurse into whileOp's regions
323  for (auto &next : op.getRegions()) {
324  validateSameOperandsAndResultRankTrait(next);
325  }
326  }
327  }
328  }
329 }
330 
331 /// Pass that performs shape propagation across TOSA operations. This includes
332 /// migrating to within the regions of if/while operations.
333 struct TosaInferShapes
334  : public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> {
335 public:
336  void runOnOperation() override {
337  func::FuncOp func = getOperation();
338  TypeModificationState state;
339  propagateShapesInRegion(func.getBody(), state);
340  state.commit();
341 
342  validateSameOperandsAndResultRankTrait(func.getBody());
343  }
344 };
345 } // namespace
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
Operation & front()
Definition: Block.h:153
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
StringRef getNamespace() const
Definition: Dialect.h:54
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition: Dialect.h:57
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Definition: Builders.h:207
This class represents an operand of an operation.
Definition: Value.h:257
This class verifies that op has same ranks for all operands and results types, if known.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition: Operation.h:220
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:279
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
Definition: Operation.h:501
operand_type_range getOperandTypes()
Definition: Operation.h:397
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:900
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext()
Return the context this region is inserted in.
Definition: Region.cpp:24
Block & front()
Definition: Region.h:65
ShapedTypeComponents that represents the components of a ShapedType.
bool hasRank() const
Return whether the shape has a rank.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition: Value.h:116
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
LogicalResult verifySameOperandsAndResultRank(Operation *op)
Definition: Operation.cpp:1139
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
Statically known information for a particular Value.
Definition: ShapeUtils.h:33
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:81
static ValueKnowledge getPessimisticValueState()
Definition: ShapeUtils.h:61
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition: ShapeUtils.h:45