MLIR 23.0.0git
TosaInferShapes.cpp
Go to the documentation of this file.
1//===- TosaInferShapes.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Propagate shapes forward along TOSA operations to resolve dynamic shape
10// operations.
11//
12//===----------------------------------------------------------------------===//
13
#include "mlir/Dialect/Tosa/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/FoldUtils.h"
26namespace mlir {
27namespace tosa {
28#define GEN_PASS_DEF_TOSAINFERSHAPESPASS
29#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
30} // namespace tosa
31} // namespace mlir
32
33using namespace mlir;
34using namespace mlir::tosa;
35
36namespace {
37
38// Check whether this use case is replaceable. We define an op as
39// being replaceable if it is used by a TosaOp, or an op with a
40// type-inference related interface.
41// When a non-replaceable use is encountered, the value is wrapped in a
42// cast back to the original type after inference.
43bool canBeRefined(Operation *user) {
44 if (!user->getDialect())
45 return false;
46 return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() ||
47 isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
48}
49
50// During type propagation, the types of values in the operator graph are
51// updated. For the tosa.while_loop operation, types are speculatively updated
52// within the body region to determine the output type of the while_loop. This
53// process is performed until a fixed point is reached, then the types are
54// rolled back.
55//
56// This class encapsulates the state information needed to perform the roll back
57// process or to commit to the final changes.
58class TypeModificationState {
59public:
60 TypeModificationState() = default;
61
62 ~TypeModificationState() {
63 // Ensure the recorded modifications are either committed or rolled back.
64 assert(oldTypes.empty() && "unhandled type modifications");
65 }
66
67 // Update the state of the value and record the old type.
68 void setType(Value value, Type type) {
69 if (value.getType() != type) {
70 oldTypes.emplace_back(value, value.getType());
71 value.setType(type);
72 }
73 }
74
75 // Roll back changes made to the types in the IR by setting all the affected
76 // values to their old types.
77 void rollBack() {
78 for (auto [value, type] : oldTypes)
79 value.setType(type);
80
81 oldTypes.clear();
82 }
83
84 // Commit the changes to the types in the IR.
85 // This requires inserting tensor.cast operations to mediate the newly
86 // inferred result types with users that do not support type inference.
87 void commit() {
88 // For each use whose type changed, cast the value with the new type back to
89 // the old type.
90 for (auto [value, oldType] : oldTypes) {
91 // The call to 'use->set()' in the body of the loop below invalidates the
92 // iterator used to traverse op uses, so it is important to make a copy of
93 // these first.
94 llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector(
95 value.getUses(),
96 [](OpOperand &use) -> OpOperand * {
97 return &use;
98 });
99
100 // A 'tensor.cast' op is emitted only if needed. Once emitted, it is
101 // cached and reused by all consumers.
102 tensor::CastOp castValue;
103
104 // Traverse all uses
105 for (OpOperand *use : uses) {
106 if (canBeRefined(use->getOwner()))
107 continue;
108
109 if (!castValue) {
110 // Set the insertion point as far back as possible, since new
111 // consumers of the 'tensor.cast' op generated in future iterations
112 // are likely to be further up in the code due to the order in which
113 // they appear in the use list.
114 OpBuilder builder{value.getContext()};
115 if (Operation *defOp = value.getDefiningOp()) {
116 builder.setInsertionPointAfter(defOp);
117 } else {
118 // For block arguments there is no defining op; insert at the start
119 // of the block that owns the argument.
121 }
122 castValue =
123 tensor::CastOp::create(builder, value.getLoc(), oldType, value);
124 }
125
126 use->set(castValue);
127 }
128 }
129
130 oldTypes.clear();
131 }
132
133private:
134 // A record of each value whose type was updated along with that value's
135 // previous type.
136 llvm::SmallVector<std::pair<Value, Type>> oldTypes;
137};
138
139/// Recursively validate tosa ops with SameOperandsAndResultRank trait in region
140/// and all nested regions
141void validateSameOperandsAndResultRankTrait(Region &region) {
142 int errs = 0;
143 for (auto &block : region) {
144 for (auto &op : block) {
145 if (!op.getDialect() ||
146 op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
147 continue;
148 if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) {
150 errs++;
151 (void)errs;
152 }
153 }
154 WhileOp whileOp = dyn_cast<WhileOp>(op);
155 IfOp ifOp = dyn_cast<IfOp>(op);
156 if (whileOp || ifOp) {
157 // recurse into whileOp's regions
158 for (auto &next : op.getRegions()) {
159 validateSameOperandsAndResultRankTrait(next);
160 }
161 }
162 }
163 }
164}
165
166/// Pass that performs shape propagation across TOSA operations. This includes
167/// migrating to within the regions of if/while operations.
168struct TosaInferShapes
169 : public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> {
170public:
171 explicit TosaInferShapes() = default;
172 explicit TosaInferShapes(const TosaInferShapesPassOptions &options)
173 : TosaInferShapes() {
174 this->foldShapeExpressions = options.foldShapeExpressions;
175 this->convertFunctionBoundaries = options.convertFunctionBoundaries;
176 }
177
178 void runOnOperation() override {
179 func::FuncOp func = getOperation();
180 TypeModificationState state;
181 propagateShapesInRegion(func.getBody(), state);
182 state.commit();
183
184 if (foldShapeExpressions) {
185 // Folding shape expressions may leave dead tosa.const_shape operations
186 func.walk<WalkOrder::PostOrder, ReverseIterator>(
187 [](tosa::ConstShapeOp op) {
188 if (isOpTriviallyDead(op))
189 op->erase();
190 });
191 }
192
193 validateSameOperandsAndResultRankTrait(func.getBody());
194
195 if (convertFunctionBoundaries)
196 convertFunctionReturnTypes(func);
197 }
198
199private:
200 void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
201 IfOp ifOp = dyn_cast<IfOp>(op);
202 if (!ifOp)
203 return;
204
205 for (auto &region : op.getRegions()) {
206 Block &frontBlock = region.front();
207 if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
208 return;
209
210 for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
211 auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
212 auto blockArg = frontBlock.getArgument(i - 1);
213 auto oldType = cast<ShapedType>(blockArg.getType());
214
215 if (inferredTy.hasRank()) {
216 Type newType = oldType.clone(inferredTy.getShape());
217 state.setType(blockArg, newType);
218 }
219 }
220
221 for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
222 ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
223 ifOp.getOperand(i + 1).getType());
224 ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
225 frontBlock.getArgument(i).getType());
226 ValueKnowledge joinedKnowledge =
227 ValueKnowledge::join(operandKnowledge, blockKnowledge);
228 if (!joinedKnowledge)
229 continue;
230 state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
231 }
232
233 propagateShapesInRegion(region, state);
234 }
235 }
236
237 void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
238 WhileOp whileOp = dyn_cast<WhileOp>(op);
239 if (!whileOp)
240 return;
241
242 // Determine what the expected argument types are to the cond/body blocks.
243 // The expected arguments should be compatible with ever iteration of the
244 // loop body / condition for tosa.while.
245 SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
246
247 bool hasNewTypes = true;
248 while (hasNewTypes) {
249 TypeModificationState localState;
250
251 // Set types on the block args.
252 Region &bodyRegion = op.getRegion(1);
253 Block &block = bodyRegion.front();
254 for (int i = 0, s = argTypes.size(); i < s; i++) {
255 localState.setType(block.getArgument(i), argTypes[i]);
256 }
257
258 // Propagate to the end.
259 propagateShapesInRegion(bodyRegion, localState);
260
261 // Find all the tosa yield types and verify there is a single one.
262 llvm::SmallVector<YieldOp> yieldOps;
263 for (auto &block : bodyRegion)
264 if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
265 yieldOps.push_back(yieldOp);
266
267 assert(yieldOps.size() == 1 && "missing or non-unique yield op");
268 // Using the new tosa.yield operand types, infer the new subtypes.
269 llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
270 for (auto ty : argTypes) {
271 yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
272 }
273
274 for (auto yieldOp : yieldOps) {
275 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
276 auto newKnowledge =
277 ValueKnowledge::getKnowledgeFromType(it.value().getType());
278 yieldTypeInfo[it.index()] =
279 ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
280 }
281 }
282
283 // This should never happen.
284 if (yieldTypeInfo.size() != argTypes.size()) {
285 op.emitWarning(
286 "has a tosa.yield with the incorrect number of operands");
287 return;
288 }
289
290 // Determine the new block args and see if any changed.
291 hasNewTypes = false;
292 for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
293 Type newType = yieldTypeInfo[i].getType();
294 hasNewTypes |= (newType != argTypes[i]);
295 argTypes[i] = newType;
296 }
297
298 // Roll back all changes made during the speculative part of the
299 // algorithm.
300 localState.rollBack();
301 }
302
303 // We now set the block arguments according to the most recent shape
304 // inference results. This gives us the block arg types for the next
305 // iteration.
306 for (auto &region : op.getRegions()) {
307 for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
308 state.setType(region.front().getArgument(i), argTypes[i]);
309 }
310
311 propagateShapesInRegion(region, state);
312 }
313 }
314
315 void propagateShapesInRegion(Region &region, TypeModificationState &state) {
316 MLIRContext *ctx = region.getContext();
317 Dialect *tosaDialect = ctx->getLoadedDialect<TosaDialect>();
318 OperationFolder folder(ctx);
319
320 for (auto &block : region) {
321 // The loop body may erase operations, so we need to be careful
322 // when iterating. Fetch the next operation before the current
323 // operation is modified.
324 for (auto it = block.begin(); it != block.end();) {
325 Operation &op = *it++;
326 if (op.getDialect() != tosaDialect)
327 continue;
328
329 propagateShapesToTosaIf(op, state);
330 propagateShapesToTosaWhile(op, state);
331
332 if (foldShapeExpressions &&
333 op.hasTrait<OpTrait::tosa::TosaShapeOperator>()) {
334 (void)folder.tryToFold(&op);
335 continue;
336 }
337
338 InferShapedTypeOpInterface shapeInterface =
339 dyn_cast<InferShapedTypeOpInterface>(op);
340 if (!shapeInterface)
341 continue;
342
343 SmallVector<ShapedTypeComponents> returnedShapes;
344
345 if (shapeInterface
346 .inferReturnTypeComponents(
347 op.getContext(), op.getLoc(), op.getOperands(),
349 op.getPropertiesStorage(), op.getRegions(), returnedShapes)
350 .succeeded()) {
351 for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
352 Value result = std::get<0>(it);
353 ShapedTypeComponents predictedShape = std::get<1>(it);
354
355 // Determine the knowledge based on the output type.
356 // TODO: should also query WIP type probably
357 Type resultTy = result.getType();
358 auto currentKnowledge =
360
361 // Compute the knowledge based on the inferred type.
362 auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
363 inferredKnowledge.dtype =
364 cast<ShapedType>(resultTy).getElementType();
365 inferredKnowledge.hasRank = predictedShape.hasRank();
366 if (predictedShape.hasRank()) {
367 for (auto dim : predictedShape.getDims()) {
368 inferredKnowledge.sizes.push_back(dim);
369 }
370 }
371
372 // Compute the new type based on the joined version.
373 auto newKnowledge =
374 ValueKnowledge::join(currentKnowledge, inferredKnowledge);
375 if (!newKnowledge)
376 continue;
377
378 // Set new type
379 state.setType(result, newKnowledge.getType());
380 }
381 }
382 }
383 }
384 }
385
386 void convertFunctionReturnTypes(func::FuncOp func) {
387 IRRewriter rewriter(func.getContext());
388 SmallVector<Type> newReturnTypes;
389
390 // Rewrite func.return ops, removing dead tensor.cast ops if possible
391 func.walk([&rewriter, &newReturnTypes](func::ReturnOp ret) {
392 SmallVector<Value> newReturnValues;
393 SmallVector<Value> maybeDeadCasts;
394 OperandRange returnOperands = ret.getOperands();
395 newReturnValues.reserve(returnOperands.size());
396 maybeDeadCasts.reserve(returnOperands.size());
397 newReturnTypes.reserve(newReturnTypes.size() + returnOperands.size());
398
399 for (const Value &v : returnOperands) {
400 Value newReturnValue = v;
401 if (auto castOp = v.getDefiningOp<tensor::CastOp>()) {
402 newReturnValue = castOp.getSource();
403 maybeDeadCasts.push_back(castOp);
404 }
405 newReturnValues.push_back(newReturnValue);
406 newReturnTypes.push_back(newReturnValue.getType());
407 }
408
409 rewriter.setInsertionPoint(ret);
410 rewriter.replaceOpWithNewOp<func::ReturnOp>(ret, newReturnValues);
411
412 if (!maybeDeadCasts.empty()) {
413 llvm::for_each(maybeDeadCasts, [&](Value castVal) {
414 if (castVal.use_empty()) {
415 rewriter.eraseOp(castVal.getDefiningOp());
416 }
417 });
418 }
419 });
420
421 // Update function return types with newly inferred types
422 const FunctionType oldType = func.getFunctionType();
423 const FunctionType newType = FunctionType::get(
424 func.getContext(), oldType.getInputs(), newReturnTypes);
425 func.setType(newType);
426 }
427};
428} // namespace
static llvm::ManagedStatic< PassManagerOptions > options
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition Dialect.h:57
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class verifies that op has same ranks for all operands and results types, if known.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:241
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:715
Value getOperand(unsigned idx)
Definition Operation.h:379
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:778
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
unsigned getNumOperands()
Definition Operation.h:375
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
Definition Operation.h:530
operand_type_range getOperandTypes()
Definition Operation.h:426
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:706
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:444
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:237
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition Operation.h:929
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
MLIRContext * getContext()
Return the context this region is inserted in.
Definition Region.cpp:24
bool hasRank() const
Return whether the shape has a rank.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
LogicalResult verifySameOperandsAndResultRank(Operation *op)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition ShapeUtils.h:81
static ValueKnowledge getPessimisticValueState()
Definition ShapeUtils.h:61
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition ShapeUtils.h:45