MLIR 23.0.0git
ControlFlowInterfaces.cpp
Go to the documentation of this file.
1//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
12#include "mlir/IR/Operation.h"
15#include "llvm/ADT/EquivalenceClasses.h"
16#include "llvm/Support/DebugLog.h"
17
18using namespace mlir;
19
20//===----------------------------------------------------------------------===//
21// ControlFlowInterfaces
22//===----------------------------------------------------------------------===//
23
24#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
25
// NOTE(review): the constructor's signature (source line 26) is missing from
// this listing; presumably
// `SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)`
// -- confirm against the header.
// Constructs a SuccessorOperands with zero produced operands, i.e. every
// successor operand comes from `forwardedOperands`.
27 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
28}
29
30SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
31 MutableOperandRange forwardedOperands)
32 : producedOperandCount(producedOperandCount),
33 forwardedOperands(std::move(forwardedOperands)) {}
34
35//===----------------------------------------------------------------------===//
36// BranchOpInterface
37//===----------------------------------------------------------------------===//
38
// NOTE(review): source line 43 (function name and the `operands` parameter) is
// missing from this listing; presumably
// `detail::getBranchSuccessorArgument(const SuccessorOperands &operands,`.
39/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
40/// successor if 'operandIndex' is within the range of 'operands', or
41/// std::nullopt if `operandIndex` isn't a successor operand index.
///
/// `operandIndex` uses the same numbering as `getBeginOperandIndex()` of the
/// forwarded operands (presumably the owning op's operand numbering). Produced
/// operands occupy the leading successor arguments, so the returned value is
/// argument `producedOperandCount + (operandIndex - start)` of `successor`.
/// std::nullopt is also returned when there are no forwarded operands at all.
42std::optional<BlockArgument>
44 unsigned operandIndex, Block *successor) {
45 LDBG() << "Getting branch successor argument for operand index "
46 << operandIndex << " in successor block";
47
48 OperandRange forwardedOperands = operands.getForwardedOperands();
49 // Check that the operands are valid.
50 if (forwardedOperands.empty()) {
51 LDBG() << "No forwarded operands, returning nullopt";
52 return std::nullopt;
53 }
54
55 // Check to ensure that this operand is within the range.
56 unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
57 if (operandIndex < operandsStart ||
58 operandIndex >= (operandsStart + forwardedOperands.size())) {
59 LDBG() << "Operand index " << operandIndex << " out of range ["
60 << operandsStart << ", "
61 << (operandsStart + forwardedOperands.size())
62 << "), returning nullopt";
63 return std::nullopt;
64 }
65
66 // Index the successor.
// Produced operands come first in the successor's argument list, so skip
// past them before adding the offset within the forwarded range.
67 unsigned argIndex =
68 operands.getProducedOperandCount() + operandIndex - operandsStart;
69 LDBG() << "Computed argument index " << argIndex << " for successor block";
70 return successor->getArgument(argIndex);
71}
72
// NOTE(review): source line 75 (function name and the `op`/`succNo`
// parameters) is missing from this listing; presumably
// `detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,`.
73/// Verify that the given operands match those of the given successor block.
///
/// Fails, emitting an error on `op`, in either of two cases:
/// - `operands.size()` differs from the number of arguments of successor
///   #`succNo`;
/// - a non-produced operand's type is rejected by
///   BranchOpInterface::areTypesCompatible for the matching block argument.
/// The leading getProducedOperandCount() operands are not type-checked here.
74LogicalResult
76 const SuccessorOperands &operands) {
77 LDBG() << "Verifying branch successor operands for successor #" << succNo
78 << " in operation " << op->getName();
79
80 // Check the count.
81 unsigned operandCount = operands.size();
82 Block *destBB = op->getSuccessor(succNo);
83 LDBG() << "Branch has " << operandCount << " operands, target block has "
84 << destBB->getNumArguments() << " arguments";
85
86 if (operandCount != destBB->getNumArguments())
87 return op->emitError() << "branch has " << operandCount
88 << " operands for successor #" << succNo
89 << ", but target block has "
90 << destBB->getNumArguments();
91
92 // Check the types.
93 LDBG() << "Checking type compatibility for "
94 << (operandCount - operands.getProducedOperandCount())
95 << " forwarded operands";
// Start after the produced operands; only forwarded operands are checked
// (presumably produced operands have no Value to inspect -- confirm).
96 for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
97 ++i) {
98 Type operandType = operands[i].getType();
99 Type argType = destBB->getArgument(i).getType();
100 LDBG() << "Checking type compatibility: operand type " << operandType
101 << " vs argument type " << argType;
102
103 if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType))
104 return op->emitError() << "type mismatch for bb argument #" << i
105 << " of successor #" << succNo;
106 }
107
108 LDBG() << "Branch successor operand verification successful";
109 return success();
110}
111
112//===----------------------------------------------------------------------===//
113// WeightedBranchOpInterface
114//===----------------------------------------------------------------------===//
115
// NOTE(review): source line 117 (the `weights` parameter) is missing from this
// listing; presumably `llvm::ArrayRef<int32_t> weights,` since the predicate
// below takes `int32_t` values -- confirm.
/// Verifies an optional list of branch/region weights attached to `op`.
/// An empty `weights` list is always accepted. Otherwise the list must contain
/// exactly `expectedWeightsNum` entries, and they must not all be zero.
/// `weightAnchorName` and `weightRefName` are only used to word the
/// size-mismatch diagnostic.
116static LogicalResult verifyWeights(Operation *op,
118 std::size_t expectedWeightsNum,
119 llvm::StringRef weightAnchorName,
120 llvm::StringRef weightRefName) {
// No weights at all means "no weight information", which is valid.
121 if (weights.empty())
122 return success();
123
124 if (weights.size() != expectedWeightsNum)
125 return op->emitError() << "expects number of " << weightAnchorName
126 << " weights to match number of " << weightRefName
127 << ": " << weights.size() << " vs "
128 << expectedWeightsNum;
129
130 if (llvm::all_of(weights, [](int32_t value) { return value == 0; }))
131 return op->emitError() << "branch weights cannot all be zero";
132
133 return success();
134}
135
// NOTE(review): the function signature and the start of the `weights`
// declaration (source lines 136-137) are missing from this listing.
// Verifies the weights returned by WeightedBranchOpInterface::getWeights():
// when any are present, there must be exactly one per successor of `op`.
138 cast<WeightedBranchOpInterface>(op).getWeights();
139 return verifyWeights(op, weights, op->getNumSuccessors(), "branch",
140 "successors");
141}
142
143//===----------------------------------------------------------------------===//
144// WeightedRegionBranchOpInterface
145//===----------------------------------------------------------------------===//
146
// NOTE(review): the function signature and the start of the `weights`
// declaration (source lines 147-148) are missing from this listing.
// Verifies the weights returned by WeightedRegionBranchOpInterface::getWeights():
// when any are present, there must be exactly one per region of `op`.
149 cast<WeightedRegionBranchOpInterface>(op).getWeights();
150 return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
151}
152
153//===----------------------------------------------------------------------===//
154// RegionBranchOpInterface
155//===----------------------------------------------------------------------===//
156
// NOTE(review): the function signature (source line 158) is missing from this
// listing.
157/// Verify that types match along control flow edges described by the given op.
///
/// For every region branch point from getAllRegionBranchPoints() and each of
/// its region successors, this checks two things:
/// - the number of successor operands equals the number of successor inputs;
/// - each operand/input type pair is accepted by
///   RegionBranchOpInterface::areTypesCompatible.
/// On the first violation it emits an op error that names the edge.
159 auto regionInterface = cast<RegionBranchOpInterface>(op);
160
161 // Verify all control flow edges from region branch points to region
162 // successors.
163 SmallVector<RegionBranchPoint> regionBranchPoints =
164 regionInterface.getAllRegionBranchPoints();
165 for (const RegionBranchPoint &branchPoint : regionBranchPoints) {
// NOTE(review): source line 166 is missing here; presumably the declaration
// of `successors` (e.g. `SmallVector<RegionSuccessor> successors;`).
167 regionInterface.getSuccessorRegions(branchPoint, successors);
168 for (const RegionSuccessor &successor : successors) {
169 // Helper function that prints the region branch point and the region
170 // successor.
171 auto emitRegionEdgeError = [&]() {
// NOTE(review): source line 172 is missing here; presumably the start of the
// `diag` declaration that receives the emitOpError result below.
173 regionInterface->emitOpError("along control flow edge from ");
174 if (branchPoint.isParent()) {
175 diag << "parent";
176 diag.attachNote(op->getLoc()) << "region branch point";
177 } else {
178 diag << "Operation "
179 << branchPoint.getTerminatorPredecessorOrNull()->getName();
180 diag.attachNote(
181 branchPoint.getTerminatorPredecessorOrNull()->getLoc())
182 << "region branch point";
183 }
184 diag << " to ";
185 if (Region *region = successor.getSuccessor()) {
186 diag << "Region #" << region->getRegionNumber();
187 } else {
188 diag << "parent";
189 }
190 return diag;
191 };
192
193 // Verify number of successor operands and successor inputs.
194 OperandRange succOperands =
195 regionInterface.getSuccessorOperands(branchPoint, successor);
196 ValueRange succInputs = regionInterface.getSuccessorInputs(successor);
197 if (succOperands.size() != succInputs.size()) {
198 return emitRegionEdgeError()
199 << ": region branch point has " << succOperands.size()
200 << " operands, but region successor needs " << succInputs.size()
201 << " inputs";
202 }
203
204 // Verify that the types are compatible.
205 TypeRange succInputTypes = succInputs.getTypes();
206 TypeRange succOperandTypes = succOperands.getTypes();
207 for (const auto &typesIdx :
208 llvm::enumerate(llvm::zip(succOperandTypes, succInputTypes))) {
209 Type succOperandType = std::get<0>(typesIdx.value());
210 Type succInputType = std::get<1>(typesIdx.value());
211 if (!regionInterface.areTypesCompatible(succOperandType, succInputType))
212 return emitRegionEdgeError()
213 << ": successor operand type #" << typesIdx.index() << " "
214 << succOperandType << " should match successor input type #"
215 << typesIdx.index() << " " << succInputType;
216 }
217 }
218 }
219 return success();
220}
221
// NOTE(review): the `StopConditionFn` type alias (source line 226) is missing
// from this listing. From its uses it is presumably a callable taking
// `(Region *, ArrayRef<bool> visited)` and returning bool -- confirm.
222/// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
223/// this function returns "true" for a successor region. The first parameter is
224/// the successor region. The second parameter indicates all already visited
225/// regions.
227
228/// Traverse the region graph starting at `begin`. The traversal is interrupted
229/// if `stopConditionFn` evaluates to "true" for a successor region. In that case,
230/// this function returns "true". Otherwise, if the traversal was not
231/// interrupted, this function returns "false".
///
/// Notes on the traversal:
/// - `begin` is marked visited before any successor is explored.
/// - The stop condition is evaluated *before* the visited check, so it also
///   sees regions, including `begin` itself, that are reached a second time.
/// - Successor edges that lead back to the parent op are ignored.
/// - The worklist is processed LIFO (depth-first).
232static bool traverseRegionGraph(Region *begin,
233 StopConditionFn stopConditionFn) {
234 auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
235 LDBG() << "Starting region graph traversal from region #"
236 << begin->getRegionNumber() << " in operation " << op->getName();
237
// visited[i] is true once region #i of `op` has been expanded.
238 SmallVector<bool> visited(op->getNumRegions(), false);
239 visited[begin->getRegionNumber()] = true;
240 LDBG() << "Initialized visited array with " << op->getNumRegions()
241 << " regions";
242
243 // Retrieve all successors of the region and enqueue them in the worklist.
244 SmallVector<Region *> worklist;
245 auto enqueueAllSuccessors = [&](Region *region) {
246 LDBG() << "Enqueuing successors for region #" << region->getRegionNumber();
// All-null operand attributes: no constant information is passed to
// getSuccessorRegions, so every possible successor is queried.
247 SmallVector<Attribute> operandAttributes(op->getNumOperands());
248 for (Block &block : *region) {
249 if (block.empty())
250 continue;
251 auto terminator =
252 dyn_cast<RegionBranchTerminatorOpInterface>(block.back());
253 if (!terminator)
254 continue;
// NOTE(review): source line 255 is missing here; presumably the declaration
// of `successors` (e.g. `SmallVector<RegionSuccessor> successors;`).
256 operandAttributes.resize(terminator->getNumOperands());
257 terminator.getSuccessorRegions(operandAttributes, successors);
258 LDBG() << "Found " << successors.size()
259 << " successors from terminator in block";
260 for (RegionSuccessor successor : successors) {
261 if (!successor.isParent()) {
262 worklist.push_back(successor.getSuccessor());
263 LDBG() << "Added region #"
264 << successor.getSuccessor()->getRegionNumber()
265 << " to worklist";
266 } else {
267 LDBG() << "Skipping parent successor";
268 }
269 }
270 }
271 };
272 enqueueAllSuccessors(begin);
273 LDBG() << "Initial worklist size: " << worklist.size();
274
275 // Process all regions in the worklist via DFS.
276 while (!worklist.empty()) {
277 Region *nextRegion = worklist.pop_back_val();
278 LDBG() << "Processing region #" << nextRegion->getRegionNumber()
279 << " from worklist (remaining: " << worklist.size() << ")";
280
281 if (stopConditionFn(nextRegion, visited)) {
282 LDBG() << "Stop condition met for region #"
283 << nextRegion->getRegionNumber() << ", returning true";
284 return true;
285 }
// NOTE(review): this writes unconditionally to llvm::errs() and looks like
// leftover debugging output. It also runs only after `nextRegion` has been
// passed to the stop condition and its region number queried above; if
// getRegionNumber() needs a parent op, this check comes too late -- confirm.
286 if (!nextRegion->getParentOp()) {
287 llvm::errs() << "Region " << *nextRegion << " has no parent op\n";
288 return false;
289 }
290 if (visited[nextRegion->getRegionNumber()]) {
291 LDBG() << "Region #" << nextRegion->getRegionNumber()
292 << " already visited, skipping";
293 continue;
294 }
295 visited[nextRegion->getRegionNumber()] = true;
296 LDBG() << "Marking region #" << nextRegion->getRegionNumber()
297 << " as visited";
298 enqueueAllSuccessors(nextRegion);
299 }
300
301 LDBG() << "Traversal completed, returning false";
302 return false;
303}
304
305/// Return `true` if region `r` is reachable from region `begin` according to
306/// the RegionBranchOpInterface (by taking a branch).
307static bool isRegionReachable(Region *begin, Region *r) {
308 assert(begin->getParentOp() == r->getParentOp() &&
309 "expected that both regions belong to the same op");
310 return traverseRegionGraph(begin,
311 [&](Region *nextRegion, ArrayRef<bool> visited) {
312 // Interrupt traversal if `r` was reached.
313 return nextRegion == r;
314 });
315}
316
// NOTE(review): the function signature (source line 326) is missing from this
// listing; presumably
// `bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {`.
317/// Return `true` if `a` and `b` are in mutually exclusive regions.
318///
319/// 1. Find the first common ancestor of `a` and `b` that implements
320/// RegionBranchOpInterface.
321/// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
322/// contained.
323/// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
324/// mutually exclusive if they are not reachable from each other as per
325/// RegionBranchOpInterface::getSuccessorRegions.
///
/// Only the innermost common RegionBranchOpInterface ancestor is consulted.
/// If `a` and `b` are in the same region of it, the result is `false`.
// NOTE(review): `a` and `b` are dereferenced by the debug print below before
// the non-null asserts run; the asserts should come first.
327 LDBG() << "Checking if operations are in mutually exclusive regions: "
328 << a->getName() << " and " << b->getName();
329
330 assert(a && "expected non-empty operation");
331 assert(b && "expected non-empty operation");
332
333 auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
334 while (branchOp) {
335 LDBG() << "Checking branch operation " << branchOp->getName();
336
337 // Check if b is inside branchOp. (We already know that a is.)
338 if (!branchOp->isProperAncestor(b)) {
339 LDBG() << "Operation b is not inside branchOp, checking next ancestor";
340 // Check next enclosing RegionBranchOpInterface.
341 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
342 continue;
343 }
344
345 LDBG() << "Both operations are inside branchOp, finding their regions";
346
347 // b is contained in branchOp. Retrieve the regions in which `a` and `b`
348 // are contained.
349 Region *regionA = nullptr, *regionB = nullptr;
350 for (Region &r : branchOp->getRegions()) {
351 if (r.findAncestorOpInRegion(*a)) {
352 assert(!regionA && "already found a region for a");
353 regionA = &r;
354 LDBG() << "Found region #" << r.getRegionNumber() << " for operation a";
355 }
356 if (r.findAncestorOpInRegion(*b)) {
357 assert(!regionB && "already found a region for b");
358 regionB = &r;
359 LDBG() << "Found region #" << r.getRegionNumber() << " for operation b";
360 }
361 }
362 assert(regionA && regionB && "could not find region of op");
363
364 LDBG() << "Region A: #" << regionA->getRegionNumber() << ", Region B: #"
365 << regionB->getRegionNumber();
366
367 // `a` and `b` are in mutually exclusive regions if both regions are
368 // distinct and neither region is reachable from the other region.
// NOTE(review): isRegionReachable(begin, r) tests whether `r` is reachable
// from `begin`. So `aNotReachableFromB` actually holds "B is not reachable
// from A", and `bNotReachableFromA` holds the reverse. The conjunction below
// is symmetric, so only the debug labels are mislabeled.
369 bool regionsAreDistinct = (regionA != regionB);
370 bool aNotReachableFromB = !isRegionReachable(regionA, regionB);
371 bool bNotReachableFromA = !isRegionReachable(regionB, regionA);
372
373 LDBG() << "Regions distinct: " << regionsAreDistinct
374 << ", A not reachable from B: " << aNotReachableFromB
375 << ", B not reachable from A: " << bNotReachableFromA;
376
377 bool mutuallyExclusive =
378 regionsAreDistinct && aNotReachableFromB && bNotReachableFromA;
379 LDBG() << "Operations are mutually exclusive: " << mutuallyExclusive;
380
381 return mutuallyExclusive;
382 }
383
384 // Could not find a common RegionBranchOpInterface among a's and b's
385 // ancestors.
386 LDBG() << "No common RegionBranchOpInterface found, operations are not "
387 "mutually exclusive";
388 return false;
389}
390
391bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
392 LDBG() << "Checking if region #" << index << " is repetitive in operation "
393 << getOperation()->getName();
394
395 Region *region = &getOperation()->getRegion(index);
396 bool isRepetitive = isRegionReachable(region, region);
397
398 LDBG() << "Region #" << index << " is repetitive: " << isRepetitive;
399 return isRepetitive;
400}
401
402bool RegionBranchOpInterface::hasLoop() {
403 LDBG() << "Checking if operation " << getOperation()->getName()
404 << " has loops";
405
406 SmallVector<RegionSuccessor> entryRegions;
407 getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
408 LDBG() << "Found " << entryRegions.size() << " entry regions";
409
410 for (RegionSuccessor successor : entryRegions) {
411 if (!successor.isParent()) {
412 LDBG() << "Checking entry region #"
413 << successor.getSuccessor()->getRegionNumber() << " for loops";
414
415 bool hasLoop =
416 traverseRegionGraph(successor.getSuccessor(),
417 [](Region *nextRegion, ArrayRef<bool> visited) {
418 // Interrupt traversal if the region was already
419 // visited.
420 return visited[nextRegion->getRegionNumber()];
421 });
422
423 if (hasLoop) {
424 LDBG() << "Found loop in entry region #"
425 << successor.getSuccessor()->getRegionNumber();
426 return true;
427 }
428 } else {
429 LDBG() << "Skipping parent successor";
430 }
431 }
432
433 LDBG() << "No loops found in operation";
434 return false;
435}
436
// NOTE(review): two lines are missing from this listing:
// - source line 437, the return type (presumably `OperandRange`, given how
//   callers consume the result);
// - source line 443, the cast operand (presumably
//   `src.getTerminatorPredecessorOrNull());`).
// Returns the operands forwarded along the control flow edge `src` -> `dest`.
// If `src` is the parent op, these are the op's entry successor operands.
// Otherwise they are the successor operands reported by the
// RegionBranchTerminatorOpInterface terminator that `src` designates.
438RegionBranchOpInterface::getSuccessorOperands(RegionBranchPoint src,
439 RegionSuccessor dest) {
440 if (src.isParent())
441 return getEntrySuccessorOperands(dest);
442 auto terminator = cast<RegionBranchTerminatorOpInterface>(
444 return terminator.getSuccessorOperands(dest);
445}
446
// NOTE(review): the signature (source line 447) is missing from this listing;
// judging by its call site it is a static helper named `operandsToOpOperands`
// that takes an `OperandRange`.
// Reinterprets an OperandRange as a mutable array of its underlying OpOperands
// (same base pointer and length). Callers can then take the address of each
// OpOperand, e.g. to use it as a map key.
448 return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
449}
450
// NOTE(review): source lines 453 (presumably the `mapping` out-parameter) and
// 455 (presumably the declaration of `successors`) are missing from this
// listing.
// For the region branch point `src`, this visits every successor region `dst`.
// Each successor operand forwarded along `src` -> `dst` is paired with the
// successor input at the same position, and that input is appended to
// `mapping[&operand]`. One operand can accumulate inputs from several
// successors.
451static void
452getSuccessorOperandInputMapping(RegionBranchOpInterface branchOp,
454 RegionBranchPoint src) {
456 branchOp.getSuccessorRegions(src, successors);
457 for (RegionSuccessor dst : successors) {
458 OperandRange operands = branchOp.getSuccessorOperands(src, dst);
459 assert(operands.size() == branchOp.getSuccessorInputs(dst).size() &&
460 "expected the same number of operands and inputs");
461 for (const auto &[operand, input] : llvm::zip_equal(
462 operandsToOpOperands(operands), branchOp.getSuccessorInputs(dst)))
463 mapping[&operand].push_back(input);
464 }
465}
// NOTE(review): source line 467 (presumably the `mapping` out-parameter) is
// missing from this listing.
// Populates `mapping` (successor operand -> successor inputs) for the given
// region branch point only. If `src` is not set, every branch point returned
// by getAllRegionBranchPoints() is processed. Entries are appended, not
// cleared first.
466void RegionBranchOpInterface::getSuccessorOperandInputMapping(
468 std::optional<RegionBranchPoint> src) {
469 if (src.has_value()) {
470 ::getSuccessorOperandInputMapping(*this, mapping, src.value());
471 } else {
472 // No region branch point specified: populate the mapping for all possible
473 // region branch points.
474 for (RegionBranchPoint branchPoint : getAllRegionBranchPoints())
475 ::getSuccessorOperandInputMapping(*this, mapping, branchPoint);
476 }
477}
478
// NOTE(review): source lines 479 (return type and function name
// `invertRegionBranchSuccessorMapping`) and 481 (presumably the declaration of
// the `inputToOperands` result map) are missing from this listing.
// Inverts an operand -> inputs mapping into an input -> operands mapping. An
// input fed by several successor operands lists all of them.
480 const RegionBranchSuccessorMapping &operandToInputs) {
482 for (const auto &[operand, inputs] : operandToInputs) {
483 for (Value input : inputs)
484 inputToOperands[input].push_back(operand);
485 }
486 return inputToOperands;
487}
488
// NOTE(review): source line 490 (presumably the `mapping` out-parameter) is
// missing from this listing.
// Computes the successor input -> successor operands mapping over all region
// branch points of this op. It builds the forward mapping (no branch point
// argument is passed) and then inverts it. `mapping` is overwritten.
489void RegionBranchOpInterface::getSuccessorInputOperandMapping(
491 RegionBranchSuccessorMapping operandToInputs;
492 getSuccessorOperandInputMapping(operandToInputs);
493 mapping = invertRegionBranchSuccessorMapping(operandToInputs);
494}
495
// NOTE(review): source lines 496 (return type, presumably a
// `SmallVector<RegionBranchPoint>`) and 498 (declaration of `branchPoints`)
// are missing from this listing.
// Returns every region branch point of this op: first the parent op itself,
// then, in region/block order, each block terminator that implements
// RegionBranchTerminatorOpInterface. Empty blocks are skipped.
497RegionBranchOpInterface::getAllRegionBranchPoints() {
499 branchPoints.push_back(RegionBranchPoint::parent());
500 for (Region &region : getOperation()->getRegions()) {
501 for (Block &block : region) {
502 if (block.empty())
503 continue;
504 if (auto terminator =
505 dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
506 branchPoints.push_back(RegionBranchPoint(terminator));
507 }
508 }
509 return branchPoints;
510}
511
// NOTE(review): the function signature (source line 512) is missing from this
// listing; presumably `Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {`.
// Walks outward from `op` through its enclosing regions. Returns the first
// region whose parent op implements RegionBranchOpInterface and reports that
// region as repetitive (isRepetitiveRegion). Returns nullptr if no such
// region exists.
// NOTE(review): the debug print dereferences region->getParentOp(); a region
// without a parent op would crash here -- confirm that cannot happen.
513 LDBG() << "Finding enclosing repetitive region for operation "
514 << op->getName();
515
516 while (Region *region = op->getParentRegion()) {
517 LDBG() << "Checking region #" << region->getRegionNumber()
518 << " in operation " << region->getParentOp()->getName();
519
520 op = region->getParentOp();
521 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
522 LDBG()
523 << "Found RegionBranchOpInterface, checking if region is repetitive";
524 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
525 LDBG() << "Found repetitive region #" << region->getRegionNumber();
526 return region;
527 }
528 } else {
529 LDBG() << "Parent operation does not implement RegionBranchOpInterface";
530 }
531 }
532
533 LDBG() << "No enclosing repetitive region found";
534 return nullptr;
535}
536
// NOTE(review): the function signature (source line 537) is missing from this
// listing; presumably `Region *mlir::getEnclosingRepetitiveRegion(Value value) {`.
// Starting with the region that contains `value`, walks outward through the
// enclosing regions. Returns the first region whose parent op implements
// RegionBranchOpInterface and reports that region as repetitive. Returns
// nullptr if no such region exists.
// NOTE(review): if a region has no parent op, both the debug print and the
// dyn_cast below would operate on a null op -- confirm that cannot happen.
538 LDBG() << "Finding enclosing repetitive region for value";
539
540 Region *region = value.getParentRegion();
541 while (region) {
542 LDBG() << "Checking region #" << region->getRegionNumber()
543 << " in operation " << region->getParentOp()->getName();
544
545 Operation *op = region->getParentOp();
546 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
547 LDBG()
548 << "Found RegionBranchOpInterface, checking if region is repetitive";
549 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
550 LDBG() << "Found repetitive region #" << region->getRegionNumber();
551 return region;
552 }
553 } else {
554 LDBG() << "Parent operation does not implement RegionBranchOpInterface";
555 }
556 region = op->getParentRegion();
557 }
558
559 LDBG() << "No enclosing repetitive region found for value";
560 return nullptr;
561}
562
563/// Return "true" if `a` can be used in lieu of `b`, where `b` is a region
564/// successor input and `a` is a "reachable value" of `b`. Reachable values
565/// are successor operand values that are (maybe transitively) forwarded to
566/// `b`.
567static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b) {
568 assert((b.getDefiningOp() == regionBranchOp ||
569 b.getParentRegion()->getParentOp() == regionBranchOp) &&
570 "b must be a region successor input");
571
572 // Case 1: `a` is defined inside of the region branch op. `a` must be
573 // directly nested in the region branch op. Otherwise, it could not have
574 // been among the reachable values for a region successor input.
575 if (a.getParentRegion()->getParentOp() == regionBranchOp) {
576 // Case 1.1: If `b` is a result of the region branch op, `a` is not in
577 // scope for `b`.
578 // Example:
579 // %b = region_op({
580 // ^bb0(%a1: ...):
581 // %a2 = ...
582 // })
583 if (isa<OpResult>(b))
584 return false;
585
586 // Case 1.2: `b` is an entry block argument of a region. `a` is in scope
587 // for `b` only if it is also an entry block argument of the same region.
588 // Example:
589 // region_op({
590 // ^bb0(%b: ..., %a: ...):
591 // ...
592 // })
593 assert(isa<BlockArgument>(b) && "b must be a block argument");
594 return isa<BlockArgument>(a) && cast<BlockArgument>(a).getOwner() ==
595 cast<BlockArgument>(b).getOwner();
596 }
597
598 // Case 2: `a` is defined outside of the region branch op. In that case, we
599 // can safely assume that `a` was defined before `b`. Otherwise, it could not
600 // be among the reachable values for a region successor input.
601 // Example:
602 // { <- %a1 parent region begins here.
603 // ^bb0(%a1: ...):
604 // %a2 = ...
605 // %b1 = reigon_op({
606 // ^bb1(%b2: ...):
607 // ...
608 // })
609 // }
610 return true;
611}
612
613/// Compute all non-successor-input values that a successor input could have
614/// based on the given successor input to successor operand mapping.
615///
616/// Example 1:
617/// %r = scf.if ... {
618/// scf.yield %a : ...
619/// } else {
620/// scf.yield %b : ...
621/// }
622/// reachableValues(%r) = {%a, %b}
623///
624/// Example 2:
625/// %r = scf.for ... iter_args(%arg0 = %0) -> ... {
626/// scf.yield %arg0 : ...
627/// }
628/// reachableValues(%arg0) = {%0}
629/// reachableValues(%r) = {%0}
630///
631/// Example 3:
632/// %r = scf.for ... iter_args(%arg0 = %0) -> ... {
633/// ...
634/// scf.yield %1 : ...
635/// }
636/// reachableValues(%arg0) = {%0, %1}
637/// reachableValues(%r) = {%0, %1}
638static llvm::SmallDenseSet<Value> computeReachableValuesFromSuccessorInput(
639 Value value, const RegionBranchInverseSuccessorMapping &inputToOperands) {
640 assert(inputToOperands.contains(value) && "value must be a successor input");
641 // Starting with the given value, trace back all predecessor values (i.e.,
642 // preceding successor operands) and add them to the set of reachable values.
643 // If the successor operand is again a successor input, do not add it to
644 // result set, but instead continue the traversal.
645 llvm::SmallDenseSet<Value> reachableValues;
646 llvm::SmallDenseSet<Value> visited;
647 SmallVector<Value> worklist;
648 worklist.push_back(value);
649 while (!worklist.empty()) {
650 Value next = worklist.pop_back_val();
651 auto it = inputToOperands.find(next);
652 if (it == inputToOperands.end()) {
653 reachableValues.insert(next);
654 continue;
655 }
656 for (OpOperand *operand : it->second)
657 if (visited.insert(operand->get()).second)
658 worklist.push_back(operand->get());
659 }
660 // Note: The result does not contain any successor inputs. (Therefore,
661 // `value` is also guaranteed to be excluded.)
662 return reachableValues;
663}
664
665namespace {
666/// Try to make successor inputs dead by replacing their uses with values that
667/// are not successor inputs. This pattern enables additional canonicalization
668/// opportunities for RemoveDeadRegionBranchOpSuccessorInputs.
669///
670/// Example:
671///
672/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
673/// scf.yield %arg1, %arg1 : ...
674/// }
675/// use(%r0, %r1)
676///
677/// reachableValues(%r0) = {%0, %1}
678/// reachableValues(%r1) = {%1} ==> replace uses of %r1 with %1.
679/// reachableValues(%arg0) = {%0, %1}
680/// reachableValues(%arg1) = {%1} ==> replace uses of %arg1 with %1.
681///
682/// IR after pattern application:
683///
684/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
685/// scf.yield %1, %1 : ...
686/// }
687/// use(%r0, %1)
688///
689/// Note that %r1 and %arg1 are dead now. The IR can now be further
690/// canonicalized by RemoveDeadRegionBranchOpSuccessorInputs.
///
/// The pattern reports success only if at least one successor input had its
/// uses replaced. The matched op must implement RegionBranchOpInterface and
/// must not be IsolatedFromAbove.
691struct MakeRegionBranchOpSuccessorInputsDead : public RewritePattern {
692 MakeRegionBranchOpSuccessorInputsDead(MLIRContext *context, StringRef name,
693 PatternBenefit benefit = 1)
694 : RewritePattern(name, benefit, context) {}
695
696 LogicalResult matchAndRewrite(Operation *op,
697 PatternRewriter &rewriter) const override {
698 assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
699 "isolated-from-above ops are not supported");
700
701 // Compute the mapping of successor inputs to successor operands.
702 auto regionBranchOp = cast<RegionBranchOpInterface>(op);
// NOTE(review): source line 703 is missing here; presumably the declaration
// `RegionBranchInverseSuccessorMapping inputToOperands;`.
704 regionBranchOp.getSuccessorInputOperandMapping(inputToOperands);
705
706 // Try to replace the uses of each successor input one-by-one.
707 bool changed = false;
708 for (Value value : inputToOperands.keys()) {
709 // Nothing to do for successor inputs that are already dead.
710 if (value.use_empty())
711 continue;
712 // Nothing to do for successor inputs that may have multiple reachable
713 // values.
714 llvm::SmallDenseSet<Value> reachableValues =
715 computeReachableValuesFromSuccessorInput(value, inputToOperands);
716 if (reachableValues.size() != 1)
717 continue;
718 assert(*reachableValues.begin() != value &&
719 "successor inputs are supposed to be excluded");
720 // Do not replace `value` with the found reachable value if doing so
721 // would violate dominance. Example:
722 // %r = scf.execute_region ... {
723 // %a = ...
724 // scf.yield %a : ...
725 // }
726 // use(%r)
727 // In the above example, reachableValues(%r) = {%a}, but %a cannot be
728 // used as a replacement for %r due to dominance / scope.
729 if (!isDefinedBefore(regionBranchOp, *reachableValues.begin(), value))
730 continue;
731 rewriter.replaceAllUsesWith(value, *reachableValues.begin());
732 changed = true;
733 }
734 return success(changed);
735 }
736};
737
738/// Lookup a bit vector in the given mapping (DenseMap). If the key was not
739/// found, create a new bit vector with the given size and initialize it with
740/// false.
741template <typename MappingTy, typename KeyTy>
742static BitVector &lookupOrCreateBitVector(MappingTy &mapping, KeyTy key,
743 unsigned size) {
744 return mapping.try_emplace(key, size, false).first->second;
745}
746
747/// Compute tied successor inputs. Tied successor inputs are successor inputs
748/// that come as a set. If you erase one value from a set, you must erase all
749/// values from the set. Otherwise, the op would become structurally invalid.
750/// Each successor input appears in exactly one set.
751///
752/// Example:
753/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
754/// ...
755/// }
756/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}.
757static llvm::EquivalenceClasses<Value> computeTiedSuccessorInputs(
758 const RegionBranchSuccessorMapping &operandToInputs) {
759 llvm::EquivalenceClasses<Value> tiedSuccessorInputs;
760 for (const auto &[operand, inputs] : operandToInputs) {
761 assert(!inputs.empty() && "expected non-empty inputs");
762 Value firstInput = inputs.front();
763 tiedSuccessorInputs.insert(firstInput);
764 for (Value nextInput : llvm::drop_begin(inputs)) {
765 // As we explore more successor operand to successor input mappings,
766 // existing sets may get merged.
767 tiedSuccessorInputs.unionSets(firstInput, nextInput);
768 }
769 }
770 return tiedSuccessorInputs;
771}
772
773/// Remove dead successor inputs from region branch ops. A successor input is
774/// dead if it has no uses. Successor inputs come in sets of tied values: if
775/// you remove one value from a set, you must remove all values from the set.
776/// Furthermore, successor operands must also be removed. (Op operands are not
777/// part of the set, but the set is built based on the successor operand to
778/// successor input mapping.)
779///
780/// Example 1:
781/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
782/// scf.yield %0, %arg1 : ...
783/// }
784/// use(%0, %1)
785///
786/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}. All values in the first
787/// set are dead, so %arg0 and %r0 can be removed, but not %r1 and %arg1. The
788/// resulting IR is as follows:
789///
790/// %r1 = scf.for ... iter_args(%arg1 = %1) -> ... {
791/// scf.yield %arg1 : ...
792/// }
793/// use(%0, %1)
794///
795/// Example 2:
796/// %r0, %r1 = scf.while (%arg0 = %0) {
797/// scf.condition(...) %arg0, %arg0 : ...
798/// } do {
799/// ^bb0(%arg1: ..., %arg2: ...):
800/// scf.yield %arg1 : ...
801/// }
802/// There are three sets: {{%r0, %arg1}, {%r1, %arg2}, {%arg0}}.
803///
804/// Example 3:
805/// %r1, %r2 = scf.if ... {
806/// scf.yield %0, %1 : ...
807/// } else {
808/// scf.yield %2, %3 : ...
809/// }
810/// There are two sets: {{%r1}, {%r2}}. Each set has one value, so each
811/// value can be removed independently of the other values.
///
/// The rewrite happens in three steps:
/// 1. Erase the successor operands that feed the removed inputs; their owners
///    are either terminators or the region branch op itself.
/// 2. Erase the removed block arguments.
/// 3. Erase the removed op results via `rewriter.eraseOpResults`.
/// Match failure is reported when no tied set is entirely dead.
812struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern {
813 RemoveDeadRegionBranchOpSuccessorInputs(MLIRContext *context, StringRef name,
814 PatternBenefit benefit = 1)
815 : RewritePattern(name, benefit, context) {}
816
817 LogicalResult matchAndRewrite(Operation *op,
818 PatternRewriter &rewriter) const override {
819 assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
820 "isolated-from-above ops are not supported");
821
822 // Compute tied values: values that must come as a set. If you remove one,
823 // you must remove all. If a successor op operand is forwarded to two
824 // successor inputs %a and %b, both %a and %b are in the same set.
825 auto regionBranchOp = cast<RegionBranchOpInterface>(op);
826 RegionBranchSuccessorMapping operandToInputs;
827 regionBranchOp.getSuccessorOperandInputMapping(operandToInputs);
828 llvm::EquivalenceClasses<Value> tiedSuccessorInputs =
829 computeTiedSuccessorInputs(operandToInputs);
830
831 // Determine which values to remove and group them by block and operation.
832 SmallVector<Value> valuesToRemove;
833 DenseMap<Block *, BitVector> blockArgsToRemove;
834 BitVector resultsToRemove(regionBranchOp->getNumResults(), false);
835 // Iterate over all sets of tied successor inputs.
836 for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end();
837 it != e; ++it) {
// Visit each set exactly once, via its leader.
838 if (!(*it)->isLeader())
839 continue;
840
841 // Value can be removed if it is dead and all other tied values are also
842 // dead.
843 bool allDead = true;
844 for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
845 memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
846 // Iterate over all values in the set and check their liveness.
847 if (!memberIt->use_empty()) {
848 allDead = false;
849 break;
850 }
851 }
852 if (!allDead)
853 continue;
854
855 // The entire set is dead. Group values by block and operation to
856 // simplify removal.
857 for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
858 memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
859 if (auto arg = dyn_cast<BlockArgument>(*memberIt)) {
860 // Set blockArgsToRemove[block][arg_number] = true.
861 BitVector &vector =
862 lookupOrCreateBitVector(blockArgsToRemove, arg.getOwner(),
863 arg.getOwner()->getNumArguments());
864 vector.set(arg.getArgNumber());
865 } else {
866 // Set resultsToRemove[result_number] = true.
867 OpResult result = cast<OpResult>(*memberIt);
868 assert(result.getDefiningOp() == regionBranchOp &&
869 "result must be a region branch op result");
870 resultsToRemove.set(result.getResultNumber());
871 }
872 valuesToRemove.push_back(*memberIt);
873 }
874 }
875
876 if (valuesToRemove.empty())
877 return rewriter.notifyMatchFailure(op, "no values to remove");
878
879 // Find operands that must be removed together with the values.
880 RegionBranchInverseSuccessorMapping inputsToOperands =
881 invertRegionBranchSuccessorMapping(operandToInputs);
// NOTE(review): source line 882 is missing here; presumably the declaration
// `DenseMap<Operation *, BitVector> operandsToRemove;`.
883 for (Value value : valuesToRemove) {
884 for (OpOperand *operand : inputsToOperands[value]) {
885 // Set operandsToRemove[op][operand_number] = true.
886 BitVector &vector =
887 lookupOrCreateBitVector(operandsToRemove, operand->getOwner(),
888 operand->getOwner()->getNumOperands());
889 vector.set(operand->getOperandNumber());
890 }
891 }
892
893 // Erase operands.
// NOTE(review): the local `op` below shadows the `op` parameter of
// matchAndRewrite; consider renaming it (e.g. `owner`).
894 for (auto &pair : operandsToRemove) {
895 Operation *op = pair.first;
896 BitVector &operands = pair.second;
897 rewriter.modifyOpInPlace(op, [&]() { op->eraseOperands(operands); });
898 }
899
900 // Erase block arguments.
901 for (auto &pair : blockArgsToRemove) {
902 Block *block = pair.first;
903 BitVector &blockArg = pair.second;
904 rewriter.modifyOpInPlace(block->getParentOp(),
905 [&]() { block->eraseArguments(blockArg); });
906 }
907
908 // Erase op results.
909 if (resultsToRemove.any())
910 rewriter.eraseOpResults(regionBranchOp, resultsToRemove);
911
912 return success();
913 }
914};
915
916/// Return "true" if the two values are owned by the same operation or block.
917static bool haveSameOwner(Value a, Value b) {
918 void *aOwner, *bOwner;
919 if (auto arg = dyn_cast<BlockArgument>(a))
920 aOwner = arg.getOwner();
921 else
922 aOwner = a.getDefiningOp();
923 if (auto arg = dyn_cast<BlockArgument>(b))
924 bOwner = arg.getOwner();
925 else
926 bOwner = b.getDefiningOp();
927 return aOwner == bOwner;
928}
929
930/// Get the block argument or op result number of the given value.
931static unsigned getArgOrResultNumber(Value value) {
932 if (auto opResult = llvm::dyn_cast<OpResult>(value))
933 return opResult.getResultNumber();
934 return llvm::cast<BlockArgument>(value).getArgNumber();
935}
936
937/// Find duplicate successor inputs and make all dead except for one. Two
938/// successor inputs are "duplicate" if their corresponding successor operands
939/// have the same values. This pattern enables additional canonicalization
940/// opportunities for RemoveDeadRegionBranchOpSuccessorInputs.
941///
942/// Example:
943/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... {
944/// use(%arg0, %arg1)
945/// ...
946/// scf.yield %x, %x : ...
947/// }
948/// use(%r0, %r1)
949///
950/// Operands of successor input %r0: [%0, %x]
951/// Operands of successor input %r1: [%0, %x] ==> DUPLICATE!
952/// Replace %r1 with %r0.
953///
954/// Operands of successor input %arg0: [%0, %x]
955/// Operands of successor input %arg1: [%0, %x] ==> DUPLICATE!
956/// Replace %arg1 with %arg0. (We have to make sure that we make same decision
957/// as for the other tied successor inputs above. Otherwise, a set of tied
958/// successor inputs may not become entirely dead.)
959///
960/// The resulting IR is as follows:
961/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... {
962/// use(%arg0, %arg0)
963/// ...
964/// scf.yield %x, %x : ...
965/// }
966/// use(%r0, %r0) // Note: We don't want use(%r1, %r1), which is also correct,
967/// // but does not help with further canonicalizations.
968struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
969 RemoveDuplicateSuccessorInputUses(MLIRContext *context, StringRef name,
970 PatternBenefit benefit = 1)
971 : RewritePattern(name, benefit, context) {}
972
973 LogicalResult matchAndRewrite(Operation *op,
974 PatternRewriter &rewriter) const override {
975 assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
976 "isolated-from-above ops are not supported");
977
978 // Collect all successor inputs and sort them. When dropping the uses of a
979 // successor input, we'd like to also drop the uses of the same tied
980 // successor inputs. Otherwise, a set of tied successor inputs may not
981 // become entirely dead, which is required for
982 // RemoveDeadRegionBranchOpSuccessorInputs to be able to erase them.
983 // (Sorting is not required for correctness.)
984 auto regionBranchOp = cast<RegionBranchOpInterface>(op);
986 regionBranchOp.getSuccessorInputOperandMapping(inputsToOperands);
987 SmallVector<Value> inputs = llvm::to_vector(inputsToOperands.keys());
988 llvm::sort(inputs, [](Value a, Value b) {
989 return getArgOrResultNumber(a) < getArgOrResultNumber(b);
990 });
991
992 // Check every distinct pair of successor inputs for duplicates. Replace
993 // `input2` with `input1` if they are duplicates.
994 bool changed = false;
995 unsigned numInputs = inputs.size();
996 for (auto i : llvm::seq<unsigned>(0, numInputs)) {
997 Value input1 = inputs[i];
998 for (auto j : llvm::seq<unsigned>(i + 1, numInputs)) {
999 Value input2 = inputs[j];
1000 // Nothing to do if input2 is already dead.
1001 if (input2.use_empty())
1002 continue;
1003 // Replace only values that belong to the same block / operation.
1004 // This implies that the two values are either both block arguments or
1005 // both op results.
1006 if (!haveSameOwner(input1, input2))
1007 continue;
1008
1009 // Gather the predecessor value for each predecessor (region branch
1010 // point). The two inputs are duplicates if each predecessor forwards
1011 // the same value.
1012 llvm::SmallDenseMap<Operation *, Value> operands1, operands2;
1013 for (OpOperand *operand : inputsToOperands[input1]) {
1014 assert(!operands1.contains(operand->getOwner()));
1015 operands1[operand->getOwner()] = operand->get();
1016 }
1017 for (OpOperand *operand : inputsToOperands[input2]) {
1018 assert(!operands2.contains(operand->getOwner()));
1019 operands2[operand->getOwner()] = operand->get();
1020 }
1021 if (operands1 == operands2) {
1022 rewriter.replaceAllUsesWith(input2, input1);
1023 changed = true;
1024 }
1025 }
1026 }
1027 return success(changed);
1028 }
1029};
1030} // namespace
1031
    RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit) {
  // Register the successor-input canonicalizations for ops named `opName`, all
  // with the same `benefit`:
  //  - MakeRegionBranchOpSuccessorInputsDead (defined outside this view;
  //    presumably drops uses of successor inputs — confirm),
  //  - RemoveDuplicateSuccessorInputUses (redirects uses of duplicate inputs
  //    to a single representative),
  //  - RemoveDeadRegionBranchOpSuccessorInputs (erases sets of tied inputs that
  //    are entirely unused, along with their forwarded operands).
  patterns.add<MakeRegionBranchOpSuccessorInputsDead,
               RemoveDuplicateSuccessorInputUses,
               RemoveDeadRegionBranchOpSuccessorInputs>(patterns.getContext(),
                                                        opName, benefit);
}
return success()
static LogicalResult verifyWeights(Operation *op, llvm::ArrayRef< int32_t > weights, std::size_t expectedWeightsNum, llvm::StringRef weightAnchorName, llvm::StringRef weightRefName)
static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b)
Return "true" if a can be used in lieu of b, where b is a region successor input and a is a "reachabl...
static void getSuccessorOperandInputMapping(RegionBranchOpInterface branchOp, RegionBranchSuccessorMapping &mapping, RegionBranchPoint src)
static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn)
Traverse the region graph starting at begin.
static llvm::SmallDenseSet< Value > computeReachableValuesFromSuccessorInput(Value value, const RegionBranchInverseSuccessorMapping &inputToOperands)
Compute all non-successor-input values that a successor input could have based on the given successor...
static RegionBranchInverseSuccessorMapping invertRegionBranchSuccessorMapping(const RegionBranchSuccessorMapping &operandToInputs)
function_ref< bool(Region *, ArrayRef< bool > visited)> StopConditionFn
Stop condition for traverseRegionGraph.
static bool isRegionReachable(Region *begin, Region *r)
Return true if region r is reachable from region begin according to the RegionBranchOpInterface (by t...
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static std::string diag(const llvm::Value &value)
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
This class represents an operand of an operation.
Definition Value.h:257
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
unsigned getNumSuccessors()
Definition Operation.h:706
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
Definition Operation.h:360
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
Block * getSuccessor(unsigned index)
Definition Operation.h:708
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition Region.cpp:62
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
RewritePattern is the common base class for all DAG to DAG replacements.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class models how operands are forwarded to block arguments in control flow.
SuccessorOperands(MutableOperandRange forwardedOperands)
Constructs a SuccessorOperands with no produced operands that simply forwards operands to the success...
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
LogicalResult verifyRegionBranchWeights(Operation *op)
Verify that the region weights attached to an operation implementing WeightedRegionBranchOpInterface ...
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands)
Verify that the given operands match those of the given successor block.
LogicalResult verifyRegionBranchOpInterface(Operation *op)
Verify that types match along control flow edges described by the given op.
LogicalResult verifyBranchWeights(Operation *op)
Verify that the branch weights attached to an operation implementing WeightedBranchOpInterface are co...
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBran...
const FrozenRewritePatternSet & patterns
DenseMap< Value, SmallVector< OpOperand * > > RegionBranchInverseSuccessorMapping
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
void populateRegionBranchOpInterfaceCanonicalizationPatterns(RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit=1)
Populate canonicalization patterns that simplify successor operands/inputs of region branch operation...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152