MLIR 22.0.0git
TosaInferShapes.cpp
Go to the documentation of this file.
1//===- TosaInferShapes.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Propagate shapes forward along TOSA operations to resolve dynamic shape
10// operations.
11//
12//===----------------------------------------------------------------------===//
13
15
20#include "mlir/IR/Builders.h"
23
24namespace mlir {
25namespace tosa {
26#define GEN_PASS_DEF_TOSAINFERSHAPESPASS
27#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
28} // namespace tosa
29} // namespace mlir
30
31using namespace mlir;
32using namespace mlir::tosa;
33
34namespace {
35
36// Check whether this use case is replaceable. We define an op as
37// being replaceable if it is used by a TosaOp, or an op with a
38// type-inference related interface.
39// When a non-replaceable use is encountered, the value is wrapped in a
40// cast back to the original type after inference.
41bool canBeRefined(Operation *user) {
42 if (!user->getDialect())
43 return false;
44 return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() ||
45 isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
46}
47
48// During type propagation, the types of values in the operator graph are
49// updated. For the tosa.while_loop operation, types are speculatively updated
50// within the body region to determine the output type of the while_loop. This
51// process is performed until a fixed point is reached, then the types are
52// rolled back.
53//
54// This class encapsulates the state information needed to perform the roll back
55// process or to commit to the final changes.
56class TypeModificationState {
57public:
58 TypeModificationState() = default;
59
60 ~TypeModificationState() {
61 // Ensure the recorded modifications are either committed or rolled back.
62 assert(oldTypes.empty() && "unhandled type modifications");
63 }
64
65 // Update the state of the value and record the old type.
66 void setType(Value value, Type type) {
67 if (value.getType() != type) {
68 oldTypes.emplace_back(value, value.getType());
69 value.setType(type);
70 }
71 }
72
73 // Roll back changes made to the types in the IR by setting all the affected
74 // values to their old types.
75 void rollBack() {
76 for (auto [value, type] : oldTypes)
77 value.setType(type);
78
79 oldTypes.clear();
80 }
81
82 // Commit the changes to the types in the IR.
83 // This requires inserting tensor.cast operations to mediate the newly
84 // inferred result types with users that do not support type inference.
85 void commit() {
86 // For each use whose type changed, cast the value with the new type back to
87 // the old type.
88 for (auto [value, oldType] : oldTypes) {
89 // The call to 'use->set()' in the body of the loop below invalidates the
90 // iterator used to traverse op uses, so it is important to make a copy of
91 // these first.
92 llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector(
93 value.getUses(),
94 [](OpOperand &use) -> OpOperand * {
95 return &use;
96 });
97
98 // A 'tensor.cast' op is emitted only if needed. Once emitted, it is
99 // cached and reused by all consumers.
100 tensor::CastOp castValue;
101
102 // Traverse all uses
103 for (OpOperand *use : uses) {
104 if (canBeRefined(use->getOwner()))
105 continue;
106
107 if (!castValue) {
108 // Set the insertion point as far back as possible, since new
109 // consumers of the 'tensor.cast' op generated in future iterations
110 // are likely to be further up in the code due to the order in which
111 // they appear in the use list.
112 OpBuilder builder{value.getContext()};
113 builder.setInsertionPointAfter(value.getDefiningOp());
114 castValue =
115 tensor::CastOp::create(builder, value.getLoc(), oldType, value);
116 }
117
118 use->set(castValue);
119 }
120 }
121
122 oldTypes.clear();
123 }
124
125private:
126 // A record of each value whose type was updated along with that value's
127 // previous type.
128 llvm::SmallVector<std::pair<Value, Type>> oldTypes;
129};
130
131void propagateShapesInRegion(Region &region, TypeModificationState &state);
132
133void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
134 IfOp ifOp = dyn_cast<IfOp>(op);
135 if (!ifOp)
136 return;
137
138 for (auto &region : op.getRegions()) {
139 Block &frontBlock = region.front();
140 if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
141 return;
142
143 for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
144 auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
145 auto blockArg = frontBlock.getArgument(i - 1);
146 auto oldType = cast<ShapedType>(blockArg.getType());
147
148 if (inferredTy.hasRank()) {
149 Type newType = oldType.clone(inferredTy.getShape());
150 state.setType(blockArg, newType);
151 }
152 }
153
154 for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
156 ifOp.getOperand(i + 1).getType());
158 frontBlock.getArgument(i).getType());
159 ValueKnowledge joinedKnowledge =
160 ValueKnowledge::join(operandKnowledge, blockKnowledge);
161 if (!joinedKnowledge)
162 continue;
163 state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
164 }
165
166 propagateShapesInRegion(region, state);
167 }
168}
169
170void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
171 WhileOp whileOp = dyn_cast<WhileOp>(op);
172 if (!whileOp)
173 return;
174
175 // Determine what the expected argument types are to the cond/body blocks.
176 // The expected arguments should be compatible with ever iteration of the
177 // loop body / condition for tosa.while.
178 SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
179
180 bool hasNewTypes = true;
181 while (hasNewTypes) {
182 TypeModificationState localState;
183
184 // Set types on the block args.
185 Region &bodyRegion = op.getRegion(1);
186 Block &block = bodyRegion.front();
187 for (int i = 0, s = argTypes.size(); i < s; i++) {
188 localState.setType(block.getArgument(i), argTypes[i]);
189 }
190
191 // Propagate to the end.
192 propagateShapesInRegion(bodyRegion, localState);
193
194 // Find all the tosa yield types and verify there is a single one.
196 for (auto &block : bodyRegion)
197 if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
198 yieldOps.push_back(yieldOp);
199
200 assert(yieldOps.size() == 1 && "missing or non-unique yield op");
201 // Using the new tosa.yield operand types, infer the new subtypes.
203 for (auto ty : argTypes) {
204 yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
205 }
206
207 for (auto yieldOp : yieldOps) {
208 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
209 auto newKnowledge =
210 ValueKnowledge::getKnowledgeFromType(it.value().getType());
211 yieldTypeInfo[it.index()] =
212 ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
213 }
214 }
215
216 // This should never happen.
217 if (yieldTypeInfo.size() != argTypes.size()) {
218 op.emitWarning("has a tosa.yield with the incorrect number of operands");
219 return;
220 }
221
222 // Determine the new block args and see if any changed.
223 hasNewTypes = false;
224 for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
225 Type newType = yieldTypeInfo[i].getType();
226 hasNewTypes |= (newType != argTypes[i]);
227 argTypes[i] = newType;
230 // Roll back all changes made during the speculative part of the algorithm.
231 localState.rollBack();
232 }
233
234 // We now set the block arguments according to the most recent shape
235 // inference results. This gives us the block arg types for the next
236 // iteration.
237 for (auto &region : op.getRegions()) {
238 for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
239 state.setType(region.front().getArgument(i), argTypes[i]);
240 }
241
242 propagateShapesInRegion(region, state);
243 }
244}
246void propagateShapesInRegion(Region &region, TypeModificationState &state) {
247 Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
249 for (auto &block : region) {
250 for (Operation &op : block) {
251 if (op.getDialect() != tosaDialect)
252 continue;
254 propagateShapesToTosaIf(op, state);
255 propagateShapesToTosaWhile(op, state);
256
257 InferShapedTypeOpInterface shapeInterface =
258 dyn_cast<InferShapedTypeOpInterface>(op);
259 if (!shapeInterface)
260 continue;
261
263
264 if (shapeInterface
265 .inferReturnTypeComponents(
266 op.getContext(), op.getLoc(), op.getOperands(),
268 op.getRegions(), returnedShapes)
269 .succeeded()) {
270 for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
271 Value result = std::get<0>(it);
272 ShapedTypeComponents predictedShape = std::get<1>(it);
273
274 // Determine the knowledge based on the output type.
275 // TODO: should also query WIP type probably
276 Type resultTy = result.getType();
277 auto currentKnowledge =
279
280 // Compute the knowledge based on the inferred type.
281 auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
282 inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
283 inferredKnowledge.hasRank = predictedShape.hasRank();
284 if (predictedShape.hasRank()) {
285 for (auto dim : predictedShape.getDims()) {
286 inferredKnowledge.sizes.push_back(dim);
287 }
288 }
289
290 // Compute the new type based on the joined version.
291 auto newKnowledge =
292 ValueKnowledge::join(currentKnowledge, inferredKnowledge);
293 if (!newKnowledge)
294 continue;
295
296 // Set new type
297 state.setType(result, newKnowledge.getType());
298 }
299 }
300 }
301 }
302}
303
304/// Recursively validate tosa ops with SameOperandsAndResultRank trait in region
305/// and all nested regions
306void validateSameOperandsAndResultRankTrait(Region &region) {
307 int errs = 0;
308 for (auto &block : region) {
309 for (auto &op : block) {
310 if (!op.getDialect() ||
311 op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
312 continue;
315 errs++;
316 (void)errs;
317 }
318 }
319 WhileOp whileOp = dyn_cast<WhileOp>(op);
320 IfOp ifOp = dyn_cast<IfOp>(op);
321 if (whileOp || ifOp) {
322 // recurse into whileOp's regions
323 for (auto &next : op.getRegions()) {
324 validateSameOperandsAndResultRankTrait(next);
325 }
326 }
327 }
328 }
329}
330
331/// Pass that performs shape propagation across TOSA operations. This includes
332/// migrating to within the regions of if/while operations.
333struct TosaInferShapes
334 : public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> {
335public:
336 void runOnOperation() override {
337 func::FuncOp func = getOperation();
338 TypeModificationState state;
339 propagateShapesInRegion(func.getBody(), state);
340 state.commit();
341
342 validateSameOperandsAndResultRankTrait(func.getBody());
343 }
344};
345} // namespace
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Definition Dialect.h:38
StringRef getNamespace() const
Definition Dialect.h:54
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition Dialect.h:57
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition Builders.h:412
This class verifies that op has same ranks for all operands and results types, if known.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition Operation.h:220
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
Value getOperand(unsigned idx)
Definition Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition Operation.h:749
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
Definition Operation.h:501
operand_type_range getOperandTypes()
Definition Operation.h:397
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition Operation.h:900
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
MLIRContext * getContext()
Return the context this region is inserted in.
Definition Region.cpp:24
ShapedTypeComponents that represents the components of a ShapedType.
bool hasRank() const
Return whether the shape has a rank.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
LogicalResult verifySameOperandsAndResultRank(Operation *op)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
Statically known information for a particular Value.
Definition ShapeUtils.h:33
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition ShapeUtils.h:81
static ValueKnowledge getPessimisticValueState()
Definition ShapeUtils.h:61
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition ShapeUtils.h:45