MLIR 23.0.0git
ControlFlowInterfaces.cpp
Go to the documentation of this file.
1//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
12#include "mlir/IR/Matchers.h"
13#include "mlir/IR/Operation.h"
16#include "llvm/ADT/EquivalenceClasses.h"
17#include "llvm/Support/DebugLog.h"
18
19using namespace mlir;
20
21//===----------------------------------------------------------------------===//
22// ControlFlowInterfaces
23//===----------------------------------------------------------------------===//
24
25#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
26
// Initializes producedOperandCount to 0: no leading successor arguments are
// produced by the branch op itself, so all successor operands come from
// `forwardedOperands`. (Constructor signature line is not visible here.)
28 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
29}
30
31SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
32 MutableOperandRange forwardedOperands)
33 : producedOperandCount(producedOperandCount),
34 forwardedOperands(std::move(forwardedOperands)) {}
35
36//===----------------------------------------------------------------------===//
37// BranchOpInterface
38//===----------------------------------------------------------------------===//
39
40/// Returns the `BlockArgument` of `successor` that receives the value of
41/// operand `operandIndex` when that index lies inside the forwarded-operand
42/// range of `operands`; returns std::nullopt otherwise.
43std::optional<BlockArgument>
45 unsigned operandIndex, Block *successor) {
46 LDBG() << "Getting branch successor argument for operand index "
47 << operandIndex << " in successor block";
48
49 OperandRange forwardedOperands = operands.getForwardedOperands();
50 // Nothing is forwarded, so no operand index can map to a successor argument.
51 if (forwardedOperands.empty()) {
52 LDBG() << "No forwarded operands, returning nullopt";
53 return std::nullopt;
54 }
55
56 // Accept `operandIndex` only if it falls in the half-open range
// [operandsStart, operandsStart + forwardedOperands.size()).
57 unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
58 if (operandIndex < operandsStart ||
59 operandIndex >= (operandsStart + forwardedOperands.size())) {
60 LDBG() << "Operand index " << operandIndex << " out of range ["
61 << operandsStart << ", "
62 << (operandsStart + forwardedOperands.size())
63 << "), returning nullopt";
64 return std::nullopt;
65 }
66
67 // Produced operands occupy the leading successor arguments, so forwarded
// operand k (counted from operandsStart) maps to successor argument
// getProducedOperandCount() + k.
68 unsigned argIndex =
69 operands.getProducedOperandCount() + operandIndex - operandsStart;
70 LDBG() << "Computed argument index " << argIndex << " for successor block";
71 return successor->getArgument(argIndex);
72}
73
74/// Verify that the given operands match those of the given successor block.
/// The operand count must equal the argument count of successor #`succNo` of
/// `op`. Each operand after the first getProducedOperandCount() ones must have
/// a type that BranchOpInterface::areTypesCompatible accepts for the
/// corresponding block argument. On the first mismatch an error is emitted on
/// `op` and failure is returned.
75LogicalResult
77 const SuccessorOperands &operands) {
78 LDBG() << "Verifying branch successor operands for successor #" << succNo
79 << " in operation " << op->getName();
80
81 // Check the count: produced + forwarded operands vs. block arguments.
82 unsigned operandCount = operands.size();
83 Block *destBB = op->getSuccessor(succNo);
84 LDBG() << "Branch has " << operandCount << " operands, target block has "
85 << destBB->getNumArguments() << " arguments";
86
87 if (operandCount != destBB->getNumArguments())
88 return op->emitError() << "branch has " << operandCount
89 << " operands for successor #" << succNo
90 << ", but target block has "
91 << destBB->getNumArguments();
92
93 // Check the types. Produced operands (the leading entries) are skipped; only
// forwarded operands are type-checked here.
// NOTE(review): the `!=` loop bound relies on getProducedOperandCount() <=
// operandCount, which is not checked in this function.
94 LDBG() << "Checking type compatibility for "
95 << (operandCount - operands.getProducedOperandCount())
96 << " forwarded operands";
97 for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
98 ++i) {
99 Type operandType = operands[i].getType();
100 Type argType = destBB->getArgument(i).getType();
101 LDBG() << "Checking type compatibility: operand type " << operandType
102 << " vs argument type " << argType;
103
104 if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType))
105 return op->emitError() << "type mismatch for bb argument #" << i
106 << " of successor #" << succNo;
107 }
108
109 LDBG() << "Branch successor operand verification successful";
110 return success();
111}
112
113//===----------------------------------------------------------------------===//
114// WeightedBranchOpInterface
115//===----------------------------------------------------------------------===//
116
// Shared verifier for branch and region weights.
// - An empty weight list is always accepted.
// - Otherwise the list must have exactly `expectedWeightsNum` entries and must
//   not consist entirely of zeros.
// `weightAnchorName` and `weightRefName` are used only to build the
// count-mismatch error message.
// NOTE(review): the all-zero diagnostic always says "branch weights", even
// when this is invoked for region weights.
117static LogicalResult verifyWeights(Operation *op,
119 std::size_t expectedWeightsNum,
120 llvm::StringRef weightAnchorName,
121 llvm::StringRef weightRefName) {
122 if (weights.empty())
123 return success();
124
125 if (weights.size() != expectedWeightsNum)
126 return op->emitError() << "expects number of " << weightAnchorName
127 << " weights to match number of " << weightRefName
128 << ": " << weights.size() << " vs "
129 << expectedWeightsNum;
130
131 if (llvm::all_of(weights, [](int32_t value) { return value == 0; }))
132 return op->emitError() << "branch weights cannot all be zero";
133
134 return success();
135}
136
139 cast<WeightedBranchOpInterface>(op).getWeights();
// Branch weights must be absent or provide one entry per successor of `op`.
140 return verifyWeights(op, weights, op->getNumSuccessors(), "branch",
141 "successors");
142}
143
144//===----------------------------------------------------------------------===//
145// WeightedRegionBranchOpInterface
146//===----------------------------------------------------------------------===//
147
150 cast<WeightedRegionBranchOpInterface>(op).getWeights();
// Region weights must be absent or provide one entry per region of `op`.
151 return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
152}
153
154//===----------------------------------------------------------------------===//
155// RegionBranchOpInterface
156//===----------------------------------------------------------------------===//
157
158/// Verify that types match along the control flow edges described by the
/// given op. For every edge from a region branch point to one of its region
/// successors, the number of successor operands must equal the number of
/// successor inputs. Each operand type must also be compatible
/// (areTypesCompatible) with the corresponding input type.
160 auto regionInterface = cast<RegionBranchOpInterface>(op);
161
162 // Verify all control flow edges from region branch points to region
163 // successors.
164 SmallVector<RegionBranchPoint> regionBranchPoints =
165 regionInterface.getAllRegionBranchPoints();
166 for (const RegionBranchPoint &branchPoint : regionBranchPoints) {
168 regionInterface.getSuccessorRegions(branchPoint, successors);
169 for (const RegionSuccessor &successor : successors) {
170 // Helper that starts an error naming the region branch point and the
171 // region successor; it is only invoked once a mismatch was detected.
172 auto emitRegionEdgeError = [&]() {
174 regionInterface->emitOpError("along control flow edge from ");
175 if (branchPoint.isParent()) {
176 diag << "parent";
177 diag.attachNote(op->getLoc()) << "region branch point";
178 } else {
179 diag << "Operation "
180 << branchPoint.getTerminatorPredecessorOrNull()->getName();
181 diag.attachNote(
182 branchPoint.getTerminatorPredecessorOrNull()->getLoc())
183 << "region branch point";
184 }
185 diag << " to ";
186 if (Region *region = successor.getSuccessor()) {
187 diag << "Region #" << region->getRegionNumber();
188 } else {
189 diag << "parent";
190 }
191 return diag;
192 };
193
194 // Verify number of successor operands and successor inputs.
195 OperandRange succOperands =
196 regionInterface.getSuccessorOperands(branchPoint, successor);
197 ValueRange succInputs = regionInterface.getSuccessorInputs(successor);
198 if (succOperands.size() != succInputs.size()) {
199 return emitRegionEdgeError()
200 << ": region branch point has " << succOperands.size()
201 << " operands, but region successor needs " << succInputs.size()
202 << " inputs";
203 }
204
205 // Verify that the types are compatible, pairwise by position. The sizes
// were checked to be equal above, so llvm::zip covers every pair.
206 TypeRange succInputTypes = succInputs.getTypes();
207 TypeRange succOperandTypes = succOperands.getTypes();
208 for (const auto &typesIdx :
209 llvm::enumerate(llvm::zip(succOperandTypes, succInputTypes))) {
210 Type succOperandType = std::get<0>(typesIdx.value());
211 Type succInputType = std::get<1>(typesIdx.value());
212 if (!regionInterface.areTypesCompatible(succOperandType, succInputType))
213 return emitRegionEdgeError()
214 << ": successor operand type #" << typesIdx.index() << " "
215 << succOperandType << " should match successor input type #"
216 << typesIdx.index() << " " << succInputType;
217 }
218 }
219 }
220 return success();
221}
222
223/// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
224/// this function returns "true" for a successor region. The first parameter is
225/// the successor region. The second parameter indicates all already visited
226/// regions.
228
229/// Traverse the region graph starting at `begin`. The traversal is interrupted
230/// if `stopConditionFn` evaluates to "true" for a successor region. In that
231/// case, this function returns "true". Otherwise, if the traversal was not
232/// interrupted, this function returns "false".
///
/// `begin` is marked visited up front and is passed to `stopConditionFn` only
/// if it is reached again through a successor edge. A region is pushed onto
/// the worklist once per incoming edge, and `stopConditionFn` runs before the
/// visited check, so it may be called several times for the same region.
233static bool traverseRegionGraph(Region *begin,
234 StopConditionFn stopConditionFn) {
235 auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
236 LDBG() << "Starting region graph traversal from region #"
237 << begin->getRegionNumber() << " in operation " << op->getName();
238
239 SmallVector<bool> visited(op->getNumRegions(), false);
240 visited[begin->getRegionNumber()] = true;
241 LDBG() << "Initialized visited array with " << op->getNumRegions()
242 << " regions";
243
244 // Retrieve all successors of the region and enqueue them in the worklist.
// Successors that branch back to the parent op are not enqueued.
245 SmallVector<Region *> worklist;
246 auto enqueueAllSuccessors = [&](Region *region) {
247 LDBG() << "Enqueuing successors for region #" << region->getRegionNumber();
// Default-constructed (null) attributes are passed for the terminator
// operands. The vector is resized per terminator below, so its initial size
// does not matter.
248 SmallVector<Attribute> operandAttributes(op->getNumOperands());
249 for (Block &block : *region) {
250 if (block.empty())
251 continue;
252 auto terminator =
253 dyn_cast<RegionBranchTerminatorOpInterface>(block.back());
254 if (!terminator)
255 continue;
257 operandAttributes.resize(terminator->getNumOperands());
258 terminator.getSuccessorRegions(operandAttributes, successors);
259 LDBG() << "Found " << successors.size()
260 << " successors from terminator in block";
261 for (RegionSuccessor successor : successors) {
262 if (!successor.isParent()) {
263 worklist.push_back(successor.getSuccessor());
264 LDBG() << "Added region #"
265 << successor.getSuccessor()->getRegionNumber()
266 << " to worklist";
267 } else {
268 LDBG() << "Skipping parent successor";
269 }
270 }
271 }
272 };
273 enqueueAllSuccessors(begin);
274 LDBG() << "Initial worklist size: " << worklist.size();
275
276 // Process the worklist in LIFO (depth-first) order.
277 while (!worklist.empty()) {
278 Region *nextRegion = worklist.pop_back_val();
279 LDBG() << "Processing region #" << nextRegion->getRegionNumber()
280 << " from worklist (remaining: " << worklist.size() << ")";
281
282 if (stopConditionFn(nextRegion, visited)) {
283 LDBG() << "Stop condition met for region #"
284 << nextRegion->getRegionNumber() << ", returning true";
285 return true;
286 }
// NOTE(review): this check runs only after `stopConditionFn` and the LDBG
// above have already called getRegionNumber() on `nextRegion`. It also
// reports via llvm::errs() instead of LDBG(), so it looks like leftover
// debugging code.
287 if (!nextRegion->getParentOp()) {
288 llvm::errs() << "Region " << *nextRegion << " has no parent op\n";
289 return false;
290 }
291 if (visited[nextRegion->getRegionNumber()]) {
292 LDBG() << "Region #" << nextRegion->getRegionNumber()
293 << " already visited, skipping";
294 continue;
295 }
296 visited[nextRegion->getRegionNumber()] = true;
297 LDBG() << "Marking region #" << nextRegion->getRegionNumber()
298 << " as visited";
299 enqueueAllSuccessors(nextRegion);
300 }
301
302 LDBG() << "Traversal completed, returning false";
303 return false;
304}
305
306/// Return `true` if region `r` is reachable from region `begin` according to
307/// the RegionBranchOpInterface (by taking a branch).
308static bool isRegionReachable(Region *begin, Region *r) {
309 assert(begin->getParentOp() == r->getParentOp() &&
310 "expected that both regions belong to the same op");
311 return traverseRegionGraph(begin,
312 [&](Region *nextRegion, ArrayRef<bool> visited) {
313 // Interrupt traversal if `r` was reached.
314 return nextRegion == r;
315 });
316}
317
318/// Return `true` if `a` and `b` are in mutually exclusive regions.
319///
320/// 1. Find the closest common ancestor of `a` and `b` that implements
321/// RegionBranchOpInterface.
322/// 2. Determine the regions `regionA` and `regionB` of that op in which `a`
323/// and `b` are contained.
324/// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
325/// mutually exclusive if they are distinct and not reachable from each other
326/// as per RegionBranchOpInterface::getSuccessorRegions.
// NOTE(review): the LDBG below dereferences `a` and `b` before the non-null
// asserts that follow it.
328 LDBG() << "Checking if operations are in mutually exclusive regions: "
329 << a->getName() << " and " << b->getName();
330
331 assert(a && "expected non-empty operation");
332 assert(b && "expected non-empty operation");
333
// Walk outward from the innermost RegionBranchOpInterface ancestor of `a`;
// the first one that also contains `b` is the closest common one.
334 auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
335 while (branchOp) {
336 LDBG() << "Checking branch operation " << branchOp->getName();
337
338 // Check if b is inside branchOp. (We already know that a is.)
339 if (!branchOp->isProperAncestor(b)) {
340 LDBG() << "Operation b is not inside branchOp, checking next ancestor";
341 // Check next enclosing RegionBranchOpInterface.
342 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
343 continue;
344 }
345
346 LDBG() << "Both operations are inside branchOp, finding their regions";
347
348 // b is contained in branchOp. Retrieve the regions in which `a` and `b`
349 // are contained.
350 Region *regionA = nullptr, *regionB = nullptr;
351 for (Region &r : branchOp->getRegions()) {
352 if (r.findAncestorOpInRegion(*a)) {
353 assert(!regionA && "already found a region for a");
354 regionA = &r;
355 LDBG() << "Found region #" << r.getRegionNumber() << " for operation a";
356 }
357 if (r.findAncestorOpInRegion(*b)) {
358 assert(!regionB && "already found a region for b");
359 regionB = &r;
360 LDBG() << "Found region #" << r.getRegionNumber() << " for operation b";
361 }
362 }
363 assert(regionA && regionB && "could not find region of op");
364
365 LDBG() << "Region A: #" << regionA->getRegionNumber() << ", Region B: #"
366 << regionB->getRegionNumber();
367
368 // `a` and `b` are in mutually exclusive regions if both regions are
369 // distinct and neither region is reachable from the other region.
// NOTE(review): isRegionReachable(regionA, regionB) asks whether regionB is
// reachable from regionA, so the names aNotReachableFromB and
// bNotReachableFromA (and their debug text) are swapped. The conjunction
// below is symmetric, so the result is unaffected. Both reachability queries
// also run when regionA == regionB, where they cannot change the outcome.
370 bool regionsAreDistinct = (regionA != regionB);
371 bool aNotReachableFromB = !isRegionReachable(regionA, regionB);
372 bool bNotReachableFromA = !isRegionReachable(regionB, regionA);
373
374 LDBG() << "Regions distinct: " << regionsAreDistinct
375 << ", A not reachable from B: " << aNotReachableFromB
376 << ", B not reachable from A: " << bNotReachableFromA;
377
378 bool mutuallyExclusive =
379 regionsAreDistinct && aNotReachableFromB && bNotReachableFromA;
380 LDBG() << "Operations are mutually exclusive: " << mutuallyExclusive;
381
382 return mutuallyExclusive;
383 }
384
385 // Could not find a common RegionBranchOpInterface among a's and b's
386 // ancestors.
387 LDBG() << "No common RegionBranchOpInterface found, operations are not "
388 "mutually exclusive";
389 return false;
390}
391
392bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
393 LDBG() << "Checking if region #" << index << " is repetitive in operation "
394 << getOperation()->getName();
395
396 Region *region = &getOperation()->getRegion(index);
397 bool isRepetitive = isRegionReachable(region, region);
398
399 LDBG() << "Region #" << index << " is repetitive: " << isRepetitive;
400 return isRepetitive;
401}
402
403bool RegionBranchOpInterface::hasLoop() {
404 LDBG() << "Checking if operation " << getOperation()->getName()
405 << " has loops";
406
407 SmallVector<RegionSuccessor> entryRegions;
408 getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
409 LDBG() << "Found " << entryRegions.size() << " entry regions";
410
411 for (RegionSuccessor successor : entryRegions) {
412 if (!successor.isParent()) {
413 LDBG() << "Checking entry region #"
414 << successor.getSuccessor()->getRegionNumber() << " for loops";
415
416 bool hasLoop =
417 traverseRegionGraph(successor.getSuccessor(),
418 [](Region *nextRegion, ArrayRef<bool> visited) {
419 // Interrupt traversal if the region was already
420 // visited.
421 return visited[nextRegion->getRegionNumber()];
422 });
423
424 if (hasLoop) {
425 LDBG() << "Found loop in entry region #"
426 << successor.getSuccessor()->getRegionNumber();
427 return true;
428 }
429 } else {
430 LDBG() << "Skipping parent successor";
431 }
432 }
433
434 LDBG() << "No loops found in operation";
435 return false;
436}
437
439RegionBranchOpInterface::getSuccessorOperands(RegionBranchPoint src,
440 RegionSuccessor dest) {
// The edge src -> dest is fed by the op's entry successor operands when `src`
// is the parent op. Otherwise it is fed by the successor operands of the
// terminator that `src` refers to.
441 if (src.isParent())
442 return getEntrySuccessorOperands(dest);
443 return src.getTerminatorPredecessorOrNull().getSuccessorOperands(dest);
444}
445
447RegionBranchOpInterface::getNonSuccessorInputs(RegionSuccessor successor) {
// Start from every value the successor defines: the op results when the
// successor is the parent, otherwise the successor region's arguments.
448 SmallVector<Value> results = llvm::to_vector(
449 successor.isParent()
450 ? ValueRange(getOperation()->getResults())
451 : ValueRange(successor.getSuccessor()->getArguments()));
452 ValueRange successorInputs = getSuccessorInputs(successor);
// Remove the successor inputs. They are erased as one contiguous run starting
// at the position of the first input.
// NOTE(review): that contiguity is assumed here, not checked.
453 if (!successorInputs.empty()) {
454 unsigned inputBegin =
455 successor.isParent()
456 ? cast<OpResult>(successorInputs.front()).getResultNumber()
457 : cast<BlockArgument>(successorInputs.front()).getArgNumber();
458 results.erase(results.begin() + inputBegin,
459 results.begin() + inputBegin + successorInputs.size());
460 }
461 return results;
462}
463
// Reinterpret `operands` as a mutable array of the OpOperands that back it.
// This assumes getBase() points at the first OpOperand of a contiguous range;
// TODO confirm.
465 return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
466}
467
// For every region successor of branch point `src`, records in `mapping` that
// each successor operand (keyed by its OpOperand) is forwarded to the
// positionally matching successor input. Inputs are appended, so one operand
// may map to several inputs across different edges.
468static void
469getSuccessorOperandInputMapping(RegionBranchOpInterface branchOp,
471 RegionBranchPoint src) {
473 branchOp.getSuccessorRegions(src, successors);
474 for (RegionSuccessor dst : successors) {
475 OperandRange operands = branchOp.getSuccessorOperands(src, dst);
476 assert(operands.size() == branchOp.getSuccessorInputs(dst).size() &&
477 "expected the same number of operands and inputs");
478 for (const auto &[operand, input] : llvm::zip_equal(
479 operandsToOpOperands(operands), branchOp.getSuccessorInputs(dst)))
480 mapping[&operand].push_back(input);
481 }
482}
// Populates `mapping` for the edges leaving `src`. When `src` is not given, it
// covers the edges leaving every region branch point (see
// getAllRegionBranchPoints). The `::` qualifier selects the file-local static
// helper of the same name rather than this member function.
483void RegionBranchOpInterface::getSuccessorOperandInputMapping(
485 std::optional<RegionBranchPoint> src) {
486 if (src.has_value()) {
487 ::getSuccessorOperandInputMapping(*this, mapping, src.value());
488 } else {
489 // No region branch point specified: populate the mapping for all possible
490 // region branch points.
491 for (RegionBranchPoint branchPoint : getAllRegionBranchPoints())
492 ::getSuccessorOperandInputMapping(*this, mapping, branchPoint);
493 }
494}
495
497 const RegionBranchSuccessorMapping &operandToInputs) {
// Invert operand -> inputs into input -> operands: each successor input
// collects every OpOperand that forwards a value to it.
499 for (const auto &[operand, inputs] : operandToInputs) {
500 for (Value input : inputs)
501 inputToOperands[input].push_back(operand);
502 }
503 return inputToOperands;
504}
505
// Builds the successor input -> successor operands mapping by inverting the
// operand -> inputs mapping, then assigns it to `mapping` (overwriting it).
// The one-argument call presumably uses a defaulted `src`, i.e. all region
// branch points — confirm against the header.
506void RegionBranchOpInterface::getSuccessorInputOperandMapping(
508 RegionBranchSuccessorMapping operandToInputs;
509 getSuccessorOperandInputMapping(operandToInputs);
510 mapping = invertRegionBranchSuccessorMapping(operandToInputs);
511}
512
514RegionBranchOpInterface::getAllRegionBranchPoints() {
// The parent op is always a branch point (control entering the regions).
516 branchPoints.push_back(RegionBranchPoint::parent());
// In addition, every non-empty block whose last op implements
// RegionBranchTerminatorOpInterface contributes that terminator.
517 for (Region &region : getOperation()->getRegions()) {
518 for (Block &block : region) {
519 if (block.empty())
520 continue;
521 if (auto terminator =
522 dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
523 branchPoints.push_back(RegionBranchPoint(terminator));
524 }
525 }
526 return branchPoints;
527}
528
530 LDBG() << "Finding enclosing repetitive region for operation "
531 << op->getName();
532
// Walk outward one region at a time. Return the innermost enclosing region
// whose RegionBranchOpInterface parent reports it as repetitive, or nullptr
// if none does.
533 while (Region *region = op->getParentRegion()) {
534 LDBG() << "Checking region #" << region->getRegionNumber()
535 << " in operation " << region->getParentOp()->getName();
536
// Continue the walk from the op that owns this region.
// NOTE(review): a region without a parent op would make `op` null here and
// be dereferenced by the LDBG above. Presumably this never happens for
// attached ops — confirm.
537 op = region->getParentOp();
538 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
539 LDBG()
540 << "Found RegionBranchOpInterface, checking if region is repetitive";
541 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
542 LDBG() << "Found repetitive region #" << region->getRegionNumber();
543 return region;
544 }
545 } else {
546 LDBG() << "Parent operation does not implement RegionBranchOpInterface";
547 }
548 }
549
550 LDBG() << "No enclosing repetitive region found";
551 return nullptr;
552}
553
555 LDBG() << "Finding enclosing repetitive region for value";
556
// Start at the region that contains `value` itself and walk outward through
// the owning ops. Return the innermost region that its
// RegionBranchOpInterface parent reports as repetitive, or nullptr.
557 Region *region = value.getParentRegion();
558 while (region) {
559 LDBG() << "Checking region #" << region->getRegionNumber()
560 << " in operation " << region->getParentOp()->getName();
561
562 Operation *op = region->getParentOp();
563 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
564 LDBG()
565 << "Found RegionBranchOpInterface, checking if region is repetitive";
566 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
567 LDBG() << "Found repetitive region #" << region->getRegionNumber();
568 return region;
569 }
570 } else {
571 LDBG() << "Parent operation does not implement RegionBranchOpInterface";
572 }
573 region = op->getParentRegion();
574 }
575
576 LDBG() << "No enclosing repetitive region found for value";
577 return nullptr;
578}
579
580/// Return "true" if `a` can be used in lieu of `b`, where `b` is a region
581/// successor input and `a` is a "reachable value" of `b`. Reachable values
582/// are successor operand values that are (maybe transitively) forwarded to
583/// `b`.
584static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b) {
585 assert((b.getDefiningOp() == regionBranchOp ||
586 b.getParentRegion()->getParentOp() == regionBranchOp) &&
587 "b must be a region successor input");
588
589 // Case 1: `a` is defined inside of the region branch op. `a` must be
590 // directly nested in the region branch op. Otherwise, it could not have
591 // been among the reachable values for a region successor input.
592 if (a.getParentRegion()->getParentOp() == regionBranchOp) {
593 // Case 1.1: If `b` is a result of the region branch op, `a` is not in
594 // scope for `b`.
595 // Example:
596 // %b = region_op({
597 // ^bb0(%a1: ...):
598 // %a2 = ...
599 // })
600 if (isa<OpResult>(b))
601 return false;
602
603 // Case 1.2: `b` is an entry block argument of a region. `a` is in scope
604 // for `b` only if it is also an entry block argument of the same region.
605 // Example:
606 // region_op({
607 // ^bb0(%b: ..., %a: ...):
608 // ...
609 // })
610 assert(isa<BlockArgument>(b) && "b must be a block argument");
611 return isa<BlockArgument>(a) && cast<BlockArgument>(a).getOwner() ==
612 cast<BlockArgument>(b).getOwner();
613 }
614
615 // Case 2: `a` is defined outside of the region branch op. In that case, we
616 // can safely assume that `a` was defined before `b`. Otherwise, it could not
617 // be among the reachable values for a region successor input.
618 // Example:
619 // { <- %a1 parent region begins here.
620 // ^bb0(%a1: ...):
621 // %a2 = ...
622 // %b1 = reigon_op({
623 // ^bb1(%b2: ...):
624 // ...
625 // })
626 // }
627 return true;
628}
629
630/// Compute all non-successor-input values that a successor input could have
631/// based on the given successor input to successor operand mapping.
632///
633/// Example 1:
634/// %r = scf.if ... {
635/// scf.yield %a : ...
636/// } else {
637/// scf.yield %b : ...
638/// }
639/// reachableValues(%r) = {%a, %b}
640///
641/// Example 2:
642/// %r = scf.for ... iter_args(%arg0 = %0) -> ... {
643/// scf.yield %arg0 : ...
644/// }
645/// reachableValues(%arg0) = {%0}
646/// reachableValues(%r) = {%0}
647///
648/// Example 3:
649/// %r = scf.for ... iter_args(%arg0 = %0) -> ... {
650/// ...
651/// scf.yield %1 : ...
652/// }
653/// reachableValues(%arg0) = {%0, %1}
654/// reachableValues(%r) = {%0, %1}
655static llvm::SmallDenseSet<Value> computeReachableValuesFromSuccessorInput(
656 Value value, const RegionBranchInverseSuccessorMapping &inputToOperands) {
657 assert(inputToOperands.contains(value) && "value must be a successor input");
658 // Starting with the given value, trace back all predecessor values (i.e.,
659 // preceding successor operands) and add them to the set of reachable values.
660 // If the successor operand is again a successor input, do not add it to
661 // result set, but instead continue the traversal.
662 llvm::SmallDenseSet<Value> reachableValues;
663 llvm::SmallDenseSet<Value> visited;
664 SmallVector<Value> worklist;
665 worklist.push_back(value);
666 while (!worklist.empty()) {
667 Value next = worklist.pop_back_val();
668 auto it = inputToOperands.find(next);
669 if (it == inputToOperands.end()) {
670 reachableValues.insert(next);
671 continue;
672 }
673 for (OpOperand *operand : it->second)
674 if (visited.insert(operand->get()).second)
675 worklist.push_back(operand->get());
676 }
677 // Note: The result does not contain any successor inputs. (Therefore,
678 // `value` is also guaranteed to be excluded.)
679 return reachableValues;
680}
681
682namespace {
683/// Try to make successor inputs dead by replacing their uses with values that
684/// are not successor inputs. This pattern enables additional canonicalization
685/// opportunities for RemoveDeadRegionBranchOpSuccessorInputs.
686///
687/// Example:
688///
689/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
690/// scf.yield %arg1, %arg1 : ...
691/// }
692/// use(%r0, %r1)
693///
694/// reachableValues(%r0) = {%0, %1}
695/// reachableValues(%r1) = {%1} ==> replace uses of %r1 with %1.
696/// reachableValues(%arg0) = {%0, %1}
697/// reachableValues(%arg1) = {%1} ==> replace uses of %arg1 with %1.
698///
699/// IR after pattern application:
700///
701/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
702/// scf.yield %1, %1 : ...
703/// }
704/// use(%r0, %1)
705///
706/// Note that %r1 and %arg1 are dead now. The IR can now be further
707/// canonicalized by RemoveDeadRegionBranchOpSuccessorInputs.
///
/// A live successor input is rewired only when exactly one
/// non-successor-input value can reach it. That value must also be in scope
/// at the input (isDefinedBefore).
708struct MakeRegionBranchOpSuccessorInputsDead : public RewritePattern {
// `name` is passed to RewritePattern as the root op name this pattern
// matches.
709 MakeRegionBranchOpSuccessorInputsDead(MLIRContext *context, StringRef name,
710 PatternBenefit benefit = 1)
711 : RewritePattern(name, benefit, context) {}
712
713 LogicalResult matchAndRewrite(Operation *op,
714 PatternRewriter &rewriter) const override {
715 assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
716 "isolated-from-above ops are not supported");
717
718 // Compute the mapping of successor inputs to successor operands.
719 auto regionBranchOp = cast<RegionBranchOpInterface>(op);
721 regionBranchOp.getSuccessorInputOperandMapping(inputToOperands);
722
723 // Try to replace the uses of each successor input one-by-one.
724 bool changed = false;
725 for (Value value : inputToOperands.keys()) {
726 // Nothing to do for successor inputs that are already dead.
727 if (value.use_empty())
728 continue;
729 // Nothing to do for successor inputs that may have multiple reachable
730 // values.
731 llvm::SmallDenseSet<Value> reachableValues =
732 computeReachableValuesFromSuccessorInput(value, inputToOperands);
733 if (reachableValues.size() != 1)
734 continue;
735 assert(*reachableValues.begin() != value &&
736 "successor inputs are supposed to be excluded");
737 // Do not replace `value` with the found reachable value if doing so
738 // would violate dominance. Example:
739 // %r = scf.execute_region ... {
740 // %a = ...
741 // scf.yield %a : ...
742 // }
743 // use(%r)
744 // In the above example, reachableValues(%r) = {%a}, but %a cannot be
745 // used as a replacement for %r due to dominance / scope.
746 if (!isDefinedBefore(regionBranchOp, *reachableValues.begin(), value))
747 continue;
748 rewriter.replaceAllUsesWith(value, *reachableValues.begin());
749 changed = true;
750 }
// Succeed only if at least one successor input was rewired.
751 return success(changed);
752 }
753};
754
755/// Lookup a bit vector in the given mapping (DenseMap). If the key was not
756/// found, create a new bit vector with the given size and initialize it with
757/// false.
758template <typename MappingTy, typename KeyTy>
759static BitVector &lookupOrCreateBitVector(MappingTy &mapping, KeyTy key,
760 unsigned size) {
761 return mapping.try_emplace(key, size, false).first->second;
762}
763
764/// Compute tied successor inputs. Tied successor inputs are successor inputs
765/// that come as a set. If you erase one value from a set, you must erase all
766/// values from the set. Otherwise, the op would become structurally invalid.
767/// Each successor input appears in exactly one set.
768///
769/// Example:
770/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
771/// ...
772/// }
773/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}.
774static llvm::EquivalenceClasses<Value> computeTiedSuccessorInputs(
775 const RegionBranchSuccessorMapping &operandToInputs) {
776 llvm::EquivalenceClasses<Value> tiedSuccessorInputs;
777 for (const auto &[operand, inputs] : operandToInputs) {
778 assert(!inputs.empty() && "expected non-empty inputs");
779 Value firstInput = inputs.front();
780 tiedSuccessorInputs.insert(firstInput);
781 for (Value nextInput : llvm::drop_begin(inputs)) {
782 // As we explore more successor operand to successor input mappings,
783 // existing sets may get merged.
784 tiedSuccessorInputs.unionSets(firstInput, nextInput);
785 }
786 }
787 return tiedSuccessorInputs;
788}
789
790/// Remove dead successor inputs from region branch ops. A successor input is
791/// dead if it has no uses. Successor inputs come in sets of tied values: if
792/// you remove one value from a set, you must remove all values from the set.
793/// Furthermore, successor operands must also be removed. (Op operands are not
794/// part of the set, but the set is built based on the successor operand to
795/// successor input mapping.)
796///
797/// Example 1:
798/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
799/// scf.yield %0, %arg1 : ...
800/// }
801/// use(%0, %1)
802///
803/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}. All values in the first
804/// set are dead, so %arg0 and %r0 can be removed, but not %r1 and %arg1. The
805/// resulting IR is as follows:
806///
807/// %r1 = scf.for ... iter_args(%arg1 = %1) -> ... {
808/// scf.yield %arg1 : ...
809/// }
810/// use(%0, %1)
811///
812/// Example 2:
813/// %r0, %r1 = scf.while (%arg0 = %0) {
814/// scf.condition(...) %arg0, %arg0 : ...
815/// } do {
816/// ^bb0(%arg1: ..., %arg2: ...):
817/// scf.yield %arg1 : ...
818/// }
819/// There are three sets: {{%r0, %arg1}, {%r1, %arg2}, {%r0}}.
820///
821/// Example 3:
822/// %r1, %r2 = scf.if ... {
823/// scf.yield %0, %1 : ...
824/// } else {
825/// scf.yield %2, %3 : ...
826/// }
827/// There are two sets: {{%r1}, {%r2}}. Each set has one value, so there each
828/// value can be removed independently of the other values.
829struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern {
830 RemoveDeadRegionBranchOpSuccessorInputs(MLIRContext *context, StringRef name,
831 PatternBenefit benefit = 1)
832 : RewritePattern(name, benefit, context) {}
833
834 LogicalResult matchAndRewrite(Operation *op,
835 PatternRewriter &rewriter) const override {
836 assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
837 "isolated-from-above ops are not supported");
838
839 // Compute tied values: values that must come as a set. If you remove one,
840 // you must remove all. If a successor op operand is forwarded to two
841 // successor inputs %a and %b, both %a and %b are in the same set.
842 auto regionBranchOp = cast<RegionBranchOpInterface>(op);
843 RegionBranchSuccessorMapping operandToInputs;
844 regionBranchOp.getSuccessorOperandInputMapping(operandToInputs);
845 llvm::EquivalenceClasses<Value> tiedSuccessorInputs =
846 computeTiedSuccessorInputs(operandToInputs);
847
848 // Determine which values to remove and group them by block and operation.
849 SmallVector<Value> valuesToRemove;
850 DenseMap<Block *, BitVector> blockArgsToRemove;
851 BitVector resultsToRemove(regionBranchOp->getNumResults(), false);
852 // Iterate over all sets of tied successor inputs.
853 for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end();
854 it != e; ++it) {
855 if (!(*it)->isLeader())
856 continue;
857
858 // Value can be removed if it is dead and all other tied values are also
859 // dead.
860 bool allDead = true;
861 for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
862 memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
863 // Iterate over all values in the set and check their liveness.
864 if (!memberIt->use_empty()) {
865 allDead = false;
866 break;
867 }
868 }
869 if (!allDead)
870 continue;
871
872 // The entire set is dead. Group values by block and operation to
873 // simplify removal.
874 for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
875 memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
876 if (auto arg = dyn_cast<BlockArgument>(*memberIt)) {
877 // Set blockArgsToRemove[block][arg_number] = true.
878 BitVector &vector =
879 lookupOrCreateBitVector(blockArgsToRemove, arg.getOwner(),
880 arg.getOwner()->getNumArguments());
881 vector.set(arg.getArgNumber());
882 } else {
883 // Set resultsToRemove[result_number] = true.
884 OpResult result = cast<OpResult>(*memberIt);
885 assert(result.getDefiningOp() == regionBranchOp &&
886 "result must be a region branch op result");
887 resultsToRemove.set(result.getResultNumber());
888 }
889 valuesToRemove.push_back(*memberIt);
890 }
891 }
892
893 if (valuesToRemove.empty())
894 return rewriter.notifyMatchFailure(op, "no values to remove");
895
896 // Find operands that must be removed together with the values.
897 RegionBranchInverseSuccessorMapping inputsToOperands =
898 invertRegionBranchSuccessorMapping(operandToInputs);
900 for (Value value : valuesToRemove) {
901 for (OpOperand *operand : inputsToOperands[value]) {
902 // Set operandsToRemove[op][operand_number] = true.
903 BitVector &vector =
904 lookupOrCreateBitVector(operandsToRemove, operand->getOwner(),
905 operand->getOwner()->getNumOperands());
906 vector.set(operand->getOperandNumber());
907 }
908 }
909
910 // Erase operands.
911 for (auto &pair : operandsToRemove) {
912 Operation *op = pair.first;
913 BitVector &operands = pair.second;
914 rewriter.modifyOpInPlace(op, [&]() { op->eraseOperands(operands); });
915 }
916
917 // Erase block arguments.
918 for (auto &pair : blockArgsToRemove) {
919 Block *block = pair.first;
920 BitVector &blockArg = pair.second;
921 rewriter.modifyOpInPlace(block->getParentOp(),
922 [&]() { block->eraseArguments(blockArg); });
923 }
924
925 // Erase op results.
926 if (resultsToRemove.any())
927 rewriter.eraseOpResults(regionBranchOp, resultsToRemove);
928
929 return success();
930 }
931};
932
933/// Return "true" if the two values are owned by the same operation or block.
934static bool haveSameOwner(Value a, Value b) {
935 void *aOwner, *bOwner;
936 if (auto arg = dyn_cast<BlockArgument>(a))
937 aOwner = arg.getOwner();
938 else
939 aOwner = a.getDefiningOp();
940 if (auto arg = dyn_cast<BlockArgument>(b))
941 bOwner = arg.getOwner();
942 else
943 bOwner = b.getDefiningOp();
944 return aOwner == bOwner;
945}
946
947/// Get the block argument or op result number of the given value.
948static unsigned getArgOrResultNumber(Value value) {
949 if (auto opResult = llvm::dyn_cast<OpResult>(value))
950 return opResult.getResultNumber();
951 return llvm::cast<BlockArgument>(value).getArgNumber();
952}
953
954/// Find duplicate successor inputs and make all dead except for one. Two
955/// successor inputs are "duplicate" if their corresponding successor operands
956/// have the same values. This pattern enables additional canonicalization
957/// opportunities for RemoveDeadRegionBranchOpSuccessorInputs.
958///
959/// Example:
960/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... {
961/// use(%arg0, %arg1)
962/// ...
963/// scf.yield %x, %x : ...
964/// }
965/// use(%r0, %r1)
966///
967/// Operands of successor input %r0: [%0, %x]
968/// Operands of successor input %r1: [%0, %x] ==> DUPLICATE!
969/// Replace %r1 with %r0.
970///
971/// Operands of successor input %arg0: [%0, %x]
972/// Operands of successor input %arg1: [%0, %x] ==> DUPLICATE!
973/// Replace %arg1 with %arg0. (We have to make sure that we make same decision
974/// as for the other tied successor inputs above. Otherwise, a set of tied
975/// successor inputs may not become entirely dead.)
976///
977/// The resulting IR is as follows:
978/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... {
979/// use(%arg0, %arg0)
980/// ...
981/// scf.yield %x, %x : ...
982/// }
983/// use(%r0, %r0) // Note: We don't want use(%r1, %r1), which is also correct,
984/// // but does not help with further canonicalizations.
985struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
986 RemoveDuplicateSuccessorInputUses(MLIRContext *context, StringRef name,
987 PatternBenefit benefit = 1)
988 : RewritePattern(name, benefit, context) {}
989
990 LogicalResult matchAndRewrite(Operation *op,
991 PatternRewriter &rewriter) const override {
992 assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
993 "isolated-from-above ops are not supported");
994
995 // Collect all successor inputs and sort them. When dropping the uses of a
996 // successor input, we'd like to also drop the uses of the same tied
997 // successor inputs. Otherwise, a set of tied successor inputs may not
998 // become entirely dead, which is required for
999 // RemoveDeadRegionBranchOpSuccessorInputs to be able to erase them.
1000 // (Sorting is not required for correctness.)
1001 auto regionBranchOp = cast<RegionBranchOpInterface>(op);
1002 RegionBranchInverseSuccessorMapping inputsToOperands;
1003 regionBranchOp.getSuccessorInputOperandMapping(inputsToOperands);
1004 SmallVector<Value> inputs = llvm::to_vector(inputsToOperands.keys());
1005 llvm::sort(inputs, [](Value a, Value b) {
1006 return getArgOrResultNumber(a) < getArgOrResultNumber(b);
1007 });
1008
1009 // Check every distinct pair of successor inputs for duplicates. Replace
1010 // `input2` with `input1` if they are duplicates.
1011 bool changed = false;
1012 unsigned numInputs = inputs.size();
1013 for (auto i : llvm::seq<unsigned>(0, numInputs)) {
1014 Value input1 = inputs[i];
1015 for (auto j : llvm::seq<unsigned>(i + 1, numInputs)) {
1016 Value input2 = inputs[j];
1017 // Nothing to do if input2 is already dead.
1018 if (input2.use_empty())
1019 continue;
1020 // Replace only values that belong to the same block / operation.
1021 // This implies that the two values are either both block arguments or
1022 // both op results.
1023 if (!haveSameOwner(input1, input2))
1024 continue;
1025
1026 // Gather the predecessor value for each predecessor (region branch
1027 // point). The two inputs are duplicates if each predecessor forwards
1028 // the same value.
1029 llvm::SmallDenseMap<Operation *, Value> operands1, operands2;
1030 for (OpOperand *operand : inputsToOperands[input1]) {
1031 assert(!operands1.contains(operand->getOwner()));
1032 operands1[operand->getOwner()] = operand->get();
1033 }
1034 for (OpOperand *operand : inputsToOperands[input2]) {
1035 assert(!operands2.contains(operand->getOwner()));
1036 operands2[operand->getOwner()] = operand->get();
1037 }
1038 if (operands1 == operands2) {
1039 rewriter.replaceAllUsesWith(input2, input1);
1040 changed = true;
1041 }
1042 }
1043 }
1044 return success(changed);
1045 }
1046};
1047
1048/// Given a range of values, return a vector of attributes of the same size,
1049/// where the i-th attribute is the constant value of the i-th value. If a
1050/// value is not constant, the corresponding attribute is null.
1051static SmallVector<Attribute> extractConstants(ValueRange values) {
1052 return llvm::map_to_vector(values, [](Value value) {
1053 Attribute attr;
1054 matchPattern(value, m_Constant(&attr));
1055 return attr;
1056 });
1057}
1058
1059/// Return all successor regions when branching from the given region branch
1060/// point. This helper functions extracts all constant operand values and
1061/// passes them to the `RegionBranchOpInterface`.
1063getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
1064 RegionBranchPoint point) {
1066 if (point.isParent()) {
1067 op.getEntrySuccessorRegions(extractConstants(op->getOperands()),
1068 successors);
1069 return successors;
1070 }
1071 RegionBranchTerminatorOpInterface terminator =
1073 terminator.getSuccessorRegions(extractConstants(terminator->getOperands()),
1074 successors);
1075 return successors;
1076}
1077
1078/// Find the single acyclic path through the given region branch op. Return an
1079/// empty vector if no such path or multiple such paths exist.
1080///
1081/// Example: "scf.if %true" has a single path: parent => then_region => parent
1082///
1083/// Example: "scf.if ???" has multiple paths:
1084/// (1) parent => then_region => parent
1085/// (2) parent => else_region => parent
1086///
1087/// Example: "scf.while with scf.condition(%false)" has a single path:
1088/// parent => before_region => parent
1089///
1090/// Example: "scf.for with 0 iterations" has a single path: parent => parent
1091///
1092/// Note: Each path starts and ends with "parent". The "parent" at the beginning
1093/// of the path is omitted from the result.
1094///
1095/// Note: This function also returns an "empty" path when a region with multiple
1096/// blocks was found.
1098computeSingleAcyclicRegionBranchPath(RegionBranchOpInterface op) {
1099 llvm::SmallDenseSet<Region *> visited;
1101
1102 // Path starts with "parent".
1104 do {
1105 SmallVector<RegionSuccessor> successors =
1106 getSuccessorRegionsWithAttrs(op, next);
1107 if (successors.size() != 1) {
1108 // There are multiple region successors. I.e., there are multiple paths
1109 // through the region branch op.
1110 return {};
1111 }
1112 path.push_back(successors.front());
1113 if (successors.front().isParent()) {
1114 // Found path that ends with "parent".
1115 return path;
1116 }
1117 Region *region = successors.front().getSuccessor();
1118 if (!region->hasOneBlock()) {
1119 // Entering a region with multiple blocks. Such regions are not supported
1120 // at the moment.
1121 return {};
1122 }
1123 if (!visited.insert(region).second) {
1124 // We have already visited this region. I.e., we have found a cycle.
1125 return {};
1126 }
1127 auto terminator =
1128 dyn_cast<RegionBranchTerminatorOpInterface>(&region->front().back());
1129 if (!terminator) {
1130 // Region has no RegionBranchTerminatorOpInterface terminator. E.g., the
1131 // terminator could be a "ub.unreachable" op. Such IR is not supported.
1132 return {};
1133 }
1134 next = RegionBranchPoint(terminator);
1135 } while (true);
1136 llvm_unreachable("expected to return from loop");
1137}
1138
1139/// Inline the body of the matched region branch op into the enclosing block if
1140/// there is exactly one acyclic path through the region branch op, starting
1141/// from "parent", and if that path ends with "parent".
1142///
1143/// Example: This pattern can inline "scf.for" operations that are guaranteed to
1144/// have a single iteration, as indicated by the region branch path "parent =>
1145/// region => parent". "scf.for" operations have a non-successor-input: the loop
1146/// induction variable. Non-successor-input values have op-specific semantics
1147/// and cannot be reasoned about through the `RegionBranchOpInterface`. A
1148/// replacement value for non-successor-inputs is injected by the user-specified
1149/// lambda: in the case of the loop induction variable of an "scf.for", the
1150/// lower bound of the loop is used as a replacement value.
1151///
1152/// Before pattern application:
1153/// %r = scf.for %iv = %c5 to %c6 step %c1 iter_args(%arg0 = %0) {
1154/// %1 = "producer"(%arg0, %iv)
1155/// scf.yield %1
1156/// }
1157/// "user"(%r)
1158///
1159/// After pattern application:
1160/// %1 = "producer"(%0, %c5)
1161/// "user"(%1)
1162///
1163/// This pattern is limited to the following cases:
1164/// - Only regions with a single block are supported. This could be generalized.
1165/// - Region branch ops with side effects are not supported. (Recursive side
1166/// effects are fine.)
1167///
1168/// Note: This pattern queries the region dataflow from the
/// `RegionBranchOpInterface`. Replacement values for block arguments / op
/// results are determined based on region dataflow. In case of
1171/// non-successor-inputs (whose values are not modeled by the
1172/// `RegionBranchOpInterface`), a user-specified lambda is queried.
1173struct InlineRegionBranchOp : public RewritePattern {
1174 InlineRegionBranchOp(MLIRContext *context, StringRef name,
1176 PatternMatcherFn matcherFn, PatternBenefit benefit = 1)
1177 : RewritePattern(name, benefit, context), replBuilderFn(replBuilderFn),
1178 matcherFn(matcherFn) {}
1179
1180 LogicalResult matchAndRewrite(Operation *op,
1181 PatternRewriter &rewriter) const override {
1182 // Check if the pattern is applicable to the given operation.
1183 if (failed(matcherFn(op)))
1184 return rewriter.notifyMatchFailure(op, "pattern not applicable");
1185
1186 // Patterns without recursive memory effects could have side effects, so
1187 // it is not safe to fold such ops away.
1188 if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
1189 return rewriter.notifyMatchFailure(
1190 op, "pattern not applicable to ops without recursive memory effects");
1191
1192 // Find the single acyclic path through the region branch op.
1193 auto regionBranchOp = cast<RegionBranchOpInterface>(op);
1194 SmallVector<RegionSuccessor> path =
1195 computeSingleAcyclicRegionBranchPath(regionBranchOp);
1196 if (path.empty())
1197 return rewriter.notifyMatchFailure(
1198 op, "failed to find acyclic region branch path");
1199
1200 // Inline all regions on the path into the enclosing block.
1201 rewriter.setInsertionPoint(op);
1202 ArrayRef remainingPath = path;
1203 SmallVector<Value> successorOperands = llvm::to_vector(
1204 regionBranchOp.getEntrySuccessorOperands(remainingPath.front()));
1205 while (!remainingPath.empty()) {
1206 RegionSuccessor nextSuccessor = remainingPath.consume_front();
1207 ValueRange successorInputs =
1208 regionBranchOp.getSuccessorInputs(nextSuccessor);
1209 assert(successorInputs.size() == successorOperands.size() &&
1210 "size mismatch");
1211 // Find the index of the first block argument / op result that is a
1212 // succesor input.
1213 unsigned firstSuccessorInputIdx = 0;
1214 if (!successorInputs.empty())
1215 firstSuccessorInputIdx =
1216 nextSuccessor.isParent()
1217 ? cast<OpResult>(successorInputs.front()).getResultNumber()
1218 : cast<BlockArgument>(successorInputs.front()).getArgNumber();
1219 // Query the total number of block arguments / op results.
1220 unsigned numValues =
1221 nextSuccessor.isParent()
1222 ? op->getNumResults()
1223 : nextSuccessor.getSuccessor()->getNumArguments();
1224 // Compute replacement values for all block arguments / op results.
1225 SmallVector<Value> replacements;
1226 // Helper function to get the i-th block argument / op result.
1227 auto getValue = [&](unsigned idx) {
1228 return nextSuccessor.isParent()
1229 ? Value(op->getResult(idx))
1230 : Value(nextSuccessor.getSuccessor()->getArgument(idx));
1231 };
1232 // Compute replacement values for all non-successor-input values that
1233 // precede the first successor input.
1234 for (unsigned i = 0; i < firstSuccessorInputIdx; ++i)
1235 replacements.push_back(
1236 replBuilderFn(rewriter, op->getLoc(), getValue(i)));
1237 // Use the successor operands of the predecessor as replacement values for
1238 // the successor inputs.
1239 llvm::append_range(replacements, successorOperands);
1240 // Compute replacement values for all block arguments / op results that
1241 // succeed the first successor input.
1242 for (unsigned i = replacements.size(); i < numValues; ++i)
1243 replacements.push_back(
1244 replBuilderFn(rewriter, op->getLoc(), getValue(i)));
1245 if (nextSuccessor.isParent()) {
1246 // The path ends with "parent". Replace the region branch op with the
1247 // computed replacement values.
1248 assert(remainingPath.empty() && "expected that the path ended");
1249 rewriter.replaceOp(op, replacements);
1250 return success();
1251 }
1252 // We are inside of a region: query the successor operands from the
1253 // terminator, inline the region into the enclosing block, and erase the
1254 // terminator.
1255 auto terminator = cast<RegionBranchTerminatorOpInterface>(
1256 &nextSuccessor.getSuccessor()->front().back());
1257 rewriter.inlineBlockBefore(&nextSuccessor.getSuccessor()->front(),
1258 op->getBlock(), op->getIterator(),
1259 replacements);
1260 successorOperands = llvm::to_vector(
1261 terminator.getSuccessorOperands(remainingPath.front()));
1262 rewriter.eraseOp(terminator);
1263 }
1264
1265 llvm_unreachable("expected that paths ends with parent");
1266 }
1267
1269 PatternMatcherFn matcherFn;
1270};
1271} // namespace
1272
1274 RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit) {
1275 patterns.add<MakeRegionBranchOpSuccessorInputsDead,
1276 RemoveDuplicateSuccessorInputUses,
1277 RemoveDeadRegionBranchOpSuccessorInputs>(patterns.getContext(),
1278 opName, benefit);
1279}
1280
1282 RewritePatternSet &patterns, StringRef opName,
1284 PatternMatcherFn matcherFn, PatternBenefit benefit) {
1285 patterns.add<InlineRegionBranchOp>(patterns.getContext(), opName,
1286 replBuilderFn, matcherFn, benefit);
1287}
return success()
static LogicalResult verifyWeights(Operation *op, llvm::ArrayRef< int32_t > weights, std::size_t expectedWeightsNum, llvm::StringRef weightAnchorName, llvm::StringRef weightRefName)
static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b)
Return "true" if a can be used in lieu of b, where b is a region successor input and a is a "reachabl...
static void getSuccessorOperandInputMapping(RegionBranchOpInterface branchOp, RegionBranchSuccessorMapping &mapping, RegionBranchPoint src)
static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn)
Traverse the region graph starting at begin.
static llvm::SmallDenseSet< Value > computeReachableValuesFromSuccessorInput(Value value, const RegionBranchInverseSuccessorMapping &inputToOperands)
Compute all non-successor-input values that a successor input could have based on the given successor...
static RegionBranchInverseSuccessorMapping invertRegionBranchSuccessorMapping(const RegionBranchSuccessorMapping &operandToInputs)
function_ref< bool(Region *, ArrayRef< bool > visited)> StopConditionFn
Stop condition for traverseRegionGraph.
static bool isRegionReachable(Region *begin, Region *r)
Return true if region r is reachable from region begin according to the RegionBranchOpInterface (by t...
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static std::string diag(const llvm::Value &value)
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & back()
Definition Block.h:162
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
This class represents an operand of an operation.
Definition Value.h:257
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
unsigned getNumSuccessors()
Definition Operation.h:706
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
Definition Operation.h:360
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
Block * getSuccessor(unsigned index)
Definition Operation.h:708
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition Region.cpp:62
unsigned getNumArguments()
Definition Region.h:123
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePattern is the common base class for all DAG to DAG replacements.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class models how operands are forwarded to block arguments in control flow.
SuccessorOperands(MutableOperandRange forwardedOperands)
Constructs a SuccessorOperands with no produced operands that simply forwards operands to the success...
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
LogicalResult verifyRegionBranchWeights(Operation *op)
Verify that the region weights attached to an operation implementing WeightedRegionBranchOpInterface ...
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands)
Verify that the given operands match those of the given successor block.
LogicalResult verifyRegionBranchOpInterface(Operation *op)
Verify that types match along control flow edges described by the given op.
LogicalResult verifyBranchWeights(Operation *op)
Verify that the branch weights attached to an operation implementing WeightedBranchOpInterface are co...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::function< LogicalResult(Operation *)> PatternMatcherFn
Helper function for the region branch op inlining pattern that checks if the pattern is applicable to...
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
std::function< Value(OpBuilder &, Location, Value)> NonSuccessorInputReplacementBuilderFn
Helper function for the region branch op inlining pattern that builds replacement values for non-succ...
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBran...
const FrozenRewritePatternSet & patterns
DenseMap< Value, SmallVector< OpOperand * > > RegionBranchInverseSuccessorMapping
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
void populateRegionBranchOpInterfaceInliningPattern(RewritePatternSet &patterns, StringRef opName, NonSuccessorInputReplacementBuilderFn replBuilderFn=detail::defaultReplBuilderFn, PatternMatcherFn matcherFn=detail::defaultMatcherFn, PatternBenefit benefit=1)
Populate a pattern that inlines the body of region branch ops when there is a single acyclic path thr...
void populateRegionBranchOpInterfaceCanonicalizationPatterns(RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit=1)
Populate canonicalization patterns that simplify successor operands/inputs of region branch operation...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144