MLIR 22.0.0git
SCF.cpp
Go to the documentation of this file.
1//===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/Operation.h"
30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/DebugLog.h"
35#include <optional>
36
37using namespace mlir;
38using namespace mlir::scf;
39
40#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
41
42//===----------------------------------------------------------------------===//
43// SCFDialect Dialect Interfaces
44//===----------------------------------------------------------------------===//
45
46namespace {
47struct SCFInlinerInterface : public DialectInlinerInterface {
48 using DialectInlinerInterface::DialectInlinerInterface;
49 // We don't have any special restrictions on what can be inlined into
50 // destination regions (e.g. while/conditional bodies). Always allow it.
51 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
52 IRMapping &valueMapping) const final {
53 return true;
54 }
55 // Operations in scf dialect are always legal to inline since they are
56 // pure.
57 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
58 return true;
59 }
60 // Handle the given inlined terminator by replacing it with a new operation
61 // as necessary. Required when the region has only one block.
62 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
63 auto retValOp = dyn_cast<scf::YieldOp>(op);
64 if (!retValOp)
65 return;
66
67 for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
68 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
69 }
70 }
71};
72} // namespace
73
74//===----------------------------------------------------------------------===//
75// SCFDialect
76//===----------------------------------------------------------------------===//
77
78void SCFDialect::initialize() {
79 addOperations<
80#define GET_OP_LIST
81#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
82 >();
83 addInterfaces<SCFInlinerInterface>();
84 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
85 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
86 InParallelOp, ReduceReturnOp>();
87 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
88 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
89 ForallOp, InParallelOp, WhileOp, YieldOp>();
90 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
91}
92
93/// Default callback for IfOp builders. Inserts a yield without arguments.
95 scf::YieldOp::create(builder, loc);
96}
97
98/// Verifies that the first block of the given `region` is terminated by a
99/// TerminatorTy. Reports errors on the given operation if it is not the case.
100template <typename TerminatorTy>
101static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
102 StringRef errorMessage) {
103 Operation *terminatorOperation = nullptr;
104 if (!region.empty() && !region.front().empty()) {
105 terminatorOperation = &region.front().back();
106 if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
107 return yield;
108 }
109 auto diag = op->emitOpError(errorMessage);
110 if (terminatorOperation)
111 diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
112 return nullptr;
113}
114
115std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
116 bool isSigned) {
117 llvm::APSInt diff;
118 auto addOp = ub.getDefiningOp<arith::AddIOp>();
119 if (!addOp)
120 return std::nullopt;
121 if ((isSigned && !addOp.hasNoSignedWrap()) ||
122 (!isSigned && !addOp.hasNoUnsignedWrap()))
123 return std::nullopt;
124
125 if (addOp.getLhs() != lb ||
126 !matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
127 return std::nullopt;
128 return diff;
129}
130
131//===----------------------------------------------------------------------===//
132// ExecuteRegionOp
133//===----------------------------------------------------------------------===//
134
135/// Replaces the given op with the contents of the given single-block region,
136/// using the operands of the block terminator to replace operation results.
138 Region &region, ValueRange blockArgs = {}) {
139 assert(region.hasOneBlock() && "expected single-block region");
140 Block *block = &region.front();
141 Operation *terminator = block->getTerminator();
142 ValueRange results = terminator->getOperands();
143 rewriter.inlineBlockBefore(block, op, blockArgs);
144 rewriter.replaceOp(op, results);
145 rewriter.eraseOp(terminator);
146}
147
148///
149/// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
150/// block+
151/// `}`
152///
153/// Example:
154/// scf.execute_region -> i32 {
155/// %idx = load %rI[%i] : memref<128xi32>
156/// return %idx : i32
157/// }
158///
159ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
161 if (parser.parseOptionalArrowTypeList(result.types))
162 return failure();
163
164 if (succeeded(parser.parseOptionalKeyword("no_inline")))
165 result.addAttribute("no_inline", parser.getBuilder().getUnitAttr());
166
167 // Introduce the body region and parse it.
168 Region *body = result.addRegion();
169 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
170 parser.parseOptionalAttrDict(result.attributes))
171 return failure();
172
173 return success();
174}
175
176void ExecuteRegionOp::print(OpAsmPrinter &p) {
177 p.printOptionalArrowTypeList(getResultTypes());
178 p << ' ';
179 if (getNoInline())
180 p << "no_inline ";
181 p.printRegion(getRegion(),
182 /*printEntryBlockArgs=*/false,
183 /*printBlockTerminators=*/true);
184 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"no_inline"});
185}
186
187LogicalResult ExecuteRegionOp::verify() {
188 if (getRegion().empty())
189 return emitOpError("region needs to have at least one block");
190 if (getRegion().front().getNumArguments() > 0)
191 return emitOpError("region cannot have any arguments");
192 return success();
193}
194
195// Inline an ExecuteRegionOp if it only contains one block.
196// "test.foo"() : () -> ()
197// %v = scf.execute_region -> i64 {
198// %x = "test.val"() : () -> i64
199// scf.yield %x : i64
200// }
201// "test.bar"(%v) : (i64) -> ()
202//
203// becomes
204//
205// "test.foo"() : () -> ()
206// %x = "test.val"() : () -> i64
207// "test.bar"(%x) : (i64) -> ()
208//
209struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
210 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
211
212 LogicalResult matchAndRewrite(ExecuteRegionOp op,
213 PatternRewriter &rewriter) const override {
214 if (!op.getRegion().hasOneBlock() || op.getNoInline())
215 return failure();
216 replaceOpWithRegion(rewriter, op, op.getRegion());
217 return success();
218 }
219};
220
221// Inline an ExecuteRegionOp if its parent can contain multiple blocks.
222// TODO generalize the conditions for operations which can be inlined into.
223// func @func_execute_region_elim() {
224// "test.foo"() : () -> ()
225// %v = scf.execute_region -> i64 {
226// %c = "test.cmp"() : () -> i1
227// cf.cond_br %c, ^bb2, ^bb3
228// ^bb2:
229// %x = "test.val1"() : () -> i64
230// cf.br ^bb4(%x : i64)
231// ^bb3:
232// %y = "test.val2"() : () -> i64
233// cf.br ^bb4(%y : i64)
234// ^bb4(%z : i64):
235// scf.yield %z : i64
236// }
237// "test.bar"(%v) : (i64) -> ()
238// return
239// }
240//
241// becomes
242//
243// func @func_execute_region_elim() {
244// "test.foo"() : () -> ()
245// %c = "test.cmp"() : () -> i1
246// cf.cond_br %c, ^bb1, ^bb2
247// ^bb1: // pred: ^bb0
248// %x = "test.val1"() : () -> i64
249// cf.br ^bb3(%x : i64)
250// ^bb2: // pred: ^bb0
251// %y = "test.val2"() : () -> i64
252// cf.br ^bb3(%y : i64)
253// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
254// "test.bar"(%z) : (i64) -> ()
255// return
256// }
257//
258struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
259 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
260
261 LogicalResult matchAndRewrite(ExecuteRegionOp op,
262 PatternRewriter &rewriter) const override {
263 if (op.getNoInline())
264 return failure();
265 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
266 return failure();
267
268 Block *prevBlock = op->getBlock();
269 Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
270 rewriter.setInsertionPointToEnd(prevBlock);
271
272 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
273
274 for (Block &blk : op.getRegion()) {
275 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
276 rewriter.setInsertionPoint(yieldOp);
277 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
278 yieldOp.getResults());
279 rewriter.eraseOp(yieldOp);
280 }
281 }
282
283 rewriter.inlineRegionBefore(op.getRegion(), postBlock);
284 SmallVector<Value> blockArgs;
285
286 for (auto res : op.getResults())
287 blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
288
289 rewriter.replaceOp(op, blockArgs);
290 return success();
291 }
292};
293
294// Pattern to eliminate ExecuteRegionOp results which forward external
295// values from the region. In case there are multiple yield operations,
296// all of them must have the same operands in order for the pattern to be
297// applicable.
299 : public OpRewritePattern<ExecuteRegionOp> {
300 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
301
302 LogicalResult matchAndRewrite(ExecuteRegionOp op,
303 PatternRewriter &rewriter) const override {
304 if (op.getNumResults() == 0)
305 return failure();
306
308 for (Block &block : op.getRegion()) {
309 if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
310 yieldOps.push_back(yield.getOperation());
311 }
312
313 if (yieldOps.empty())
314 return failure();
315
316 // Check if all yield operations have the same operands.
317 auto yieldOpsOperands = yieldOps[0]->getOperands();
318 for (auto *yieldOp : yieldOps) {
319 if (yieldOp->getOperands() != yieldOpsOperands)
320 return failure();
321 }
322
323 SmallVector<Value> externalValues;
324 SmallVector<Value> internalValues;
325 SmallVector<Value> opResultsToReplaceWithExternalValues;
326 SmallVector<Value> opResultsToKeep;
327 for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
328 if (isValueFromInsideRegion(yieldedValue, op)) {
329 internalValues.push_back(yieldedValue);
330 opResultsToKeep.push_back(op.getResult(index));
331 } else {
332 externalValues.push_back(yieldedValue);
333 opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
334 }
335 }
336 // No yielded external values - nothing to do.
337 if (externalValues.empty())
338 return failure();
339
340 // There are yielded external values - create a new execute_region returning
341 // just the internal values.
342 SmallVector<Type> resultTypes;
343 for (Value value : internalValues)
344 resultTypes.push_back(value.getType());
345 auto newOp =
346 ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
347 newOp->setAttrs(op->getAttrs());
348
349 // Move old op's region to the new operation.
350 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
351 newOp.getRegion().end());
352
353 // Replace all yield operations with a new yield operation with updated
354 // results. scf.execute_region must have at least one yield operation.
355 for (auto *yieldOp : yieldOps) {
356 rewriter.setInsertionPoint(yieldOp);
357 rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
358 ValueRange(internalValues));
359 }
360
361 // Replace the old operation with the external values directly.
362 rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
363 externalValues);
364 // Replace the old operation's remaining results with the new operation's
365 // results.
366 rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
367 rewriter.eraseOp(op);
368 return success();
369 }
370
371private:
372 bool isValueFromInsideRegion(Value value,
373 ExecuteRegionOp executeRegionOp) const {
374 // Check if the value is defined within the execute_region
375 if (Operation *defOp = value.getDefiningOp())
376 return &executeRegionOp.getRegion() == defOp->getParentRegion();
377
378 // If it's a block argument, check if it's from within the region
379 if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
380 return &executeRegionOp.getRegion() == blockArg.getParentRegion();
381
382 return false; // Value is from outside the region
383 }
384};
385
386void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
387 MLIRContext *context) {
390}
391
392void ExecuteRegionOp::getSuccessorRegions(
394 // If the predecessor is the ExecuteRegionOp, branch into the body.
395 if (point.isParent()) {
396 regions.push_back(RegionSuccessor(&getRegion()));
397 return;
398 }
399
400 // Otherwise, the region branches back to the parent operation.
401 regions.push_back(RegionSuccessor(getOperation(), getResults()));
402}
403
404//===----------------------------------------------------------------------===//
405// ConditionOp
406//===----------------------------------------------------------------------===//
407
409ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
410 assert(
411 (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) &&
412 "condition op can only exit the loop or branch to the after"
413 "region");
414 // Pass all operands except the condition to the successor region.
415 return getArgsMutable();
416}
417
418void ConditionOp::getSuccessorRegions(
420 FoldAdaptor adaptor(operands, *this);
421
422 WhileOp whileOp = getParentOp();
423
424 // Condition can either lead to the after region or back to the parent op
425 // depending on whether the condition is true or not.
426 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
427 if (!boolAttr || boolAttr.getValue())
428 regions.emplace_back(&whileOp.getAfter(),
429 whileOp.getAfter().getArguments());
430 if (!boolAttr || !boolAttr.getValue())
431 regions.emplace_back(whileOp.getOperation(), whileOp.getResults());
432}
433
434//===----------------------------------------------------------------------===//
435// ForOp
436//===----------------------------------------------------------------------===//
437
438void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
439 Value ub, Value step, ValueRange initArgs,
440 BodyBuilderFn bodyBuilder, bool unsignedCmp) {
441 OpBuilder::InsertionGuard guard(builder);
442
443 if (unsignedCmp)
444 result.addAttribute(getUnsignedCmpAttrName(result.name),
445 builder.getUnitAttr());
446 result.addOperands({lb, ub, step});
447 result.addOperands(initArgs);
448 for (Value v : initArgs)
449 result.addTypes(v.getType());
450 Type t = lb.getType();
451 Region *bodyRegion = result.addRegion();
452 Block *bodyBlock = builder.createBlock(bodyRegion);
453 bodyBlock->addArgument(t, result.location);
454 for (Value v : initArgs)
455 bodyBlock->addArgument(v.getType(), v.getLoc());
456
457 // Create the default terminator if the builder is not provided and if the
458 // iteration arguments are not provided. Otherwise, leave this to the caller
459 // because we don't know which values to return from the loop.
460 if (initArgs.empty() && !bodyBuilder) {
461 ForOp::ensureTerminator(*bodyRegion, builder, result.location);
462 } else if (bodyBuilder) {
463 OpBuilder::InsertionGuard guard(builder);
464 builder.setInsertionPointToStart(bodyBlock);
465 bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
466 bodyBlock->getArguments().drop_front());
467 }
468}
469
470LogicalResult ForOp::verify() {
471 // Check that the number of init args and op results is the same.
472 if (getInitArgs().size() != getNumResults())
473 return emitOpError(
474 "mismatch in number of loop-carried values and defined values");
475
476 return success();
477}
478
479LogicalResult ForOp::verifyRegions() {
480 // Check that the body defines as single block argument for the induction
481 // variable.
482 if (getInductionVar().getType() != getLowerBound().getType())
483 return emitOpError(
484 "expected induction variable to be same type as bounds and step");
485
486 if (getNumRegionIterArgs() != getNumResults())
487 return emitOpError(
488 "mismatch in number of basic block args and defined values");
489
490 auto initArgs = getInitArgs();
491 auto iterArgs = getRegionIterArgs();
492 auto opResults = getResults();
493 unsigned i = 0;
494 for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
495 if (std::get<0>(e).getType() != std::get<2>(e).getType())
496 return emitOpError() << "types mismatch between " << i
497 << "th iter operand and defined value";
498 if (std::get<1>(e).getType() != std::get<2>(e).getType())
499 return emitOpError() << "types mismatch between " << i
500 << "th iter region arg and defined value";
501
502 ++i;
503 }
504 return success();
505}
506
507std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
508 return SmallVector<Value>{getInductionVar()};
509}
510
511std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
513}
514
515std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
516 return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
517}
518
519std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
521}
522
523std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
524
525/// Promotes the loop body of a forOp to its containing block if the forOp
526/// it can be determined that the loop has a single iteration.
527LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
528 std::optional<APInt> tripCount = getStaticTripCount();
529 LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
530 << " for loop "
531 << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
532 if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
533 return failure();
534
535 if (*tripCount == 0) {
536 rewriter.replaceAllUsesWith(getResults(), getInitArgs());
537 rewriter.eraseOp(*this);
538 return success();
539 }
540
541 // Replace all results with the yielded values.
542 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
543 rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
544
545 // Replace block arguments with lower bound (replacement for IV) and
546 // iter_args.
547 SmallVector<Value> bbArgReplacements;
548 bbArgReplacements.push_back(getLowerBound());
549 llvm::append_range(bbArgReplacements, getInitArgs());
550
551 // Move the loop body operations to the loop's containing block.
552 rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
553 getOperation()->getIterator(), bbArgReplacements);
554
555 // Erase the old terminator and the loop.
556 rewriter.eraseOp(yieldOp);
557 rewriter.eraseOp(*this);
558
559 return success();
560}
561
562/// Prints the initialization list in the form of
563/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
564/// where 'inner' values are assumed to be region arguments and 'outer' values
565/// are regular SSA values.
567 Block::BlockArgListType blocksArgs,
568 ValueRange initializers,
569 StringRef prefix = "") {
570 assert(blocksArgs.size() == initializers.size() &&
571 "expected same length of arguments and initializers");
572 if (initializers.empty())
573 return;
574
575 p << prefix << '(';
576 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
577 p << std::get<0>(it) << " = " << std::get<1>(it);
578 });
579 p << ")";
580}
581
582void ForOp::print(OpAsmPrinter &p) {
583 if (getUnsignedCmp())
584 p << " unsigned";
585
586 p << " " << getInductionVar() << " = " << getLowerBound() << " to "
587 << getUpperBound() << " step " << getStep();
588
589 printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
590 if (!getInitArgs().empty())
591 p << " -> (" << getInitArgs().getTypes() << ')';
592 p << ' ';
593 if (Type t = getInductionVar().getType(); !t.isIndex())
594 p << " : " << t << ' ';
595 p.printRegion(getRegion(),
596 /*printEntryBlockArgs=*/false,
597 /*printBlockTerminators=*/!getInitArgs().empty());
598 p.printOptionalAttrDict((*this)->getAttrs(),
599 /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
600}
601
602ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
603 auto &builder = parser.getBuilder();
604 Type type;
605
606 OpAsmParser::Argument inductionVariable;
608
609 if (succeeded(parser.parseOptionalKeyword("unsigned")))
610 result.addAttribute(getUnsignedCmpAttrName(result.name),
611 builder.getUnitAttr());
612
613 // Parse the induction variable followed by '='.
614 if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
615 // Parse loop bounds.
616 parser.parseOperand(lb) || parser.parseKeyword("to") ||
617 parser.parseOperand(ub) || parser.parseKeyword("step") ||
618 parser.parseOperand(step))
619 return failure();
620
621 // Parse the optional initial iteration arguments.
624 regionArgs.push_back(inductionVariable);
625
626 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
627 if (hasIterArgs) {
628 // Parse assignment list and results type list.
629 if (parser.parseAssignmentList(regionArgs, operands) ||
630 parser.parseArrowTypeList(result.types))
631 return failure();
632 }
633
634 if (regionArgs.size() != result.types.size() + 1)
635 return parser.emitError(
636 parser.getNameLoc(),
637 "mismatch in number of loop-carried values and defined values");
638
639 // Parse optional type, else assume Index.
640 if (parser.parseOptionalColon())
641 type = builder.getIndexType();
642 else if (parser.parseType(type))
643 return failure();
644
645 // Set block argument types, so that they are known when parsing the region.
646 regionArgs.front().type = type;
647 for (auto [iterArg, type] :
648 llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
649 iterArg.type = type;
650
651 // Parse the body region.
652 Region *body = result.addRegion();
653 if (parser.parseRegion(*body, regionArgs))
654 return failure();
655 ForOp::ensureTerminator(*body, builder, result.location);
656
657 // Resolve input operands. This should be done after parsing the region to
658 // catch invalid IR where operands were defined inside of the region.
659 if (parser.resolveOperand(lb, type, result.operands) ||
660 parser.resolveOperand(ub, type, result.operands) ||
661 parser.resolveOperand(step, type, result.operands))
662 return failure();
663 if (hasIterArgs) {
664 for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
665 operands, result.types)) {
666 Type type = std::get<2>(argOperandType);
667 std::get<0>(argOperandType).type = type;
668 if (parser.resolveOperand(std::get<1>(argOperandType), type,
669 result.operands))
670 return failure();
671 }
672 }
673
674 // Parse the optional attribute list.
675 if (parser.parseOptionalAttrDict(result.attributes))
676 return failure();
677
678 return success();
679}
680
681SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
682
683Block::BlockArgListType ForOp::getRegionIterArgs() {
684 return getBody()->getArguments().drop_front(getNumInductionVars());
685}
686
687MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
688 return getInitArgsMutable();
689}
690
691FailureOr<LoopLikeOpInterface>
692ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
693 ValueRange newInitOperands,
694 bool replaceInitOperandUsesInLoop,
695 const NewYieldValuesFn &newYieldValuesFn) {
696 // Create a new loop before the existing one, with the extra operands.
697 OpBuilder::InsertionGuard g(rewriter);
698 rewriter.setInsertionPoint(getOperation());
699 auto inits = llvm::to_vector(getInitArgs());
700 inits.append(newInitOperands.begin(), newInitOperands.end());
701 scf::ForOp newLoop = scf::ForOp::create(
702 rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
703 [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp());
704 newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
705
706 // Generate the new yield values and append them to the scf.yield operation.
707 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
708 ArrayRef<BlockArgument> newIterArgs =
709 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
710 {
711 OpBuilder::InsertionGuard g(rewriter);
712 rewriter.setInsertionPoint(yieldOp);
713 SmallVector<Value> newYieldedValues =
714 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
715 assert(newInitOperands.size() == newYieldedValues.size() &&
716 "expected as many new yield values as new iter operands");
717 rewriter.modifyOpInPlace(yieldOp, [&]() {
718 yieldOp.getResultsMutable().append(newYieldedValues);
719 });
720 }
721
722 // Move the loop body to the new op.
723 rewriter.mergeBlocks(getBody(), newLoop.getBody(),
724 newLoop.getBody()->getArguments().take_front(
725 getBody()->getNumArguments()));
726
727 if (replaceInitOperandUsesInLoop) {
728 // Replace all uses of `newInitOperands` with the corresponding basic block
729 // arguments.
730 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
731 rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
732 [&](OpOperand &use) {
733 Operation *user = use.getOwner();
734 return newLoop->isProperAncestor(user);
735 });
736 }
737 }
738
739 // Replace the old loop.
740 rewriter.replaceOp(getOperation(),
741 newLoop->getResults().take_front(getNumResults()));
742 return cast<LoopLikeOpInterface>(newLoop.getOperation());
743}
744
746 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
747 if (!ivArg)
748 return ForOp();
749 assert(ivArg.getOwner() && "unlinked block argument");
750 auto *containingOp = ivArg.getOwner()->getParentOp();
751 return dyn_cast_or_null<ForOp>(containingOp);
752}
753
754OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
755 return getInitArgs();
756}
757
758void ForOp::getSuccessorRegions(RegionBranchPoint point,
760 // Both the operation itself and the region may be branching into the body or
761 // back into the operation itself. It is possible for loop not to enter the
762 // body.
763 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
764 regions.push_back(RegionSuccessor(getOperation(), getResults()));
765}
766
767SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
768
769/// Promotes the loop body of a forallOp to its containing block if it can be
770/// determined that the loop has a single iteration.
771LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
772 for (auto [lb, ub, step] :
773 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
774 auto tripCount =
775 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
776 if (!tripCount.has_value() || *tripCount != 1)
777 return failure();
778 }
779
780 promote(rewriter, *this);
781 return success();
782}
783
784Block::BlockArgListType ForallOp::getRegionIterArgs() {
785 return getBody()->getArguments().drop_front(getRank());
786}
787
788MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
789 return getOutputsMutable();
790}
791
792/// Promotes the loop body of a scf::ForallOp to its containing block.
793void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
794 OpBuilder::InsertionGuard g(rewriter);
795 scf::InParallelOp terminator = forallOp.getTerminator();
796
797 // Replace block arguments with lower bounds (replacements for IVs) and
798 // outputs.
799 SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
800 bbArgReplacements.append(forallOp.getOutputs().begin(),
801 forallOp.getOutputs().end());
802
803 // Move the loop body operations to the loop's containing block.
804 rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
805 forallOp->getIterator(), bbArgReplacements);
806
807 // Replace the terminator with tensor.insert_slice ops.
808 rewriter.setInsertionPointAfter(forallOp);
809 SmallVector<Value> results;
810 results.reserve(forallOp.getResults().size());
811 for (auto &yieldingOp : terminator.getYieldingOps()) {
812 auto parallelInsertSliceOp =
813 dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
814 if (!parallelInsertSliceOp)
815 continue;
816
817 Value dst = parallelInsertSliceOp.getDest();
818 Value src = parallelInsertSliceOp.getSource();
819 if (llvm::isa<TensorType>(src.getType())) {
820 results.push_back(tensor::InsertSliceOp::create(
821 rewriter, forallOp.getLoc(), dst.getType(), src, dst,
822 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
823 parallelInsertSliceOp.getStrides(),
824 parallelInsertSliceOp.getStaticOffsets(),
825 parallelInsertSliceOp.getStaticSizes(),
826 parallelInsertSliceOp.getStaticStrides()));
827 } else {
828 llvm_unreachable("unsupported terminator");
829 }
830 }
831 rewriter.replaceAllUsesWith(forallOp.getResults(), results);
832
833 // Erase the old terminator and the loop.
834 rewriter.eraseOp(terminator);
835 rewriter.eraseOp(forallOp);
836}
837
839 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
840 ValueRange steps, ValueRange iterArgs,
842 bodyBuilder) {
843 assert(lbs.size() == ubs.size() &&
844 "expected the same number of lower and upper bounds");
845 assert(lbs.size() == steps.size() &&
846 "expected the same number of lower bounds and steps");
847
848 // If there are no bounds, call the body-building function and return early.
849 if (lbs.empty()) {
850 ValueVector results =
851 bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
852 : ValueVector();
853 assert(results.size() == iterArgs.size() &&
854 "loop nest body must return as many values as loop has iteration "
855 "arguments");
856 return LoopNest{{}, std::move(results)};
857 }
858
859 // First, create the loop structure iteratively using the body-builder
860 // callback of `ForOp::build`. Do not create `YieldOp`s yet.
861 OpBuilder::InsertionGuard guard(builder);
864 loops.reserve(lbs.size());
865 ivs.reserve(lbs.size());
866 ValueRange currentIterArgs = iterArgs;
867 Location currentLoc = loc;
868 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
869 auto loop = scf::ForOp::create(
870 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
871 [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
872 ValueRange args) {
873 ivs.push_back(iv);
874 // It is safe to store ValueRange args because it points to block
875 // arguments of a loop operation that we also own.
876 currentIterArgs = args;
877 currentLoc = nestedLoc;
878 });
879 // Set the builder to point to the body of the newly created loop. We don't
880 // do this in the callback because the builder is reset when the callback
881 // returns.
882 builder.setInsertionPointToStart(loop.getBody());
883 loops.push_back(loop);
884 }
885
886 // For all loops but the innermost, yield the results of the nested loop.
887 for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
888 builder.setInsertionPointToEnd(loops[i].getBody());
889 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
890 }
891
892 // In the body of the innermost loop, call the body building function if any
893 // and yield its results.
894 builder.setInsertionPointToStart(loops.back().getBody());
895 ValueVector results = bodyBuilder
896 ? bodyBuilder(builder, currentLoc, ivs,
897 loops.back().getRegionIterArgs())
898 : ValueVector();
899 assert(results.size() == iterArgs.size() &&
900 "loop nest body must return as many values as loop has iteration "
901 "arguments");
902 builder.setInsertionPointToEnd(loops.back().getBody());
903 scf::YieldOp::create(builder, loc, results);
904
905 // Return the loops.
906 ValueVector nestResults;
907 llvm::append_range(nestResults, loops.front().getResults());
908 return LoopNest{std::move(loops), std::move(nestResults)};
909}
910
912 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
913 ValueRange steps,
914 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
915 // Delegate to the main function by wrapping the body builder.
916 return buildLoopNest(builder, loc, lbs, ubs, steps, {},
917 [&bodyBuilder](OpBuilder &nestedBuilder,
918 Location nestedLoc, ValueRange ivs,
920 if (bodyBuilder)
921 bodyBuilder(nestedBuilder, nestedLoc, ivs);
922 return {};
923 });
924}
925
928 OpOperand &operand, Value replacement,
929 const ValueTypeCastFnTy &castFn) {
930 assert(operand.getOwner() == forOp);
931 Type oldType = operand.get().getType(), newType = replacement.getType();
932
933 // 1. Create new iter operands, exactly 1 is replaced.
934 assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
935 "expected an iter OpOperand");
936 assert(operand.get().getType() != replacement.getType() &&
937 "Expected a different type");
938 SmallVector<Value> newIterOperands;
939 for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
940 if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
941 newIterOperands.push_back(replacement);
942 continue;
943 }
944 newIterOperands.push_back(opOperand.get());
945 }
946
947 // 2. Create the new forOp shell.
948 scf::ForOp newForOp = scf::ForOp::create(
949 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
950 forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr,
951 forOp.getUnsignedCmp());
952 newForOp->setAttrs(forOp->getAttrs());
953 Block &newBlock = newForOp.getRegion().front();
954 SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
955 newBlock.getArguments().end());
956
957 // 3. Inject an incoming cast op at the beginning of the block for the bbArg
958 // corresponding to the `replacement` value.
959 OpBuilder::InsertionGuard g(rewriter);
960 rewriter.setInsertionPointToStart(&newBlock);
961 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
962 &newForOp->getOpOperand(operand.getOperandNumber()));
963 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
964 newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
965
966 // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
967 Block &oldBlock = forOp.getRegion().front();
968 rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
969
970 // 5. Inject an outgoing cast op at the end of the block and yield it instead.
971 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
972 rewriter.setInsertionPoint(clonedYieldOp);
973 unsigned yieldIdx =
974 newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
975 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
976 clonedYieldOp.getOperand(yieldIdx));
977 SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
978 newYieldOperands[yieldIdx] = castOut;
979 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
980 rewriter.eraseOp(clonedYieldOp);
981
982 // 6. Inject an outgoing cast op after the forOp.
983 rewriter.setInsertionPointAfter(newForOp);
984 SmallVector<Value> newResults = newForOp.getResults();
985 newResults[yieldIdx] =
986 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
987
988 return newResults;
989}
990
991namespace {
992// Fold away ForOp iter arguments when:
993// 1) The op yields the iter arguments.
994// 2) The argument's corresponding outer region iterators (inputs) are yielded.
995// 3) The iter arguments have no use and the corresponding (operation) results
996// have no use.
997//
998// These arguments must be defined outside of the ForOp region and can just be
999// forwarded after simplifying the op inits, yields and returns.
1000//
// In addition, an iter argument whose (init, yielded) value pair is identical
// to that of an earlier kept iter argument is dropped as a duplicate: its
// region argument and result are replaced by those of the earlier argument.
//
1001// The implementation uses `inlineBlockBefore` to steal the content of the
1002// original ForOp and avoid cloning.
1003struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
1004 using OpRewritePattern<scf::ForOp>::OpRewritePattern;
1005
1006 LogicalResult matchAndRewrite(scf::ForOp forOp,
1007 PatternRewriter &rewriter) const final {
 // Set as soon as at least one iter argument can be dropped.
1008 bool canonicalize = false;
1009
1010 // An internal flat vector of block transfer
1011 // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
1012 // transformed block argument mappings. This plays the role of a
1013 // IRMapping for the particular use case of calling into
1014 // `inlineBlockBefore`.
1015 int64_t numResults = forOp.getNumResults();
 // keepMask[i] is true iff the i-th iter argument survives in the new loop.
1016 SmallVector<bool, 4> keepMask;
1017 keepMask.reserve(numResults);
1018 SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
1019 newResultValues;
1020 newBlockTransferArgs.reserve(1 + numResults);
1021 newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
1022 newIterArgs.reserve(forOp.getInitArgs().size());
1023 newYieldValues.reserve(numResults);
1024 newResultValues.reserve(numResults);
 // Maps the (init, yielded) pair of every kept iter argument to its
 // (region argument, op result) pair; used below to detect duplicates.
1025 DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
1026 for (auto [init, arg, result, yielded] :
1027 llvm::zip(forOp.getInitArgs(), // iter from outside
1028 forOp.getRegionIterArgs(), // iter inside region
1029 forOp.getResults(), // op results
1030 forOp.getYieldedValues() // iter yield
1031 )) {
1032 // Forwarded is `true` when:
1033 // 1) The region `iter` argument is yielded.
1034 // 2) The init value tied to the region `iter` argument is yielded.
1035 // 3) The region `iter` argument has no use, and the corresponding op
1036 // result has no use.
1037 bool forwarded = (arg == yielded) || (init == yielded) ||
1038 (arg.use_empty() && result.use_empty());
1039 if (forwarded) {
 // Both the region argument and the op result are replaced by `init`.
1040 canonicalize = true;
1041 keepMask.push_back(false);
1042 newBlockTransferArgs.push_back(init);
1043 newResultValues.push_back(init);
1044 continue;
1045 }
1046
1047 // Check if a previous kept argument always has the same values for init
1048 // and yielded values.
1049 if (auto it = initYieldToArg.find({init, yielded});
1050 it != initYieldToArg.end()) {
1051 canonicalize = true;
1052 keepMask.push_back(false);
1053 auto [sameArg, sameResult] = it->second;
1054 rewriter.replaceAllUsesWith(arg, sameArg);
1055 rewriter.replaceAllUsesWith(result, sameResult);
1056 // The replacement value doesn't matter because there are no uses.
1057 newBlockTransferArgs.push_back(init);
1058 newResultValues.push_back(init);
1059 continue;
1060 }
1061
1062 // This value is kept.
1063 initYieldToArg.insert({{init, yielded}, {arg, result}});
1064 keepMask.push_back(true);
1065 newIterArgs.push_back(init);
1066 newYieldValues.push_back(yielded);
1067 newBlockTransferArgs.push_back(Value()); // placeholder with null value
1068 newResultValues.push_back(Value()); // placeholder with null value
1069 }
1070
 // Every iter argument is kept: there is nothing to simplify.
1071 if (!canonicalize)
1072 return failure();
1073
 // Create the replacement loop with only the kept init values; attributes
 // of the original loop are copied over.
1074 scf::ForOp newForOp =
1075 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
1076 forOp.getUpperBound(), forOp.getStep(), newIterArgs,
1077 /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
1078 newForOp->setAttrs(forOp->getAttrs());
1079 Block &newBlock = newForOp.getRegion().front();
1080
1081 // Replace the null placeholders with newly constructed values.
1082 newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
 // `idx` walks the original iter arguments, `collapsedIdx` the kept ones.
1083 for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
1084 idx != e; ++idx) {
1085 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
1086 Value &newResultVal = newResultValues[idx];
1087 assert((blockTransferArg && newResultVal) ||
1088 (!blockTransferArg && !newResultVal));
1089 if (!blockTransferArg) {
1090 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
1091 newResultVal = newForOp.getResult(collapsedIdx++);
1092 }
1093 }
1094
1095 Block &oldBlock = forOp.getRegion().front();
1096 assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
1097 "unexpected argument size mismatch");
1098
1099 // No results case: the scf::ForOp builder already created a zero
1100 // result terminator. Merge before this terminator and just get rid of the
1101 // original terminator that has been merged in.
1102 if (newIterArgs.empty()) {
1103 auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
1104 rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
 // The op right before the new terminator is the old, merged-in yield.
1105 rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
1106 rewriter.replaceOp(forOp, newResultValues);
1107 return success();
1108 }
1109
1110 // No terminator case: merge and rewrite the merged terminator.
 // The helper creates a new yield right before `mergedTerminator` that
 // carries only the operands whose keepMask entry is set.
