MLIR  20.0.0git
TosaInferShapes.cpp
Go to the documentation of this file.
1 //===- TosaInferShapes.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Propagate shapes forward along TOSA operations to resolve dynamic shape
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
20 #include "mlir/IR/Builders.h"
23 #include "mlir/Pass/Pass.h"
25 
26 namespace mlir {
27 namespace tosa {
28 #define GEN_PASS_DEF_TOSAINFERSHAPES
29 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
30 } // namespace tosa
31 } // namespace mlir
32 
33 using namespace mlir;
34 using namespace mlir::tosa;
35 
36 namespace {
37 
38 // Check whether this use case is replaceable. We define an op as
39 // being replaceable if it is used by a TosaOp, or an op with a
40 // type-inference related interface.
41 // When a non-replaceable use is encountered, the value is wrapped in a
42 // cast back to the original type after inference.
43 bool canBeRefined(Operation *user) {
44  if (!user->getDialect())
45  return false;
46  return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() ||
47  isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
48 }
49 
50 // During type propagation, the types of values in the operator graph are
51 // updated. For the tosa.while_loop operation, types are speculatively updated
52 // within the body region to determine the output type of the while_loop. This
53 // process is performed until a fixed point is reached, then the types are
54 // rolled back.
55 //
56 // This class encapsulates the state information needed to perform the roll back
57 // process or to commit to the final changes.
58 class TypeModificationState {
59 public:
60  TypeModificationState() = default;
61 
62  ~TypeModificationState() {
63  // Ensure the recorded modifications are either committed or rolled back.
64  assert(oldTypes.empty() && "unhandled type modifications");
65  }
66 
67  // Update the state of the value and record the old type.
68  void setType(Value value, Type type) {
69  if (value.getType() != type) {
70  oldTypes.emplace_back(value, value.getType());
71  value.setType(type);
72  }
73  }
74 
75  // Roll back changes made to the types in the IR by setting all the affected
76  // values to their old types.
77  void rollBack() {
78  for (auto [value, type] : oldTypes)
79  value.setType(type);
80 
81  oldTypes.clear();
82  }
83 
84  // Commit the changes to the types in the IR.
85  // This requires inserting tensor.cast operations to mediate the newly
86  // inferred result types with users that do not support type inference.
87  void commit() {
88  // For each use whose type changed, cast the value with the new type back to
89  // the old type.
90  for (auto [value, oldType] : oldTypes) {
91  // The call to 'use->set()' in the body of the loop below invalidates the
92  // iterator used to traverse op uses, so it is important to make a copy of
93  // these first.
94  llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector(
95  value.getUses(),
96  [](OpOperand &use) -> OpOperand * {
97  return &use;
98  });
99 
100  // A 'tensor.cast' op is emitted only if needed. Once emitted, it is
101  // cached and reused by all consumers.
102  tensor::CastOp castValue;
103 
104  // Traverse all uses
105  for (OpOperand *use : uses) {
106  if (canBeRefined(use->getOwner()))
107  continue;
108 
109  if (!castValue) {
110  // Set the insertion point as far back as possible, since new
111  // consumers of the 'tensor.cast' op generated in future iterations
112  // are likely to be further up in the code due to the order in which
113  // they appear in the use list.
114  OpBuilder builder{value.getContext()};
115  builder.setInsertionPointAfter(value.getDefiningOp());
116  castValue =
117  builder.create<tensor::CastOp>(value.getLoc(), oldType, value);
118  }
119 
120  use->set(castValue);
121  }
122  }
123 
124  oldTypes.clear();
125  }
126 
127 private:
128  // A record of each value whose type was updated along with that value's
129  // previous type.
131 };
132 
133 void propagateShapesInRegion(Region &region, TypeModificationState &state);
134 
135 void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
136  IfOp ifOp = dyn_cast<IfOp>(op);
137  if (!ifOp)
138  return;
139 
140  for (auto &region : op.getRegions()) {
141  Block &frontBlock = region.front();
142  if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
143  return;
144 
145  for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
146  auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
147  auto blockArg = frontBlock.getArgument(i - 1);
148  auto oldType = cast<ShapedType>(blockArg.getType());
149 
150  if (inferredTy.hasRank()) {
151  Type newType = oldType.clone(inferredTy.getShape());
152  state.setType(blockArg, newType);
153  }
154  }
155 
156  for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
158  ifOp.getOperand(i + 1).getType());
160  frontBlock.getArgument(i).getType());
161  ValueKnowledge joinedKnowledge =
162  ValueKnowledge::join(operandKnowledge, blockKnowledge);
163  if (!joinedKnowledge)
164  continue;
165  state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
166  }
167 
168  propagateShapesInRegion(region, state);
169  }
170 }
171 
172 void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
173  WhileOp whileOp = dyn_cast<WhileOp>(op);
174  if (!whileOp)
175  return;
176 
177  // Determine what the expected argument types are to the cond/body blocks.
178  // The expected arguments should be compatible with ever iteration of the
179  // loop body / condition for tosa.while.
180  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
181 
182  bool hasNewTypes = true;
183  while (hasNewTypes) {
184  TypeModificationState localState;
185 
186  // Set types on the block args.
187  Region &bodyRegion = op.getRegion(1);
188  Block &block = bodyRegion.front();
189  for (int i = 0, s = argTypes.size(); i < s; i++) {
190  localState.setType(block.getArgument(i), argTypes[i]);
191  }
192 
193  // Propagate to the end.
194  propagateShapesInRegion(bodyRegion, localState);
195 
196  // Find all the tosa yield types and verify there is a single one.
198  for (auto &block : bodyRegion)
199  if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
200  yieldOps.push_back(yieldOp);
201 
202  assert(yieldOps.size() == 1 && "missing or non-unique yield op");
203  // Using the new tosa.yield operand types, infer the new subtypes.
204  llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
205  for (auto ty : argTypes) {
206  yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
207  }
208 
209  for (auto yieldOp : yieldOps) {
210  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
211  auto newKnowledge =
212  ValueKnowledge::getKnowledgeFromType(it.value().getType());
213  yieldTypeInfo[it.index()] =
214  ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
215  }
216  }
217 
218  // This should never happen.
219  if (yieldTypeInfo.size() != argTypes.size()) {
220  op.emitWarning("has a tosa.yield with the incorrect number of operands");
221  return;
222  }
223 
224  // Determine the new block args and see if any changed.
225  hasNewTypes = false;
226  for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
227  Type newType = yieldTypeInfo[i].getType();
228  hasNewTypes |= (newType != argTypes[i]);
229  argTypes[i] = newType;
230  }
231 
232  // Roll back all changes made during the speculative part of the algorithm.
233  localState.rollBack();
234  }
235 
236  // We now set the block arguments according to the most recent shape
237  // inference results. This gives us the block arg types for the next
238  // iteration.
239  for (auto &region : op.getRegions()) {
240  for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
241  state.setType(region.front().getArgument(i), argTypes[i]);
242  }
243 
244  propagateShapesInRegion(region, state);
245  }
246 }
247 
248 void propagateShapesInRegion(Region &region, TypeModificationState &state) {
249  Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
250 
251  for (auto &block : region) {
252  for (Operation &op : block) {
253  if (op.getDialect() != tosaDialect)
254  continue;
255 
256  propagateShapesToTosaIf(op, state);
257  propagateShapesToTosaWhile(op, state);
258 
259  InferShapedTypeOpInterface shapeInterface =
260  dyn_cast<InferShapedTypeOpInterface>(op);
261  if (!shapeInterface)
262  continue;
263 
264  SmallVector<ShapedTypeComponents> returnedShapes;
265 
266  if (shapeInterface
267  .inferReturnTypeComponents(
268  op.getContext(), op.getLoc(), op.getOperands(),
270  op.getRegions(), returnedShapes)
271  .succeeded()) {
272  for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
273  Value result = std::get<0>(it);
274  ShapedTypeComponents predictedShape = std::get<1>(it);
275 
276  // Determine the knowledge based on the output type.
277  // TODO: should also query WIP type probably
278  Type resultTy = result.getType();
279  auto currentKnowledge =
281 
282  // Compute the knowledge based on the inferred type.
283  auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
284  inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
285  inferredKnowledge.hasRank = predictedShape.hasRank();
286  if (predictedShape.hasRank()) {
287  for (auto dim : predictedShape.getDims()) {
288  inferredKnowledge.sizes.push_back(dim);
289  }
290  }
291 
292  // Compute the new type based on the joined version.
293  auto newKnowledge =
294  ValueKnowledge::join(currentKnowledge, inferredKnowledge);
295  if (!newKnowledge)
296  continue;
297 
298  // Set new type
299  state.setType(result, newKnowledge.getType());
300  }
301  }
302  }
303  }
304 }
305 
306 /// Pass that performs shape propagation across TOSA operations. This includes
307 /// migrating to within the regions of if/while operations.
308 struct TosaInferShapes
309  : public tosa::impl::TosaInferShapesBase<TosaInferShapes> {
310 public:
311  void runOnOperation() override {
312  func::FuncOp func = getOperation();
313  TypeModificationState state;
314  propagateShapesInRegion(func.getBody(), state);
315  state.commit();
316  }
317 };
318 } // namespace
319 
320 std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {
321  return std::make_unique<TosaInferShapes>();
322 }
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
Operation & front()
Definition: Block.h:153
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Definition: Dialect.h:38
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition: Dialect.h:57
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Definition: Builders.h:215
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition: Operation.h:220
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:280
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
Definition: Operation.h:496
operand_type_range getOperandTypes()
Definition: Operation.h:392
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:896
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext()
Return the context this region is inserted in.
Definition: Region.cpp:24
Block & front()
Definition: Region.h:65
ShapedTypeComponents that represents the components of a ShapedType.
bool hasRank() const
Return whether the shape has a rank.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition: Value.h:140
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:132
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
std::unique_ptr< Pass > createTosaInferShapesPass()
Include the generated interface declarations.
Statically known information for a particular Value.
Definition: ShapeUtils.h:33
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:81
static ValueKnowledge getPessimisticValueState()
Definition: ShapeUtils.h:61
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition: ShapeUtils.h:45