MLIR  19.0.0git
TosaInferShapes.cpp
Go to the documentation of this file.
1 //===- TosaInferShapes.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Propagate shapes forward along TOSA operations to resolve dynamic shape
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
20 #include "mlir/IR/Builders.h"
22 #include "mlir/Pass/Pass.h"
24 
25 namespace mlir {
26 namespace tosa {
27 #define GEN_PASS_DEF_TOSAINFERSHAPES
28 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
29 } // namespace tosa
30 } // namespace mlir
31 
32 using namespace mlir;
33 using namespace mlir::tosa;
34 
35 namespace {
36 
37 // Check whether this use case is replaceable. We define an op as
38 // being replaceable if it is used by a TosaOp, or an op with a
39 // type-inference related interface.
40 // When a non-replaceable use is encountered, the value is wrapped in a
41 // cast back to the original type after inference.
42 bool isReplaceableUser(Operation *user) {
43  // Handle unregistered dialects.
44  if (!user->getDialect())
45  return false;
46 
47  return user->getDialect()->getNamespace() ==
48  TosaDialect::getDialectNamespace() ||
49  isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
50 }
51 
52 // During type propagation, the types of values in the operator graph are
53 // updated. For the tosa.while_loop operation, types are speculatively updated
54 // within the body region to determine the output type of the while_loop. This
55 // process is performed until a fixed point is reached, then the types are
56 // reverted.
57 //
58 // This class encapsulates the state information needed to perform the reversion
59 // process or to commit to the final changes.
60 class TypeModificationState {
61 public:
62  TypeModificationState() = default;
63 
64  ~TypeModificationState() {
65  // Ensure the recorded modifications are either committed or reverted.
66  assert(oldTypes.empty() && "unhandled type modifications");
67  }
68 
69  // Update the state of the value and record the old type.
70  void setType(Value value, Type type) {
71  if (value.getType() != type) {
72  oldTypes.emplace_back(value, value.getType());
73  value.setType(type);
74  }
75  }
76 
77  // Revert changes made to the types in the IR by setting all the affected
78  // values to their old types.
79  void revert() {
80  // Otherwise revert the changes.
81  for (auto [value, type] : oldTypes)
82  value.setType(type);
83 
84  oldTypes.clear();
85  }
86 
87  // Commit the changes to the types in the IR.
88  // This requires inserting tensor.cast operations to mediate the newly
89  // inferred result types with users that do not support type inference.
90  void commit() {
91  // For each use whose type changed, cast the value with the new type back to
92  // the old type.
93  for (auto [value, oldType] : oldTypes) {
94  for (auto &use : value.getUses()) {
95  if (isReplaceableUser(use.getOwner()))
96  continue;
97 
98  OpBuilder builder(value.getContext());
99  builder.setInsertionPoint(use.getOwner());
100 
101  Location loc = value.getLoc();
102  use.set(builder.create<tensor::CastOp>(loc, oldType, value));
103  }
104  }
105 
106  oldTypes.clear();
107  }
108 
109 private:
110  // A record of each value whose type was updated along with that value's
111  // previous type.
113 };
114 
115 void propagateShapesInRegion(Region &region, TypeModificationState &state);
116 
117 void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
118  IfOp ifOp = dyn_cast<IfOp>(op);
119  if (!ifOp)
120  return;
121 
122  for (auto &region : op.getRegions()) {
123  Block &frontBlock = region.front();
124  if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
125  return;
126 
127  for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
128  auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
129  auto blockArg = frontBlock.getArgument(i - 1);
130  auto oldType = cast<ShapedType>(blockArg.getType());
131 
132  if (inferredTy.hasRank()) {
133  Type newType = oldType.clone(inferredTy.getShape());
134  state.setType(blockArg, newType);
135  }
136  }
137 
138  for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
140  ifOp.getOperand(i + 1).getType());
142  frontBlock.getArgument(i).getType());
143  ValueKnowledge joinedKnowledge =
144  ValueKnowledge::join(operandKnowledge, blockKnowledge);
145  if (!joinedKnowledge)
146  continue;
147  state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
148  }
149 
150  propagateShapesInRegion(region, state);
151  }
152 }
153 
154 void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
155  WhileOp whileOp = dyn_cast<WhileOp>(op);
156  if (!whileOp)
157  return;
158 
159  // Determine what the expected argument types are to the cond/body blocks.
160  // The expected arguments should be compatible with ever iteration of the
161  // loop body / condition for tosa.while.
162  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
163 
164  bool hasNewTypes = true;
165  while (hasNewTypes) {
166  TypeModificationState localState;
167 
168  // Set types on the block args.
169  Region &bodyRegion = op.getRegion(1);
170  Block &block = bodyRegion.front();
171  for (int i = 0, s = argTypes.size(); i < s; i++) {
172  localState.setType(block.getArgument(i), argTypes[i]);
173  }
174 
175  // Propagate to the end.
176  propagateShapesInRegion(bodyRegion, localState);
177 
178  // Find all the tosa yield types and verify there is a single one.
180  for (auto &block : bodyRegion)
181  if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
182  yieldOps.push_back(yieldOp);
183 
184  assert(yieldOps.size() == 1 && "missing or non-unique yield op");
185  // Using the new tosa.yield operand types, infer the new subtypes.
186  llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
187  for (auto ty : argTypes) {
188  yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
189  }
190 
191  for (auto yieldOp : yieldOps) {
192  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
193  auto newKnowledge =
194  ValueKnowledge::getKnowledgeFromType(it.value().getType());
195  yieldTypeInfo[it.index()] =
196  ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
197  }
198  }
199 
200  // This should never happen.
201  if (yieldTypeInfo.size() != argTypes.size()) {
202  op.emitWarning("has a tosa.yield with the incorrect number of operands");
203  return;
204  }
205 
206  // Determine the new block args and see if any changed.
207  hasNewTypes = false;
208  for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
209  Type newType = yieldTypeInfo[i].getType();
210  hasNewTypes |= (newType != argTypes[i]);
211  argTypes[i] = newType;
212  }
213 
214  // Revert all changes made during the speculative part of the algorithm.
215  localState.revert();
216  }
217 
218  // We now set the block arguments according to the most recent shape
219  // inference results. This gives us the block arg types for the next
220  // iteration.
221  for (auto &region : op.getRegions()) {
222  for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
223  state.setType(region.front().getArgument(i), argTypes[i]);
224  }
225 
226  propagateShapesInRegion(region, state);
227  }
228 }
229 
230 void propagateShapesInRegion(Region &region, TypeModificationState &state) {
231  for (auto &block : region) {
232  for (Operation &op : block) {
233  if (!op.getDialect() ||
234  op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
235  continue;
236 
237  propagateShapesToTosaIf(op, state);
238  propagateShapesToTosaWhile(op, state);
239 
240  InferShapedTypeOpInterface shapeInterface =
241  dyn_cast<InferShapedTypeOpInterface>(op);
242  if (!shapeInterface)
243  continue;
244 
245  SmallVector<ShapedTypeComponents> returnedShapes;
246 
247  if (shapeInterface
248  .inferReturnTypeComponents(
249  op.getContext(), op.getLoc(), op.getOperands(),
251  op.getRegions(), returnedShapes)
252  .succeeded()) {
253  for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
254  Value result = std::get<0>(it);
255  ShapedTypeComponents predictedShape = std::get<1>(it);
256 
257  // Determine the knowledge based on the output type.
258  // TODO: should also query WIP type probably
259  Type resultTy = result.getType();
260  auto currentKnowledge =
262 
263  // Compute the knowledge based on the inferred type.
264  auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
265  inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
266  inferredKnowledge.hasRank = predictedShape.hasRank();
267  if (predictedShape.hasRank()) {
268  for (auto dim : predictedShape.getDims()) {
269  inferredKnowledge.sizes.push_back(dim);
270  }
271  }
272 
273  // Compute the new type based on the joined version.
274  auto newKnowledge =
275  ValueKnowledge::join(currentKnowledge, inferredKnowledge);
276  if (!newKnowledge)
277  continue;
278 
279  // Set new type
280  state.setType(result, newKnowledge.getType());
281  }
282  }
283  }
284  }
285 }
286 
287 /// Pass that performs shape propagation across TOSA operations. This includes
288 /// migrating to within the regions of if/while operations.
289 struct TosaInferShapes
290  : public tosa::impl::TosaInferShapesBase<TosaInferShapes> {
291 public:
292  void runOnOperation() override {
293  func::FuncOp func = getOperation();
294  TypeModificationState state;
295  propagateShapesInRegion(func.getBody(), state);
296  state.commit();
297  }
298 };
299 } // namespace
300 
301 std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {
302  return std::make_unique<TosaInferShapes>();
303 }
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
Operation & front()
Definition: Block.h:150
StringRef getNamespace() const
Definition: Dialect.h:57
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition: Operation.h:220
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:280
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
Definition: Operation.h:496
operand_type_range getOperandTypes()
Definition: Operation.h:392
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:896
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
ShapedTypeComponents that represents the components of a ShapedType.
bool hasRank() const
Return whether the shape has a rank.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition: Value.h:140
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:132
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
std::unique_ptr< Pass > createTosaInferShapesPass()
Include the generated interface declarations.
Statically known information for a particular Value.
Definition: ShapeUtils.h:33
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:81
static ValueKnowledge getPessimisticValueState()
Definition: ShapeUtils.h:61
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition: ShapeUtils.h:45