1111 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
1112 OpBuilder::InsertionGuard g(rewriter);
1113 rewriter.setInsertionPoint(mergedTerminator);
1114 SmallVector<Value, 4> filteredOperands;
1115 filteredOperands.reserve(newResultValues.size());
1116 for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
1117 if (keepMask[idx])
1118 filteredOperands.push_back(mergedTerminator.getOperand(idx));
1119 scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
1120 filteredOperands);
1121 };
1122
1123 rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1124 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
1125 cloneFilteredTerminator(mergedYieldOp);
1126 rewriter.eraseOp(mergedYieldOp);
1127 rewriter.replaceOp(forOp, newResultValues);
1128 return success();
1129 }
1130};
1131
1132/// Rewriting pattern that erases loops that are known not to iterate, replaces
1133/// single-iteration loops with their bodies, and removes empty loops that
1134/// iterate at least once and only return values defined outside of the loop.
1135struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
1136 using OpRewritePattern<ForOp>::OpRewritePattern;
1137
1138 LogicalResult matchAndRewrite(ForOp op,
1139 PatternRewriter &rewriter) const override {
1140 std::optional<APInt> tripCount = op.getStaticTripCount();
1141 if (!tripCount.has_value())
1142 return rewriter.notifyMatchFailure(op,
1143 "can't compute constant trip count");
1144
1145 if (tripCount->isZero()) {
1146 LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop "
1147 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1148 rewriter.replaceOp(op, op.getInitArgs());
1149 return success();
1150 }
1151
1152 if (tripCount->getSExtValue() == 1) {
1153 LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop "
1154 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1155 SmallVector<Value, 4> blockArgs;
1156 blockArgs.reserve(op.getInitArgs().size() + 1);
1157 blockArgs.push_back(op.getLowerBound());
1158 llvm::append_range(blockArgs, op.getInitArgs());
1159 replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1160 return success();
1161 }
1162
1163 // Now we are left with loops that have more than 1 iterations.
1164 Block &block = op.getRegion().front();
1165 if (!llvm::hasSingleElement(block))
1166 return failure();
1167 // The loop is empty and iterates at least once, if it only returns values
1168 // defined outside of the loop, remove it and replace it with yield values.
1169 if (llvm::any_of(op.getYieldedValues(),
1170 [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1171 return failure();
1172 LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with "
1173 "yield operands for loop "
1174 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1175 rewriter.replaceOp(op, op.getYieldedValues());
1176 return success();
1177 }
1178};
1179
1180/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1181/// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1182///
1183/// ```
1184/// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1185/// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1186/// -> (tensor<?x?xf32>) {
1187/// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1188/// scf.yield %2 : tensor<?x?xf32>
1189/// }
1190/// use_of(%1)
1191/// ```
1192///
1193/// folds into:
1194///
1195/// ```
1196/// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1197/// -> (tensor<32x1024xf32>) {
1198/// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1199/// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1200/// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1201/// scf.yield %4 : tensor<32x1024xf32>
1202/// }
1203/// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1204/// use_of(%1)
1205/// ```
1206struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1207 using OpRewritePattern<ForOp>::OpRewritePattern;
1208
1209 LogicalResult matchAndRewrite(ForOp op,
1210 PatternRewriter &rewriter) const override {
 // Visit each (init operand, op result) pair; the first pair that
 // qualifies is rewritten and the pattern returns.
1211 for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1212 OpOperand &iterOpOperand = std::get<0>(it);
1213 auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
 // Skip init values that are not produced by a tensor.cast, and casts
 // whose source and result types are identical.
1214 if (!incomingCast ||
1215 incomingCast.getSource().getType() == incomingCast.getType())
1216 continue;
1217 // If the dest type of the cast does not preserve static information in
1218 // the source type.
1220 incomingCast.getDest().getType(),
1221 incomingCast.getSource().getType()))
1222 continue;
 // Only results with exactly one use are handled.
1223 if (!std::get<1>(it).hasOneUse())
1224 continue;
1225
1226 // Create a new ForOp with that iter operand replaced.
 // The callback below materializes a tensor.cast wherever a value has to
 // be converted to the requested type.
1227 rewriter.replaceOp(
1229 rewriter, op, iterOpOperand, incomingCast.getSource(),
1230 [](OpBuilder &b, Location loc, Type type, Value source) {
1231 return tensor::CastOp::create(b, loc, type, source);
1232 }));
 // At most one iter operand is rewritten per pattern application.
