MLIR  19.0.0git
TosaInferShapes.cpp
Go to the documentation of this file.
1 //===- TosaInferShapes.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Propagate shapes forward along TOSA operations to resolve dynamic shape
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
20 #include "mlir/IR/Builders.h"
23 #include "mlir/Pass/Pass.h"
25 
26 namespace mlir {
27 namespace tosa {
28 #define GEN_PASS_DEF_TOSAINFERSHAPES
29 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
30 } // namespace tosa
31 } // namespace mlir
32 
33 using namespace mlir;
34 using namespace mlir::tosa;
35 
36 namespace {
37 
38 // Check whether this use case is replaceable. We define an op as
39 // being replaceable if it is used by a TosaOp, or an op with a
40 // type-inference related interface.
41 // When a non-replaceable use is encountered, the value is wrapped in a
42 // cast back to the original type after inference.
43 bool canBeRefined(Operation *user) {
44  if (!user->getDialect())
45  return false;
46  return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() ||
47  isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
48 }
49 
50 // During type propagation, the types of values in the operator graph are
51 // updated. For the tosa.while_loop operation, types are speculatively updated
52 // within the body region to determine the output type of the while_loop. This
53 // process is performed until a fixed point is reached, then the types are
54 // rolled back.
55 //
56 // This class encapsulates the state information needed to perform the roll back
57 // process or to commit to the final changes.
58 class TypeModificationState {
59 public:
60  TypeModificationState() = default;
61 
62  ~TypeModificationState() {
63  // Ensure the recorded modifications are either committed or rolled back.
64  assert(oldTypes.empty() && "unhandled type modifications");
65  }
66 
67  // Update the state of the value and record the old type.
68  void setType(Value value, Type type) {
69  if (value.getType() != type) {
70  oldTypes.emplace_back(value, value.getType());
71  value.setType(type);
72  }
73  }
74 
75  // Roll back changes made to the types in the IR by setting all the affected
76  // values to their old types.
77  void rollBack() {
78  for (auto [value, type] : oldTypes)
79  value.setType(type);
80 
81  oldTypes.clear();
82  }
83 
84  // Commit the changes to the types in the IR.
85  // This requires inserting tensor.cast operations to mediate the newly
86  // inferred result types with users that do not support type inference.
87  void commit() {
88  // For each use whose type changed, cast the value with the new type back to
89  // the old type.
90  for (auto [value, oldType] : oldTypes) {
91  tensor::CastOp castedValue;
92  for (auto &use : value.getUses()) {
93  if (canBeRefined(use.getOwner()))
94  continue;
95 
96  // Cache the cast to avoid generating duplicates
97  if (!castedValue) {
98  ImplicitLocOpBuilder builder{value.getLoc(), use.getOwner()};
99  castedValue = builder.create<tensor::CastOp>(oldType, value);
100  }
101 
102  use.set(castedValue);
103  }
104  }
105 
106  oldTypes.clear();
107  }
108 
109 private:
110  // A record of each value whose type was updated along with that value's
111  // previous type.
113 };
114 
115 void propagateShapesInRegion(Region &region, TypeModificationState &state);
116 
117 void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
118  IfOp ifOp = dyn_cast<IfOp>(op);
119  if (!ifOp)
120  return;
121 
122  for (auto &region : op.getRegions()) {
123  Block &frontBlock = region.front();
124  if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
125  return;
126 
127  for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
128  auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
129  auto blockArg = frontBlock.getArgument(i - 1);
130  auto oldType = cast<ShapedType>(blockArg.getType());
131 
132  if (inferredTy.hasRank()) {
133  Type newType = oldType.clone(inferredTy.getShape());
134  state.setType(blockArg, newType);
135  }
136  }
137 
138  for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
140  ifOp.getOperand(i + 1).getType());
142  frontBlock.getArgument(i).getType());
143  ValueKnowledge joinedKnowledge =
144  ValueKnowledge::join(operandKnowledge, blockKnowledge);
145  if (!joinedKnowledge)
146  continue;
147  state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
148  }
149 
150  propagateShapesInRegion(region, state);
151  }
152 }
153 
154 void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
155  WhileOp whileOp = dyn_cast<WhileOp>(op);
156  if (!whileOp)
157  return;
158 
159  // Determine what the expected argument types are to the cond/body blocks.
160  // The expected arguments should be compatible with ever iteration of the
161  // loop body / condition for tosa.while.
162  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
163 
164  bool hasNewTypes = true;
165  while (hasNewTypes) {
166  TypeModificationState localState;
167 
168  // Set types on the block args.
169  Region &bodyRegion = op.getRegion(1);
170  Block &block = bodyRegion.front();
171  for (int i = 0, s = argTypes.size(); i < s; i++) {
172  localState.setType(block.getArgument(i), argTypes[i]);
173  }
174 
175  // Propagate to the end.
176  propagateShapesInRegion(bodyRegion, localState);
177 
178  // Find all the tosa yield types and verify there is a single one.
180  for (auto &block : bodyRegion)
181  if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
182  yieldOps.push_back(yieldOp);
183 
184  assert(yieldOps.size() == 1 && "missing or non-unique yield op");
185  // Using the new tosa.yield operand types, infer the new subtypes.
186  llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
187  for (auto ty : argTypes) {
188  yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
189  }
190 
191  for (auto yieldOp : yieldOps) {
192  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
193  auto newKnowledge =
194  ValueKnowledge::getKnowledgeFromType(it.value().getType());
195  yieldTypeInfo[it.index()] =
196  ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
197  }
198  }
199 
200  // This should never happen.
201  if (yieldTypeInfo.size() != argTypes.size()) {
202  op.emitWarning("has a tosa.yield with the incorrect number of operands");
203  return;
204  }
205 
206  // Determine the new block args and see if any changed.
207  hasNewTypes = false;
208  for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
209  Type newType = yieldTypeInfo[i].getType();
210  hasNewTypes |= (newType != argTypes[i]);
211  argTypes[i] = newType;
212  }
213 
214  // Roll back all changes made during the speculative part of the algorithm.
215  localState.rollBack();
216  }
217 
218  // We now set the block arguments according to the most recent shape
219  // inference results. This gives us the block arg types for the next
220  // iteration.
221  for (auto &region : op.getRegions()) {
222  for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
223  state.setType(region.front().getArgument(i), argTypes[i]);
224  }
225 
226  propagateShapesInRegion(region, state);
227  }
228 }
229 
230 void propagateShapesInRegion(Region &region, TypeModificationState &state) {
231  Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
232 
233  for (auto &block : region) {
234  for (Operation &op : block) {
235  if (op.getDialect() != tosaDialect)
236  continue;
237 
238  propagateShapesToTosaIf(op, state);
239  propagateShapesToTosaWhile(op, state);
240 
241  InferShapedTypeOpInterface shapeInterface =
242  dyn_cast<InferShapedTypeOpInterface>(op);
243  if (!shapeInterface)
244  continue;
245 
246  SmallVector<ShapedTypeComponents> returnedShapes;
247 
248  if (shapeInterface
249  .inferReturnTypeComponents(
250  op.getContext(), op.getLoc(), op.getOperands(),
252  op.getRegions(), returnedShapes)
253  .succeeded()) {
254  for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
255  Value result = std::get<0>(it);
256  ShapedTypeComponents predictedShape = std::get<1>(it);
257 
258  // Determine the knowledge based on the output type.
259  // TODO: should also query WIP type probably
260  Type resultTy = result.getType();
261  auto currentKnowledge =
263 
264  // Compute the knowledge based on the inferred type.
265  auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
266  inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
267  inferredKnowledge.hasRank = predictedShape.hasRank();
268  if (predictedShape.hasRank()) {
269  for (auto dim : predictedShape.getDims()) {
270  inferredKnowledge.sizes.push_back(dim);
271  }
272  }
273 
274  // Compute the new type based on the joined version.
275  auto newKnowledge =
276  ValueKnowledge::join(currentKnowledge, inferredKnowledge);
277  if (!newKnowledge)
278  continue;
279 
280  // Set new type
281  state.setType(result, newKnowledge.getType());
282  }
283  }
284  }
285  }
286 }
287 
288 /// Pass that performs shape propagation across TOSA operations. This includes
289 /// migrating to within the regions of if/while operations.
290 struct TosaInferShapes
291  : public tosa::impl::TosaInferShapesBase<TosaInferShapes> {
292 public:
293  void runOnOperation() override {
294  func::FuncOp func = getOperation();
295  TypeModificationState state;
296  propagateShapesInRegion(func.getBody(), state);
297  state.commit();
298  }
299 };
300 } // namespace
301 
302 std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {
303  return std::make_unique<TosaInferShapes>();
304 }
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
Operation & front()
Definition: Block.h:151
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Definition: Dialect.h:41
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition: Dialect.h:60
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition: Operation.h:220
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:280
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
Definition: Operation.h:496
operand_type_range getOperandTypes()
Definition: Operation.h:392
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:896
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext()
Return the context this region is inserted in.
Definition: Region.cpp:24
Block & front()
Definition: Region.h:65
ShapedTypeComponents that represents the components of a ShapedType.
bool hasRank() const
Return whether the shape has a rank.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition: Value.h:140
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
std::unique_ptr< Pass > createTosaInferShapesPass()
Include the generated interface declarations.
Statically known information for a particular Value.
Definition: ShapeUtils.h:33
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:81
static ValueKnowledge getPessimisticValueState()
Definition: ShapeUtils.h:61
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition: ShapeUtils.h:45