MLIR 23.0.0git
ControlFlowInterfaces.cpp
Go to the documentation of this file.
1//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <map>
10#include <utility>
11
13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/Operation.h"
17#include "llvm/ADT/EquivalenceClasses.h"
18#include "llvm/Support/DebugLog.h"
19
20using namespace mlir;
21
22//===----------------------------------------------------------------------===//
23// ControlFlowInterfaces
24//===----------------------------------------------------------------------===//
25
26#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
27
29 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
30}
31
32SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
33 MutableOperandRange forwardedOperands)
34 : producedOperandCount(producedOperandCount),
35 forwardedOperands(std::move(forwardedOperands)) {}
36
37//===----------------------------------------------------------------------===//
38// BranchOpInterface
39//===----------------------------------------------------------------------===//
40
41/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
42/// successor if 'operandIndex' is within the range of 'operands', or
43/// std::nullopt if `operandIndex` isn't a successor operand index.
44std::optional<BlockArgument>
46 unsigned operandIndex, Block *successor) {
47 LDBG() << "Getting branch successor argument for operand index "
48 << operandIndex << " in successor block";
49
50 OperandRange forwardedOperands = operands.getForwardedOperands();
51 // Check that the operands are valid.
52 if (forwardedOperands.empty()) {
53 LDBG() << "No forwarded operands, returning nullopt";
54 return std::nullopt;
55 }
56
57 // Check to ensure that this operand is within the range.
58 unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
59 if (operandIndex < operandsStart ||
60 operandIndex >= (operandsStart + forwardedOperands.size())) {
61 LDBG() << "Operand index " << operandIndex << " out of range ["
62 << operandsStart << ", "
63 << (operandsStart + forwardedOperands.size())
64 << "), returning nullopt";
65 return std::nullopt;
66 }
67
68 // Index the successor.
69 unsigned argIndex =
70 operands.getProducedOperandCount() + operandIndex - operandsStart;
71 LDBG() << "Computed argument index " << argIndex << " for successor block";
72 return successor->getArgument(argIndex);
73}
74
75/// Verify that the given operands match those of the given successor block.
76LogicalResult
78 const SuccessorOperands &operands) {
79 LDBG() << "Verifying branch successor operands for successor #" << succNo
80 << " in operation " << op->getName();
81
82 // Check the count.
83 unsigned operandCount = operands.size();
84 Block *destBB = op->getSuccessor(succNo);
85 LDBG() << "Branch has " << operandCount << " operands, target block has "
86 << destBB->getNumArguments() << " arguments";
87
88 if (operandCount != destBB->getNumArguments())
89 return op->emitError() << "branch has " << operandCount
90 << " operands for successor #" << succNo
91 << ", but target block has "
92 << destBB->getNumArguments();
93
94 // Check the types.
95 LDBG() << "Checking type compatibility for "
96 << (operandCount - operands.getProducedOperandCount())
97 << " forwarded operands";
98 for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
99 ++i) {
100 Type operandType = operands[i].getType();
101 Type argType = destBB->getArgument(i).getType();
102 LDBG() << "Checking type compatibility: operand type " << operandType
103 << " vs argument type " << argType;
104
105 if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType))
106 return op->emitError() << "type mismatch for bb argument #" << i
107 << " of successor #" << succNo;
108 }
109
110 LDBG() << "Branch successor operand verification successful";
111 return success();
112}
113
114//===----------------------------------------------------------------------===//
115// WeightedBranchOpInterface
116//===----------------------------------------------------------------------===//
117
118static LogicalResult verifyWeights(Operation *op,
120 std::size_t expectedWeightsNum,
121 llvm::StringRef weightAnchorName,
122 llvm::StringRef weightRefName) {
123 if (weights.empty())
124 return success();
125
126 if (weights.size() != expectedWeightsNum)
127 return op->emitError() << "expects number of " << weightAnchorName
128 << " weights to match number of " << weightRefName
129 << ": " << weights.size() << " vs "
130 << expectedWeightsNum;
131
132 if (llvm::all_of(weights, [](int32_t value) { return value == 0; }))
133 return op->emitError() << "branch weights cannot all be zero";
134
135 return success();
136}
137
140 cast<WeightedBranchOpInterface>(op).getWeights();
141 return verifyWeights(op, weights, op->getNumSuccessors(), "branch",
142 "successors");
143}
144
145//===----------------------------------------------------------------------===//
146// WeightedRegionBranchOpInterface
147//===----------------------------------------------------------------------===//
148
151 cast<WeightedRegionBranchOpInterface>(op).getWeights();
152 return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
153}
154
155//===----------------------------------------------------------------------===//
156// RegionBranchOpInterface
157//===----------------------------------------------------------------------===//
158
159/// Verify that types match along control flow edges described the given op.
161 auto regionInterface = cast<RegionBranchOpInterface>(op);
162
163 // Verify all control flow edges from region branch points to region
164 // successors.
165 SmallVector<RegionBranchPoint> regionBranchPoints =
166 regionInterface.getAllRegionBranchPoints();
167 for (const RegionBranchPoint &branchPoint : regionBranchPoints) {
169 regionInterface.getSuccessorRegions(branchPoint, successors);
170 for (const RegionSuccessor &successor : successors) {
171 // Helper function that print the region branch point and the region
172 // successor.
173 auto emitRegionEdgeError = [&]() {
175 regionInterface->emitOpError("along control flow edge from ");
176 if (branchPoint.isParent()) {
177 diag << "parent";
178 diag.attachNote(op->getLoc()) << "region branch point";
179 } else {
180 diag << "Operation "
181 << branchPoint.getTerminatorPredecessorOrNull()->getName();
182 diag.attachNote(
183 branchPoint.getTerminatorPredecessorOrNull()->getLoc())
184 << "region branch point";
185 }
186 diag << " to ";
187 if (Region *region = successor.getSuccessor()) {
188 diag << "Region #" << region->getRegionNumber();
189 } else {
190 diag << "parent";
191 }
192 return diag;
193 };
194
195 // Verify number of successor operands and successor inputs.
196 OperandRange succOperands =
197 regionInterface.getSuccessorOperands(branchPoint, successor);
198 ValueRange succInputs = regionInterface.getSuccessorInputs(successor);
199 if (succOperands.size() != succInputs.size()) {
200 return emitRegionEdgeError()
201 << ": region branch point has " << succOperands.size()
202 << " operands, but region successor needs " << succInputs.size()
203 << " inputs";
204 }
205
206 // Verify that the types are compatible.
207 TypeRange succInputTypes = succInputs.getTypes();
208 TypeRange succOperandTypes = succOperands.getTypes();
209 for (const auto &typesIdx :
210 llvm::enumerate(llvm::zip(succOperandTypes, succInputTypes))) {
211 Type succOperandType = std::get<0>(typesIdx.value());
212 Type succInputType = std::get<1>(typesIdx.value());
213 if (!regionInterface.areTypesCompatible(succOperandType, succInputType))
214 return emitRegionEdgeError()
215 << ": successor operand type #" << typesIdx.index() << " "
216 << succOperandType << " should match successor input type #"
217 << typesIdx.index() << " " << succInputType;
218 }
219 }
220 }
221 return success();
222}
223
224/// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
225/// this function returns "true" for a successor region. The first parameter is
226/// the successor region. The second parameter indicates all already visited
227/// regions.
229
230/// Traverse the region graph starting at `begin`. The traversal is interrupted
231/// if `stopCondition` evaluates to "true" for a successor region. In that case,
232/// this function returns "true". Otherwise, if the traversal was not
233/// interrupted, this function returns "false".
234static bool traverseRegionGraph(Region *begin,
235 StopConditionFn stopConditionFn) {
236 auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
237 LDBG() << "Starting region graph traversal from region #"
238 << begin->getRegionNumber() << " in operation " << op->getName();
239
240 SmallVector<bool> visited(op->getNumRegions(), false);
241 visited[begin->getRegionNumber()] = true;
242 LDBG() << "Initialized visited array with " << op->getNumRegions()
243 << " regions";
244
245 // Retrieve all successors of the region and enqueue them in the worklist.
246 SmallVector<Region *> worklist;
247 auto enqueueAllSuccessors = [&](Region *region) {
248 LDBG() << "Enqueuing successors for region #" << region->getRegionNumber();
249 SmallVector<Attribute> operandAttributes(op->getNumOperands());
250 for (Block &block : *region) {
251 if (block.empty())
252 continue;
253 auto terminator =
254 dyn_cast<RegionBranchTerminatorOpInterface>(block.back());
255 if (!terminator)
256 continue;
258 operandAttributes.resize(terminator->getNumOperands());
259 terminator.getSuccessorRegions(operandAttributes, successors);
260 LDBG() << "Found " << successors.size()
261 << " successors from terminator in block";
262 for (RegionSuccessor successor : successors) {
263 if (!successor.isParent()) {
264 worklist.push_back(successor.getSuccessor());
265 LDBG() << "Added region #"
266 << successor.getSuccessor()->getRegionNumber()
267 << " to worklist";
268 } else {
269 LDBG() << "Skipping parent successor";
270 }
271 }
272 }
273 };
274 enqueueAllSuccessors(begin);
275 LDBG() << "Initial worklist size: " << worklist.size();
276
277 // Process all regions in the worklist via DFS.
278 while (!worklist.empty()) {
279 Region *nextRegion = worklist.pop_back_val();
280 LDBG() << "Processing region #" << nextRegion->getRegionNumber()
281 << " from worklist (remaining: " << worklist.size() << ")";
282
283 if (stopConditionFn(nextRegion, visited)) {
284 LDBG() << "Stop condition met for region #"
285 << nextRegion->getRegionNumber() << ", returning true";
286 return true;
287 }
288 if (!nextRegion->getParentOp()) {
289 llvm::errs() << "Region " << *nextRegion << " has no parent op\n";
290 return false;
291 }
292 if (visited[nextRegion->getRegionNumber()]) {
293 LDBG() << "Region #" << nextRegion->getRegionNumber()
294 << " already visited, skipping";
295 continue;
296 }
297 visited[nextRegion->getRegionNumber()] = true;
298 LDBG() << "Marking region #" << nextRegion->getRegionNumber()
299 << " as visited";
300 enqueueAllSuccessors(nextRegion);
301 }
302
303 LDBG() << "Traversal completed, returning false";
304 return false;
305}
306
307/// Return `true` if region `r` is reachable from region `begin` according to
308/// the RegionBranchOpInterface (by taking a branch).
309static bool isRegionReachable(Region *begin, Region *r) {
310 assert(begin->getParentOp() == r->getParentOp() &&
311 "expected that both regions belong to the same op");
312 return traverseRegionGraph(begin,
313 [&](Region *nextRegion, ArrayRef<bool> visited) {
314 // Interrupt traversal if `r` was reached.
315 return nextRegion == r;
316 });
317}
318
319/// Return `true` if `a` and `b` are in mutually exclusive regions.
320///
321/// 1. Find the first common of `a` and `b` (ancestor) that implements
322/// RegionBranchOpInterface.
323/// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
324/// contained.
325/// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
326/// mutually exclusive if they are not reachable from each other as per
327/// RegionBranchOpInterface::getSuccessorRegions.
329 LDBG() << "Checking if operations are in mutually exclusive regions: "
330 << a->getName() << " and " << b->getName();
331
332 assert(a && "expected non-empty operation");
333 assert(b && "expected non-empty operation");
334
335 auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
336 while (branchOp) {
337 LDBG() << "Checking branch operation " << branchOp->getName();
338
339 // Check if b is inside branchOp. (We already know that a is.)
340 if (!branchOp->isProperAncestor(b)) {
341 LDBG() << "Operation b is not inside branchOp, checking next ancestor";
342 // Check next enclosing RegionBranchOpInterface.
343 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
344 continue;
345 }
346
347 LDBG() << "Both operations are inside branchOp, finding their regions";
348
349 // b is contained in branchOp. Retrieve the regions in which `a` and `b`
350 // are contained.
351 Region *regionA = nullptr, *regionB = nullptr;
352 for (Region &r : branchOp->getRegions()) {
353 if (r.findAncestorOpInRegion(*a)) {
354 assert(!regionA && "already found a region for a");
355 regionA = &r;
356 LDBG() << "Found region #" << r.getRegionNumber() << " for operation a";
357 }
358 if (r.findAncestorOpInRegion(*b)) {
359 assert(!regionB && "already found a region for b");
360 regionB = &r;
361 LDBG() << "Found region #" << r.getRegionNumber() << " for operation b";
362 }
363 }
364 assert(regionA && regionB && "could not find region of op");
365
366 LDBG() << "Region A: #" << regionA->getRegionNumber() << ", Region B: #"
367 << regionB->getRegionNumber();
368
369 // `a` and `b` are in mutually exclusive regions if both regions are
370 // distinct and neither region is reachable from the other region.
371 bool regionsAreDistinct = (regionA != regionB);
372 bool aNotReachableFromB = !isRegionReachable(regionA, regionB);
373 bool bNotReachableFromA = !isRegionReachable(regionB, regionA);
374
375 LDBG() << "Regions distinct: " << regionsAreDistinct
376 << ", A not reachable from B: " << aNotReachableFromB
377 << ", B not reachable from A: " << bNotReachableFromA;
378
379 bool mutuallyExclusive =
380 regionsAreDistinct && aNotReachableFromB && bNotReachableFromA;
381 LDBG() << "Operations are mutually exclusive: " << mutuallyExclusive;
382
383 return mutuallyExclusive;
384 }
385
386 // Could not find a common RegionBranchOpInterface among a's and b's
387 // ancestors.
388 LDBG() << "No common RegionBranchOpInterface found, operations are not "
389 "mutually exclusive";
390 return false;
391}
392
393bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
394 LDBG() << "Checking if region #" << index << " is repetitive in operation "
395 << getOperation()->getName();
396
397 Region *region = &getOperation()->getRegion(index);
398 bool isRepetitive = isRegionReachable(region, region);
399
400 LDBG() << "Region #" << index << " is repetitive: " << isRepetitive;
401 return isRepetitive;
402}
403
404bool RegionBranchOpInterface::hasLoop() {
405 LDBG() << "Checking if operation " << getOperation()->getName()
406 << " has loops";
407
408 SmallVector<RegionSuccessor> entryRegions;
409 getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
410 LDBG() << "Found " << entryRegions.size() << " entry regions";
411
412 for (RegionSuccessor successor : entryRegions) {
413 if (!successor.isParent()) {
414 LDBG() << "Checking entry region #"
415 << successor.getSuccessor()->getRegionNumber() << " for loops";
416
417 bool hasLoop =
418 traverseRegionGraph(successor.getSuccessor(),
419 [](Region *nextRegion, ArrayRef<bool> visited) {
420 // Interrupt traversal if the region was already
421 // visited.
422 return visited[nextRegion->getRegionNumber()];
423 });
424
425 if (hasLoop) {
426 LDBG() << "Found loop in entry region #"
427 << successor.getSuccessor()->getRegionNumber();
428 return true;
429 }
430 } else {
431 LDBG() << "Skipping parent successor";
432 }
433 }
434
435 LDBG() << "No loops found in operation";
436 return false;
437}
438
440RegionBranchOpInterface::getSuccessorOperands(RegionBranchPoint src,
441 RegionSuccessor dest) {
442 if (src.isParent())
443 return getEntrySuccessorOperands(dest);
444 return src.getTerminatorPredecessorOrNull().getSuccessorOperands(dest);
445}
446
448RegionBranchOpInterface::getNonSuccessorInputs(RegionSuccessor successor) {
449 SmallVector<Value> results = llvm::to_vector(
450 successor.isParent()
451 ? ValueRange(getOperation()->getResults())
452 : ValueRange(successor.getSuccessor()->getArguments()));
453 ValueRange successorInputs = getSuccessorInputs(successor);
454 if (!successorInputs.empty()) {
455 unsigned inputBegin =
456 successor.isParent()
457 ? cast<OpResult>(successorInputs.front()).getResultNumber()
458 : cast<BlockArgument>(successorInputs.front()).getArgNumber();
459 results.erase(results.begin() + inputBegin,
460 results.begin() + inputBegin + successorInputs.size());
461 }
462 return results;
463}
464
466 return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
467}
468
469static void
470getSuccessorOperandInputMapping(RegionBranchOpInterface branchOp,
472 RegionBranchPoint src) {
474 branchOp.getSuccessorRegions(src, successors);
475 for (RegionSuccessor dst : successors) {
476 OperandRange operands = branchOp.getSuccessorOperands(src, dst);
477 assert(operands.size() == branchOp.getSuccessorInputs(dst).size() &&
478 "expected the same number of operands and inputs");
479 for (const auto &[operand, input] : llvm::zip_equal(
480 operandsToOpOperands(operands), branchOp.getSuccessorInputs(dst)))
481 mapping[&operand].push_back(input);
482 }
483}
484void RegionBranchOpInterface::getSuccessorOperandInputMapping(
486 std::optional<RegionBranchPoint> src) {
487 if (src.has_value()) {
488 ::getSuccessorOperandInputMapping(*this, mapping, src.value());
489 } else {
490 // No region branch point specified: populate the mapping for all possible
491 // region branch points.
492 for (RegionBranchPoint branchPoint : getAllRegionBranchPoints())
493 ::getSuccessorOperandInputMapping(*this, mapping, branchPoint);
494 }
495}
496
498 const RegionBranchSuccessorMapping &operandToInputs) {
500 for (const auto &[operand, inputs] : operandToInputs) {
501 for (Value input : inputs)
502 inputToOperands[input].push_back(operand);
503 }
504 return inputToOperands;
505}
506
507void RegionBranchOpInterface::getSuccessorInputOperandMapping(
509 RegionBranchSuccessorMapping operandToInputs;
510 getSuccessorOperandInputMapping(operandToInputs);
511 mapping = invertRegionBranchSuccessorMapping(operandToInputs);
512}
513
515RegionBranchOpInterface::getAllRegionBranchPoints() {
517 branchPoints.push_back(RegionBranchPoint::parent());
518 for (Region &region : getOperation()->getRegions()) {
519 for (Block &block : region) {
520 if (block.empty())
521 continue;
522 if (auto terminator =
523 dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
524 branchPoints.push_back(RegionBranchPoint(terminator));
525 }
526 }
527 return branchPoints;
528}
529
531 LDBG() << "Finding enclosing repetitive region for operation "
532 << op->getName();
533
534 while (Region *region = op->getParentRegion()) {
535 LDBG() << "Checking region #" << region->getRegionNumber()
536 << " in operation " << region->getParentOp()->getName();
537
538 op = region->getParentOp();
539 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
540 LDBG()
541 << "Found RegionBranchOpInterface, checking if region is repetitive";
542 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
543 LDBG() << "Found repetitive region #" << region->getRegionNumber();
544 return region;
545 }
546 } else {
547 LDBG() << "Parent operation does not implement RegionBranchOpInterface";
548 }
549 }
550
551 LDBG() << "No enclosing repetitive region found";
552 return nullptr;
553}
554
556 LDBG() << "Finding enclosing repetitive region for value";
557
558 Region *region = value.getParentRegion();
559 while (region) {
560 LDBG() << "Checking region #" << region->getRegionNumber()
561 << " in operation " << region->getParentOp()->getName();
562
563 Operation *op = region->getParentOp();
564 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
565 LDBG()
566 << "Found RegionBranchOpInterface, checking if region is repetitive";
567 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
568 LDBG() << "Found repetitive region #" << region->getRegionNumber();
569 return region;
570 }
571 } else {
572 LDBG() << "Parent operation does not implement RegionBranchOpInterface";
573 }
574 region = op->getParentRegion();
575 }
576
577 LDBG() << "No enclosing repetitive region found for value";
578 return nullptr;
579}
580
581/// Return "true" if `a` can be used in lieu of `b`, where `b` is a region
582/// successor input and `a` is a "reachable value" of `b`. Reachable values
583/// are successor operand values that are (maybe transitively) forwarded to
584/// `b`.
585static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b) {
586 assert((b.getDefiningOp() == regionBranchOp ||
587 b.getParentRegion()->getParentOp() == regionBranchOp) &&
588 "b must be a region successor input");
589
590 // Case 1: `a` is defined inside of the region branch op. `a` must be
591 // directly nested in the region branch op. Otherwise, it could not have
592 // been among the reachable values for a region successor input.
593 if (a.getParentRegion()->getParentOp() == regionBranchOp) {
594 // Case 1.1: If `b` is a result of the region branch op, `a` is not in
595 // scope for `b`.
596 // Example:
597 // %b = region_op({
598 // ^bb0(%a1: ...):
599 // %a2 = ...
600 // })
601 if (isa<OpResult>(b))
602 return false;
603
604 // Case 1.2: `b` is an entry block argument of a region. `a` is in scope
605 // for `b` only if it is also an entry block argument of the same region.
606 // Example:
607 // region_op({
608 // ^bb0(%b: ..., %a: ...):
609 // ...
610 // })
611 assert(isa<BlockArgument>(b) && "b must be a block argument");
612 return isa<BlockArgument>(a) && cast<BlockArgument>(a).getOwner() ==
613 cast<BlockArgument>(b).getOwner();
614 }
615
616 // Case 2: `a` is defined outside of the region branch op. In that case, we
617 // can safely assume that `a` was defined before `b`. Otherwise, it could not
618 // be among the reachable values for a region successor input.
619 // Example:
620 // { <- %a1 parent region begins here.
621 // ^bb0(%a1: ...):
622 // %a2 = ...
623 // %b1 = reigon_op({
624 // ^bb1(%b2: ...):
625 // ...
626 // })
627 // }
628 return true;
629}
630
631/// Compute all non-successor-input values that a successor input could have
632/// based on the given successor input to successor operand mapping.
633///
634/// Starting with the given value, trace back all predecessor values (i.e.,
635/// preceding successor operands) and add them to the set of reachable values.
636/// If the successor operand is again a successor input, do not add it to the
637/// result set, but instead continue the traversal.
638///
639/// If `maxReachableValues` is set, the traversal is aborted early and
640/// `failure` is returned as soon as the number of reachable values exceeds
641/// the limit. Otherwise, `success` is returned and the result set contains
642/// all reachable values.
643///
644/// Example 1:
645/// %r = scf.if ... {
646/// scf.yield %a : ...
647/// } else {
648/// scf.yield %b : ...
649/// }
650/// reachableValues(%r) = {%a, %b}
651///
652/// Example 2:
653/// %r = scf.for ... iter_args(%arg0 = %0) -> ... {
654/// scf.yield %arg0 : ...
655/// }
656/// reachableValues(%arg0) = {%0}
657/// reachableValues(%r) = {%0}
658///
659/// Example 3:
660/// %r = scf.for ... iter_args(%arg0 = %0) -> ... {
661/// ...
662/// scf.yield %1 : ...
663/// }
664/// reachableValues(%arg0) = {%0, %1}
665/// reachableValues(%r) = {%0, %1}
667 llvm::SmallDenseSet<Value> &result, Value value,
668 const RegionBranchInverseSuccessorMapping &inputToOperands,
669 std::optional<unsigned> maxReachableValues = std::nullopt) {
670 assert(inputToOperands.contains(value) && "value must be a successor input");
671 llvm::SmallDenseSet<Value> visited;
672 SmallVector<Value> worklist;
673 worklist.push_back(value);
674 while (!worklist.empty()) {
675 Value next = worklist.pop_back_val();
676 auto it = inputToOperands.find(next);
677 if (it == inputToOperands.end()) {
678 result.insert(next);
679 if (maxReachableValues && result.size() > *maxReachableValues)
680 return failure();
681 continue;
682 }
683 for (OpOperand *operand : it->second)
684 if (visited.insert(operand->get()).second)
685 worklist.push_back(operand->get());
686 }
687 // Note: The result does not contain any successor inputs. (Therefore,
688 // `value` is also guaranteed to be excluded.)
689 return success();
690}
691
692namespace {
693/// Try to make successor inputs dead by replacing their uses with values that
694/// are not successor inputs. This pattern enables additional canonicalization
695/// opportunities for RemoveDeadRegionBranchOpSuccessorInputs.
696///
697/// Example:
698///
699/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
700/// scf.yield %arg1, %arg1 : ...
701/// }
702/// use(%r0, %r1)
703///
704/// reachableValues(%r0) = {%0, %1}
705/// reachableValues(%r1) = {%1} ==> replace uses of %r1 with %1.
706/// reachableValues(%arg0) = {%0, %1}
707/// reachableValues(%arg1) = {%1} ==> replace uses of %arg1 with %1.
708///
709/// IR after pattern application:
710///
711/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
712/// scf.yield %1, %1 : ...
713/// }
714/// use(%r0, %1)
715///
716/// Note that %r1 and %arg1 are dead now. The IR can now be further
717/// canonicalized by RemoveDeadRegionBranchOpSuccessorInputs.
718struct MakeRegionBranchOpSuccessorInputsDead : public RewritePattern {
719 MakeRegionBranchOpSuccessorInputsDead(MLIRContext *context, StringRef name,
720 PatternBenefit benefit = 1)
721 : RewritePattern(name, benefit, context) {}
722
723 LogicalResult matchAndRewrite(Operation *op,
724 PatternRewriter &rewriter) const override {
725 assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
726 "isolated-from-above ops are not supported");
727
728 // Compute the mapping of successor inputs to successor operands.
729 auto regionBranchOp = cast<RegionBranchOpInterface>(op);
731 regionBranchOp.getSuccessorInputOperandMapping(inputToOperands);
732
733 // Try to replace the uses of each successor input one-by-one.
734 bool changed = false;
735 for (Value value : inputToOperands.keys()) {
736 // Nothing to do for successor inputs that are already dead.
737 if (value.use_empty())
738 continue;
739 // Nothing to do for successor inputs that may have multiple reachable
740 // values.
741 llvm::SmallDenseSet<Value> reachableValues;
743 reachableValues, value, inputToOperands,
744 /*maxReachableValues=*/1)) ||
745 reachableValues.empty())
746 continue;
747 assert(*reachableValues.begin() != value &&
748 "successor inputs are supposed to be excluded");
749 // Do not replace `value` with the found reachable value if doing so
750 // would violate dominance. Example:
751 // %r = scf.execute_region ... {
752 // %a = ...
753 // scf.yield %a : ...
754 // }
755 // use(%r)
756 // In the above example, reachableValues(%r) = {%a}, but %a cannot be
757 // used as a replacement for %r due to dominance / scope.
758 if (!isDefinedBefore(regionBranchOp, *reachableValues.begin(), value))
759 continue;
760 rewriter.replaceAllUsesWith(value, *reachableValues.begin());
761 changed = true;
762 }
763 return success(changed);
764 }
765};
766
767/// Lookup a bit vector in the given mapping (DenseMap). If the key was not
768/// found, create a new bit vector with the given size and initialize it with
769/// false.
770template <typename MappingTy, typename KeyTy>
771static BitVector &lookupOrCreateBitVector(MappingTy &mapping, KeyTy key,
772 unsigned size) {
773 return mapping.try_emplace(key, size, false).first->second;
774}
775
776/// Compute tied successor inputs. Tied successor inputs are successor inputs
777/// that come as a set. If you erase one value from a set, you must erase all
778/// values from the set. Otherwise, the op would become structurally invalid.
779/// Each successor input appears in exactly one set.
780///
781/// Example:
782/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
783/// ...
784/// }
785/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}.
786static llvm::EquivalenceClasses<Value> computeTiedSuccessorInputs(
787 const RegionBranchSuccessorMapping &operandToInputs) {
788 llvm::EquivalenceClasses<Value> tiedSuccessorInputs;
789 for (const auto &[operand, inputs] : operandToInputs) {
790 assert(!inputs.empty() && "expected non-empty inputs");
791 Value firstInput = inputs.front();
792 tiedSuccessorInputs.insert(firstInput);
793 for (Value nextInput : llvm::drop_begin(inputs)) {
794 // As we explore more successor operand to successor input mappings,
795 // existing sets may get merged.
796 tiedSuccessorInputs.unionSets(firstInput, nextInput);
797 }
798 }
799 return tiedSuccessorInputs;
800}
801
802/// Remove dead successor inputs from region branch ops. A successor input is
803/// dead if it has no uses. Successor inputs come in sets of tied values: if
804/// you remove one value from a set, you must remove all values from the set.
805/// Furthermore, successor operands must also be removed. (Op operands are not
806/// part of the set, but the set is built based on the successor operand to
807/// successor input mapping.)
808///
809/// Example 1:
810/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
811/// scf.yield %0, %arg1 : ...
812/// }
813/// use(%0, %1)
814///
815/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}. All values in the first
816/// set are dead, so %arg0 and %r0 can be removed, but not %r1 and %arg1. The
817/// resulting IR is as follows:
818///
819/// %r1 = scf.for ... iter_args(%arg1 = %1) -> ... {
820/// scf.yield %arg1 : ...
821/// }
822/// use(%0, %1)
823///
824/// Example 2:
825/// %r0, %r1 = scf.while (%arg0 = %0) {
826/// scf.condition(...) %arg0, %arg0 : ...
827/// } do {
828/// ^bb0(%arg1: ..., %arg2: ...):
829/// scf.yield %arg1 : ...
830/// }
831/// There are three sets: {{%r0, %arg1}, {%r1, %arg2}, {%arg0}}.
832///
833/// Example 3:
834/// %r1, %r2 = scf.if ... {
835/// scf.yield %0, %1 : ...
836/// } else {
837/// scf.yield %2, %3 : ...
838/// }
839/// There are two sets: {{%r1}, {%r2}}. Each set has one value, so each
840/// value can be removed independently of the other values.
841struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern {
842 RemoveDeadRegionBranchOpSuccessorInputs(MLIRContext *context, StringRef name,
843 PatternBenefit benefit = 1)
844 : RewritePattern(name, benefit, context) {}
845
846 LogicalResult matchAndRewrite(Operation *op,
847 PatternRewriter &rewriter) const override {
848 assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
849 "isolated-from-above ops are not supported");
850
851 // Compute tied values: values that must come as a set. If you remove one,
852 // you must remove all. If a successor op operand is forwarded to two
853 // successor inputs %a and %b, both %a and %b are in the same set.
854 auto regionBranchOp = cast<RegionBranchOpInterface>(op);
855 RegionBranchSuccessorMapping operandToInputs;
856 regionBranchOp.getSuccessorOperandInputMapping(operandToInputs);
857 llvm::EquivalenceClasses<Value> tiedSuccessorInputs =
858 computeTiedSuccessorInputs(operandToInputs);
859
860 // Determine which values to remove and group them by block and operation.
861 SmallVector<Value> valuesToRemove;
862 DenseMap<Block *, BitVector> blockArgsToRemove;
863 BitVector resultsToRemove(regionBranchOp->getNumResults(), false);
864 // Iterate over all sets of tied successor inputs.
865 for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end();
866 it != e; ++it) {
867 if (!(*it)->isLeader())
868 continue;
869
870 // Value can be removed if it is dead and all other tied values are also
871 // dead.
872 bool allDead = true;
873 for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
874 memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
875 // Iterate over all values in the set and check their liveness.
876 if (!memberIt->use_empty()) {
877 allDead = false;
878 break;
879 }
880 }
881 if (!allDead)
882 continue;
883
884 // The entire set is dead. Group values by block and operation to
885 // simplify removal.
886 for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
887 memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
888 if (auto arg = dyn_cast<BlockArgument>(*memberIt)) {
889 // Set blockArgsToRemove[block][arg_number] = true.
890 BitVector &vector =
891 lookupOrCreateBitVector(blockArgsToRemove, arg.getOwner(),
892 arg.getOwner()->getNumArguments());
893 vector.set(arg.getArgNumber());
894 } else {
895 // Set resultsToRemove[result_number] = true.
896 OpResult result = cast<OpResult>(*memberIt);
897 assert(result.getDefiningOp() == regionBranchOp &&
898 "result must be a region branch op result");
899 resultsToRemove.set(result.getResultNumber());
900 }
901 valuesToRemove.push_back(*memberIt);
902 }
903 }
904
905 if (valuesToRemove.empty())
906 return rewriter.notifyMatchFailure(op, "no values to remove");
907
908 // Find operands that must be removed together with the values.
909 RegionBranchInverseSuccessorMapping inputsToOperands =
910 invertRegionBranchSuccessorMapping(operandToInputs);
912 for (Value value : valuesToRemove) {
913 for (OpOperand *operand : inputsToOperands[value]) {
914 // Set operandsToRemove[op][operand_number] = true.
915 BitVector &vector =
916 lookupOrCreateBitVector(operandsToRemove, operand->getOwner(),
917 operand->getOwner()->getNumOperands());
918 vector.set(operand->getOperandNumber());
919 }
920 }
921
922 // Erase operands.
923 for (auto &pair : operandsToRemove) {
924 Operation *op = pair.first;
925 BitVector &operands = pair.second;
926 rewriter.modifyOpInPlace(op, [&]() { op->eraseOperands(operands); });
927 }
928
929 // Erase block arguments.
930 for (auto &pair : blockArgsToRemove) {
931 Block *block = pair.first;
932 BitVector &blockArg = pair.second;
933 rewriter.modifyOpInPlace(block->getParentOp(),
934 [&]() { block->eraseArguments(blockArg); });
935 }
936
937 // Erase op results.
938 if (resultsToRemove.any())
939 rewriter.eraseOpResults(regionBranchOp, resultsToRemove);
940
941 return success();
942 }
943};
944
945/// Return the "owner" of a value: the parent block for block arguments, the
946/// defining op for op results.
947static void *getOwnerOfValue(Value value) {
948 if (auto arg = dyn_cast<BlockArgument>(value))
949 return arg.getOwner();
950 return value.getDefiningOp();
951}
952
953/// Get the block argument or op result number of the given value.
954static unsigned getArgOrResultNumber(Value value) {
955 if (auto opResult = llvm::dyn_cast<OpResult>(value))
956 return opResult.getResultNumber();
957 return llvm::cast<BlockArgument>(value).getArgNumber();
958}
959
960/// Find duplicate successor inputs and make all dead except for one. Two
961/// successor inputs are "duplicate" if their corresponding successor operands
962/// have the same values. This pattern enables additional canonicalization
963/// opportunities for RemoveDeadRegionBranchOpSuccessorInputs.
964///
965/// Example:
966/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... {
967/// use(%arg0, %arg1)
968/// ...
969/// scf.yield %x, %x : ...
970/// }
971/// use(%r0, %r1)
972///
973/// Operands of successor input %r0: [%0, %x]
974/// Operands of successor input %r1: [%0, %x] ==> DUPLICATE!
975/// Replace %r1 with %r0.
976///
977/// Operands of successor input %arg0: [%0, %x]
978/// Operands of successor input %arg1: [%0, %x] ==> DUPLICATE!
979/// Replace %arg1 with %arg0. (We have to make sure that we make same decision
980/// as for the other tied successor inputs above. Otherwise, a set of tied
981/// successor inputs may not become entirely dead.)
982///
983/// The resulting IR is as follows:
984/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... {
985/// use(%arg0, %arg0)
986/// ...
987/// scf.yield %x, %x : ...
988/// }
989/// use(%r0, %r0) // Note: We don't want use(%r1, %r1), which is also correct,
990/// // but does not help with further canonicalizations.
991struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
992 RemoveDuplicateSuccessorInputUses(MLIRContext *context, StringRef name,
993 PatternBenefit benefit = 1)
994 : RewritePattern(name, benefit, context) {}
995
996 LogicalResult matchAndRewrite(Operation *op,
997 PatternRewriter &rewriter) const override {
998 assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
999 "isolated-from-above ops are not supported");
1000
1001 // Collect all successor inputs and sort them. When dropping the uses of a
1002 // successor input, we'd like to also drop the uses of the same tied
1003 // successor inputs. Otherwise, a set of tied successor inputs may not
1004 // become entirely dead, which is required for
1005 // RemoveDeadRegionBranchOpSuccessorInputs to be able to erase them.
1006 // (Sorting is not required for correctness.)
1007 auto regionBranchOp = cast<RegionBranchOpInterface>(op);
1008 RegionBranchInverseSuccessorMapping inputsToOperands;
1009 regionBranchOp.getSuccessorInputOperandMapping(inputsToOperands);
1010 SmallVector<Value> inputs = llvm::to_vector(inputsToOperands.keys());
1011 llvm::sort(inputs, [](Value a, Value b) {
1012 return getArgOrResultNumber(a) < getArgOrResultNumber(b);
1013 });
1014
1015 // Group inputs by their operand "signature" to find duplicates. Two
1016 // successor inputs are duplicates if each predecessor (region branch point)
1017 // forwards the same value for both. Let n = number of successor inputs and
1018 // k = number of predecessors per input. Instead of comparing every pair of
1019 // inputs (O(n² * k)), we build a signature for each input and group them
1020 // via a std::map.
1021 //
1022 // A signature is a sorted list of (predecessor, forwarded value) pairs.
1023 // Within each group, all but the first (canonical) input are replaced with
1024 // the canonical one.
1025 using SigEntry = std::pair<Operation *, Value>;
1026 using Signature = SmallVector<SigEntry>;
1027 auto sigEntryLess = [](const SigEntry &a, const SigEntry &b) {
1028 if (a.first != b.first)
1029 return a.first < b.first;
1030 return a.second.getAsOpaquePointer() < b.second.getAsOpaquePointer();
1031 };
1032 // The map key is (signature, owner). Two inputs are duplicates only if they
1033 // have the same signature AND the same owner (block or defining op). This
1034 // ensures we track one canonical per owner group.
1035 using MapKey = std::pair<Signature, void *>;
1036 auto mapKeyLess = [&](const MapKey &a, const MapKey &b) {
1037 if (a.second != b.second)
1038 return a.second < b.second;
1039 return std::lexicographical_compare(a.first.begin(), a.first.end(),
1040 b.first.begin(), b.first.end(),
1041 sigEntryLess);
1042 };
1043 std::map<MapKey, Value, decltype(mapKeyLess)> signatureToCanonical(
1044 mapKeyLess);
1045 bool changed = false;
1046 // Total complexity: O(n * k * max(log k, log n)). For each input, sorting
1047 // the signature costs O(k log k) and the std::map lookup costs O(k log n).
1048 for (Value input : inputs) {
1049 // Gather the predecessor value for each predecessor (region branch
1050 // point) and sort them to form this input's signature.
1051 Signature sig;
1052 for (OpOperand *operand : inputsToOperands[input])
1053 sig.emplace_back(operand->getOwner(), operand->get());
1054 llvm::sort(sig, sigEntryLess);
1055
1056 void *owner = getOwnerOfValue(input);
1057
1058 auto [it, inserted] = signatureToCanonical.try_emplace(
1059 MapKey{std::move(sig), owner}, input);
1060 if (!inserted) {
1061 Value canonical = it->second;
1062 // Nothing to do if input is already dead.
1063 if (input.use_empty())
1064 continue;
1065 rewriter.replaceAllUsesWith(input, canonical);
1066 changed = true;
1067 }
1068 }
1069 return success(changed);
1070 }
1071};
1072
1073/// Given a range of values, return a vector of attributes of the same size,
1074/// where the i-th attribute is the constant value of the i-th value. If a
1075/// value is not constant, the corresponding attribute is null.
1076static SmallVector<Attribute> extractConstants(ValueRange values) {
1077 return llvm::map_to_vector(values, [](Value value) {
1078 Attribute attr;
1079 matchPattern(value, m_Constant(&attr));
1080 return attr;
1081 });
1082}
1083
1084/// Return all successor regions when branching from the given region branch
1085/// point. This helper functions extracts all constant operand values and
1086/// passes them to the `RegionBranchOpInterface`.
1088getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
1089 RegionBranchPoint point) {
1091 if (point.isParent()) {
1092 op.getEntrySuccessorRegions(extractConstants(op->getOperands()),
1093 successors);
1094 return successors;
1095 }
1096 RegionBranchTerminatorOpInterface terminator =
1098 terminator.getSuccessorRegions(extractConstants(terminator->getOperands()),
1099 successors);
1100 return successors;
1101}
1102
1103/// Find the single acyclic path through the given region branch op. Return an
1104/// empty vector if no such path or multiple such paths exist.
1105///
1106/// Example: "scf.if %true" has a single path: parent => then_region => parent
1107///
1108/// Example: "scf.if ???" has multiple paths:
1109/// (1) parent => then_region => parent
1110/// (2) parent => else_region => parent
1111///
1112/// Example: "scf.while with scf.condition(%false)" has a single path:
1113/// parent => before_region => parent
1114///
1115/// Example: "scf.for with 0 iterations" has a single path: parent => parent
1116///
1117/// Note: Each path starts and ends with "parent". The "parent" at the beginning
1118/// of the path is omitted from the result.
1119///
1120/// Note: This function also returns an "empty" path when a region with multiple
1121/// blocks was found.
1123computeSingleAcyclicRegionBranchPath(RegionBranchOpInterface op) {
1124 llvm::SmallDenseSet<Region *> visited;
1126
1127 // Path starts with "parent".
1129 do {
1130 SmallVector<RegionSuccessor> successors =
1131 getSuccessorRegionsWithAttrs(op, next);
1132 if (successors.size() != 1) {
1133 // There are multiple region successors. I.e., there are multiple paths
1134 // through the region branch op.
1135 return {};
1136 }
1137 path.push_back(successors.front());
1138 if (successors.front().isParent()) {
1139 // Found path that ends with "parent".
1140 return path;
1141 }
1142 Region *region = successors.front().getSuccessor();
1143 if (!region->hasOneBlock()) {
1144 // Entering a region with multiple blocks. Such regions are not supported
1145 // at the moment.
1146 return {};
1147 }
1148 if (!visited.insert(region).second) {
1149 // We have already visited this region. I.e., we have found a cycle.
1150 return {};
1151 }
1152 auto terminator =
1153 dyn_cast<RegionBranchTerminatorOpInterface>(&region->front().back());
1154 if (!terminator) {
1155 // Region has no RegionBranchTerminatorOpInterface terminator. E.g., the
1156 // terminator could be a "ub.unreachable" op. Such IR is not supported.
1157 return {};
1158 }
1159 next = RegionBranchPoint(terminator);
1160 } while (true);
1161 llvm_unreachable("expected to return from loop");
1162}
1163
1164/// Inline the body of the matched region branch op into the enclosing block if
1165/// there is exactly one acyclic path through the region branch op, starting
1166/// from "parent", and if that path ends with "parent".
1167///
1168/// Example: This pattern can inline "scf.for" operations that are guaranteed to
1169/// have a single iteration, as indicated by the region branch path "parent =>
1170/// region => parent". "scf.for" operations have a non-successor-input: the loop
1171/// induction variable. Non-successor-input values have op-specific semantics
1172/// and cannot be reasoned about through the `RegionBranchOpInterface`. A
1173/// replacement value for non-successor-inputs is injected by the user-specified
1174/// lambda: in the case of the loop induction variable of an "scf.for", the
1175/// lower bound of the loop is used as a replacement value.
1176///
1177/// Before pattern application:
1178/// %r = scf.for %iv = %c5 to %c6 step %c1 iter_args(%arg0 = %0) {
1179/// %1 = "producer"(%arg0, %iv)
1180/// scf.yield %1
1181/// }
1182/// "user"(%r)
1183///
1184/// After pattern application:
1185/// %1 = "producer"(%0, %c5)
1186/// "user"(%1)
1187///
1188/// This pattern is limited to the following cases:
1189/// - Only regions with a single block are supported. This could be generalized.
1190/// - Region branch ops with side effects are not supported. (Recursive side
1191/// effects are fine.)
1192///
1193/// Note: This pattern queries the region dataflow from the
1194/// `RegionBranchOpInterface`. Replacement values are for block arguments / op
1195/// results are determined based on region dataflow. In case of
1196/// non-successor-inputs (whose values are not modeled by the
1197/// `RegionBranchOpInterface`), a user-specified lambda is queried.
1198struct InlineRegionBranchOp : public RewritePattern {
1199 InlineRegionBranchOp(MLIRContext *context, StringRef name,
1201 PatternMatcherFn matcherFn, PatternBenefit benefit = 1)
1202 : RewritePattern(name, benefit, context), replBuilderFn(replBuilderFn),
1203 matcherFn(matcherFn) {}
1204
1205 LogicalResult matchAndRewrite(Operation *op,
1206 PatternRewriter &rewriter) const override {
1207 // Check if the pattern is applicable to the given operation.
1208 if (failed(matcherFn(op)))
1209 return rewriter.notifyMatchFailure(op, "pattern not applicable");
1210
1211 // Patterns without recursive memory effects could have side effects, so
1212 // it is not safe to fold such ops away.
1213 if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
1214 return rewriter.notifyMatchFailure(
1215 op, "pattern not applicable to ops without recursive memory effects");
1216
1217 // Find the single acyclic path through the region branch op.
1218 auto regionBranchOp = cast<RegionBranchOpInterface>(op);
1219 SmallVector<RegionSuccessor> path =
1220 computeSingleAcyclicRegionBranchPath(regionBranchOp);
1221 if (path.empty())
1222 return rewriter.notifyMatchFailure(
1223 op, "failed to find acyclic region branch path");
1224
1225 // Inline all regions on the path into the enclosing block.
1226 rewriter.setInsertionPoint(op);
1227 ArrayRef remainingPath = path;
1228 SmallVector<Value> successorOperands = llvm::to_vector(
1229 regionBranchOp.getEntrySuccessorOperands(remainingPath.front()));
1230 while (!remainingPath.empty()) {
1231 RegionSuccessor nextSuccessor = remainingPath.consume_front();
1232 ValueRange successorInputs =
1233 regionBranchOp.getSuccessorInputs(nextSuccessor);
1234 assert(successorInputs.size() == successorOperands.size() &&
1235 "size mismatch");
1236 // Find the index of the first block argument / op result that is a
1237 // succesor input.
1238 unsigned firstSuccessorInputIdx = 0;
1239 if (!successorInputs.empty())
1240 firstSuccessorInputIdx =
1241 nextSuccessor.isParent()
1242 ? cast<OpResult>(successorInputs.front()).getResultNumber()
1243 : cast<BlockArgument>(successorInputs.front()).getArgNumber();
1244 // Query the total number of block arguments / op results.
1245 unsigned numValues =
1246 nextSuccessor.isParent()
1247 ? op->getNumResults()
1248 : nextSuccessor.getSuccessor()->getNumArguments();
1249 // Compute replacement values for all block arguments / op results.
1250 SmallVector<Value> replacements;
1251 // Helper function to get the i-th block argument / op result.
1252 auto getValue = [&](unsigned idx) {
1253 return nextSuccessor.isParent()
1254 ? Value(op->getResult(idx))
1255 : Value(nextSuccessor.getSuccessor()->getArgument(idx));
1256 };
1257 // Compute replacement values for all non-successor-input values that
1258 // precede the first successor input.
1259 for (unsigned i = 0; i < firstSuccessorInputIdx; ++i)
1260 replacements.push_back(
1261 replBuilderFn(rewriter, op->getLoc(), getValue(i)));
1262 // Use the successor operands of the predecessor as replacement values for
1263 // the successor inputs.
1264 llvm::append_range(replacements, successorOperands);
1265 // Compute replacement values for all block arguments / op results that
1266 // succeed the first successor input.
1267 for (unsigned i = replacements.size(); i < numValues; ++i)
1268 replacements.push_back(
1269 replBuilderFn(rewriter, op->getLoc(), getValue(i)));
1270 if (nextSuccessor.isParent()) {
1271 // The path ends with "parent". Replace the region branch op with the
1272 // computed replacement values.
1273 assert(remainingPath.empty() && "expected that the path ended");
1274 rewriter.replaceOp(op, replacements);
1275 return success();
1276 }
1277 // We are inside of a region: query the successor operands from the
1278 // terminator, inline the region into the enclosing block, and erase the
1279 // terminator.
1280 auto terminator = cast<RegionBranchTerminatorOpInterface>(
1281 &nextSuccessor.getSuccessor()->front().back());
1282 rewriter.inlineBlockBefore(&nextSuccessor.getSuccessor()->front(),
1283 op->getBlock(), op->getIterator(),
1284 replacements);
1285 successorOperands = llvm::to_vector(
1286 terminator.getSuccessorOperands(remainingPath.front()));
1287 rewriter.eraseOp(terminator);
1288 }
1289
1290 llvm_unreachable("expected that paths ends with parent");
1291 }
1292
1294 PatternMatcherFn matcherFn;
1295};
1296} // namespace
1297
1299 RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit) {
1300 patterns.add<MakeRegionBranchOpSuccessorInputsDead,
1301 RemoveDuplicateSuccessorInputUses,
1302 RemoveDeadRegionBranchOpSuccessorInputs>(patterns.getContext(),
1303 opName, benefit);
1304}
1305
1307 RewritePatternSet &patterns, StringRef opName,
1309 PatternMatcherFn matcherFn, PatternBenefit benefit) {
1310 patterns.add<InlineRegionBranchOp>(patterns.getContext(), opName,
1311 replBuilderFn, matcherFn, benefit);
1312}
return success()
static LogicalResult verifyWeights(Operation *op, llvm::ArrayRef< int32_t > weights, std::size_t expectedWeightsNum, llvm::StringRef weightAnchorName, llvm::StringRef weightRefName)
static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b)
Return "true" if a can be used in lieu of b, where b is a region successor input and a is a "reachabl...
static void getSuccessorOperandInputMapping(RegionBranchOpInterface branchOp, RegionBranchSuccessorMapping &mapping, RegionBranchPoint src)
static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn)
Traverse the region graph starting at begin.
static LogicalResult computeReachableValuesFromSuccessorInput(llvm::SmallDenseSet< Value > &result, Value value, const RegionBranchInverseSuccessorMapping &inputToOperands, std::optional< unsigned > maxReachableValues=std::nullopt)
Compute all non-successor-input values that a successor input could have based on the given successor...
static RegionBranchInverseSuccessorMapping invertRegionBranchSuccessorMapping(const RegionBranchSuccessorMapping &operandToInputs)
function_ref< bool(Region *, ArrayRef< bool > visited)> StopConditionFn
Stop condition for traverseRegionGraph.
static bool isRegionReachable(Region *begin, Region *r)
Return true if region r is reachable from region begin according to the RegionBranchOpInterface (by t...
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static std::string diag(const llvm::Value &value)
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
static Operation * getOwnerOfValue(Value value)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & back()
Definition Block.h:162
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
This class represents an operand of an operation.
Definition Value.h:254
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
unsigned getNumSuccessors()
Definition Operation.h:732
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
Definition Operation.h:386
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
Block * getSuccessor(unsigned index)
Definition Operation.h:734
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:248
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition Region.cpp:62
unsigned getNumArguments()
Definition Region.h:123
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class models how operands are forwarded to block arguments in control flow.
SuccessorOperands(MutableOperandRange forwardedOperands)
Constructs a SuccessorOperands with no produced operands that simply forwards operands to the success...
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition Value.h:233
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
LogicalResult verifyRegionBranchWeights(Operation *op)
Verify that the region weights attached to an operation implementing WeightedRegionBranchOpInterface ...
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands)
Verify that the given operands match those of the given successor block.
LogicalResult verifyRegionBranchOpInterface(Operation *op)
Verify that types match along control flow edges described the given op.
LogicalResult verifyBranchWeights(Operation *op)
Verify that the branch weights attached to an operation implementing WeightedBranchOpInterface are co...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
std::function< LogicalResult(Operation *)> PatternMatcherFn
Helper function for the region branch op inlining pattern that checks if the pattern is applicable to...
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
std::function< Value(OpBuilder &, Location, Value)> NonSuccessorInputReplacementBuilderFn
Helper function for the region branch op inlining pattern that builds replacement values for non-succ...
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBran...
DenseMap< Value, SmallVector< OpOperand * > > RegionBranchInverseSuccessorMapping
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
void populateRegionBranchOpInterfaceInliningPattern(RewritePatternSet &patterns, StringRef opName, NonSuccessorInputReplacementBuilderFn replBuilderFn=detail::defaultReplBuilderFn, PatternMatcherFn matcherFn=detail::defaultMatcherFn, PatternBenefit benefit=1)
Populate a pattern that inlines the body of region branch ops when there is a single acyclic path thr...
void populateRegionBranchOpInterfaceCanonicalizationPatterns(RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit=1)
Populate canonicalization patterns that simplify successor operands/inputs of region branch operation...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147