1233 return success();
1234 }
1235 return failure();
1236 }
1237};
1238
1239} // namespace
1240
1241void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1242 MLIRContext *context) {
1243 results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1244 context);
1245}
1246
1247std::optional<APInt> ForOp::getConstantStep() {
1248 IntegerAttr step;
1249 if (matchPattern(getStep(), m_Constant(&step)))
1250 return step.getValue();
1251 return {};
1252}
1253
1254std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1255 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1256}
1257
// Reports the speculatability of the loop. The decision is keyed on whether
// the step is the constant 1, as explained by the comments below.
1258Speculation::Speculatability ForOp::getSpeculatability() {
1259 // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1260 // and End.
1261 if (auto constantStep = getConstantStep())
1262 if (*constantStep == 1)
1264
1265 // For Step != 1, the loop may not terminate. We can add more smarts here if
1266 // needed.
1268}
1269
1270std::optional<APInt> ForOp::getStaticTripCount() {
1271 return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
1272 /*isSigned=*/!getUnsignedCmp(), computeUbMinusLb);
1273}
1274
1275//===----------------------------------------------------------------------===//
1276// ForallOp
1277//===----------------------------------------------------------------------===//
1278
1279LogicalResult ForallOp::verify() {
1280 unsigned numLoops = getRank();
1281 // Check number of outputs.
1282 if (getNumResults() != getOutputs().size())
1283 return emitOpError("produces ")
1284 << getNumResults() << " results, but has only "
1285 << getOutputs().size() << " outputs";
1286
1287 // Check that the body defines block arguments for thread indices and outputs.
1288 auto *body = getBody();
1289 if (body->getNumArguments() != numLoops + getOutputs().size())
1290 return emitOpError("region expects ") << numLoops << " arguments";
1291 for (int64_t i = 0; i < numLoops; ++i)
1292 if (!body->getArgument(i).getType().isIndex())
1293 return emitOpError("expects ")
1294 << i << "-th block argument to be an index";
1295 for (unsigned i = 0; i < getOutputs().size(); ++i)
1296 if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1297 return emitOpError("type mismatch between ")
1298 << i << "-th output and corresponding block argument";
1299 if (getMapping().has_value() && !getMapping()->empty()) {
1300 if (getDeviceMappingAttrs().size() != numLoops)
1301 return emitOpError() << "mapping attribute size must match op rank";
1302 if (failed(getDeviceMaskingAttr()))
1303 return emitOpError() << getMappingAttrName()
1304 << " supports at most one device masking attribute";
1305 }
1306
1307 // Verify mixed static/dynamic control variables.
1308 Operation *op = getOperation();
1309 if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1310 getStaticLowerBound(),
1311 getDynamicLowerBound())))
1312 return failure();
1313 if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1314 getStaticUpperBound(),
1315 getDynamicUpperBound())))
1316 return failure();
1317 if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1318 getStaticStep(), getDynamicStep())))
1319 return failure();
1320
1321 return success();
1322}
1323
// Prints the custom assembly of scf.forall: the induction variables, the
// bounds (short `in` form when the loop is normalized, `= ... to ... step ...`
// otherwise), the shared_outs list, the result types, the region and finally
// the attribute dictionary.
1324void ForallOp::print(OpAsmPrinter &p) {
1325 Operation *op = getOperation();
1326 p << " (" << getInductionVars();
 // Normalized loops only print their upper bounds.
1327 if (isNormalized()) {
1328 p << ") in ";
1329 printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1330 /*valueTypes=*/{}, /*scalables=*/{},
1332 } else {
1333 p << ") = ";
1334 printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1335 /*valueTypes=*/{}, /*scalables=*/{},
1337 p << " to ";
1338 printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1339 /*valueTypes=*/{}, /*scalables=*/{},
1341 p << " step ";
1342 printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1343 /*valueTypes=*/{}, /*scalables=*/{},
1345 }
1346 printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1347 p << " ";
 // Result types are only printed when there are shared outputs.
1348 if (!getRegionOutArgs().empty())
1349 p << "-> (" << getResultTypes() << ") ";
 // Entry block arguments are never printed; the terminator is printed only
 // when the op has results.
1350 p.printRegion(getRegion(),
1351 /*printEntryBlockArgs=*/false,
1352 /*printBlockTerminators=*/getNumResults() > 0);
 // The attributes listed here are elided from the printed dictionary.
1353 p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1354 getStaticLowerBoundAttrName(),
1355 getStaticUpperBoundAttrName(),
1356 getStaticStepAttrName()});
1357}
1358
// Parses the custom assembly of scf.forall: the induction variables, the
// bounds (either the short `in` form or `= ... to ... step ...`), the optional
// shared_outs list with its result types, the region and the optional
// attribute dictionary. The static bound/step arrays and the operand segment
// sizes are added to `result` as attributes at the end.
1359ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1360 OpBuilder b(parser.getContext());
1361 auto indexType = b.getIndexType();
1362
1363 // Parse an opening `(` followed by thread index variables followed by `)`
1364 // TODO: when we can refer to such "induction variable"-like handles from the
1365 // declarative assembly format, we can implement the parser as a custom hook.
1366 SmallVector<OpAsmParser::Argument, 4> ivs;
1368 return failure();
1369
1370 DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1371 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1372 dynamicSteps;
1373 if (succeeded(parser.parseOptionalKeyword("in"))) {
1374 // Parse upper bounds.
1375 if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1376 /*valueTypes=*/nullptr,
1378 parser.resolveOperands(dynamicUbs, indexType, result.operands))
1379 return failure();
1380
 // Short form: lower bounds default to 0 and steps to 1 for every
 // induction variable.
1381 unsigned numLoops = ivs.size();
1382 staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1383 staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1384 } else {
1385 // Parse lower bounds.
1386 if (parser.parseEqual() ||
1387 parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1388 /*valueTypes=*/nullptr,
1390
1391 parser.resolveOperands(dynamicLbs, indexType, result.operands))
1392 return failure();
1393
1394 // Parse upper bounds.
1395 if (parser.parseKeyword("to") ||
1396 parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1397 /*valueTypes=*/nullptr,
1399 parser.resolveOperands(dynamicUbs, indexType, result.operands))
1400 return failure();
1401
1402 // Parse step values.
1403 if (parser.parseKeyword("step") ||
1404 parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1405 /*valueTypes=*/nullptr,
1407 parser.resolveOperands(dynamicSteps, indexType, result.operands))
1408 return failure();
1409 }
1410
1411 // Parse out operands and results.
1412 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1413 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1414 SMLoc outOperandsLoc = parser.getCurrentLocation();
1415 if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
 // NOTE(review): this size comparison runs before `outOperands` is filled
 // by parseAssignmentList and before parseOptionalArrowTypeList adds to
 // `result.types` in this function, so it looks like it can never fire
 // here. Confirm whether it was meant to follow the parse calls below.
1416 if (outOperands.size() != result.types.size())
1417 return parser.emitError(outOperandsLoc,
1418 "mismatch between out operands and types");
1419 if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1420 parser.parseOptionalArrowTypeList(result.types) ||
1421 parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1422 result.operands))
1423 return failure();
1424 }
1425
1426 // Parse region.
 // Region arguments are the induction variables (index type) followed by
 // the shared_outs arguments typed after the corresponding result types.
1427 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1428 std::unique_ptr<Region> region = std::make_unique<Region>();
1429 for (auto &iv : ivs) {
1430 iv.type = b.getIndexType();
1431 regionArgs.push_back(iv);
1432 }
1433 for (const auto &it : llvm::enumerate(regionOutArgs)) {
1434 auto &out = it.value();
1435 out.type = result.types[it.index()];
1436 regionArgs.push_back(out);
1437 }
1438 if (parser.parseRegion(*region, regionArgs))
1439 return failure();
1440
1441 // Ensure terminator and move region.
1442 ForallOp::ensureTerminator(*region, b, result.location);
1443 result.addRegion(std::move(region));
1444
1445 // Parse the optional attribute list.
1446 if (parser.parseOptionalAttrDict(result.attributes))
1447 return failure();
1448
 // Record the static bounds/steps and the operand segment sizes
 // (dynamic lbs, dynamic ubs, dynamic steps, outputs).
1449 result.addAttribute("staticLowerBound", staticLbs);
1450 result.addAttribute("staticUpperBound", staticUbs);
1451 result.addAttribute("staticStep", staticSteps);
1452 result.addAttribute("operandSegmentSizes",
1454 {static_cast<int32_t>(dynamicLbs.size()),
1455 static_cast<int32_t>(dynamicUbs.size()),
1456 static_cast<int32_t>(dynamicSteps.size()),
1457 static_cast<int32_t>(outOperands.size())}));
1458 return success();
1459}
1460
1461// Builder that takes loop bounds.
1462void ForallOp::build(
1463 mlir::OpBuilder &b, mlir::OperationState &result,
1464 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1465 ArrayRef<OpFoldResult> steps, ValueRange outputs,
1466 std::optional<ArrayAttr> mapping,
1467 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1468 SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1469 SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1470 dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1471 dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1472 dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1473
1474 result.addOperands(dynamicLbs);
1475 result.addOperands(dynamicUbs);
1476 result.addOperands(dynamicSteps);
1477 result.addOperands(outputs);
1478 result.addTypes(TypeRange(outputs));
1479
1480 result.addAttribute(getStaticLowerBoundAttrName(result.name),
1481 b.getDenseI64ArrayAttr(staticLbs));
1482 result.addAttribute(getStaticUpperBoundAttrName(result.name),
1483 b.getDenseI64ArrayAttr(staticUbs));
1484 result.addAttribute(getStaticStepAttrName(result.name),
1485 b.getDenseI64ArrayAttr(staticSteps));
1486 result.addAttribute(
1487 "operandSegmentSizes",
1488 b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1489 static_cast<int32_t>(dynamicUbs.size()),
1490 static_cast<int32_t>(dynamicSteps.size()),
1491 static_cast<int32_t>(outputs.size())}));
1492 if (mapping.has_value()) {
1493 result.addAttribute(ForallOp::getMappingAttrName(result.name),
1494 mapping.value());
1495 }
1496
1497 Region *bodyRegion = result.addRegion();
1498 OpBuilder::InsertionGuard g(b);
1499 b.createBlock(bodyRegion);
1500 Block &bodyBlock = bodyRegion->front();
1501
1502 // Add block arguments for indices and outputs.
1503 bodyBlock.addArguments(
1504 SmallVector<Type>(lbs.size(), b.getIndexType()),
1505 SmallVector<Location>(staticLbs.size(), result.location));
1506 bodyBlock.addArguments(
1507 TypeRange(outputs),
1508 SmallVector<Location>(outputs.size(), result.location));
1509
1510 b.setInsertionPointToStart(&bodyBlock);
1511 if (!bodyBuilderFn) {
1512 ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1513 return;
1514 }
1515 bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1516}
1517
1518// Builder that takes loop bounds.
1519void ForallOp::build(
1520 mlir::OpBuilder &b, mlir::OperationState &result,
1521 ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1522 std::optional<ArrayAttr> mapping,
1523 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1524 unsigned numLoops = ubs.size();
1525 SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1526 SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1527 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1528}
1529
1530// Checks if the lbs are zeros and steps are ones.
1531bool ForallOp::isNormalized() {
1532 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1533 return llvm::all_of(results, [&](OpFoldResult ofr) {
1534 auto intValue = getConstantIntValue(ofr);
1535 return intValue.has_value() && intValue == val;
1536 });
1537 };
1538 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1539}
1540
1541InParallelOp ForallOp::getTerminator() {
1542 return cast<InParallelOp>(getBody()->getTerminator());
1543}
1544
1545SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1546 SmallVector<Operation *> storeOps;
1547 for (Operation *user : bbArg.getUsers()) {
1548 if (auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1549 storeOps.push_back(parallelOp);
1550 }
1551 }
1552 return storeOps;
1553}
1554
1555SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1556 SmallVector<DeviceMappingAttrInterface> res;
1557 if (!getMapping())
1558 return res;
1559 for (auto attr : getMapping()->getValue()) {
1560 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1561 if (m)
1562 res.push_back(m);
1563 }
1564 return res;
1565}
1566
1567FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1568 DeviceMaskingAttrInterface res;
1569 if (!getMapping())
1570 return res;
1571 for (auto attr : getMapping()->getValue()) {
1572 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1573 if (m && res)
1574 return failure();
1575 if (m)
1576 res = m;
1577 }
1578 return res;
1579}
1580
1581bool ForallOp::usesLinearMapping() {
1582 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1583 if (ifaces.empty())
1584 return false;
1585 return ifaces.front().isLinearMapping();
1586}
1587
1588std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1589 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1590}
1591
1592// Get lower bounds as OpFoldResult.
1593std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1594 Builder b(getOperation()->getContext());
1595 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1596}
1597
1598// Get upper bounds as OpFoldResult.
1599std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1600 Builder b(getOperation()->getContext());
1601 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1602}
1603
1604// Get steps as OpFoldResult.
1605std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1606 Builder b(getOperation()->getContext());
1607 return getMixedValues(getStaticStep(), getDynamicStep(), b);
1608}
1609
1611 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1612 if (!tidxArg)
1613 return ForallOp();
1614 assert(tidxArg.getOwner() && "unlinked block argument");
1615 auto *containingOp = tidxArg.getOwner()->getParentOp();
1616 return dyn_cast<ForallOp>(containingOp);
1617}
1618
1619namespace {
1620/// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1621struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1622 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1623
1624 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1625 PatternRewriter &rewriter) const final {
1626 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1627 if (!forallOp)
1628 return failure();
1629 Value sharedOut =
1630 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1631 ->get();
1632 rewriter.modifyOpInPlace(
1633 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1634 return success();
1635 }
1636};
1637
1638class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1639public:
1640 using OpRewritePattern<ForallOp>::OpRewritePattern;
1641
1642 LogicalResult matchAndRewrite(ForallOp op,
1643 PatternRewriter &rewriter) const override {
1644 SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1645 SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1646 SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1647 if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1648 failed(foldDynamicIndexList(mixedUpperBound)) &&
1649 failed(foldDynamicIndexList(mixedStep)))
1650 return failure();
1651
1652 rewriter.modifyOpInPlace(op, [&]() {
1653 SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1654 SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1655 dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1656 staticLowerBound);
1657 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1658 op.setStaticLowerBound(staticLowerBound);
1659
1660 dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1661 staticUpperBound);
1662 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1663 op.setStaticUpperBound(staticUpperBound);
1664
1665 dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1666 op.getDynamicStepMutable().assign(dynamicStep);
1667 op.setStaticStep(staticStep);
1668
1669 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1670 rewriter.getDenseI32ArrayAttr(
1671 {static_cast<int32_t>(dynamicLowerBound.size()),
1672 static_cast<int32_t>(dynamicUpperBound.size()),
1673 static_cast<int32_t>(dynamicStep.size()),
1674 static_cast<int32_t>(op.getNumResults())}));
1675 });
1676 return success();
1677 }
1678};
1679
1680/// The following canonicalization pattern folds the iter arguments of
1681/// scf.forall op if :-
1682/// 1. The corresponding result has zero uses.
1683/// 2. The iter argument is NOT being modified within the loop body.
1685///
1686/// Example of first case :-
1687/// INPUT:
1688/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1689/// {
1690/// ...
1691/// <SOME USE OF %arg0>
1692/// <SOME USE OF %arg1>
1693/// <SOME USE OF %arg2>
1694/// ...
1695/// scf.forall.in_parallel {
1696/// <STORE OP WITH DESTINATION %arg1>
1697/// <STORE OP WITH DESTINATION %arg0>
1698/// <STORE OP WITH DESTINATION %arg2>
1699/// }
1700/// }
1701/// return %res#1
1702///
1703/// OUTPUT:
1704/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1705/// {
1706/// ...
1707/// <SOME USE OF %a>
1708/// <SOME USE OF %new_arg0>
1709/// <SOME USE OF %c>
1710/// ...
1711/// scf.forall.in_parallel {
1712/// <STORE OP WITH DESTINATION %new_arg0>
1713/// }
1714/// }
1715/// return %res
1716///
1717/// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1718/// scf.forall is replaced by their corresponding operands.
1719/// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1720/// of the scf.forall besides within scf.forall.in_parallel terminator,
1721/// this canonicalization remains valid. For more details, please refer
1722/// to :
1723/// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1724/// 3. TODO(avarma): Generalize it for other store ops. Currently it
1725/// handles tensor.parallel_insert_slice ops only.
1726///
1727/// Example of second case :-
1728/// INPUT:
1729/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1730/// {
1731/// ...
1732/// <SOME USE OF %arg0>
1733/// <SOME USE OF %arg1>
1734/// ...
1735/// scf.forall.in_parallel {
1736/// <STORE OP WITH DESTINATION %arg1>
1737/// }
1738/// }
1739/// return %res#0, %res#1
1740///
1741/// OUTPUT:
1742/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1743/// {
1744/// ...
1745/// <SOME USE OF %a>
1746/// <SOME USE OF %new_arg0>
1747/// ...
1748/// scf.forall.in_parallel {
1749/// <STORE OP WITH DESTINATION %new_arg0>
1750/// }
1751/// }
1752/// return %a, %res
1753struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1754 using OpRewritePattern<ForallOp>::OpRewritePattern;
1755
1756 LogicalResult matchAndRewrite(ForallOp forallOp,
1757 PatternRewriter &rewriter) const final {
1758 // Step 1: For a given i-th result of scf.forall, check the following :-
1759 // a. If it has any use.
1760 // b. If the corresponding iter argument is being modified within
1761 // the loop, i.e. has at least one store op with the iter arg as
1762 // its destination operand. For this we use
1763 // ForallOp::getCombiningOps(iter_arg).
1764 //
1765 // Based on the check we maintain the following :-
1766 // a. op results, block arguments, outputs to delete
1767 // b. new outputs (i.e., outputs to retain)
1768 SmallVector<Value> resultsToDelete;
1769 SmallVector<Value> outsToDelete;
1770 SmallVector<BlockArgument> blockArgsToDelete;
1771 SmallVector<Value> newOuts;
1772 BitVector resultIndicesToDelete(forallOp.getNumResults(), false);
1773 BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
1774 false);
1775 for (OpResult result : forallOp.getResults()) {
1776 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1777 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1778 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1779 resultsToDelete.push_back(result);
1780 outsToDelete.push_back(opOperand->get());
1781 blockArgsToDelete.push_back(blockArg);
1782 resultIndicesToDelete[result.getResultNumber()] = true;
1783 blockIndicesToDelete[blockArg.getArgNumber()] = true;
1784 } else {
1785 newOuts.push_back(opOperand->get());
1786 }
1787 }
1788
1789 // Return early if all results of scf.forall have at least one use and being
1790 // modified within the loop.
1791 if (resultsToDelete.empty())
1792 return failure();
1793
1794 // Step 2: Erase combining ops and replace uses of deleted results and
1795 // block arguments with the corresponding outputs.
1796 for (auto blockArg : blockArgsToDelete) {
1797 SmallVector<Operation *> combiningOps =
1798 forallOp.getCombiningOps(blockArg);
1799 for (Operation *combiningOp : combiningOps)
1800 rewriter.eraseOp(combiningOp);
1801 }
1802 for (auto [blockArg, result, out] :
1803 llvm::zip_equal(blockArgsToDelete, resultsToDelete, outsToDelete)) {
1804 rewriter.replaceAllUsesWith(blockArg, out);
1805 rewriter.replaceAllUsesWith(result, out);
1806 }
1807 // TODO: There is no rewriter API for erasing block arguments.
1808 rewriter.modifyOpInPlace(forallOp, [&]() {
1809 forallOp.getBody()->eraseArguments(blockIndicesToDelete);
1810 });
1811
1812 // Step 3. Create a new scf.forall op with only the shared_outs/results
1813 // that should be retained.
1814 auto newForallOp = cast<scf::ForallOp>(
1815 rewriter.eraseOpResults(forallOp, resultIndicesToDelete));
1816 newForallOp.getOutputsMutable().assign(newOuts);
1817
1818 return success();
1819 }
1820};
1821
1822struct ForallOpSingleOrZeroIterationDimsFolder
1823 : public OpRewritePattern<ForallOp> {
1824 using OpRewritePattern<ForallOp>::OpRewritePattern;
1825
1826 LogicalResult matchAndRewrite(ForallOp op,
1827 PatternRewriter &rewriter) const override {
1828 // Do not fold dimensions if they are mapped to processing units.
1829 if (op.getMapping().has_value() && !op.getMapping()->empty())
1830 return failure();
1831 Location loc = op.getLoc();
1832
1833 // Compute new loop bounds that omit all single-iteration loop dimensions.
1834 SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1835 newMixedSteps;
1836 IRMapping mapping;
1837 for (auto [lb, ub, step, iv] :
1838 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1839 op.getMixedStep(), op.getInductionVars())) {
1840 auto numIterations =
1841 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
1842 if (numIterations.has_value()) {
1843 // Remove the loop if it performs zero iterations.
1844 if (*numIterations == 0) {
1845 rewriter.replaceOp(op, op.getOutputs());
1846 return success();
1847 }
1848 // Replace the loop induction variable by the lower bound if the loop
1849 // performs a single iteration. Otherwise, copy the loop bounds.
1850 if (*numIterations == 1) {
1851 mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1852 continue;
1853 }
1854 }
1855 newMixedLowerBounds.push_back(lb);
1856 newMixedUpperBounds.push_back(ub);
1857 newMixedSteps.push_back(step);
1858 }
1859
1860 // All of the loop dimensions perform a single iteration. Inline loop body.
1861 if (newMixedLowerBounds.empty()) {
1862 promote(rewriter, op);
1863 return success();
1864 }
1865
1866 // Exit if none of the loop dimensions perform a single iteration.
1867 if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1868 return rewriter.notifyMatchFailure(
1869 op, "no dimensions have 0 or 1 iterations");
1870 }
1871
1872 // Replace the loop by a lower-dimensional loop.
1873 ForallOp newOp;
1874 newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1875 newMixedUpperBounds, newMixedSteps,
1876 op.getOutputs(), std::nullopt, nullptr);
1877 newOp.getBodyRegion().getBlocks().clear();
1878 // The new loop needs to keep all attributes from the old one, except for
1879 // "operandSegmentSizes" and static loop bound attributes which capture
1880 // the outdated information of the old iteration domain.
1881 SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1882 newOp.getStaticLowerBoundAttrName(),
1883 newOp.getStaticUpperBoundAttrName(),
1884 newOp.getStaticStepAttrName()};
1885 for (const auto &namedAttr : op->getAttrs()) {
1886 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1887 continue;
1888 rewriter.modifyOpInPlace(newOp, [&]() {
1889 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1890 });
1891 }
1892 rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1893 newOp.getRegion().begin(), mapping);
1894 rewriter.replaceOp(op, newOp.getResults());
1895 return success();
1896 }
1897};
1898
1899/// Replace all induction vars with a single trip count with their lower bound.
1900struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1901 using OpRewritePattern<ForallOp>::OpRewritePattern;
1902
1903 LogicalResult matchAndRewrite(ForallOp op,
1904 PatternRewriter &rewriter) const override {
1905 Location loc = op.getLoc();
1906 bool changed = false;
1907 for (auto [lb, ub, step, iv] :
1908 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1909 op.getMixedStep(), op.getInductionVars())) {
1910 if (iv.hasNUses(0))
1911 continue;
1912 auto numIterations =
1913 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
1914 if (!numIterations.has_value() || numIterations.value() != 1) {
1915 continue;
1916 }
1917 rewriter.replaceAllUsesWith(
1918 iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1919 changed = true;
1920 }
1921 return success(changed);
1922 }
1923};
1924
1925struct FoldTensorCastOfOutputIntoForallOp
1926 : public OpRewritePattern<scf::ForallOp> {
1927 using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1928
1929 struct TypeCast {
1930 Type srcType;
1931 Type dstType;
1932 };
1933
1934 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1935 PatternRewriter &rewriter) const final {
1936 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1937 llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1938 for (auto en : llvm::enumerate(newOutputTensors)) {
1939 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1940 if (!castOp)
1941 continue;
1942
1943 // Only casts that that preserve static information, i.e. will make the
1944 // loop result type "more" static than before, will be folded.
1945 if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1946 castOp.getSource().getType())) {
1947 continue;
1948 }
1949
1950 tensorCastProducers[en.index()] =
1951 TypeCast{castOp.getSource().getType(), castOp.getType()};
1952 newOutputTensors[en.index()] = castOp.getSource();
1953 }
1954
1955 if (tensorCastProducers.empty())
1956 return failure();
1957
1958 // Create new loop.
1959 Location loc = forallOp.getLoc();
1960 auto newForallOp = ForallOp::create(
1961 rewriter, loc, forallOp.getMixedLowerBound(),
1962 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1963 newOutputTensors, forallOp.getMapping(),
1964 [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1965 auto castBlockArgs =
1966 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1967 for (auto [index, cast] : tensorCastProducers) {
1968 Value &oldTypeBBArg = castBlockArgs[index];
1969 oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
1970 cast.dstType, oldTypeBBArg);
1971 }
1972
1973 // Move old body into new parallel loop.
1974 SmallVector<Value> ivsBlockArgs =
1975 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1976 ivsBlockArgs.append(castBlockArgs);
1977 rewriter.mergeBlocks(forallOp.getBody(),
1978 bbArgs.front().getParentBlock(), ivsBlockArgs);
1979 });
1980
1981 // After `mergeBlocks` happened, the destinations in the terminator were
1982 // mapped to the tensor.cast old-typed results of the output bbArgs. The
1983 // destination have to be updated to point to the output bbArgs directly.
1984 auto terminator = newForallOp.getTerminator();
1985 for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1986 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1987 if (auto parallelCombingingOp =
1988 dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
1989 parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
1990 }
1991 }
1992
1993 // Cast results back to the original types.
1994 rewriter.setInsertionPointAfter(newForallOp);
1995 SmallVector<Value> castResults = newForallOp.getResults();
1996 for (auto &item : tensorCastProducers) {
1997 Value &oldTypeResult = castResults[item.first];
1998 oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
1999 oldTypeResult);
2000 }
2001 rewriter.replaceOp(forallOp, castResults);
2002 return success();
2003 }
2004};
2005
2006} // namespace
2007
2008void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
2009 MLIRContext *context) {
2010 results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
2011 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
2012 ForallOpSingleOrZeroIterationDimsFolder,
2013 ForallOpReplaceConstantInductionVar>(context);
2014}
2015
2016/// Given the region at `index`, or the parent operation if `index` is None,
2017/// return the successor regions. These are the regions that may be selected
2018/// during the flow of control. `operands` is a set of optional attributes that
2019/// correspond to a constant value for each operand, or null if that operand is
2020/// not a constant.
2021void ForallOp::getSuccessorRegions(RegionBranchPoint point,
2022 SmallVectorImpl<RegionSuccessor> &regions) {
2023 // In accordance with the semantics of forall, its body is executed in
2024 // parallel by multiple threads. We should not expect to branch back into
2025 // the forall body after the region's execution is complete.
2026 if (point.isParent())
2027 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2028 else
2029 regions.push_back(
2030 RegionSuccessor(getOperation(), getOperation()->getResults()));
2031}
2032
2033//===----------------------------------------------------------------------===//
2034// InParallelOp
2035//===----------------------------------------------------------------------===//
2036
2037// Build a InParallelOp with mixed static and dynamic entries.
2038void InParallelOp::build(OpBuilder &b, OperationState &result) {
2039 OpBuilder::InsertionGuard g(b);
2040 Region *bodyRegion = result.addRegion();
2041 b.createBlock(bodyRegion);
2042}
2043
2044LogicalResult InParallelOp::verify() {
2045 scf::ForallOp forallOp =
2046 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
2047 if (!forallOp)
2048 return this->emitOpError("expected forall op parent");
2049
2050 for (Operation &op : getRegion().front().getOperations()) {
2051 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
2052 if (!parallelCombiningOp) {
2053 return this->emitOpError("expected only ParallelCombiningOpInterface")
2054 << " ops";
2055 }
2056
2057 // Verify that inserts are into out block arguments.
2058 MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
2059 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
2060 for (OpOperand &dest : dests) {
2061 if (!llvm::is_contained(regionOutArgs, dest.get()))
2062 return op.emitOpError("may only insert into an output block argument");
2063 }
2064 }
2065
2066 return success();
2067}
2068
2069void InParallelOp::print(OpAsmPrinter &p) {
2070 p << " ";
2071 p.printRegion(getRegion(),
2072 /*printEntryBlockArgs=*/false,
2073 /*printBlockTerminators=*/false);
2074 p.printOptionalAttrDict(getOperation()->getAttrs());
2075}
2076
2077ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2078 auto &builder = parser.getBuilder();
2079
2080 SmallVector<OpAsmParser::Argument, 8> regionOperands;
2081 std::unique_ptr<Region> region = std::make_unique<Region>();
2082 if (parser.parseRegion(*region, regionOperands))
2083 return failure();
2084
2085 if (region->empty())
2086 OpBuilder(builder.getContext()).createBlock(region.get());
2087 result.addRegion(std::move(region));
2088
2089 // Parse the optional attribute list.
2090 if (parser.parseOptionalAttrDict(result.attributes))
2091 return failure();
2092 return success();
2093}
2094
2095OpResult InParallelOp::getParentResult(int64_t idx) {
2096 return getOperation()->getParentOp()->getResult(idx);
2097}
2098
2099SmallVector<BlockArgument> InParallelOp::getDests() {
2100 SmallVector<BlockArgument> updatedDests;
2101 for (Operation &yieldingOp : getYieldingOps()) {
2102 auto parallelCombiningOp =
2103 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
2104 if (!parallelCombiningOp)
2105 continue;
2106 for (OpOperand &updatedOperand :
2107 parallelCombiningOp.getUpdatedDestinations())
2108 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
2109 }
2110 return updatedDests;
2111}
2112
2113llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
2114 return getRegion().front().getOperations();
2115}
2116
2117//===----------------------------------------------------------------------===//
2118// IfOp
2119//===----------------------------------------------------------------------===//
2120
2122 assert(a && "expected non-empty operation");
2123 assert(b && "expected non-empty operation");
2124
2125 IfOp ifOp = a->getParentOfType<IfOp>();
2126 while (ifOp) {
2127 // Check if b is inside ifOp. (We already know that a is.)
2128 if (ifOp->isProperAncestor(b))
2129 // b is contained in ifOp. a and b are in mutually exclusive branches if
2130 // they are in different blocks of ifOp.
2131 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2132 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
2133 // Check next enclosing IfOp.
2134 ifOp = ifOp->getParentOfType<IfOp>();
2135 }
2136
2137 // Could not find a common IfOp among a's and b's ancestors.
2138 return false;
2139}
2140
2141LogicalResult
2142IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2143 IfOp::Adaptor adaptor,
2144 SmallVectorImpl<Type> &inferredReturnTypes) {
2145 if (adaptor.getRegions().empty())
2146 return failure();
2147 Region *r = &adaptor.getThenRegion();
2148 if (r->empty())
2149 return failure();
2150 Block &b = r->front();
2151 if (b.empty())
2152 return failure();
2153 auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2154 if (!yieldOp)
2155 return failure();
2156 TypeRange types = yieldOp.getOperandTypes();
2157 llvm::append_range(inferredReturnTypes, types);
2158 return success();
2159}
2160
2161void IfOp::build(OpBuilder &builder, OperationState &result,
2162 TypeRange resultTypes, Value cond) {
2163 return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2164 /*addElseBlock=*/false);
2165}
2166
2167void IfOp::build(OpBuilder &builder, OperationState &result,
2168 TypeRange resultTypes, Value cond, bool addThenBlock,
2169 bool addElseBlock) {
2170 assert((!addElseBlock || addThenBlock) &&
2171 "must not create else block w/o then block");
2172 result.addTypes(resultTypes);
2173 result.addOperands(cond);
2174
2175 // Add regions and blocks.
2176 OpBuilder::InsertionGuard guard(builder);
2177 Region *thenRegion = result.addRegion();
2178 if (addThenBlock)
2179 builder.createBlock(thenRegion);
2180 Region *elseRegion = result.addRegion();
2181 if (addElseBlock)
2182 builder.createBlock(elseRegion);
2183}
2184
2185void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2186 bool withElseRegion) {
2187 build(builder, result, TypeRange{}, cond, withElseRegion);
2188}
2189
2190void IfOp::build(OpBuilder &builder, OperationState &result,
2191 TypeRange resultTypes, Value cond, bool withElseRegion) {
2192 result.addTypes(resultTypes);
2193 result.addOperands(cond);
2194
2195 // Build then region.
2196 OpBuilder::InsertionGuard guard(builder);
2197 Region *thenRegion = result.addRegion();
2198 builder.createBlock(thenRegion);
2199 if (resultTypes.empty())
2200 IfOp::ensureTerminator(*thenRegion, builder, result.location);
2201
2202 // Build else region.
2203 Region *elseRegion = result.addRegion();
2204 if (withElseRegion) {
2205 builder.createBlock(elseRegion);
2206 if (resultTypes.empty())
2207 IfOp::ensureTerminator(*elseRegion, builder, result.location);
2208 }
2209}
2210
2211void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2212 function_ref<void(OpBuilder &, Location)> thenBuilder,
2213 function_ref<void(OpBuilder &, Location)> elseBuilder) {
2214 assert(thenBuilder && "the builder callback for 'then' must be present");
2215 result.addOperands(cond);
2216
2217 // Build then region.
2218 OpBuilder::InsertionGuard guard(builder);
2219 Region *thenRegion = result.addRegion();
2220 builder.createBlock(thenRegion);
2221 thenBuilder(builder, result.location);
2222
2223 // Build else region.
2224 Region *elseRegion = result.addRegion();
2225 if (elseBuilder) {
2226 builder.createBlock(elseRegion);
2227 elseBuilder(builder, result.location);
2228 }
2229
2230 // Infer result types.
2231 SmallVector<Type> inferredReturnTypes;
2232 MLIRContext *ctx = builder.getContext();
2233 auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2234 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2235 /*properties=*/nullptr, result.regions,
2236 inferredReturnTypes))) {
2237 result.addTypes(inferredReturnTypes);
2238 }
2239}
2240
2241LogicalResult IfOp::verify() {
2242 if (getNumResults() != 0 && getElseRegion().empty())
2243 return emitOpError("must have an else block if defining values");
2244 return success();
2245}
2246
2247ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2248 // Create the regions for 'then'.
2249 result.regions.reserve(2);
2250 Region *thenRegion = result.addRegion();
2251 Region *elseRegion = result.addRegion();
2252
2253 auto &builder = parser.getBuilder();
2254 OpAsmParser::UnresolvedOperand cond;
2255 Type i1Type = builder.getIntegerType(1);
2256 if (parser.parseOperand(cond) ||
2257 parser.resolveOperand(cond, i1Type, result.operands))
2258 return failure();
2259 // Parse optional results type list.
2260 if (parser.parseOptionalArrowTypeList(result.types))
2261 return failure();
2262 // Parse the 'then' region.
2263 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2264 return failure();
2265 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2266
2267 // If we find an 'else' keyword then parse the 'else' region.
2268 if (!parser.parseOptionalKeyword("else")) {
2269 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2270 return failure();
2271 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2272 }
2273
2274 // Parse the optional attribute list.
2275 if (parser.parseOptionalAttrDict(result.attributes))
2276 return failure();
2277 return success();
2278}
2279
2280void IfOp::print(OpAsmPrinter &p) {
2281 bool printBlockTerminators = false;
2282
2283 p << " " << getCondition();
2284 if (!getResults().empty()) {
2285 p << " -> (" << getResultTypes() << ")";
2286 // Print yield explicitly if the op defines values.
2287 printBlockTerminators = true;
2288 }
2289 p << ' ';
2290 p.printRegion(getThenRegion(),
2291 /*printEntryBlockArgs=*/false,
2292 /*printBlockTerminators=*/printBlockTerminators);
2293
2294 // Print the 'else' regions if it exists and has a block.
2295 auto &elseRegion = getElseRegion();
2296 if (!elseRegion.empty()) {
2297 p << " else ";
2298 p.printRegion(elseRegion,
2299 /*printEntryBlockArgs=*/false,
2300 /*printBlockTerminators=*/printBlockTerminators);
2301 }
2302
2303 p.printOptionalAttrDict((*this)->getAttrs());
2304}
2305
2306void IfOp::getSuccessorRegions(RegionBranchPoint point,
2307 SmallVectorImpl<RegionSuccessor> &regions) {
2308 // The `then` and the `else` region branch back to the parent operation or one
2309 // of the recursive parent operations (early exit case).
2310 if (!point.isParent()) {
2311 regions.push_back(RegionSuccessor(getOperation(), getResults()));
2312 return;
2313 }
2314
2315 regions.push_back(RegionSuccessor(&getThenRegion()));
2316
2317 // Don't consider the else region if it is empty.
2318 Region *elseRegion = &this->getElseRegion();
2319 if (elseRegion->empty())
2320 regions.push_back(
2321 RegionSuccessor(getOperation(), getOperation()->getResults()));
2322 else
2323 regions.push_back(RegionSuccessor(elseRegion));
2324}
2325
2326void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2327 SmallVectorImpl<RegionSuccessor> &regions) {
2328 FoldAdaptor adaptor(operands, *this);
2329 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2330 if (!boolAttr || boolAttr.getValue())
2331 regions.emplace_back(&getThenRegion());
2332
2333 // If the else region is empty, execution continues after the parent op.
2334 if (!boolAttr || !boolAttr.getValue()) {
2335 if (!getElseRegion().empty())
2336 regions.emplace_back(&getElseRegion());
2337 else
2338 regions.emplace_back(getOperation(), getResults());
2339 }
2340}
2341
2342LogicalResult IfOp::fold(FoldAdaptor adaptor,
2343 SmallVectorImpl<OpFoldResult> &results) {
2344 // if (!c) then A() else B() -> if c then B() else A()
2345 if (getElseRegion().empty())
2346 return failure();
2347
2348 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2349 if (!xorStmt)
2350 return failure();
2351
2352 if (!matchPattern(xorStmt.getRhs(), m_One()))
2353 return failure();
2354
2355 getConditionMutable().assign(xorStmt.getLhs());
2356 Block *thenBlock = &getThenRegion().front();
2357 // It would be nicer to use iplist::swap, but that has no implemented
2358 // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2359 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2360 getElseRegion().getBlocks());
2361 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2362 getThenRegion().getBlocks(), thenBlock);
2363 return success();
2364}
2365
2366void IfOp::getRegionInvocationBounds(
2367 ArrayRef<Attribute> operands,
2368 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2369 if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2370 // If the condition is known, then one region is known to be executed once
2371 // and the other zero times.
2372 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2373 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2374 } else {
2375 // Non-constant condition. Each region may be executed 0 or 1 times.
2376 invocationBounds.assign(2, {0, 1});
2377 }
2378}
2379
2380namespace {
2381// Pattern to remove unused IfOp results.
2382struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2383 using OpRewritePattern<IfOp>::OpRewritePattern;
2384
2385 LogicalResult matchAndRewrite(IfOp op,
2386 PatternRewriter &rewriter) const override {
2387 // Compute the list of unused results.
2388 BitVector toErase(op.getNumResults(), false);
2389 for (auto [idx, result] : llvm::enumerate(op.getResults()))
2390 if (result.use_empty())
2391 toErase[idx] = true;
2392 if (toErase.none())
2393 return rewriter.notifyMatchFailure(op, "no results to erase");
2394
2395 // Erase results.
2396 auto newOp = cast<scf::IfOp>(rewriter.eraseOpResults(op, toErase));
2397
2398 // Erase operands.
2399 rewriter.modifyOpInPlace(newOp.thenYield(), [&]() {
2400 newOp.thenYield()->eraseOperands(toErase);
2401 });
2402 rewriter.modifyOpInPlace(newOp.elseYield(), [&]() {
2403 newOp.elseYield()->eraseOperands(toErase);
2404 });
2405
2406 return success();
2407 }
2408};
2409
2410struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2411 using OpRewritePattern<IfOp>::OpRewritePattern;
2412
2413 LogicalResult matchAndRewrite(IfOp op,
2414 PatternRewriter &rewriter) const override {
2415 BoolAttr condition;
2416 if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2417 return failure();
2418
2419 if (condition.getValue())
2420 replaceOpWithRegion(rewriter, op, op.getThenRegion());
2421 else if (!op.getElseRegion().empty())
2422 replaceOpWithRegion(rewriter, op, op.getElseRegion());
2423 else
2424 rewriter.eraseOp(op);
2425
2426 return success();
2427 }
2428};
2429
2430/// Hoist any yielded results whose operands are defined outside
2431/// the if, to a select instruction.
2432struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2433 using OpRewritePattern<IfOp>::OpRewritePattern;
2434
2435 LogicalResult matchAndRewrite(IfOp op,
2436 PatternRewriter &rewriter) const override {
2437 if (op->getNumResults() == 0)
2438 return failure();
2439
2440 auto cond = op.getCondition();
2441 auto thenYieldArgs = op.thenYield().getOperands();
2442 auto elseYieldArgs = op.elseYield().getOperands();
2443
2444 SmallVector<Type> nonHoistable;
2445 for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2446 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2447 &op.getElseRegion() == falseVal.getParentRegion())
2448 nonHoistable.push_back(trueVal.getType());
2449 }
2450 // Early exit if there aren't any yielded values we can
2451 // hoist outside the if.
2452 if (nonHoistable.size() == op->getNumResults())
2453 return failure();
2454
2455 IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2456 /*withElseRegion=*/false);
2457 if (replacement.thenBlock())
2458 rewriter.eraseBlock(replacement.thenBlock());
2459 replacement.getThenRegion().takeBody(op.getThenRegion());
2460 replacement.getElseRegion().takeBody(op.getElseRegion());
2461
2462 SmallVector<Value> results(op->getNumResults());
2463 assert(thenYieldArgs.size() == results.size());
2464 assert(elseYieldArgs.size() == results.size());
2465
2466 SmallVector<Value> trueYields;
2467 SmallVector<Value> falseYields;
2469 for (const auto &it :
2470 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2471 Value trueVal = std::get<0>(it.value());
2472 Value falseVal = std::get<1>(it.value());
2473 if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2474 &replacement.getElseRegion() == falseVal.getParentRegion()) {
2475 results[it.index()] = replacement.getResult(trueYields.size());
2476 trueYields.push_back(trueVal);
2477 falseYields.push_back(falseVal);
2478 } else if (trueVal == falseVal)
2479 results[it.index()] = trueVal;
2480 else
2481 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2482 cond, trueVal, falseVal);
2483 }
2484
2485 rewriter.setInsertionPointToEnd(replacement.thenBlock());
2486 rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2487
2488 rewriter.setInsertionPointToEnd(replacement.elseBlock());
2489 rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2490
2491 rewriter.replaceOp(op, results);
2492 return success();
2493 }
2494};
2495
2496/// Allow the true region of an if to assume the condition is true
2497/// and vice versa. For example:
2498///
2499/// scf.if %cmp {
2500/// print(%cmp)
2501/// }
2502///
2503/// becomes
2504///
2505/// scf.if %cmp {
2506/// print(true)
2507/// }
2508///
2509struct ConditionPropagation : public OpRewritePattern<IfOp> {
2510 using OpRewritePattern<IfOp>::OpRewritePattern;
2511
2512 /// Kind of parent region in the ancestor cache.
2513 enum class Parent { Then, Else, None };
2514
2515 /// Returns the kind of region ("then", "else", or "none") of the
2516 /// IfOp that the given region is transitively nested in. Updates
2517 /// the cache accordingly.
2518 static Parent getParentType(Region *toCheck, IfOp op,
2520 Region *endRegion) {
2521 SmallVector<Region *> seen;
2522 while (toCheck != endRegion) {
2523 auto found = cache.find(toCheck);
2524 if (found != cache.end())
2525 return found->second;
2526 seen.push_back(toCheck);
2527 if (&op.getThenRegion() == toCheck) {
2528 for (Region *region : seen)
2529 cache[region] = Parent::Then;
2530 return Parent::Then;
2531 }
2532 if (&op.getElseRegion() == toCheck) {
2533 for (Region *region : seen)
2534 cache[region] = Parent::Else;
2535 return Parent::Else;
2536 }
2537 toCheck = toCheck->getParentRegion();
2538 }
2539
2540 for (Region *region : seen)
2541 cache[region] = Parent::None;
2542 return Parent::None;
2543 }
2544
2545 LogicalResult matchAndRewrite(IfOp op,
2546 PatternRewriter &rewriter) const override {
2547 // Early exit if the condition is constant since replacing a constant
2548 // in the body with another constant isn't a simplification.
2549 if (matchPattern(op.getCondition(), m_Constant()))
2550 return failure();
2551
2552 bool changed = false;
2553 mlir::Type i1Ty = rewriter.getI1Type();
2554
2555 // These variables serve to prevent creating duplicate constants
2556 // and hold constant true or false values.
2557 Value constantTrue = nullptr;
2558 Value constantFalse = nullptr;
2559
2561 for (OpOperand &use :
2562 llvm::make_early_inc_range(op.getCondition().getUses())) {
2563 switch (getParentType(use.getOwner()->getParentRegion(), op, cache,
2564 op.getCondition().getParentRegion())) {
2565 case Parent::Then: {
2566 changed = true;
2567
2568 if (!constantTrue)
2569 constantTrue = arith::ConstantOp::create(
2570 rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2571
2572 rewriter.modifyOpInPlace(use.getOwner(),
2573 [&]() { use.set(constantTrue); });
2574 break;
2575 }
2576 case Parent::Else: {
2577 changed = true;
2578
2579 if (!constantFalse)
2580 constantFalse = arith::ConstantOp::create(
2581 rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2582
2583 rewriter.modifyOpInPlace(use.getOwner(),
2584 [&]() { use.set(constantFalse); });
2585 break;
2586 }
2587 case Parent::None:
2588 break;
2589 }
2590 }
2591
2592 return success(changed);
2593 }
2594};
2595
2596/// Remove any statements from an if that are equivalent to the condition
2597/// or its negation. For example:
2598///
2599/// %res:2 = scf.if %cmp {
2600/// yield something(), true
2601/// } else {
2602/// yield something2(), false
2603/// }
2604/// print(%res#1)
2605///
2606/// becomes
2607/// %res = scf.if %cmp {
2608/// yield something()
2609/// } else {
2610/// yield something2()
2611/// }
2612/// print(%cmp)
2613///
2614/// Additionally if both branches yield the same value, replace all uses
2615/// of the result with the yielded value.
2616///
2617/// %res:2 = scf.if %cmp {
2618/// yield something(), %arg1
2619/// } else {
2620/// yield something2(), %arg1
2621/// }
2622/// print(%res#1)
2623///
2624/// becomes
2625/// %res = scf.if %cmp {
2626/// yield something()
2627/// } else {
2628/// yield something2()
2629/// }
2630/// print(%arg1)
2631///
2632struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2633 using OpRewritePattern<IfOp>::OpRewritePattern;
2634
2635 LogicalResult matchAndRewrite(IfOp op,
2636 PatternRewriter &rewriter) const override {
2637 // Early exit if there are no results that could be replaced.
2638 if (op.getNumResults() == 0)
2639 return failure();
2640
2641 auto trueYield =
2642 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2643 auto falseYield =
2644 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2645
2646 rewriter.setInsertionPoint(op->getBlock(),
2647 op.getOperation()->getIterator());
2648 bool changed = false;
2649 Type i1Ty = rewriter.getI1Type();
2650 for (auto [trueResult, falseResult, opResult] :
2651 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2652 op.getResults())) {
2653 if (trueResult == falseResult) {
2654 if (!opResult.use_empty()) {
2655 opResult.replaceAllUsesWith(trueResult);
2656 changed = true;
2657 }
2658 continue;
2659 }
2660
2661 BoolAttr trueYield, falseYield;
2662 if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2663 !matchPattern(falseResult, m_Constant(&falseYield)))
2664 continue;
2665
2666 bool trueVal = trueYield.getValue();
2667 bool falseVal = falseYield.getValue();
2668 if (!trueVal && falseVal) {
2669 if (!opResult.use_empty()) {
2670 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2671 Value notCond = arith::XOrIOp::create(
2672 rewriter, op.getLoc(), op.getCondition(),
2673 constDialect
2674 ->materializeConstant(rewriter,
2675 rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2676 op.getLoc())
2677 ->getResult(0));
2678 opResult.replaceAllUsesWith(notCond);
2679 changed = true;
2680 }
2681 }
2682 if (trueVal && !falseVal) {
2683 if (!opResult.use_empty()) {
2684 opResult.replaceAllUsesWith(op.getCondition());
2685 changed = true;
2686 }
2687 }
2688 }
2689 return success(changed);
2690 }
2691};
2692
2693/// Merge any consecutive scf.if's with the same condition.
2694///
2695/// scf.if %cond {
2696/// firstCodeTrue();...
2697/// } else {
2698/// firstCodeFalse();...
2699/// }
2700/// %res = scf.if %cond {
2701/// secondCodeTrue();...
2702/// } else {
2703/// secondCodeFalse();...
2704/// }
2705///
2706/// becomes
2707/// %res = scf.if %cmp {
2708/// firstCodeTrue();...
2709/// secondCodeTrue();...
2710/// } else {
2711/// firstCodeFalse();...
2712/// secondCodeFalse();...
2713/// }
2714struct CombineIfs : public OpRewritePattern<IfOp> {
2715 using OpRewritePattern<IfOp>::OpRewritePattern;
2716
2717 LogicalResult matchAndRewrite(IfOp nextIf,
2718 PatternRewriter &rewriter) const override {
2719 Block *parent = nextIf->getBlock();
2720 if (nextIf == &parent->front())
2721 return failure();
2722
2723 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2724 if (!prevIf)
2725 return failure();
2726
2727 // Determine the logical then/else blocks when prevIf's
2728 // condition is used. Null means the block does not exist
2729 // in that case (e.g. empty else). If neither of these
2730 // are set, the two conditions cannot be compared.
2731 Block *nextThen = nullptr;
2732 Block *nextElse = nullptr;
2733 if (nextIf.getCondition() == prevIf.getCondition()) {
2734 nextThen = nextIf.thenBlock();
2735 if (!nextIf.getElseRegion().empty())
2736 nextElse = nextIf.elseBlock();
2737 }
2738 if (arith::XOrIOp notv =
2739 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2740 if (notv.getLhs() == prevIf.getCondition() &&
2741 matchPattern(notv.getRhs(), m_One())) {
2742 nextElse = nextIf.thenBlock();
2743 if (!nextIf.getElseRegion().empty())
2744 nextThen = nextIf.elseBlock();
2745 }
2746 }
2747 if (arith::XOrIOp notv =
2748 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2749 if (notv.getLhs() == nextIf.getCondition() &&
2750 matchPattern(notv.getRhs(), m_One())) {
2751 nextElse = nextIf.thenBlock();
2752 if (!nextIf.getElseRegion().empty())
2753 nextThen = nextIf.elseBlock();
2754 }
2755 }
2756
2757 if (!nextThen && !nextElse)
2758 return failure();
2759
2760 SmallVector<Value> prevElseYielded;
2761 if (!prevIf.getElseRegion().empty())
2762 prevElseYielded = prevIf.elseYield().getOperands();
2763 // Replace all uses of return values of op within nextIf with the
2764 // corresponding yields
2765 for (auto it : llvm::zip(prevIf.getResults(),
2766 prevIf.thenYield().getOperands(), prevElseYielded))
2767 for (OpOperand &use :
2768 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2769 if (nextThen && nextThen->getParent()->isAncestor(
2770 use.getOwner()->getParentRegion())) {
2771 rewriter.startOpModification(use.getOwner());
2772 use.set(std::get<1>(it));
2773 rewriter.finalizeOpModification(use.getOwner());
2774 } else if (nextElse && nextElse->getParent()->isAncestor(
2775 use.getOwner()->getParentRegion())) {
2776 rewriter.startOpModification(use.getOwner());
2777 use.set(std::get<2>(it));
2778 rewriter.finalizeOpModification(use.getOwner());
2779 }
2780 }
2781
2782 SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2783 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2784
2785 IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2786 prevIf.getCondition(), /*hasElse=*/false);
2787 rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2788
2789 rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2790 combinedIf.getThenRegion(),
2791 combinedIf.getThenRegion().begin());
2792
2793 if (nextThen) {
2794 YieldOp thenYield = combinedIf.thenYield();
2795 YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2796 rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2797 rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2798
2799 SmallVector<Value> mergedYields(thenYield.getOperands());
2800 llvm::append_range(mergedYields, thenYield2.getOperands());
2801 YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2802 rewriter.eraseOp(thenYield);
2803 rewriter.eraseOp(thenYield2);
2804 }
2805
2806 rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2807 combinedIf.getElseRegion(),
2808 combinedIf.getElseRegion().begin());
2809
2810 if (nextElse) {
2811 if (combinedIf.getElseRegion().empty()) {
2812 rewriter.inlineRegionBefore(*nextElse->getParent(),
2813 combinedIf.getElseRegion(),
2814 combinedIf.getElseRegion().begin());
2815 } else {
2816 YieldOp elseYield = combinedIf.elseYield();
2817 YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2818 rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2819
2820 rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2821
2822 SmallVector<Value> mergedElseYields(elseYield.getOperands());
2823 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2824
2825 YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2826 rewriter.eraseOp(elseYield);
2827 rewriter.eraseOp(elseYield2);
2828 }
2829 }
2830
2831 SmallVector<Value> prevValues;
2832 SmallVector<Value> nextValues;
2833 for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2834 if (pair.index() < prevIf.getNumResults())
2835 prevValues.push_back(pair.value());
2836 else
2837 nextValues.push_back(pair.value());
2838 }
2839 rewriter.replaceOp(prevIf, prevValues);
2840 rewriter.replaceOp(nextIf, nextValues);
2841 return success();
2842 }
2843};
2844
2845/// Pattern to remove an empty else branch.
2846struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2847 using OpRewritePattern<IfOp>::OpRewritePattern;
2848
2849 LogicalResult matchAndRewrite(IfOp ifOp,
2850 PatternRewriter &rewriter) const override {
2851 // Cannot remove else region when there are operation results.
2852 if (ifOp.getNumResults())
2853 return failure();
2854 Block *elseBlock = ifOp.elseBlock();
2855 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2856 return failure();
2857 auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2858 rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2859 newIfOp.getThenRegion().begin());
2860 rewriter.eraseOp(ifOp);
2861 return success();
2862 }
2863};
2864
2865/// Convert nested `if`s into `arith.andi` + single `if`.
2866///
2867/// scf.if %arg0 {
2868/// scf.if %arg1 {
2869/// ...
2870/// scf.yield
2871/// }
2872/// scf.yield
2873/// }
2874/// becomes
2875///
2876/// %0 = arith.andi %arg0, %arg1
2877/// scf.if %0 {
2878/// ...
2879/// scf.yield
2880/// }
2881struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2882 using OpRewritePattern<IfOp>::OpRewritePattern;
2883
2884 LogicalResult matchAndRewrite(IfOp op,
2885 PatternRewriter &rewriter) const override {
2886 auto nestedOps = op.thenBlock()->without_terminator();
2887 // Nested `if` must be the only op in block.
2888 if (!llvm::hasSingleElement(nestedOps))
2889 return failure();
2890
2891 // If there is an else block, it can only yield
2892 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2893 return failure();
2894
2895 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2896 if (!nestedIf)
2897 return failure();
2898
2899 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2900 return failure();
2901
2902 SmallVector<Value> thenYield(op.thenYield().getOperands());
2903 SmallVector<Value> elseYield;
2904 if (op.elseBlock())
2905 llvm::append_range(elseYield, op.elseYield().getOperands());
2906
2907 // A list of indices for which we should upgrade the value yielded
2908 // in the else to a select.
2909 SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2910
2911 // If the outer scf.if yields a value produced by the inner scf.if,
2912 // only permit combining if the value yielded when the condition
2913 // is false in the outer scf.if is the same value yielded when the
2914 // inner scf.if condition is false.
2915 // Note that the array access to elseYield will not go out of bounds
2916 // since it must have the same length as thenYield, since they both
2917 // come from the same scf.if.
2918 for (const auto &tup : llvm::enumerate(thenYield)) {
2919 if (tup.value().getDefiningOp() == nestedIf) {
2920 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2921 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2922 elseYield[tup.index()]) {
2923 return failure();
2924 }
2925 // If the correctness test passes, we will yield
2926 // corresponding value from the inner scf.if
2927 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2928 continue;
2929 }
2930
2931 // Otherwise, we need to ensure the else block of the combined
2932 // condition still returns the same value when the outer condition is
2933 // true and the inner condition is false. This can be accomplished if
2934 // the then value is defined outside the outer scf.if and we replace the
2935 // value with a select that considers just the outer condition. Since
2936 // the else region contains just the yield, its yielded value is
2937 // defined outside the scf.if, by definition.
2938
2939 // If the then value is defined within the scf.if, bail.
2940 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2941 return failure();
2942 }
2943 elseYieldsToUpgradeToSelect.push_back(tup.index());
2944 }
2945
2946 Location loc = op.getLoc();
2947 Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
2948 nestedIf.getCondition());
2949 auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
2950 Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2951
2952 SmallVector<Value> results;
2953 llvm::append_range(results, newIf.getResults());
2954 rewriter.setInsertionPoint(newIf);
2955
2956 for (auto idx : elseYieldsToUpgradeToSelect)
2957 results[idx] =
2958 arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
2959 thenYield[idx], elseYield[idx]);
2960
2961 rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2962 rewriter.setInsertionPointToEnd(newIf.thenBlock());
2963 rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2964 if (!elseYield.empty()) {
2965 rewriter.createBlock(&newIf.getElseRegion());
2966 rewriter.setInsertionPointToEnd(newIf.elseBlock());
2967 YieldOp::create(rewriter, loc, elseYield);
2968 }
2969 rewriter.replaceOp(op, results);
2970 return success();
2971 }
2972};
2973
2974} // namespace
2975
2976void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2977 MLIRContext *context) {
2978 results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2979 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2980 RemoveStaticCondition, RemoveUnusedResults,
2981 ReplaceIfYieldWithConditionOrValue>(context);
2982}
2983
2984Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2985YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2986Block *IfOp::elseBlock() {
2987 Region &r = getElseRegion();
2988 if (r.empty())
2989 return nullptr;
2990 return &r.back();
2991}
2992YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2993
2994//===----------------------------------------------------------------------===//
2995// ParallelOp
2996//===----------------------------------------------------------------------===//
2997
2998void ParallelOp::build(
2999 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
3000 ValueRange upperBounds, ValueRange steps, ValueRange initVals,
3001 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
3002 bodyBuilderFn) {
3003 result.addOperands(lowerBounds);
3004 result.addOperands(upperBounds);
3005 result.addOperands(steps);
3006 result.addOperands(initVals);
3007 result.addAttribute(
3008 ParallelOp::getOperandSegmentSizeAttr(),
3009 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
3010 static_cast<int32_t>(upperBounds.size()),
3011 static_cast<int32_t>(steps.size()),
3012 static_cast<int32_t>(initVals.size())}));
3013 result.addTypes(initVals.getTypes());
3014
3015 OpBuilder::InsertionGuard guard(builder);
3016 unsigned numIVs = steps.size();
3017 SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
3018 SmallVector<Location, 8> argLocs(numIVs, result.location);
3019 Region *bodyRegion = result.addRegion();
3020 Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
3021
3022 if (bodyBuilderFn) {
3023 builder.setInsertionPointToStart(bodyBlock);
3024 bodyBuilderFn(builder, result.location,
3025 bodyBlock->getArguments().take_front(numIVs),
3026 bodyBlock->getArguments().drop_front(numIVs));
3027 }
3028 // Add terminator only if there are no reductions.
3029 if (initVals.empty())
3030 ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
3031}
3032
3033void ParallelOp::build(
3034 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
3035 ValueRange upperBounds, ValueRange steps,
3036 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
3037 // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
3038 // we don't capture a reference to a temporary by constructing the lambda at
3039 // function level.
3040 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
3041 Location nestedLoc, ValueRange ivs,
3042 ValueRange) {
3043 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
3044 };
3045 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
3046 if (bodyBuilderFn)
3047 wrapper = wrappedBuilderFn;
3048
3049 build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
3050 wrapper);
3051}
3052
3053LogicalResult ParallelOp::verify() {
3054 // Check that there is at least one value in lowerBound, upperBound and step.
3055 // It is sufficient to test only step, because it is ensured already that the
3056 // number of elements in lowerBound, upperBound and step are the same.
3057 Operation::operand_range stepValues = getStep();
3058 if (stepValues.empty())
3059 return emitOpError(
3060 "needs at least one tuple element for lowerBound, upperBound and step");
3061
3062 // Check whether all constant step values are positive.
3063 for (Value stepValue : stepValues)
3064 if (auto cst = getConstantIntValue(stepValue))
3065 if (*cst <= 0)
3066 return emitOpError("constant step operand must be positive");
3067
3068 // Check that the body defines the same number of block arguments as the
3069 // number of tuple elements in step.
3070 Block *body = getBody();
3071 if (body->getNumArguments() != stepValues.size())
3072 return emitOpError() << "expects the same number of induction variables: "
3073 << body->getNumArguments()
3074 << " as bound and step values: " << stepValues.size();
3075 for (auto arg : body->getArguments())
3076 if (!arg.getType().isIndex())
3077 return emitOpError(
3078 "expects arguments for the induction variable to be of index type");
3079
3080 // Check that the terminator is an scf.reduce op.
3082 *this, getRegion(), "expects body to terminate with 'scf.reduce'");
3083 if (!reduceOp)
3084 return failure();
3085
3086 // Check that the number of results is the same as the number of reductions.
3087 auto resultsSize = getResults().size();
3088 auto reductionsSize = reduceOp.getReductions().size();
3089 auto initValsSize = getInitVals().size();
3090 if (resultsSize != reductionsSize)
3091 return emitOpError() << "expects number of results: " << resultsSize
3092 << " to be the same as number of reductions: "
3093 << reductionsSize;
3094 if (resultsSize != initValsSize)
3095 return emitOpError() << "expects number of results: " << resultsSize
3096 << " to be the same as number of initial values: "
3097 << initValsSize;
3098 if (reduceOp.getNumOperands() != initValsSize)
3099 // Delegate error reporting to ReduceOp
3100 return success();
3101
3102 // Check that the types of the results and reductions are the same.
3103 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
3104 auto resultType = getOperation()->getResult(i).getType();
3105 auto reductionOperandType = reduceOp.getOperands()[i].getType();
3106 if (resultType != reductionOperandType)
3107 return reduceOp.emitOpError()
3108 << "expects type of " << i
3109 << "-th reduction operand: " << reductionOperandType
3110 << " to be the same as the " << i
3111 << "-th result type: " << resultType;
3112 }
3113 return success();
3114}
3115
3116ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
3117 auto &builder = parser.getBuilder();
3118 // Parse an opening `(` followed by induction variables followed by `)`
3119 SmallVector<OpAsmParser::Argument, 4> ivs;
3121 return failure();
3122
3123 // Parse loop bounds.
3124 SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
3125 if (parser.parseEqual() ||
3126 parser.parseOperandList(lower, ivs.size(),
3128 parser.resolveOperands(lower, builder.getIndexType(), result.operands))
3129 return failure();
3130
3131 SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
3132 if (parser.parseKeyword("to") ||
3133 parser.parseOperandList(upper, ivs.size(),
3135 parser.resolveOperands(upper, builder.getIndexType(), result.operands))
3136 return failure();
3137
3138 // Parse step values.
3139 SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
3140 if (parser.parseKeyword("step") ||
3141 parser.parseOperandList(steps, ivs.size(),
3143 parser.resolveOperands(steps, builder.getIndexType(), result.operands))
3144 return failure();
3145
3146 // Parse init values.
3147 SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
3148 if (succeeded(parser.parseOptionalKeyword("init"))) {
3149 if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
3150 return failure();
3151 }
3152
3153 // Parse optional results in case there is a reduce.
3154 if (parser.parseOptionalArrowTypeList(result.types))
3155 return failure();
3156
3157 // Now parse the body.
3158 Region *body = result.addRegion();
3159 for (auto &iv : ivs)
3160 iv.type = builder.getIndexType();
3161 if (parser.parseRegion(*body, ivs))
3162 return failure();
3163
3164 // Set `operandSegmentSizes` attribute.
3165 result.addAttribute(
3166 ParallelOp::getOperandSegmentSizeAttr(),
3167 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
3168 static_cast<int32_t>(upper.size()),
3169 static_cast<int32_t>(steps.size()),
3170 static_cast<int32_t>(initVals.size())}));
3171
3172 // Parse attributes.
3173 if (parser.parseOptionalAttrDict(result.attributes) ||
3174 parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
3175 result.operands))
3176 return failure();
3177
3178 // Add a terminator if none was parsed.
3179 ParallelOp::ensureTerminator(*body, builder, result.location);
3180 return success();
3181}
3182
3183void ParallelOp::print(OpAsmPrinter &p) {
3184 p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3185 << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3186 if (!getInitVals().empty())
3187 p << " init (" << getInitVals() << ")";
3188 p.printOptionalArrowTypeList(getResultTypes());
3189 p << ' ';
3190 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3192 (*this)->getAttrs(),
3193 /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3194}
3195
3196SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3197
3198std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3199 return SmallVector<Value>{getBody()->getArguments()};
3200}
3201
3202std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3203 return getLowerBound();
3204}
3205
3206std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3207 return getUpperBound();
3208}
3209
3210std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3211 return getStep();
3212}
3213
3215 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3216 if (!ivArg)
3217 return ParallelOp();
3218 assert(ivArg.getOwner() && "unlinked block argument");
3219 auto *containingOp = ivArg.getOwner()->getParentOp();
3220 return dyn_cast<ParallelOp>(containingOp);
3221}
3222
3223namespace {
3224// Collapse loop dimensions that perform a single iteration.
3225struct ParallelOpSingleOrZeroIterationDimsFolder
3226 : public OpRewritePattern<ParallelOp> {
3227 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3228
3229 LogicalResult matchAndRewrite(ParallelOp op,
3230 PatternRewriter &rewriter) const override {
3231 Location loc = op.getLoc();
3232
3233 // Compute new loop bounds that omit all single-iteration loop dimensions.
3234 SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3235 IRMapping mapping;
3236 for (auto [lb, ub, step, iv] :
3237 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3238 op.getInductionVars())) {
3239 auto numIterations =
3240 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
3241 if (numIterations.has_value()) {
3242 // Remove the loop if it performs zero iterations.
3243 if (*numIterations == 0) {
3244 rewriter.replaceOp(op, op.getInitVals());
3245 return success();
3246 }
3247 // Replace the loop induction variable by the lower bound if the loop
3248 // performs a single iteration. Otherwise, copy the loop bounds.
3249 if (*numIterations == 1) {
3250 mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3251 continue;
3252 }
3253 }
3254 newLowerBounds.push_back(lb);
3255 newUpperBounds.push_back(ub);
3256 newSteps.push_back(step);
3257 }
3258 // Exit if none of the loop dimensions perform a single iteration.
3259 if (newLowerBounds.size() == op.getLowerBound().size())
3260 return failure();
3261
3262 if (newLowerBounds.empty()) {
3263 // All of the loop dimensions perform a single iteration. Inline
3264 // loop body and nested ReduceOp's
3265 SmallVector<Value> results;
3266 results.reserve(op.getInitVals().size());
3267 for (auto &bodyOp : op.getBody()->without_terminator())
3268 rewriter.clone(bodyOp, mapping);
3269 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3270 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3271 Block &reduceBlock = reduceOp.getReductions()[i].front();
3272 auto initValIndex = results.size();
3273 mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3274 mapping.map(reduceBlock.getArgument(1),
3275 mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3276 for (auto &reduceBodyOp : reduceBlock.without_terminator())
3277 rewriter.clone(reduceBodyOp, mapping);
3278
3279 auto result = mapping.lookupOrDefault(
3280 cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3281 results.push_back(result);
3282 }
3283
3284 rewriter.replaceOp(op, results);
3285 return success();
3286 }
3287 // Replace the parallel loop by lower-dimensional parallel loop.
3288 auto newOp =
3289 ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3290 newUpperBounds, newSteps, op.getInitVals(), nullptr);
3291 // Erase the empty block that was inserted by the builder.
3292 rewriter.eraseBlock(newOp.getBody());
3293 // Clone the loop body and remap the block arguments of the collapsed loops
3294 // (inlining does not support a cancellable block argument mapping).
3295 rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3296 newOp.getRegion().begin(), mapping);
3297 rewriter.replaceOp(op, newOp.getResults());
3298 return success();
3299 }
3300};
3301
3302struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3303 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3304
3305 LogicalResult matchAndRewrite(ParallelOp op,
3306 PatternRewriter &rewriter) const override {
3307 Block &outerBody = *op.getBody();
3308 if (!llvm::hasSingleElement(outerBody.without_terminator()))
3309 return failure();
3310
3311 auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3312 if (!innerOp)
3313 return failure();
3314
3315 for (auto val : outerBody.getArguments())
3316 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3317 llvm::is_contained(innerOp.getUpperBound(), val) ||
3318 llvm::is_contained(innerOp.getStep(), val))
3319 return failure();
3320
3321 // Reductions are not supported yet.
3322 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3323 return failure();
3324
3325 auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3326 ValueRange iterVals, ValueRange) {
3327 Block &innerBody = *innerOp.getBody();
3328 assert(iterVals.size() ==
3329 (outerBody.getNumArguments() + innerBody.getNumArguments()));
3330 IRMapping mapping;
3331 mapping.map(outerBody.getArguments(),
3332 iterVals.take_front(outerBody.getNumArguments()));
3333 mapping.map(innerBody.getArguments(),
3334 iterVals.take_back(innerBody.getNumArguments()));
3335 for (Operation &op : innerBody.without_terminator())
3336 builder.clone(op, mapping);
3337 };
3338
3339 auto concatValues = [](const auto &first, const auto &second) {
3340 SmallVector<Value> ret;
3341 ret.reserve(first.size() + second.size());
3342 ret.assign(first.begin(), first.end());
3343 ret.append(second.begin(), second.end());
3344 return ret;
3345 };
3346
3347 auto newLowerBounds =
3348 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3349 auto newUpperBounds =
3350 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3351 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3352
3353 rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3354 newSteps, ValueRange(),
3355 bodyBuilder);
3356 return success();
3357 }
3358};
3359
3360} // namespace
3361
3362void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3363 MLIRContext *context) {
3364 results
3365 .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3366 context);
3367}
3368
3369/// Given the region at `index`, or the parent operation if `index` is None,
3370/// return the successor regions. These are the regions that may be selected
3371/// during the flow of control. `operands` is a set of optional attributes that
3372/// correspond to a constant value for each operand, or null if that operand is
3373/// not a constant.
3374void ParallelOp::getSuccessorRegions(
3375 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3376 // Both the operation itself and the region may be branching into the body or
3377 // back into the operation itself. It is possible for loop not to enter the
3378 // body.
3379 regions.push_back(RegionSuccessor(&getRegion()));
3380 regions.push_back(RegionSuccessor(
3381 getOperation(), ResultRange{getResults().end(), getResults().end()}));
3382}
3383
3384//===----------------------------------------------------------------------===//
3385// ReduceOp
3386//===----------------------------------------------------------------------===//
3387
3388void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3389
3390void ReduceOp::build(OpBuilder &builder, OperationState &result,
3391 ValueRange operands) {
3392 result.addOperands(operands);
3393 for (Value v : operands) {
3394 OpBuilder::InsertionGuard guard(builder);
3395 Region *bodyRegion = result.addRegion();
3396 builder.createBlock(bodyRegion, {},
3397 ArrayRef<Type>{v.getType(), v.getType()},
3398 {result.location, result.location});
3399 }
3400}
3401
3402LogicalResult ReduceOp::verifyRegions() {
3403 if (getReductions().size() != getOperands().size())
3404 return emitOpError() << "expects number of reduction regions: "
3405 << getReductions().size()
3406 << " to be the same as number of reduction operands: "
3407 << getOperands().size();
3408 // The region of a ReduceOp has two arguments of the same type as its
3409 // corresponding operand.
3410 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3411 auto type = getOperands()[i].getType();
3412 Block &block = getReductions()[i].front();
3413 if (block.empty())
3414 return emitOpError() << i << "-th reduction has an empty body";
3415 if (block.getNumArguments() != 2 ||
3416 llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3417 return arg.getType() != type;
3418 }))
3419 return emitOpError() << "expected two block arguments with type " << type
3420 << " in the " << i << "-th reduction region";
3421
3422 // Check that the block is terminated by a ReduceReturnOp.
3423 if (!isa<ReduceReturnOp>(block.getTerminator()))
3424 return emitOpError("reduction bodies must be terminated with an "
3425 "'scf.reduce.return' op");
3426 }
3427
3428 return success();
3429}
3430
3431MutableOperandRange
3432ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
3433 // No operands are forwarded to the next iteration.
3434 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3435}
3436
3437//===----------------------------------------------------------------------===//
3438// ReduceReturnOp
3439//===----------------------------------------------------------------------===//
3440
3441LogicalResult ReduceReturnOp::verify() {
3442 // The type of the return value should be the same type as the types of the
3443 // block arguments of the reduction body.
3444 Block *reductionBody = getOperation()->getBlock();
3445 // Should already be verified by an op trait.
3446 assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3447 Type expectedResultType = reductionBody->getArgument(0).getType();
3448 if (expectedResultType != getResult().getType())
3449 return emitOpError() << "must have type " << expectedResultType
3450 << " (the type of the reduction inputs)";
3451 return success();
3452}
3453
3454//===----------------------------------------------------------------------===//
3455// WhileOp
3456//===----------------------------------------------------------------------===//
3457
3458void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3459 ::mlir::OperationState &odsState, TypeRange resultTypes,
3460 ValueRange inits, BodyBuilderFn beforeBuilder,
3461 BodyBuilderFn afterBuilder) {
3462 odsState.addOperands(inits);
3463 odsState.addTypes(resultTypes);
3464
3465 OpBuilder::InsertionGuard guard(odsBuilder);
3466
3467 // Build before region.
3468 SmallVector<Location, 4> beforeArgLocs;
3469 beforeArgLocs.reserve(inits.size());
3470 for (Value operand : inits) {
3471 beforeArgLocs.push_back(operand.getLoc());
3472 }
3473
3474 Region *beforeRegion = odsState.addRegion();
3475 Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3476 inits.getTypes(), beforeArgLocs);
3477 if (beforeBuilder)
3478 beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3479
3480 // Build after region.
3481 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3482
3483 Region *afterRegion = odsState.addRegion();
3484 Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3485 resultTypes, afterArgLocs);
3486
3487 if (afterBuilder)
3488 afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3489}
3490
3491ConditionOp WhileOp::getConditionOp() {
3492 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3493}
3494
3495YieldOp WhileOp::getYieldOp() {
3496 return cast<YieldOp>(getAfterBody()->getTerminator());
3497}
3498
3499std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3500 return getYieldOp().getResultsMutable();
3501}
3502
3503Block::BlockArgListType WhileOp::getBeforeArguments() {
3504 return getBeforeBody()->getArguments();
3505}
3506
3507Block::BlockArgListType WhileOp::getAfterArguments() {
3508 return getAfterBody()->getArguments();
3509}
3510
3511Block::BlockArgListType WhileOp::getRegionIterArgs() {
3512 return getBeforeArguments();
3513}
3514
3515OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
3516 assert(successor.getSuccessor() == &getBefore() &&
3517 "WhileOp is expected to branch only to the first region");
3518 return getInits();
3519}
3520
3521void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3522 SmallVectorImpl<RegionSuccessor> &regions) {
3523 // The parent op always branches to the condition region.
3524 if (point.isParent()) {
3525 regions.emplace_back(&getBefore(), getBefore().getArguments());
3526 return;
3527 }
3528
3529 assert(llvm::is_contained(
3530 {&getAfter(), &getBefore()},
3531 point.getTerminatorPredecessorOrNull()->getParentRegion()) &&
3532 "there are only two regions in a WhileOp");
3533 // The body region always branches back to the condition region.
3535 &getAfter()) {
3536 regions.emplace_back(&getBefore(), getBefore().getArguments());
3537 return;
3538 }
3539
3540 regions.emplace_back(getOperation(), getResults());
3541 regions.emplace_back(&getAfter(), getAfter().getArguments());
3542}
3543
3544SmallVector<Region *> WhileOp::getLoopRegions() {
3545 return {&getBefore(), &getAfter()};
3546}
3547
3548/// Parses a `while` op.
3549///
3550/// op ::= `scf.while` assignments `:` function-type region `do` region
3551/// `attributes` attribute-dict
3552/// initializer ::= /* empty */ | `(` assignment-list `)`
3553/// assignment-list ::= assignment | assignment `,` assignment-list
3554/// assignment ::= ssa-value `=` ssa-value
3555ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3556 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3557 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3558 Region *before = result.addRegion();
3559 Region *after = result.addRegion();
3560
3561 OptionalParseResult listResult =
3562 parser.parseOptionalAssignmentList(regionArgs, operands);
3563 if (listResult.has_value() && failed(listResult.value()))
3564 return failure();
3565
3566 FunctionType functionType;
3567 SMLoc typeLoc = parser.getCurrentLocation();
3568 if (failed(parser.parseColonType(functionType)))
3569 return failure();
3570
3571 result.addTypes(functionType.getResults());
3572
3573 if (functionType.getNumInputs() != operands.size()) {
3574 return parser.emitError(typeLoc)
3575 << "expected as many input types as operands " << "(expected "
3576 << operands.size() << " got " << functionType.getNumInputs() << ")";
3577 }
3578
3579 // Resolve input operands.
3580 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3581 parser.getCurrentLocation(),
3582 result.operands)))
3583 return failure();
3584
3585 // Propagate the types into the region arguments.
3586 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3587 regionArgs[i].type = functionType.getInput(i);
3588
3589 return failure(parser.parseRegion(*before, regionArgs) ||
3590 parser.parseKeyword("do") || parser.parseRegion(*after) ||
3591 parser.parseOptionalAttrDictWithKeyword(result.attributes));
3592}
3593
3594/// Prints a `while` op.
3595void scf::WhileOp::print(OpAsmPrinter &p) {
3596 printInitializationList(p, getBeforeArguments(), getInits(), " ");
3597 p << " : ";
3598 p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3599 p << ' ';
3600 p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3601 p << " do ";
3602 p.printRegion(getAfter());
3603 p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3604}
3605
3606/// Verifies that two ranges of types match, i.e. have the same number of
3607/// entries and that types are pairwise equals. Reports errors on the given
3608/// operation in case of mismatch.
3609template <typename OpTy>
3610static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3611 TypeRange right, StringRef message) {
3612 if (left.size() != right.size())
3613 return op.emitOpError("expects the same number of ") << message;
3614
3615 for (unsigned i = 0, e = left.size(); i < e; ++i) {
3616 if (left[i] != right[i]) {
3617 InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3618 << message;
3619 diag.attachNote() << "for argument " << i << ", found " << left[i]
3620 << " and " << right[i];
3621 return diag;
3622 }
3623 }
3624
3625 return success();
3626}
3627
3628LogicalResult scf::WhileOp::verify() {
3629 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3630 *this, getBefore(),
3631 "expects the 'before' region to terminate with 'scf.condition'");
3632 if (!beforeTerminator)
3633 return failure();
3634
3635 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3636 *this, getAfter(),
3637 "expects the 'after' region to terminate with 'scf.yield'");
3638 return success(afterTerminator != nullptr);
3639}
3640
3641namespace {
3642/// Move a scf.if op that is directly before the scf.condition op in the while
3643/// before region, and whose condition matches the condition of the
3644/// scf.condition op, down into the while after region.
3645///
3646/// scf.while (..) : (...) -> ... {
3647/// %additional_used_values = ...
3648/// %cond = ...
3649/// ...
3650/// %res = scf.if %cond -> (...) {
3651/// use(%additional_used_values)
3652/// ... // then block
3653/// scf.yield %then_value
3654/// } else {
3655/// scf.yield %else_value
3656/// }
3657/// scf.condition(%cond) %res, ...
3658/// } do {
3659/// ^bb0(%res_arg, ...):
3660/// use(%res_arg)
3661/// ...
3662///
3663/// becomes
3664/// scf.while (..) : (...) -> ... {
3665/// %additional_used_values = ...
3666/// %cond = ...
3667/// ...
3668/// scf.condition(%cond) %else_value, ..., %additional_used_values
3669/// } do {
3670/// ^bb0(%res_arg ..., %additional_args): :
3671/// use(%additional_args)
3672/// ... // if then block
3673/// use(%then_value)
3674/// ...
3675struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
3676 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3677
3678 LogicalResult matchAndRewrite(scf::WhileOp op,
3679 PatternRewriter &rewriter) const override {
3680 auto conditionOp = op.getConditionOp();
3681
3682 // Only support ifOp right before the condition at the moment. Relaxing this
3683 // would require to:
3684 // - check that the body does not have side-effects conflicting with
3685 // operations between the if and the condition.
3686 // - check that results of the if operation are only used as arguments to
3687 // the condition.
3688 auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
3689
3690 // Check that the ifOp is directly before the conditionOp and that it
3691 // matches the condition of the conditionOp. Also ensure that the ifOp has
3692 // no else block with content, as that would complicate the transformation.
3693 // TODO: support else blocks with content.
3694 if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
3695 (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
3696 return failure();
3697
3698 assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
3699 *ifOp->user_begin() == conditionOp)) &&
3700 "ifOp has unexpected uses");
3701
3702 Location loc = op.getLoc();
3703
3704 // Replace uses of ifOp results in the conditionOp with the yielded values
3705 // from the ifOp branches.
3706 for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
3707 auto it = llvm::find(ifOp->getResults(), arg);
3708 if (it != ifOp->getResults().end()) {
3709 size_t ifOpIdx = it.getIndex();
3710 Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
3711 Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
3712
3713 rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
3714 rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
3715 }
3716 }
3717
3718 // Collect additional used values from before region.
3719 SetVector<Value> additionalUsedValuesSet;
3720 visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
3721 if (&op.getBefore() == operand->get().getParentRegion())
3722 additionalUsedValuesSet.insert(operand->get());
3723 });
3724
3725 // Create new whileOp with additional used values as results.
3726 auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
3727 auto additionalValueTypes = llvm::map_to_vector(
3728 additionalUsedValues, [](Value val) { return val.getType(); });
3729 size_t additionalValueSize = additionalUsedValues.size();
3730 SmallVector<Type> newResultTypes(op.getResultTypes());
3731 newResultTypes.append(additionalValueTypes);
3732
3733 auto newWhileOp =
3734 scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
3735
3736 rewriter.modifyOpInPlace(newWhileOp, [&] {
3737 newWhileOp.getBefore().takeBody(op.getBefore());
3738 newWhileOp.getAfter().takeBody(op.getAfter());
3739 newWhileOp.getAfter().addArguments(
3740 additionalValueTypes,
3741 SmallVector<Location>(additionalValueSize, loc));
3742 });
3743
3744 rewriter.modifyOpInPlace(conditionOp, [&] {
3745 conditionOp.getArgsMutable().append(additionalUsedValues);
3746 });
3747
3748 // Replace uses of additional used values inside the ifOp then region with
3749 // the whileOp after region arguments.
3750 rewriter.replaceUsesWithIf(
3751 additionalUsedValues,
3752 newWhileOp.getAfterArguments().take_back(additionalValueSize),
3753 [&](OpOperand &use) {
3754 return ifOp.getThenRegion().isAncestor(
3755 use.getOwner()->getParentRegion());
3756 });
3757
3758 // Inline ifOp then region into new whileOp after region.
3759 rewriter.eraseOp(ifOp.thenYield());
3760 rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
3761 newWhileOp.getAfterBody()->begin());
3762 rewriter.eraseOp(ifOp);
3763 rewriter.replaceOp(op,
3764 newWhileOp->getResults().drop_back(additionalValueSize));
3765 return success();
3766 }
3767};
3768
3769/// Replace uses of the condition within the do block with true, since otherwise
3770/// the block would not be evaluated.
3771///
3772/// scf.while (..) : (i1, ...) -> ... {
3773/// %condition = call @evaluate_condition() : () -> i1
3774/// scf.condition(%condition) %condition : i1, ...
3775/// } do {
3776/// ^bb0(%arg0: i1, ...):
3777/// use(%arg0)
3778/// ...
3779///
3780/// becomes
3781/// scf.while (..) : (i1, ...) -> ... {
3782/// %condition = call @evaluate_condition() : () -> i1
3783/// scf.condition(%condition) %condition : i1, ...
3784/// } do {
3785/// ^bb0(%arg0: i1, ...):
3786/// use(%true)
3787/// ...
3788struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3789 using OpRewritePattern<WhileOp>::OpRewritePattern;
3790
3791 LogicalResult matchAndRewrite(WhileOp op,
3792 PatternRewriter &rewriter) const override {
3793 auto term = op.getConditionOp();
3794
3795 // These variables serve to prevent creating duplicate constants
3796 // and hold constant true or false values.
3797 Value constantTrue = nullptr;
3798
3799 bool replaced = false;
3800 for (auto yieldedAndBlockArgs :
3801 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3802 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3803 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3804 if (!constantTrue)
3805 constantTrue = arith::ConstantOp::create(
3806 rewriter, op.getLoc(), term.getCondition().getType(),
3807 rewriter.getBoolAttr(true));
3808
3809 rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3810 constantTrue);
3811 replaced = true;
3812 }
3813 }
3814 }
3815 return success(replaced);
3816 }
3817};
3818
3819/// Remove loop invariant arguments from `before` block of scf.while.
3820/// A before block argument is considered loop invariant if :-
3821/// 1. i-th yield operand is equal to the i-th while operand.
3822/// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3823/// condition operand AND this (k+1)-th condition operand is equal to i-th
3824/// iter argument/while operand.
3825/// For the arguments which are removed, their uses inside scf.while
3826/// are replaced with their corresponding initial value.
3827///
3828/// Eg:
3829/// INPUT :-
3830/// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3831/// ..., %argN_before = %N)
3832/// {
3833/// ...
3834/// scf.condition(%cond) %arg1_before, %arg0_before,
3835/// %arg2_before, %arg0_before, ...
3836/// } do {
3837/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3838/// ..., %argK_after):
3839/// ...
3840/// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3841/// }
3842///
3843/// OUTPUT :-
3844/// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3845/// %N)
3846/// {
3847/// ...
3848/// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3849/// } do {
3850/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3851/// ..., %argK_after):
3852/// ...
3853/// scf.yield %arg1_after, ..., %argN
3854/// }
3855///
3856/// EXPLANATION:
3857/// We iterate over each yield operand.
3858/// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3859/// %arg0_before, which in turn is the 0-th iter argument. So we
3860/// remove 0-th before block argument and yield operand, and replace
3861/// all uses of the 0-th before block argument with its initial value
3862/// %a.
3863/// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3864/// value. So we remove this operand and the corresponding before
3865/// block argument and replace all uses of 1-th before block argument
3866/// with %b.
3867struct RemoveLoopInvariantArgsFromBeforeBlock
3868 : public OpRewritePattern<WhileOp> {
3869 using OpRewritePattern<WhileOp>::OpRewritePattern;
3870
3871 LogicalResult matchAndRewrite(WhileOp op,
3872 PatternRewriter &rewriter) const override {
3873 Block &afterBlock = *op.getAfterBody();
3874 Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3875 ConditionOp condOp = op.getConditionOp();
3876 OperandRange condOpArgs = condOp.getArgs();
3877 Operation *yieldOp = afterBlock.getTerminator();
3878 ValueRange yieldOpArgs = yieldOp->getOperands();
3879
3880 bool canSimplify = false;
3881 for (const auto &it :
3882 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3883 auto index = static_cast<unsigned>(it.index());
3884 auto [initVal, yieldOpArg] = it.value();
3885 // If i-th yield operand is equal to the i-th operand of the scf.while,
3886 // the i-th before block argument is a loop invariant.
3887 if (yieldOpArg == initVal) {
3888 canSimplify = true;
3889 break;
3890 }
3891 // If the i-th yield operand is k-th after block argument, then we check
3892 // if the (k+1)-th condition op operand is equal to either the i-th before
3893 // block argument or the initial value of i-th before block argument. If
3894 // the comparison results `true`, i-th before block argument is a loop
3895 // invariant.
3896 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3897 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3898 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3899 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3900 canSimplify = true;
3901 break;
3902 }
3903 }
3904 }
3905
3906 if (!canSimplify)
3907 return failure();
3908
3909 SmallVector<Value> newInitArgs, newYieldOpArgs;
3910 DenseMap<unsigned, Value> beforeBlockInitValMap;
3911 SmallVector<Location> newBeforeBlockArgLocs;
3912 for (const auto &it :
3913 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3914 auto index = static_cast<unsigned>(it.index());
3915 auto [initVal, yieldOpArg] = it.value();
3916
3917 // If i-th yield operand is equal to the i-th operand of the scf.while,
3918 // the i-th before block argument is a loop invariant.
3919 if (yieldOpArg == initVal) {
3920 beforeBlockInitValMap.insert({index, initVal});
3921 continue;
3922 } else {
3923 // If the i-th yield operand is k-th after block argument, then we check
3924 // if the (k+1)-th condition op operand is equal to either the i-th
3925 // before block argument or the initial value of i-th before block
3926 // argument. If the comparison results `true`, i-th before block
3927 // argument is a loop invariant.
3928 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3929 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3930 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3931 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3932 beforeBlockInitValMap.insert({index, initVal});
3933 continue;
3934 }
3935 }
3936 }
3937 newInitArgs.emplace_back(initVal);
3938 newYieldOpArgs.emplace_back(yieldOpArg);
3939 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3940 }
3941
3942 {
3943 OpBuilder::InsertionGuard g(rewriter);
3944 rewriter.setInsertionPoint(yieldOp);
3945 rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3946 }
3947
3948 auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
3949 newInitArgs);
3950
3951 Block &newBeforeBlock = *rewriter.createBlock(
3952 &newWhile.getBefore(), /*insertPt*/ {},
3953 ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3954
3955 Block &beforeBlock = *op.getBeforeBody();
3956 SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3957 // For each i-th before block argument we find it's replacement value as :-
3958 // 1. If i-th before block argument is a loop invariant, we fetch it's
3959 // initial value from `beforeBlockInitValMap` by querying for key `i`.
3960 // 2. Else we fetch j-th new before block argument as the replacement
3961 // value of i-th before block argument.
3962 for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3963 // If the index 'i' argument was a loop invariant we fetch it's initial
3964 // value from `beforeBlockInitValMap`.
3965 if (beforeBlockInitValMap.count(i) != 0)
3966 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3967 else
3968 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3969 }
3970
3971 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3972 rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3973 newWhile.getAfter().begin());
3974
3975 rewriter.replaceOp(op, newWhile.getResults());
3976 return success();
3977 }
3978};
3979
3980/// Remove loop invariant value from result (condition op) of scf.while.
3981/// A value is considered loop invariant if the final value yielded by
3982/// scf.condition is defined outside of the `before` block. We remove the
3983/// corresponding argument in `after` block and replace the use with the value.
3984/// We also replace the use of the corresponding result of scf.while with the
3985/// value.
3986///
3987/// Eg:
3988/// INPUT :-
3989/// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3990/// %argN_before = %N) {
3991/// ...
3992/// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3993/// } do {
3994/// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3995/// ...
3996/// some_func(%arg1_after)
3997/// ...
3998/// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3999/// }
4000///
4001/// OUTPUT :-
4002/// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
4003/// ...
4004/// scf.condition(%cond) %arg0, %arg1, ..., %argM
4005/// } do {
4006/// ^bb0(%arg0, %arg3, ..., %argM):
4007/// ...
4008/// some_func(%a)
4009/// ...
4010/// scf.yield %arg0, %b, ..., %argN
4011/// }
4012///
4013/// EXPLANATION:
4014/// 1. The 1-th and 2-th operand of scf.condition are defined outside the
4015/// before block of scf.while, so they get removed.
4016/// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
4017/// replaced by %b.
4018/// 3. The corresponding after block argument %arg1_after's uses are
4019/// replaced by %a and %arg2_after's uses are replaced by %b.
4020struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
4021 using OpRewritePattern<WhileOp>::OpRewritePattern;
4022
4023 LogicalResult matchAndRewrite(WhileOp op,
4024 PatternRewriter &rewriter) const override {
4025 Block &beforeBlock = *op.getBeforeBody();
4026 ConditionOp condOp = op.getConditionOp();
4027 OperandRange condOpArgs = condOp.getArgs();
4028
4029 bool canSimplify = false;
4030 for (Value condOpArg : condOpArgs) {
4031 // Those values not defined within `before` block will be considered as
4032 // loop invariant values. We map the corresponding `index` with their
4033 // value.
4034 if (condOpArg.getParentBlock() != &beforeBlock) {
4035 canSimplify = true;
4036 break;
4037 }
4038 }
4039
4040 if (!canSimplify)
4041 return failure();
4042
4043 Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
4044
4045 SmallVector<Value> newCondOpArgs;
4046 SmallVector<Type> newAfterBlockType;
4047 DenseMap<unsigned, Value> condOpInitValMap;
4048 SmallVector<Location> newAfterBlockArgLocs;
4049 for (const auto &it : llvm::enumerate(condOpArgs)) {
4050 auto index = static_cast<unsigned>(it.index());
4051 Value condOpArg = it.value();
4052 // Those values not defined within `before` block will be considered as
4053 // loop invariant values. We map the corresponding `index` with their
4054 // value.
4055 if (condOpArg.getParentBlock() != &beforeBlock) {
4056 condOpInitValMap.insert({index, condOpArg});
4057 } else {
4058 newCondOpArgs.emplace_back(condOpArg);
4059 newAfterBlockType.emplace_back(condOpArg.getType());
4060 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
4061 }
4062 }
4063
4064 {
4065 OpBuilder::InsertionGuard g(rewriter);
4066 rewriter.setInsertionPoint(condOp);
4067 rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4068 newCondOpArgs);
4069 }
4070
4071 auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
4072 op.getOperands());
4073
4074 Block &newAfterBlock =
4075 *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
4076 newAfterBlockType, newAfterBlockArgLocs);
4077
4078 Block &afterBlock = *op.getAfterBody();
4079 // Since a new scf.condition op was created, we need to fetch the new
4080 // `after` block arguments which will be used while replacing operations of
4081 // previous scf.while's `after` blocks. We'd also be fetching new result
4082 // values too.
4083 SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
4084 SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
4085 for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
4086 Value afterBlockArg, result;
4087 // If index 'i' argument was loop invariant we fetch it's value from the
4088 // `condOpInitMap` map.
4089 if (condOpInitValMap.count(i) != 0) {
4090 afterBlockArg = condOpInitValMap[i];
4091 result = afterBlockArg;
4092 } else {
4093 afterBlockArg = newAfterBlock.getArgument(j);
4094 result = newWhile.getResult(j);
4095 j++;
4096 }
4097 newAfterBlockArgs[i] = afterBlockArg;
4098 newWhileResults[i] = result;
4099 }
4100
4101 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
4102 rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
4103 newWhile.getBefore().begin());
4104
4105 rewriter.replaceOp(op, newWhileResults);
4106 return success();
4107 }
4108};
4109
4110/// Remove WhileOp results that are also unused in 'after' block.
4111///
4112/// %0:2 = scf.while () : () -> (i32, i64) {
4113/// %condition = "test.condition"() : () -> i1
4114/// %v1 = "test.get_some_value"() : () -> i32
4115/// %v2 = "test.get_some_value"() : () -> i64
4116/// scf.condition(%condition) %v1, %v2 : i32, i64
4117/// } do {
4118/// ^bb0(%arg0: i32, %arg1: i64):
4119/// "test.use"(%arg0) : (i32) -> ()
4120/// scf.yield
4121/// }
4122/// return %0#0 : i32
4123///
4124/// becomes
4125/// %0 = scf.while () : () -> (i32) {
4126/// %condition = "test.condition"() : () -> i1
4127/// %v1 = "test.get_some_value"() : () -> i32
4128/// %v2 = "test.get_some_value"() : () -> i64
4129/// scf.condition(%condition) %v1 : i32
4130/// } do {
4131/// ^bb0(%arg0: i32):
4132/// "test.use"(%arg0) : (i32) -> ()
4133/// scf.yield
4134/// }
4135/// return %0 : i32
4136struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
4137 using OpRewritePattern<WhileOp>::OpRewritePattern;
4138
4139 LogicalResult matchAndRewrite(WhileOp op,
4140 PatternRewriter &rewriter) const override {
4141 auto term = op.getConditionOp();
4142 auto afterArgs = op.getAfterArguments();
4143 auto termArgs = term.getArgs();
4144
4145 // Collect results mapping, new terminator args and new result types.
4146 SmallVector<unsigned> newResultsIndices;
4147 SmallVector<Type> newResultTypes;
4148 SmallVector<Value> newTermArgs;
4149 SmallVector<Location> newArgLocs;
4150 bool needUpdate = false;
4151 for (const auto &it :
4152 llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
4153 auto i = static_cast<unsigned>(it.index());
4154 Value result = std::get<0>(it.value());
4155 Value afterArg = std::get<1>(it.value());
4156 Value termArg = std::get<2>(it.value());
4157 if (result.use_empty() && afterArg.use_empty()) {
4158 needUpdate = true;
4159 } else {
4160 newResultsIndices.emplace_back(i);
4161 newTermArgs.emplace_back(termArg);
4162 newResultTypes.emplace_back(result.getType());
4163 newArgLocs.emplace_back(result.getLoc());
4164 }
4165 }
4166
4167 if (!needUpdate)
4168 return failure();
4169
4170 {
4171 OpBuilder::InsertionGuard g(rewriter);
4172 rewriter.setInsertionPoint(term);
4173 rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
4174 newTermArgs);
4175 }
4176
4177 auto newWhile =
4178 WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
4179
4180 Block &newAfterBlock = *rewriter.createBlock(
4181 &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
4182
4183 // Build new results list and new after block args (unused entries will be
4184 // null).
4185 SmallVector<Value> newResults(op.getNumResults());
4186 SmallVector<Value> newAfterBlockArgs(op.getNumResults());
4187 for (const auto &it : llvm::enumerate(newResultsIndices)) {
4188 newResults[it.value()] = newWhile.getResult(it.index());
4189 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
4190 }
4191
4192 rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
4193 newWhile.getBefore().begin());
4194
4195 Block &afterBlock = *op.getAfterBody();
4196 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
4197
4198 rewriter.replaceOp(op, newResults);
4199 return success();
4200 }
4201};
4202
4203/// Replace operations equivalent to the condition in the do block with true,
4204/// since otherwise the block would not be evaluated.
4205///
4206/// scf.while (..) : (i32, ...) -> ... {
4207/// %z = ... : i32
4208/// %condition = cmpi pred %z, %a
4209/// scf.condition(%condition) %z : i32, ...
4210/// } do {
4211/// ^bb0(%arg0: i32, ...):
4212/// %condition2 = cmpi pred %arg0, %a
4213/// use(%condition2)
4214/// ...
4215///
4216/// becomes
4217/// scf.while (..) : (i32, ...) -> ... {
4218/// %z = ... : i32
4219/// %condition = cmpi pred %z, %a
4220/// scf.condition(%condition) %z : i32, ...
4221/// } do {
4222/// ^bb0(%arg0: i32, ...):
4223/// use(%true)
4224/// ...
4225struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
4226 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
4227
4228 LogicalResult matchAndRewrite(scf::WhileOp op,
4229 PatternRewriter &rewriter) const override {
4230 using namespace scf;
4231 auto cond = op.getConditionOp();
4232 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
4233 if (!cmp)
4234 return failure();
4235 bool changed = false;
4236 for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
4237 for (size_t opIdx = 0; opIdx < 2; opIdx++) {
4238 if (std::get<0>(tup) != cmp.getOperand(opIdx))
4239 continue;
4240 for (OpOperand &u :
4241 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
4242 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
4243 if (!cmp2)
4244 continue;
4245 // For a binary operator 1-opIdx gets the other side.
4246 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
4247 continue;
4248 bool samePredicate;
4249 if (cmp2.getPredicate() == cmp.getPredicate())
4250 samePredicate = true;
4251 else if (cmp2.getPredicate() ==
4252 arith::invertPredicate(cmp.getPredicate()))
4253 samePredicate = false;
4254 else
4255 continue;
4256
4257 rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
4258 1);
4259 changed = true;
4260 }
4261 }
4262 }
4263 return success(changed);
4264 }
4265};
4266
4267/// Remove unused init/yield args.
4268struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
4269 using OpRewritePattern<WhileOp>::OpRewritePattern;
4270
4271 LogicalResult matchAndRewrite(WhileOp op,
4272 PatternRewriter &rewriter) const override {
4273
4274 if (!llvm::any_of(op.getBeforeArguments(),
4275 [](Value arg) { return arg.use_empty(); }))
4276 return rewriter.notifyMatchFailure(op, "No args to remove");
4277
4278 YieldOp yield = op.getYieldOp();
4279
4280 // Collect results mapping, new terminator args and new result types.
4281 SmallVector<Value> newYields;
4282 SmallVector<Value> newInits;
4283 llvm::BitVector argsToErase;
4284
4285 size_t argsCount = op.getBeforeArguments().size();
4286 newYields.reserve(argsCount);
4287 newInits.reserve(argsCount);
4288 argsToErase.reserve(argsCount);
4289 for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
4290 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
4291 if (beforeArg.use_empty()) {
4292 argsToErase.push_back(true);
4293 } else {
4294 argsToErase.push_back(false);
4295 newYields.emplace_back(yieldValue);
4296 newInits.emplace_back(initValue);
4297 }
4298 }
4299
4300 Block &beforeBlock = *op.getBeforeBody();
4301 Block &afterBlock = *op.getAfterBody();
4302
4303 beforeBlock.eraseArguments(argsToErase);
4304
4305 Location loc = op.getLoc();
4306 auto newWhileOp =
4307 WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
4308 /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
4309 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4310 Block &newAfterBlock = *newWhileOp.getAfterBody();
4311
4312 OpBuilder::InsertionGuard g(rewriter);
4313 rewriter.setInsertionPoint(yield);
4314 rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4315
4316 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4317 newBeforeBlock.getArguments());
4318 rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4319 newAfterBlock.getArguments());
4320
4321 rewriter.replaceOp(op, newWhileOp.getResults());
4322 return success();
4323 }
4324};
4325
4326/// Remove duplicated ConditionOp args.
4327struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4329
4330 LogicalResult matchAndRewrite(WhileOp op,
4331 PatternRewriter &rewriter) const override {
4332 ConditionOp condOp = op.getConditionOp();
4333 ValueRange condOpArgs = condOp.getArgs();
4334
4335 llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4336
4337 if (argsSet.size() == condOpArgs.size())
4338 return rewriter.notifyMatchFailure(op, "No results to remove");
4339
4340 llvm::SmallDenseMap<Value, unsigned> argsMap;
4341 SmallVector<Value> newArgs;
4342 argsMap.reserve(condOpArgs.size());
4343 newArgs.reserve(condOpArgs.size());
4344 for (Value arg : condOpArgs) {
4345 if (!argsMap.count(arg)) {
4346 auto pos = static_cast<unsigned>(argsMap.size());
4347 argsMap.insert({arg, pos});
4348 newArgs.emplace_back(arg);
4349 }
4350 }
4351
4352 ValueRange argsRange(newArgs);
4353
4354 Location loc = op.getLoc();
4355 auto newWhileOp =
4356 scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
4357 /*beforeBody*/ nullptr,
4358 /*afterBody*/ nullptr);
4359 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4360 Block &newAfterBlock = *newWhileOp.getAfterBody();
4361
4362 SmallVector<Value> afterArgsMapping;
4363 SmallVector<Value> resultsMapping;
4364 for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4365 auto it = argsMap.find(arg);
4366 assert(it != argsMap.end());
4367 auto pos = it->second;
4368 afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4369 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4370 }
4371
4372 OpBuilder::InsertionGuard g(rewriter);
4373 rewriter.setInsertionPoint(condOp);
4374 rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4375 argsRange);
4376
4377 Block &beforeBlock = *op.getBeforeBody();
4378 Block &afterBlock = *op.getAfterBody();
4379
4380 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4381 newBeforeBlock.getArguments());
4382 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4383 rewriter.replaceOp(op, resultsMapping);
4384 return success();
4385 }
4386};
4387
4388/// If both ranges contain same values return mappping indices from args2 to
4389/// args1. Otherwise return std::nullopt.
4390static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4391 ValueRange args2) {
4392 if (args1.size() != args2.size())
4393 return std::nullopt;
4394
4395 SmallVector<unsigned> ret(args1.size());
4396 for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4397 auto it = llvm::find(args2, arg1);
4398 if (it == args2.end())
4399 return std::nullopt;
4400
4401 ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4402 }
4403
4404 return ret;
4405}
4406
4407static bool hasDuplicates(ValueRange args) {
4408 llvm::SmallDenseSet<Value> set;
4409 for (Value arg : args) {
4410 if (!set.insert(arg).second)
4411 return true;
4412 }
4413 return false;
4414}
4415
4416/// If `before` block args are directly forwarded to `scf.condition`, rearrange
4417/// `scf.condition` args into same order as block args. Update `after` block
4418/// args and op result values accordingly.
4419/// Needed to simplify `scf.while` -> `scf.for` uplifting.
4420struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4422
4423 LogicalResult matchAndRewrite(WhileOp loop,
4424 PatternRewriter &rewriter) const override {
4425 auto *oldBefore = loop.getBeforeBody();
4426 ConditionOp oldTerm = loop.getConditionOp();
4427 ValueRange beforeArgs = oldBefore->getArguments();
4428 ValueRange termArgs = oldTerm.getArgs();
4429 if (beforeArgs == termArgs)
4430 return failure();
4431
4432 if (hasDuplicates(termArgs))
4433 return failure();
4434
4435 auto mapping = getArgsMapping(beforeArgs, termArgs);
4436 if (!mapping)
4437 return failure();
4438
4439 {
4440 OpBuilder::InsertionGuard g(rewriter);
4441 rewriter.setInsertionPoint(oldTerm);
4442 rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4443 beforeArgs);
4444 }
4445
4446 auto *oldAfter = loop.getAfterBody();
4447
4448 SmallVector<Type> newResultTypes(beforeArgs.size());
4449 for (auto &&[i, j] : llvm::enumerate(*mapping))
4450 newResultTypes[j] = loop.getResult(i).getType();
4451
4452 auto newLoop = WhileOp::create(
4453 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
4454 /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4455 auto *newBefore = newLoop.getBeforeBody();
4456 auto *newAfter = newLoop.getAfterBody();
4457
4458 SmallVector<Value> newResults(beforeArgs.size());
4459 SmallVector<Value> newAfterArgs(beforeArgs.size());
4460 for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4461 newResults[i] = newLoop.getResult(j);
4462 newAfterArgs[i] = newAfter->getArgument(j);
4463 }
4464
4465 rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4466 newBefore->getArguments());
4467 rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4468 newAfterArgs);
4469
4470 rewriter.replaceOp(loop, newResults);
4471 return success();
4472 }
4473};
4474} // namespace
4475
4476void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4477 MLIRContext *context) {
4478 results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4479 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4480 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4481 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
4482 context);
4483}
4484
4485//===----------------------------------------------------------------------===//
4486// IndexSwitchOp
4487//===----------------------------------------------------------------------===//
4488
4489/// Parse the case regions and values.
4490static ParseResult
4492 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4493 SmallVector<int64_t> caseValues;
4494 while (succeeded(p.parseOptionalKeyword("case"))) {
4495 int64_t value;
4496 Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4497 if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4498 return failure();
4499 caseValues.push_back(value);
4500 }
4501 cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4502 return success();
4503}
4504
4505/// Print the case regions and values.
4507 DenseI64ArrayAttr cases, RegionRange caseRegions) {
4508 for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4509 p.printNewline();
4510 p << "case " << value << ' ';
4511 p.printRegion(*region, /*printEntryBlockArgs=*/false);
4512 }
4513}
4514
4515LogicalResult scf::IndexSwitchOp::verify() {
4516 if (getCases().size() != getCaseRegions().size()) {
4517 return emitOpError("has ")
4518 << getCaseRegions().size() << " case regions but "
4519 << getCases().size() << " case values";
4520 }
4521
4522 DenseSet<int64_t> valueSet;
4523 for (int64_t value : getCases())
4524 if (!valueSet.insert(value).second)
4525 return emitOpError("has duplicate case value: ") << value;
4526 auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4527 auto yield = dyn_cast<YieldOp>(region.front().back());
4528 if (!yield)
4529 return emitOpError("expected region to end with scf.yield, but got ")
4530 << region.front().back().getName();
4531
4532 if (yield.getNumOperands() != getNumResults()) {
4533 return (emitOpError("expected each region to return ")
4534 << getNumResults() << " values, but " << name << " returns "
4535 << yield.getNumOperands())
4536 .attachNote(yield.getLoc())
4537 << "see yield operation here";
4538 }
4539 for (auto [idx, result, operand] :
4540 llvm::enumerate(getResultTypes(), yield.getOperands())) {
4541 if (!operand)
4542 return yield.emitOpError() << "operand " << idx << " is null\n";
4543 if (result == operand.getType())
4544 continue;
4545 return (emitOpError("expected result #")
4546 << idx << " of each region to be " << result)
4547 .attachNote(yield.getLoc())
4548 << name << " returns " << operand.getType() << " here";
4549 }
4550 return success();
4551 };
4552
4553 if (failed(verifyRegion(getDefaultRegion(), "default region")))
4554 return failure();
4555 for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4556 if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4557 return failure();
4558
4559 return success();
4560}
4561
4562unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4563
4564Block &scf::IndexSwitchOp::getDefaultBlock() {
4565 return getDefaultRegion().front();
4566}
4567
4568Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4569 assert(idx < getNumCases() && "case index out-of-bounds");
4570 return getCaseRegions()[idx].front();
4571}
4572
4573void IndexSwitchOp::getSuccessorRegions(
4574 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4575 // All regions branch back to the parent op.
4576 if (!point.isParent()) {
4577 successors.emplace_back(getOperation(), getResults());
4578 return;
4579 }
4580
4581 llvm::append_range(successors, getRegions());
4582}
4583
4584void IndexSwitchOp::getEntrySuccessorRegions(
4585 ArrayRef<Attribute> operands,
4586 SmallVectorImpl<RegionSuccessor> &successors) {
4587 FoldAdaptor adaptor(operands, *this);
4588
4589 // If a constant was not provided, all regions are possible successors.
4590 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4591 if (!arg) {
4592 llvm::append_range(successors, getRegions());
4593 return;
4594 }
4595
4596 // Otherwise, try to find a case with a matching value. If not, the
4597 // default region is the only successor.
4598 for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4599 if (caseValue == arg.getInt()) {
4600 successors.emplace_back(&caseRegion);
4601 return;
4602 }
4603 }
4604 successors.emplace_back(&getDefaultRegion());
4605}
4606
4607void IndexSwitchOp::getRegionInvocationBounds(
4608 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4609 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4610 if (!operandValue) {
4611 // All regions are invoked at most once.
4612 bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4613 return;
4614 }
4615
4616 unsigned liveIndex = getNumRegions() - 1;
4617 const auto *it = llvm::find(getCases(), operandValue.getInt());
4618 if (it != getCases().end())
4619 liveIndex = std::distance(getCases().begin(), it);
4620 for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4621 bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4622}
4623
4624struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4625 using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
4626
4627 LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4628 PatternRewriter &rewriter) const override {
4629 // If `op.getArg()` is a constant, select the region that matches with
4630 // the constant value. Use the default region if no matche is found.
4631 std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4632 if (!maybeCst.has_value())
4633 return failure();
4634 int64_t cst = *maybeCst;
4635 int64_t caseIdx, e = op.getNumCases();
4636 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4637 if (cst == op.getCases()[caseIdx])
4638 break;
4639 }
4640
4641 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4642 : op.getDefaultRegion();
4643 Block &source = r.front();
4644 Operation *terminator = source.getTerminator();
4645 SmallVector<Value> results = terminator->getOperands();
4646
4647 rewriter.inlineBlockBefore(&source, op);
4648 rewriter.eraseOp(terminator);
4649 // Replace the operation with a potentially empty list of results.
4650 // Fold mechanism doesn't support the case where the result list is empty.
4651 rewriter.replaceOp(op, results);
4652
4653 return success();
4654 }
4655};
4656
4657/// Canonicalization patterns that folds away dead results of
4658/// "scf.index_switch" ops.
4660 using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
4661
4662 LogicalResult matchAndRewrite(IndexSwitchOp op,
4663 PatternRewriter &rewriter) const override {
4664 // Find dead results.
4665 BitVector deadResults(op.getNumResults(), false);
4666 for (auto [idx, result] : llvm::enumerate(op.getResults()))
4667 if (result.use_empty())
4668 deadResults[idx] = true;
4669 if (!deadResults.any())
4670 return rewriter.notifyMatchFailure(op, "no dead results to fold");
4671
4672 // Erase dead results.
4673 auto newOp =
4674 cast<scf::IndexSwitchOp>(rewriter.eraseOpResults(op, deadResults));
4675
4676 // Erase operands from yield ops.
4677 auto updateCaseRegion = [&](Region &region) {
4678 Operation *terminator = region.front().getTerminator();
4679 assert(isa<YieldOp>(terminator) && "expected yield op");
4680 rewriter.modifyOpInPlace(
4681 terminator, [&]() { terminator->eraseOperands(deadResults); });
4682 };
4683 updateCaseRegion(newOp.getDefaultRegion());
4684 for (Region &caseRegion : newOp.getCaseRegions())
4685 updateCaseRegion(caseRegion);
4686
4687 return success();
4688 }
4689};
4690
4691void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4692 MLIRContext *context) {
4693 results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
4694}
4695
4696//===----------------------------------------------------------------------===//
4697// TableGen'd op method definitions
4698//===----------------------------------------------------------------------===//
4699
4700#define GET_OP_CLASSES
4701#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1400
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
Definition EmitC.cpp:1375
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition EmitC.cpp:1391
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition SCF.cpp:137
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition SCF.cpp:3610
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:566
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition SCF.cpp:101
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
@ None
static std::string diag(const llvm::Value &value)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition Block.cpp:206
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
bool getValue() const
Return the boolean value of this attribute.
UnitAttr getUnitAttr()
Definition Builders.cpp:98
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:167
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
IntegerType getI1Type()
Definition Builders.cpp:53
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition Builders.cpp:589
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition Builders.h:594
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
Definition Operation.h:360
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OperandRange operand_range
Definition Operation.h:371
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:346
This class represents a successor of a region.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:222
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
BlockArgument getArgument(unsigned i)
Definition Region.h:124
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition SCF.cpp:3214
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition SCF.cpp:94
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition SCF.cpp:838
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition SCF.cpp:2121
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition SCF.cpp:793
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition SCF.cpp:115
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition SCF.cpp:745
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition SCF.h:64
llvm::function_ref< Value(OpBuilder &, Location loc, Type, Value)> ValueTypeCastFnTy
Perform a replacement of one iter OpOperand of an scf.for to the replacement value with a different t...
Definition SCF.h:107
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition SCF.cpp:1610
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition SCF.cpp:927
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:302
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:4627
Canonicalization pattern that folds away dead results of "scf.index_switch" ops.
Definition SCF.cpp:4659
LogicalResult matchAndRewrite(IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:4662
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:261
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:212
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.