MLIR 22.0.0git
SCF.cpp
Go to the documentation of this file.
1//===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/Operation.h"
29#include "llvm/ADT/MapVector.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallPtrSet.h"
32#include "llvm/Support/Casting.h"
33#include "llvm/Support/DebugLog.h"
34#include <optional>
35
36using namespace mlir;
37using namespace mlir::scf;
38
39#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
40
41//===----------------------------------------------------------------------===//
42// SCFDialect Dialect Interfaces
43//===----------------------------------------------------------------------===//
44
45namespace {
46struct SCFInlinerInterface : public DialectInlinerInterface {
48 // We don't have any special restrictions on what can be inlined into
49 // destination regions (e.g. while/conditional bodies). Always allow it.
50 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
51 IRMapping &valueMapping) const final {
52 return true;
53 }
54 // Operations in scf dialect are always legal to inline since they are
55 // pure.
56 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
57 return true;
58 }
59 // Handle the given inlined terminator by replacing it with a new operation
60 // as necessary. Required when the region has only one block.
61 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
62 auto retValOp = dyn_cast<scf::YieldOp>(op);
63 if (!retValOp)
64 return;
65
66 for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
67 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
68 }
69 }
70};
71} // namespace
72
73//===----------------------------------------------------------------------===//
74// SCFDialect
75//===----------------------------------------------------------------------===//
76
77void SCFDialect::initialize() {
78 addOperations<
79#define GET_OP_LIST
80#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
81 >();
82 addInterfaces<SCFInlinerInterface>();
83 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
84 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
85 InParallelOp, ReduceReturnOp>();
86 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
87 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
88 ForallOp, InParallelOp, WhileOp, YieldOp>();
89 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
90}
91
92/// Default callback for IfOp builders. Inserts a yield without arguments.
94 scf::YieldOp::create(builder, loc);
95}
96
97/// Verifies that the first block of the given `region` is terminated by a
98/// TerminatorTy. Reports errors on the given operation if it is not the case.
99template <typename TerminatorTy>
100static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
101 StringRef errorMessage) {
102 Operation *terminatorOperation = nullptr;
103 if (!region.empty() && !region.front().empty()) {
104 terminatorOperation = &region.front().back();
105 if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
106 return yield;
107 }
108 auto diag = op->emitOpError(errorMessage);
109 if (terminatorOperation)
110 diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
111 return nullptr;
112}
113
114std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
115 bool isSigned) {
116 llvm::APSInt diff;
117 auto addOp = ub.getDefiningOp<arith::AddIOp>();
118 if (!addOp)
119 return std::nullopt;
120 if ((isSigned && !addOp.hasNoSignedWrap()) ||
121 (!isSigned && !addOp.hasNoUnsignedWrap()))
122 return std::nullopt;
123
124 if (addOp.getLhs() != lb ||
125 !matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
126 return std::nullopt;
127 return diff;
128}
129
130//===----------------------------------------------------------------------===//
131// ExecuteRegionOp
132//===----------------------------------------------------------------------===//
133
134/// Replaces the given op with the contents of the given single-block region,
135/// using the operands of the block terminator to replace operation results.
137 Region &region, ValueRange blockArgs = {}) {
138 assert(region.hasOneBlock() && "expected single-block region");
139 Block *block = &region.front();
140 Operation *terminator = block->getTerminator();
141 ValueRange results = terminator->getOperands();
142 rewriter.inlineBlockBefore(block, op, blockArgs);
143 rewriter.replaceOp(op, results);
144 rewriter.eraseOp(terminator);
145}
146
147///
148/// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
149/// block+
150/// `}`
151///
152/// Example:
153/// scf.execute_region -> i32 {
154/// %idx = load %rI[%i] : memref<128xi32>
155/// return %idx : i32
156/// }
157///
158ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
160 if (parser.parseOptionalArrowTypeList(result.types))
161 return failure();
162
163 if (succeeded(parser.parseOptionalKeyword("no_inline")))
164 result.addAttribute("no_inline", parser.getBuilder().getUnitAttr());
165
166 // Introduce the body region and parse it.
167 Region *body = result.addRegion();
168 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
169 parser.parseOptionalAttrDict(result.attributes))
170 return failure();
171
172 return success();
173}
174
175void ExecuteRegionOp::print(OpAsmPrinter &p) {
176 p.printOptionalArrowTypeList(getResultTypes());
177 p << ' ';
178 if (getNoInline())
179 p << "no_inline ";
180 p.printRegion(getRegion(),
181 /*printEntryBlockArgs=*/false,
182 /*printBlockTerminators=*/true);
183 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"no_inline"});
184}
185
186LogicalResult ExecuteRegionOp::verify() {
187 if (getRegion().empty())
188 return emitOpError("region needs to have at least one block");
189 if (getRegion().front().getNumArguments() > 0)
190 return emitOpError("region cannot have any arguments");
191 return success();
192}
193
194// Inline an ExecuteRegionOp if it only contains one block.
195// "test.foo"() : () -> ()
196// %v = scf.execute_region -> i64 {
197// %x = "test.val"() : () -> i64
198// scf.yield %x : i64
199// }
200// "test.bar"(%v) : (i64) -> ()
201//
202// becomes
203//
204// "test.foo"() : () -> ()
205// %x = "test.val"() : () -> i64
206// "test.bar"(%x) : (i64) -> ()
207//
208struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
209 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
210
211 LogicalResult matchAndRewrite(ExecuteRegionOp op,
212 PatternRewriter &rewriter) const override {
213 if (!op.getRegion().hasOneBlock() || op.getNoInline())
214 return failure();
215 replaceOpWithRegion(rewriter, op, op.getRegion());
216 return success();
217 }
218};
219
220// Inline an ExecuteRegionOp if its parent can contain multiple blocks.
221// TODO generalize the conditions for operations which can be inlined into.
222// func @func_execute_region_elim() {
223// "test.foo"() : () -> ()
224// %v = scf.execute_region -> i64 {
225// %c = "test.cmp"() : () -> i1
226// cf.cond_br %c, ^bb2, ^bb3
227// ^bb2:
228// %x = "test.val1"() : () -> i64
229// cf.br ^bb4(%x : i64)
230// ^bb3:
231// %y = "test.val2"() : () -> i64
232// cf.br ^bb4(%y : i64)
233// ^bb4(%z : i64):
234// scf.yield %z : i64
235// }
236// "test.bar"(%v) : (i64) -> ()
237// return
238// }
239//
240// becomes
241//
242// func @func_execute_region_elim() {
243// "test.foo"() : () -> ()
244// %c = "test.cmp"() : () -> i1
245// cf.cond_br %c, ^bb1, ^bb2
246// ^bb1: // pred: ^bb0
247// %x = "test.val1"() : () -> i64
248// cf.br ^bb3(%x : i64)
249// ^bb2: // pred: ^bb0
250// %y = "test.val2"() : () -> i64
251// cf.br ^bb3(%y : i64)
252// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
253// "test.bar"(%z) : (i64) -> ()
254// return
255// }
256//
257struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
258 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
259
260 LogicalResult matchAndRewrite(ExecuteRegionOp op,
261 PatternRewriter &rewriter) const override {
262 if (op.getNoInline())
263 return failure();
264 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
265 return failure();
266
267 Block *prevBlock = op->getBlock();
268 Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
269 rewriter.setInsertionPointToEnd(prevBlock);
270
271 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
272
273 for (Block &blk : op.getRegion()) {
274 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
275 rewriter.setInsertionPoint(yieldOp);
276 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
277 yieldOp.getResults());
278 rewriter.eraseOp(yieldOp);
279 }
280 }
281
282 rewriter.inlineRegionBefore(op.getRegion(), postBlock);
283 SmallVector<Value> blockArgs;
284
285 for (auto res : op.getResults())
286 blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
287
288 rewriter.replaceOp(op, blockArgs);
289 return success();
290 }
291};
292
293// Pattern to eliminate ExecuteRegionOp results which forward external
294// values from the region. In case there are multiple yield operations,
295// all of them must have the same operands in order for the pattern to be
296// applicable.
298 : public OpRewritePattern<ExecuteRegionOp> {
299 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
300
301 LogicalResult matchAndRewrite(ExecuteRegionOp op,
302 PatternRewriter &rewriter) const override {
303 if (op.getNumResults() == 0)
304 return failure();
305
307 for (Block &block : op.getRegion()) {
308 if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
309 yieldOps.push_back(yield.getOperation());
310 }
311
312 if (yieldOps.empty())
313 return failure();
314
315 // Check if all yield operations have the same operands.
316 auto yieldOpsOperands = yieldOps[0]->getOperands();
317 for (auto *yieldOp : yieldOps) {
318 if (yieldOp->getOperands() != yieldOpsOperands)
319 return failure();
320 }
321
322 SmallVector<Value> externalValues;
323 SmallVector<Value> internalValues;
324 SmallVector<Value> opResultsToReplaceWithExternalValues;
325 SmallVector<Value> opResultsToKeep;
326 for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
327 if (isValueFromInsideRegion(yieldedValue, op)) {
328 internalValues.push_back(yieldedValue);
329 opResultsToKeep.push_back(op.getResult(index));
330 } else {
331 externalValues.push_back(yieldedValue);
332 opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
333 }
334 }
335 // No yielded external values - nothing to do.
336 if (externalValues.empty())
337 return failure();
338
339 // There are yielded external values - create a new execute_region returning
340 // just the internal values.
341 SmallVector<Type> resultTypes;
342 for (Value value : internalValues)
343 resultTypes.push_back(value.getType());
344 auto newOp =
345 ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
346 newOp->setAttrs(op->getAttrs());
347
348 // Move old op's region to the new operation.
349 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
350 newOp.getRegion().end());
351
352 // Replace all yield operations with a new yield operation with updated
353 // results. scf.execute_region must have at least one yield operation.
354 for (auto *yieldOp : yieldOps) {
355 rewriter.setInsertionPoint(yieldOp);
356 rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
357 ValueRange(internalValues));
358 }
359
360 // Replace the old operation with the external values directly.
361 rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
362 externalValues);
363 // Replace the old operation's remaining results with the new operation's
364 // results.
365 rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
366 rewriter.eraseOp(op);
367 return success();
368 }
369
370private:
371 bool isValueFromInsideRegion(Value value,
372 ExecuteRegionOp executeRegionOp) const {
373 // Check if the value is defined within the execute_region
374 if (Operation *defOp = value.getDefiningOp())
375 return &executeRegionOp.getRegion() == defOp->getParentRegion();
376
377 // If it's a block argument, check if it's from within the region
378 if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
379 return &executeRegionOp.getRegion() == blockArg.getParentRegion();
380
381 return false; // Value is from outside the region
382 }
383};
384
385void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
386 MLIRContext *context) {
389}
390
391void ExecuteRegionOp::getSuccessorRegions(
393 // If the predecessor is the ExecuteRegionOp, branch into the body.
394 if (point.isParent()) {
395 regions.push_back(RegionSuccessor(&getRegion()));
396 return;
397 }
398
399 // Otherwise, the region branches back to the parent operation.
400 regions.push_back(RegionSuccessor(getOperation(), getResults()));
401}
402
403//===----------------------------------------------------------------------===//
404// ConditionOp
405//===----------------------------------------------------------------------===//
406
408ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
409 assert(
410 (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) &&
411 "condition op can only exit the loop or branch to the after"
412 "region");
413 // Pass all operands except the condition to the successor region.
414 return getArgsMutable();
415}
416
417void ConditionOp::getSuccessorRegions(
419 FoldAdaptor adaptor(operands, *this);
420
421 WhileOp whileOp = getParentOp();
422
423 // Condition can either lead to the after region or back to the parent op
424 // depending on whether the condition is true or not.
425 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
426 if (!boolAttr || boolAttr.getValue())
427 regions.emplace_back(&whileOp.getAfter(),
428 whileOp.getAfter().getArguments());
429 if (!boolAttr || !boolAttr.getValue())
430 regions.emplace_back(whileOp.getOperation(), whileOp.getResults());
431}
432
433//===----------------------------------------------------------------------===//
434// ForOp
435//===----------------------------------------------------------------------===//
436
437void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
438 Value ub, Value step, ValueRange initArgs,
439 BodyBuilderFn bodyBuilder, bool unsignedCmp) {
440 OpBuilder::InsertionGuard guard(builder);
441
442 if (unsignedCmp)
443 result.addAttribute(getUnsignedCmpAttrName(result.name),
444 builder.getUnitAttr());
445 result.addOperands({lb, ub, step});
446 result.addOperands(initArgs);
447 for (Value v : initArgs)
448 result.addTypes(v.getType());
449 Type t = lb.getType();
450 Region *bodyRegion = result.addRegion();
451 Block *bodyBlock = builder.createBlock(bodyRegion);
452 bodyBlock->addArgument(t, result.location);
453 for (Value v : initArgs)
454 bodyBlock->addArgument(v.getType(), v.getLoc());
455
456 // Create the default terminator if the builder is not provided and if the
457 // iteration arguments are not provided. Otherwise, leave this to the caller
458 // because we don't know which values to return from the loop.
459 if (initArgs.empty() && !bodyBuilder) {
460 ForOp::ensureTerminator(*bodyRegion, builder, result.location);
461 } else if (bodyBuilder) {
462 OpBuilder::InsertionGuard guard(builder);
463 builder.setInsertionPointToStart(bodyBlock);
464 bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
465 bodyBlock->getArguments().drop_front());
466 }
467}
468
469LogicalResult ForOp::verify() {
470 // Check that the number of init args and op results is the same.
471 if (getInitArgs().size() != getNumResults())
472 return emitOpError(
473 "mismatch in number of loop-carried values and defined values");
474
475 return success();
476}
477
478LogicalResult ForOp::verifyRegions() {
479 // Check that the body defines as single block argument for the induction
480 // variable.
481 if (getInductionVar().getType() != getLowerBound().getType())
482 return emitOpError(
483 "expected induction variable to be same type as bounds and step");
484
485 if (getNumRegionIterArgs() != getNumResults())
486 return emitOpError(
487 "mismatch in number of basic block args and defined values");
488
489 auto initArgs = getInitArgs();
490 auto iterArgs = getRegionIterArgs();
491 auto opResults = getResults();
492 unsigned i = 0;
493 for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
494 if (std::get<0>(e).getType() != std::get<2>(e).getType())
495 return emitOpError() << "types mismatch between " << i
496 << "th iter operand and defined value";
497 if (std::get<1>(e).getType() != std::get<2>(e).getType())
498 return emitOpError() << "types mismatch between " << i
499 << "th iter region arg and defined value";
500
501 ++i;
502 }
503 return success();
504}
505
506std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
507 return SmallVector<Value>{getInductionVar()};
508}
509
510std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
512}
513
514std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
515 return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
516}
517
518std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
520}
521
522std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
523
524/// Promotes the loop body of a forOp to its containing block if the forOp
525/// it can be determined that the loop has a single iteration.
526LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
527 std::optional<APInt> tripCount = getStaticTripCount();
528 LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
529 << " for loop "
530 << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
531 if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
532 return failure();
533
534 if (*tripCount == 0) {
535 rewriter.replaceAllUsesWith(getResults(), getInitArgs());
536 rewriter.eraseOp(*this);
537 return success();
538 }
539
540 // Replace all results with the yielded values.
541 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
542 rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
543
544 // Replace block arguments with lower bound (replacement for IV) and
545 // iter_args.
546 SmallVector<Value> bbArgReplacements;
547 bbArgReplacements.push_back(getLowerBound());
548 llvm::append_range(bbArgReplacements, getInitArgs());
549
550 // Move the loop body operations to the loop's containing block.
551 rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
552 getOperation()->getIterator(), bbArgReplacements);
553
554 // Erase the old terminator and the loop.
555 rewriter.eraseOp(yieldOp);
556 rewriter.eraseOp(*this);
557
558 return success();
559}
560
561/// Prints the initialization list in the form of
562/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
563/// where 'inner' values are assumed to be region arguments and 'outer' values
564/// are regular SSA values.
566 Block::BlockArgListType blocksArgs,
567 ValueRange initializers,
568 StringRef prefix = "") {
569 assert(blocksArgs.size() == initializers.size() &&
570 "expected same length of arguments and initializers");
571 if (initializers.empty())
572 return;
573
574 p << prefix << '(';
575 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
576 p << std::get<0>(it) << " = " << std::get<1>(it);
577 });
578 p << ")";
579}
580
581void ForOp::print(OpAsmPrinter &p) {
582 if (getUnsignedCmp())
583 p << " unsigned";
584
585 p << " " << getInductionVar() << " = " << getLowerBound() << " to "
586 << getUpperBound() << " step " << getStep();
587
588 printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
589 if (!getInitArgs().empty())
590 p << " -> (" << getInitArgs().getTypes() << ')';
591 p << ' ';
592 if (Type t = getInductionVar().getType(); !t.isIndex())
593 p << " : " << t << ' ';
594 p.printRegion(getRegion(),
595 /*printEntryBlockArgs=*/false,
596 /*printBlockTerminators=*/!getInitArgs().empty());
597 p.printOptionalAttrDict((*this)->getAttrs(),
598 /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
599}
600
601ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
602 auto &builder = parser.getBuilder();
603 Type type;
604
605 OpAsmParser::Argument inductionVariable;
607
608 if (succeeded(parser.parseOptionalKeyword("unsigned")))
609 result.addAttribute(getUnsignedCmpAttrName(result.name),
610 builder.getUnitAttr());
611
612 // Parse the induction variable followed by '='.
613 if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
614 // Parse loop bounds.
615 parser.parseOperand(lb) || parser.parseKeyword("to") ||
616 parser.parseOperand(ub) || parser.parseKeyword("step") ||
617 parser.parseOperand(step))
618 return failure();
619
620 // Parse the optional initial iteration arguments.
623 regionArgs.push_back(inductionVariable);
624
625 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
626 if (hasIterArgs) {
627 // Parse assignment list and results type list.
628 if (parser.parseAssignmentList(regionArgs, operands) ||
629 parser.parseArrowTypeList(result.types))
630 return failure();
631 }
632
633 if (regionArgs.size() != result.types.size() + 1)
634 return parser.emitError(
635 parser.getNameLoc(),
636 "mismatch in number of loop-carried values and defined values");
637
638 // Parse optional type, else assume Index.
639 if (parser.parseOptionalColon())
640 type = builder.getIndexType();
641 else if (parser.parseType(type))
642 return failure();
643
644 // Set block argument types, so that they are known when parsing the region.
645 regionArgs.front().type = type;
646 for (auto [iterArg, type] :
647 llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
648 iterArg.type = type;
649
650 // Parse the body region.
651 Region *body = result.addRegion();
652 if (parser.parseRegion(*body, regionArgs))
653 return failure();
654 ForOp::ensureTerminator(*body, builder, result.location);
655
656 // Resolve input operands. This should be done after parsing the region to
657 // catch invalid IR where operands were defined inside of the region.
658 if (parser.resolveOperand(lb, type, result.operands) ||
659 parser.resolveOperand(ub, type, result.operands) ||
660 parser.resolveOperand(step, type, result.operands))
661 return failure();
662 if (hasIterArgs) {
663 for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
664 operands, result.types)) {
665 Type type = std::get<2>(argOperandType);
666 std::get<0>(argOperandType).type = type;
667 if (parser.resolveOperand(std::get<1>(argOperandType), type,
668 result.operands))
669 return failure();
670 }
671 }
672
673 // Parse the optional attribute list.
674 if (parser.parseOptionalAttrDict(result.attributes))
675 return failure();
676
677 return success();
678}
679
680SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
681
682Block::BlockArgListType ForOp::getRegionIterArgs() {
683 return getBody()->getArguments().drop_front(getNumInductionVars());
684}
685
686MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
687 return getInitArgsMutable();
688}
689
690FailureOr<LoopLikeOpInterface>
691ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
692 ValueRange newInitOperands,
693 bool replaceInitOperandUsesInLoop,
694 const NewYieldValuesFn &newYieldValuesFn) {
695 // Create a new loop before the existing one, with the extra operands.
696 OpBuilder::InsertionGuard g(rewriter);
697 rewriter.setInsertionPoint(getOperation());
698 auto inits = llvm::to_vector(getInitArgs());
699 inits.append(newInitOperands.begin(), newInitOperands.end());
700 scf::ForOp newLoop = scf::ForOp::create(
701 rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
702 [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp());
703 newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
704
705 // Generate the new yield values and append them to the scf.yield operation.
706 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
707 ArrayRef<BlockArgument> newIterArgs =
708 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
709 {
710 OpBuilder::InsertionGuard g(rewriter);
711 rewriter.setInsertionPoint(yieldOp);
712 SmallVector<Value> newYieldedValues =
713 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
714 assert(newInitOperands.size() == newYieldedValues.size() &&
715 "expected as many new yield values as new iter operands");
716 rewriter.modifyOpInPlace(yieldOp, [&]() {
717 yieldOp.getResultsMutable().append(newYieldedValues);
718 });
719 }
720
721 // Move the loop body to the new op.
722 rewriter.mergeBlocks(getBody(), newLoop.getBody(),
723 newLoop.getBody()->getArguments().take_front(
724 getBody()->getNumArguments()));
725
726 if (replaceInitOperandUsesInLoop) {
727 // Replace all uses of `newInitOperands` with the corresponding basic block
728 // arguments.
729 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
730 rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
731 [&](OpOperand &use) {
732 Operation *user = use.getOwner();
733 return newLoop->isProperAncestor(user);
734 });
735 }
736 }
737
738 // Replace the old loop.
739 rewriter.replaceOp(getOperation(),
740 newLoop->getResults().take_front(getNumResults()));
741 return cast<LoopLikeOpInterface>(newLoop.getOperation());
742}
743
745 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
746 if (!ivArg)
747 return ForOp();
748 assert(ivArg.getOwner() && "unlinked block argument");
749 auto *containingOp = ivArg.getOwner()->getParentOp();
750 return dyn_cast_or_null<ForOp>(containingOp);
751}
752
753OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
754 return getInitArgs();
755}
756
757void ForOp::getSuccessorRegions(RegionBranchPoint point,
759 // Both the operation itself and the region may be branching into the body or
760 // back into the operation itself. It is possible for loop not to enter the
761 // body.
762 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
763 regions.push_back(RegionSuccessor(getOperation(), getResults()));
764}
765
766SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
767
768/// Promotes the loop body of a forallOp to its containing block if it can be
769/// determined that the loop has a single iteration.
770LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
771 for (auto [lb, ub, step] :
772 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
773 auto tripCount =
774 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
775 if (!tripCount.has_value() || *tripCount != 1)
776 return failure();
777 }
778
779 promote(rewriter, *this);
780 return success();
781}
782
783Block::BlockArgListType ForallOp::getRegionIterArgs() {
784 return getBody()->getArguments().drop_front(getRank());
785}
786
787MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
788 return getOutputsMutable();
789}
790
791/// Promotes the loop body of a scf::ForallOp to its containing block.
792void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
793 OpBuilder::InsertionGuard g(rewriter);
794 scf::InParallelOp terminator = forallOp.getTerminator();
795
796 // Replace block arguments with lower bounds (replacements for IVs) and
797 // outputs.
798 SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
799 bbArgReplacements.append(forallOp.getOutputs().begin(),
800 forallOp.getOutputs().end());
801
802 // Move the loop body operations to the loop's containing block.
803 rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
804 forallOp->getIterator(), bbArgReplacements);
805
806 // Replace the terminator with tensor.insert_slice ops.
807 rewriter.setInsertionPointAfter(forallOp);
808 SmallVector<Value> results;
809 results.reserve(forallOp.getResults().size());
810 for (auto &yieldingOp : terminator.getYieldingOps()) {
811 auto parallelInsertSliceOp =
812 dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
813 if (!parallelInsertSliceOp)
814 continue;
815
816 Value dst = parallelInsertSliceOp.getDest();
817 Value src = parallelInsertSliceOp.getSource();
818 if (llvm::isa<TensorType>(src.getType())) {
819 results.push_back(tensor::InsertSliceOp::create(
820 rewriter, forallOp.getLoc(), dst.getType(), src, dst,
821 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
822 parallelInsertSliceOp.getStrides(),
823 parallelInsertSliceOp.getStaticOffsets(),
824 parallelInsertSliceOp.getStaticSizes(),
825 parallelInsertSliceOp.getStaticStrides()));
826 } else {
827 llvm_unreachable("unsupported terminator");
828 }
829 }
830 rewriter.replaceAllUsesWith(forallOp.getResults(), results);
831
832 // Erase the old terminator and the loop.
833 rewriter.eraseOp(terminator);
834 rewriter.eraseOp(forallOp);
835}
836
838 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
839 ValueRange steps, ValueRange iterArgs,
841 bodyBuilder) {
842 assert(lbs.size() == ubs.size() &&
843 "expected the same number of lower and upper bounds");
844 assert(lbs.size() == steps.size() &&
845 "expected the same number of lower bounds and steps");
846
847 // If there are no bounds, call the body-building function and return early.
848 if (lbs.empty()) {
849 ValueVector results =
850 bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
851 : ValueVector();
852 assert(results.size() == iterArgs.size() &&
853 "loop nest body must return as many values as loop has iteration "
854 "arguments");
855 return LoopNest{{}, std::move(results)};
856 }
857
858 // First, create the loop structure iteratively using the body-builder
859 // callback of `ForOp::build`. Do not create `YieldOp`s yet.
860 OpBuilder::InsertionGuard guard(builder);
863 loops.reserve(lbs.size());
864 ivs.reserve(lbs.size());
865 ValueRange currentIterArgs = iterArgs;
866 Location currentLoc = loc;
867 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
868 auto loop = scf::ForOp::create(
869 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
870 [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
871 ValueRange args) {
872 ivs.push_back(iv);
873 // It is safe to store ValueRange args because it points to block
874 // arguments of a loop operation that we also own.
875 currentIterArgs = args;
876 currentLoc = nestedLoc;
877 });
878 // Set the builder to point to the body of the newly created loop. We don't
879 // do this in the callback because the builder is reset when the callback
880 // returns.
881 builder.setInsertionPointToStart(loop.getBody());
882 loops.push_back(loop);
883 }
884
885 // For all loops but the innermost, yield the results of the nested loop.
886 for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
887 builder.setInsertionPointToEnd(loops[i].getBody());
888 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
889 }
890
891 // In the body of the innermost loop, call the body building function if any
892 // and yield its results.
893 builder.setInsertionPointToStart(loops.back().getBody());
894 ValueVector results = bodyBuilder
895 ? bodyBuilder(builder, currentLoc, ivs,
896 loops.back().getRegionIterArgs())
897 : ValueVector();
898 assert(results.size() == iterArgs.size() &&
899 "loop nest body must return as many values as loop has iteration "
900 "arguments");
901 builder.setInsertionPointToEnd(loops.back().getBody());
902 scf::YieldOp::create(builder, loc, results);
903
904 // Return the loops.
905 ValueVector nestResults;
906 llvm::append_range(nestResults, loops.front().getResults());
907 return LoopNest{std::move(loops), std::move(nestResults)};
908}
909
911 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
912 ValueRange steps,
913 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
914 // Delegate to the main function by wrapping the body builder.
915 return buildLoopNest(builder, loc, lbs, ubs, steps, {},
916 [&bodyBuilder](OpBuilder &nestedBuilder,
917 Location nestedLoc, ValueRange ivs,
919 if (bodyBuilder)
920 bodyBuilder(nestedBuilder, nestedLoc, ivs);
921 return {};
922 });
923}
924
927 OpOperand &operand, Value replacement,
928 const ValueTypeCastFnTy &castFn) {
929 assert(operand.getOwner() == forOp);
930 Type oldType = operand.get().getType(), newType = replacement.getType();
931
932 // 1. Create new iter operands, exactly 1 is replaced.
933 assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
934 "expected an iter OpOperand");
935 assert(operand.get().getType() != replacement.getType() &&
936 "Expected a different type");
937 SmallVector<Value> newIterOperands;
938 for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
939 if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
940 newIterOperands.push_back(replacement);
941 continue;
942 }
943 newIterOperands.push_back(opOperand.get());
944 }
945
946 // 2. Create the new forOp shell.
947 scf::ForOp newForOp = scf::ForOp::create(
948 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
949 forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr,
950 forOp.getUnsignedCmp());
951 newForOp->setAttrs(forOp->getAttrs());
952 Block &newBlock = newForOp.getRegion().front();
953 SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
954 newBlock.getArguments().end());
955
956 // 3. Inject an incoming cast op at the beginning of the block for the bbArg
957 // corresponding to the `replacement` value.
958 OpBuilder::InsertionGuard g(rewriter);
959 rewriter.setInsertionPointToStart(&newBlock);
960 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
961 &newForOp->getOpOperand(operand.getOperandNumber()));
962 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
963 newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
964
965 // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
966 Block &oldBlock = forOp.getRegion().front();
967 rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
968
969 // 5. Inject an outgoing cast op at the end of the block and yield it instead.
970 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
971 rewriter.setInsertionPoint(clonedYieldOp);
972 unsigned yieldIdx =
973 newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
974 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
975 clonedYieldOp.getOperand(yieldIdx));
976 SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
977 newYieldOperands[yieldIdx] = castOut;
978 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
979 rewriter.eraseOp(clonedYieldOp);
980
981 // 6. Inject an outgoing cast op after the forOp.
982 rewriter.setInsertionPointAfter(newForOp);
983 SmallVector<Value> newResults = newForOp.getResults();
984 newResults[yieldIdx] =
985 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
986
987 return newResults;
988}
989
990namespace {
991// Fold away ForOp iter arguments when:
992// 1) The op yields the iter arguments.
993// 2) The argument's corresponding outer region iterators (inputs) are yielded.
994// 3) The iter arguments have no use and the corresponding (operation) results
995// have no use.
996//
997// These arguments must be defined outside of the ForOp region and can just be
998// forwarded after simplifying the op inits, yields and returns.
999//
1000// The implementation uses `inlineBlockBefore` to steal the content of the
1001// original ForOp and avoid cloning.
1002struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
1003 using OpRewritePattern<scf::ForOp>::OpRewritePattern;
1004
1005 LogicalResult matchAndRewrite(scf::ForOp forOp,
1006 PatternRewriter &rewriter) const final {
1007 bool canonicalize = false;
1008
1009 // An internal flat vector of block transfer
1010 // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
1011 // transformed block argument mappings. This plays the role of a
1012 // IRMapping for the particular use case of calling into
1013 // `inlineBlockBefore`.
1014 int64_t numResults = forOp.getNumResults();
1015 SmallVector<bool, 4> keepMask;
1016 keepMask.reserve(numResults);
1017 SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
1018 newResultValues;
1019 newBlockTransferArgs.reserve(1 + numResults);
1020 newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
1021 newIterArgs.reserve(forOp.getInitArgs().size());
1022 newYieldValues.reserve(numResults);
1023 newResultValues.reserve(numResults);
1024 DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
1025 for (auto [init, arg, result, yielded] :
1026 llvm::zip(forOp.getInitArgs(), // iter from outside
1027 forOp.getRegionIterArgs(), // iter inside region
1028 forOp.getResults(), // op results
1029 forOp.getYieldedValues() // iter yield
1030 )) {
1031 // Forwarded is `true` when:
1032 // 1) The region `iter` argument is yielded.
1033 // 2) The region `iter` argument the corresponding input is yielded.
1034 // 3) The region `iter` argument has no use, and the corresponding op
1035 // result has no use.
1036 bool forwarded = (arg == yielded) || (init == yielded) ||
1037 (arg.use_empty() && result.use_empty());
1038 if (forwarded) {
1039 canonicalize = true;
1040 keepMask.push_back(false);
1041 newBlockTransferArgs.push_back(init);
1042 newResultValues.push_back(init);
1043 continue;
1044 }
1045
1046 // Check if a previous kept argument always has the same values for init
1047 // and yielded values.
1048 if (auto it = initYieldToArg.find({init, yielded});
1049 it != initYieldToArg.end()) {
1050 canonicalize = true;
1051 keepMask.push_back(false);
1052 auto [sameArg, sameResult] = it->second;
1053 rewriter.replaceAllUsesWith(arg, sameArg);
1054 rewriter.replaceAllUsesWith(result, sameResult);
1055 // The replacement value doesn't matter because there are no uses.
1056 newBlockTransferArgs.push_back(init);
1057 newResultValues.push_back(init);
1058 continue;
1059 }
1060
1061 // This value is kept.
1062 initYieldToArg.insert({{init, yielded}, {arg, result}});
1063 keepMask.push_back(true);
1064 newIterArgs.push_back(init);
1065 newYieldValues.push_back(yielded);
1066 newBlockTransferArgs.push_back(Value()); // placeholder with null value
1067 newResultValues.push_back(Value()); // placeholder with null value
1068 }
1069
1070 if (!canonicalize)
1071 return failure();
1072
1073 scf::ForOp newForOp =
1074 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
1075 forOp.getUpperBound(), forOp.getStep(), newIterArgs,
1076 /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
1077 newForOp->setAttrs(forOp->getAttrs());
1078 Block &newBlock = newForOp.getRegion().front();
1079
1080 // Replace the null placeholders with newly constructed values.
1081 newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
1082 for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
1083 idx != e; ++idx) {
1084 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
1085 Value &newResultVal = newResultValues[idx];
1086 assert((blockTransferArg && newResultVal) ||
1087 (!blockTransferArg && !newResultVal));
1088 if (!blockTransferArg) {
1089 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
1090 newResultVal = newForOp.getResult(collapsedIdx++);
1091 }
1092 }
1093
1094 Block &oldBlock = forOp.getRegion().front();
1095 assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
1096 "unexpected argument size mismatch");
1097
1098 // No results case: the scf::ForOp builder already created a zero
1099 // result terminator. Merge before this terminator and just get rid of the
1100 // original terminator that has been merged in.
1101 if (newIterArgs.empty()) {
1102 auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
1103 rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
1104 rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
1105 rewriter.replaceOp(forOp, newResultValues);
1106 return success();
1107 }
1108
1109 // No terminator case: merge and rewrite the merged terminator.
1110 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
1111 OpBuilder::InsertionGuard g(rewriter);
1112 rewriter.setInsertionPoint(mergedTerminator);
1113 SmallVector<Value, 4> filteredOperands;
1114 filteredOperands.reserve(newResultValues.size());
1115 for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
1116 if (keepMask[idx])
1117 filteredOperands.push_back(mergedTerminator.getOperand(idx));
1118 scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
1119 filteredOperands);
1120 };
1121
1122 rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1123 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
1124 cloneFilteredTerminator(mergedYieldOp);
1125 rewriter.eraseOp(mergedYieldOp);
1126 rewriter.replaceOp(forOp, newResultValues);
1127 return success();
1128 }
1129};
1130
1131/// Rewriting pattern that erases loops that are known not to iterate, replaces
1132/// single-iteration loops with their bodies, and removes empty loops that
1133/// iterate at least once and only return values defined outside of the loop.
1134struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
1135 using OpRewritePattern<ForOp>::OpRewritePattern;
1136
1137 LogicalResult matchAndRewrite(ForOp op,
1138 PatternRewriter &rewriter) const override {
1139 std::optional<APInt> tripCount = op.getStaticTripCount();
1140 if (!tripCount.has_value())
1141 return rewriter.notifyMatchFailure(op,
1142 "can't compute constant trip count");
1143
1144 if (tripCount->isZero()) {
1145 LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop "
1146 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1147 rewriter.replaceOp(op, op.getInitArgs());
1148 return success();
1149 }
1150
1151 if (tripCount->getSExtValue() == 1) {
1152 LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop "
1153 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1154 SmallVector<Value, 4> blockArgs;
1155 blockArgs.reserve(op.getInitArgs().size() + 1);
1156 blockArgs.push_back(op.getLowerBound());
1157 llvm::append_range(blockArgs, op.getInitArgs());
1158 replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1159 return success();
1160 }
1161
1162 // Now we are left with loops that have more than 1 iterations.
1163 Block &block = op.getRegion().front();
1164 if (!llvm::hasSingleElement(block))
1165 return failure();
1166 // The loop is empty and iterates at least once, if it only returns values
1167 // defined outside of the loop, remove it and replace it with yield values.
1168 if (llvm::any_of(op.getYieldedValues(),
1169 [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1170 return failure();
1171 LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with "
1172 "yield operands for loop "
1173 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1174 rewriter.replaceOp(op, op.getYieldedValues());
1175 return success();
1176 }
1177};
1178
/// Fold scf.for iter_arg/result pairs that go through incoming/outgoing
/// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1181///
1182/// ```
1183/// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1184/// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1185/// -> (tensor<?x?xf32>) {
1186/// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1187/// scf.yield %2 : tensor<?x?xf32>
1188/// }
1189/// use_of(%1)
1190/// ```
1191///
1192/// folds into:
1193///
1194/// ```
1195/// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1196/// -> (tensor<32x1024xf32>) {
1197/// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1198/// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1199/// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1200/// scf.yield %4 : tensor<32x1024xf32>
1201/// }
1202/// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1203/// use_of(%1)
1204/// ```
1205struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1206 using OpRewritePattern<ForOp>::OpRewritePattern;
1207
1208 LogicalResult matchAndRewrite(ForOp op,
1209 PatternRewriter &rewriter) const override {
1210 for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1211 OpOperand &iterOpOperand = std::get<0>(it);
1212 auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1213 if (!incomingCast ||
1214 incomingCast.getSource().getType() == incomingCast.getType())
1215 continue;
1216 // If the dest type of the cast does not preserve static information in
1217 // the source type.
1219 incomingCast.getDest().getType(),
1220 incomingCast.getSource().getType()))
1221 continue;
1222 if (!std::get<1>(it).hasOneUse())
1223 continue;
1224
1225 // Create a new ForOp with that iter operand replaced.
1226 rewriter.replaceOp(
1228 rewriter, op, iterOpOperand, incomingCast.getSource(),
1229 [](OpBuilder &b, Location loc, Type type, Value source) {
1230 return tensor::CastOp::create(b, loc, type, source);
1231 }));
1232 return success();
1233 }
1234 return failure();
1235 }
1236};
1237
1238} // namespace
1239
1240void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1241 MLIRContext *context) {
1242 results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1243 context);
1244}
1245
1246std::optional<APInt> ForOp::getConstantStep() {
1247 IntegerAttr step;
1248 if (matchPattern(getStep(), m_Constant(&step)))
1249 return step.getValue();
1250 return {};
1251}
1252
1253std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1254 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1255}
1256
1257Speculation::Speculatability ForOp::getSpeculatability() {
1258 // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1259 // and End.
1260 if (auto constantStep = getConstantStep())
1261 if (*constantStep == 1)
1263
1264 // For Step != 1, the loop may not terminate. We can add more smarts here if
1265 // needed.
1267}
1268
1269std::optional<APInt> ForOp::getStaticTripCount() {
1270 return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
1271 /*isSigned=*/!getUnsignedCmp(), computeUbMinusLb);
1272}
1273
1274//===----------------------------------------------------------------------===//
1275// ForallOp
1276//===----------------------------------------------------------------------===//
1277
1278LogicalResult ForallOp::verify() {
1279 unsigned numLoops = getRank();
1280 // Check number of outputs.
1281 if (getNumResults() != getOutputs().size())
1282 return emitOpError("produces ")
1283 << getNumResults() << " results, but has only "
1284 << getOutputs().size() << " outputs";
1285
1286 // Check that the body defines block arguments for thread indices and outputs.
1287 auto *body = getBody();
1288 if (body->getNumArguments() != numLoops + getOutputs().size())
1289 return emitOpError("region expects ") << numLoops << " arguments";
1290 for (int64_t i = 0; i < numLoops; ++i)
1291 if (!body->getArgument(i).getType().isIndex())
1292 return emitOpError("expects ")
1293 << i << "-th block argument to be an index";
1294 for (unsigned i = 0; i < getOutputs().size(); ++i)
1295 if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1296 return emitOpError("type mismatch between ")
1297 << i << "-th output and corresponding block argument";
1298 if (getMapping().has_value() && !getMapping()->empty()) {
1299 if (getDeviceMappingAttrs().size() != numLoops)
1300 return emitOpError() << "mapping attribute size must match op rank";
1301 if (failed(getDeviceMaskingAttr()))
1302 return emitOpError() << getMappingAttrName()
1303 << " supports at most one device masking attribute";
1304 }
1305
1306 // Verify mixed static/dynamic control variables.
1307 Operation *op = getOperation();
1308 if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1309 getStaticLowerBound(),
1310 getDynamicLowerBound())))
1311 return failure();
1312 if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1313 getStaticUpperBound(),
1314 getDynamicUpperBound())))
1315 return failure();
1316 if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1317 getStaticStep(), getDynamicStep())))
1318 return failure();
1319
1320 return success();
1321}
1322
1323void ForallOp::print(OpAsmPrinter &p) {
1324 Operation *op = getOperation();
1325 p << " (" << getInductionVars();
1326 if (isNormalized()) {
1327 p << ") in ";
1328 printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1329 /*valueTypes=*/{}, /*scalables=*/{},
1331 } else {
1332 p << ") = ";
1333 printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1334 /*valueTypes=*/{}, /*scalables=*/{},
1336 p << " to ";
1337 printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1338 /*valueTypes=*/{}, /*scalables=*/{},
1340 p << " step ";
1341 printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1342 /*valueTypes=*/{}, /*scalables=*/{},
1344 }
1345 printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1346 p << " ";
1347 if (!getRegionOutArgs().empty())
1348 p << "-> (" << getResultTypes() << ") ";
1349 p.printRegion(getRegion(),
1350 /*printEntryBlockArgs=*/false,
1351 /*printBlockTerminators=*/getNumResults() > 0);
1352 p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1353 getStaticLowerBoundAttrName(),
1354 getStaticUpperBoundAttrName(),
1355 getStaticStepAttrName()});
1356}
1357
1358ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1359 OpBuilder b(parser.getContext());
1360 auto indexType = b.getIndexType();
1361
1362 // Parse an opening `(` followed by thread index variables followed by `)`
1363 // TODO: when we can refer to such "induction variable"-like handles from the
1364 // declarative assembly format, we can implement the parser as a custom hook.
1365 SmallVector<OpAsmParser::Argument, 4> ivs;
1367 return failure();
1368
1369 DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1370 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1371 dynamicSteps;
1372 if (succeeded(parser.parseOptionalKeyword("in"))) {
1373 // Parse upper bounds.
1374 if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1375 /*valueTypes=*/nullptr,
1377 parser.resolveOperands(dynamicUbs, indexType, result.operands))
1378 return failure();
1379
1380 unsigned numLoops = ivs.size();
1381 staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1382 staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1383 } else {
1384 // Parse lower bounds.
1385 if (parser.parseEqual() ||
1386 parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1387 /*valueTypes=*/nullptr,
1389
1390 parser.resolveOperands(dynamicLbs, indexType, result.operands))
1391 return failure();
1392
1393 // Parse upper bounds.
1394 if (parser.parseKeyword("to") ||
1395 parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1396 /*valueTypes=*/nullptr,
1398 parser.resolveOperands(dynamicUbs, indexType, result.operands))
1399 return failure();
1400
1401 // Parse step values.
1402 if (parser.parseKeyword("step") ||
1403 parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1404 /*valueTypes=*/nullptr,
1406 parser.resolveOperands(dynamicSteps, indexType, result.operands))
1407 return failure();
1408 }
1409
1410 // Parse out operands and results.
1411 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1412 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1413 SMLoc outOperandsLoc = parser.getCurrentLocation();
1414 if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1415 if (outOperands.size() != result.types.size())
1416 return parser.emitError(outOperandsLoc,
1417 "mismatch between out operands and types");
1418 if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1419 parser.parseOptionalArrowTypeList(result.types) ||
1420 parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1421 result.operands))
1422 return failure();
1423 }
1424
1425 // Parse region.
1426 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1427 std::unique_ptr<Region> region = std::make_unique<Region>();
1428 for (auto &iv : ivs) {
1429 iv.type = b.getIndexType();
1430 regionArgs.push_back(iv);
1431 }
1432 for (const auto &it : llvm::enumerate(regionOutArgs)) {
1433 auto &out = it.value();
1434 out.type = result.types[it.index()];
1435 regionArgs.push_back(out);
1436 }
1437 if (parser.parseRegion(*region, regionArgs))
1438 return failure();
1439
1440 // Ensure terminator and move region.
1441 ForallOp::ensureTerminator(*region, b, result.location);
1442 result.addRegion(std::move(region));
1443
1444 // Parse the optional attribute list.
1445 if (parser.parseOptionalAttrDict(result.attributes))
1446 return failure();
1447
1448 result.addAttribute("staticLowerBound", staticLbs);
1449 result.addAttribute("staticUpperBound", staticUbs);
1450 result.addAttribute("staticStep", staticSteps);
1451 result.addAttribute("operandSegmentSizes",
1453 {static_cast<int32_t>(dynamicLbs.size()),
1454 static_cast<int32_t>(dynamicUbs.size()),
1455 static_cast<int32_t>(dynamicSteps.size()),
1456 static_cast<int32_t>(outOperands.size())}));
1457 return success();
1458}
1459
1460// Builder that takes loop bounds.
1461void ForallOp::build(
1462 mlir::OpBuilder &b, mlir::OperationState &result,
1463 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1464 ArrayRef<OpFoldResult> steps, ValueRange outputs,
1465 std::optional<ArrayAttr> mapping,
1466 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1467 SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1468 SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1469 dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1470 dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1471 dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1472
1473 result.addOperands(dynamicLbs);
1474 result.addOperands(dynamicUbs);
1475 result.addOperands(dynamicSteps);
1476 result.addOperands(outputs);
1477 result.addTypes(TypeRange(outputs));
1478
1479 result.addAttribute(getStaticLowerBoundAttrName(result.name),
1480 b.getDenseI64ArrayAttr(staticLbs));
1481 result.addAttribute(getStaticUpperBoundAttrName(result.name),
1482 b.getDenseI64ArrayAttr(staticUbs));
1483 result.addAttribute(getStaticStepAttrName(result.name),
1484 b.getDenseI64ArrayAttr(staticSteps));
1485 result.addAttribute(
1486 "operandSegmentSizes",
1487 b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1488 static_cast<int32_t>(dynamicUbs.size()),
1489 static_cast<int32_t>(dynamicSteps.size()),
1490 static_cast<int32_t>(outputs.size())}));
1491 if (mapping.has_value()) {
1492 result.addAttribute(ForallOp::getMappingAttrName(result.name),
1493 mapping.value());
1494 }
1495
1496 Region *bodyRegion = result.addRegion();
1497 OpBuilder::InsertionGuard g(b);
1498 b.createBlock(bodyRegion);
1499 Block &bodyBlock = bodyRegion->front();
1500
1501 // Add block arguments for indices and outputs.
1502 bodyBlock.addArguments(
1503 SmallVector<Type>(lbs.size(), b.getIndexType()),
1504 SmallVector<Location>(staticLbs.size(), result.location));
1505 bodyBlock.addArguments(
1506 TypeRange(outputs),
1507 SmallVector<Location>(outputs.size(), result.location));
1508
1509 b.setInsertionPointToStart(&bodyBlock);
1510 if (!bodyBuilderFn) {
1511 ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1512 return;
1513 }
1514 bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1515}
1516
1517// Builder that takes loop bounds.
1518void ForallOp::build(
1519 mlir::OpBuilder &b, mlir::OperationState &result,
1520 ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1521 std::optional<ArrayAttr> mapping,
1522 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1523 unsigned numLoops = ubs.size();
1524 SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1525 SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1526 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1527}
1528
1529// Checks if the lbs are zeros and steps are ones.
1530bool ForallOp::isNormalized() {
1531 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1532 return llvm::all_of(results, [&](OpFoldResult ofr) {
1533 auto intValue = getConstantIntValue(ofr);
1534 return intValue.has_value() && intValue == val;
1535 });
1536 };
1537 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1538}
1539
1540InParallelOp ForallOp::getTerminator() {
1541 return cast<InParallelOp>(getBody()->getTerminator());
1542}
1543
1544SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1545 SmallVector<Operation *> storeOps;
1546 for (Operation *user : bbArg.getUsers()) {
1547 if (auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1548 storeOps.push_back(parallelOp);
1549 }
1550 }
1551 return storeOps;
1552}
1553
1554SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1555 SmallVector<DeviceMappingAttrInterface> res;
1556 if (!getMapping())
1557 return res;
1558 for (auto attr : getMapping()->getValue()) {
1559 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1560 if (m)
1561 res.push_back(m);
1562 }
1563 return res;
1564}
1565
1566FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1567 DeviceMaskingAttrInterface res;
1568 if (!getMapping())
1569 return res;
1570 for (auto attr : getMapping()->getValue()) {
1571 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1572 if (m && res)
1573 return failure();
1574 if (m)
1575 res = m;
1576 }
1577 return res;
1578}
1579
1580bool ForallOp::usesLinearMapping() {
1581 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1582 if (ifaces.empty())
1583 return false;
1584 return ifaces.front().isLinearMapping();
1585}
1586
1587std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1588 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1589}
1590
1591// Get lower bounds as OpFoldResult.
1592std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1593 Builder b(getOperation()->getContext());
1594 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1595}
1596
1597// Get upper bounds as OpFoldResult.
1598std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1599 Builder b(getOperation()->getContext());
1600 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1601}
1602
1603// Get steps as OpFoldResult.
1604std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1605 Builder b(getOperation()->getContext());
1606 return getMixedValues(getStaticStep(), getDynamicStep(), b);
1607}
1608
1610 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1611 if (!tidxArg)
1612 return ForallOp();
1613 assert(tidxArg.getOwner() && "unlinked block argument");
1614 auto *containingOp = tidxArg.getOwner()->getParentOp();
1615 return dyn_cast<ForallOp>(containingOp);
1616}
1617
1618namespace {
1619/// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1620struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1621 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1622
1623 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1624 PatternRewriter &rewriter) const final {
1625 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1626 if (!forallOp)
1627 return failure();
1628 Value sharedOut =
1629 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1630 ->get();
1631 rewriter.modifyOpInPlace(
1632 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1633 return success();
1634 }
1635};
1636
1637class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1638public:
1639 using OpRewritePattern<ForallOp>::OpRewritePattern;
1640
1641 LogicalResult matchAndRewrite(ForallOp op,
1642 PatternRewriter &rewriter) const override {
1643 SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1644 SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1645 SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1646 if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1647 failed(foldDynamicIndexList(mixedUpperBound)) &&
1648 failed(foldDynamicIndexList(mixedStep)))
1649 return failure();
1650
1651 rewriter.modifyOpInPlace(op, [&]() {
1652 SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1653 SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1654 dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1655 staticLowerBound);
1656 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1657 op.setStaticLowerBound(staticLowerBound);
1658
1659 dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1660 staticUpperBound);
1661 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1662 op.setStaticUpperBound(staticUpperBound);
1663
1664 dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1665 op.getDynamicStepMutable().assign(dynamicStep);
1666 op.setStaticStep(staticStep);
1667
1668 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1669 rewriter.getDenseI32ArrayAttr(
1670 {static_cast<int32_t>(dynamicLowerBound.size()),
1671 static_cast<int32_t>(dynamicUpperBound.size()),
1672 static_cast<int32_t>(dynamicStep.size()),
1673 static_cast<int32_t>(op.getNumResults())}));
1674 });
1675 return success();
1676 }
1677};
1678
/// The following canonicalization pattern folds the iter arguments of
/// scf.forall op if:
/// 1. The corresponding result has zero uses.
/// 2. The iter argument is NOT being modified within the loop body.
1684///
1685/// Example of first case :-
1686/// INPUT:
1687/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1688/// {
1689/// ...
1690/// <SOME USE OF %arg0>
1691/// <SOME USE OF %arg1>
1692/// <SOME USE OF %arg2>
1693/// ...
1694/// scf.forall.in_parallel {
1695/// <STORE OP WITH DESTINATION %arg1>
1696/// <STORE OP WITH DESTINATION %arg0>
1697/// <STORE OP WITH DESTINATION %arg2>
1698/// }
1699/// }
1700/// return %res#1
1701///
1702/// OUTPUT:
/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1704/// {
1705/// ...
1706/// <SOME USE OF %a>
1707/// <SOME USE OF %new_arg0>
1708/// <SOME USE OF %c>
1709/// ...
1710/// scf.forall.in_parallel {
1711/// <STORE OP WITH DESTINATION %new_arg0>
1712/// }
1713/// }
1714/// return %res
1715///
/// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
///       scf.forall are replaced by their corresponding operands.
1718/// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1719/// of the scf.forall besides within scf.forall.in_parallel terminator,
1720/// this canonicalization remains valid. For more details, please refer
1721/// to :
1722/// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1723/// 3. TODO(avarma): Generalize it for other store ops. Currently it
1724/// handles tensor.parallel_insert_slice ops only.
1725///
1726/// Example of second case :-
1727/// INPUT:
1728/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1729/// {
1730/// ...
1731/// <SOME USE OF %arg0>
1732/// <SOME USE OF %arg1>
1733/// ...
1734/// scf.forall.in_parallel {
1735/// <STORE OP WITH DESTINATION %arg1>
1736/// }
1737/// }
1738/// return %res#0, %res#1
1739///
1740/// OUTPUT:
1741/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1742/// {
1743/// ...
1744/// <SOME USE OF %a>
1745/// <SOME USE OF %new_arg0>
1746/// ...
1747/// scf.forall.in_parallel {
1748/// <STORE OP WITH DESTINATION %new_arg0>
1749/// }
1750/// }
1751/// return %a, %res
1752struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1753 using OpRewritePattern<ForallOp>::OpRewritePattern;
1754
1755 LogicalResult matchAndRewrite(ForallOp forallOp,
1756 PatternRewriter &rewriter) const final {
1757 // Step 1: For a given i-th result of scf.forall, check the following :-
1758 // a. If it has any use.
1759 // b. If the corresponding iter argument is being modified within
1760 // the loop, i.e. has at least one store op with the iter arg as
1761 // its destination operand. For this we use
1762 // ForallOp::getCombiningOps(iter_arg).
1763 //
1764 // Based on the check we maintain the following :-
1765 // a. `resultToDelete` - i-th result of scf.forall that'll be
1766 // deleted.
1767 // b. `resultToReplace` - i-th result of the old scf.forall
1768 // whose uses will be replaced by the new scf.forall.
1769 // c. `newOuts` - the shared_outs' operand of the new scf.forall
1770 // corresponding to the i-th result with at least one use.
1771 SetVector<OpResult> resultToDelete;
1772 SmallVector<Value> resultToReplace;
1773 SmallVector<Value> newOuts;
1774 for (OpResult result : forallOp.getResults()) {
1775 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1776 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1777 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1778 resultToDelete.insert(result);
1779 } else {
1780 resultToReplace.push_back(result);
1781 newOuts.push_back(opOperand->get());
1782 }
1783 }
1784
1785 // Return early if all results of scf.forall have at least one use and being
1786 // modified within the loop.
1787 if (resultToDelete.empty())
1788 return failure();
1789
1790 // Step 2: For the the i-th result, do the following :-
1791 // a. Fetch the corresponding BlockArgument.
1792 // b. Look for store ops (currently tensor.parallel_insert_slice)
1793 // with the BlockArgument as its destination operand.
1794 // c. Remove the operations fetched in b.
1795 for (OpResult result : resultToDelete) {
1796 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1797 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1798 SmallVector<Operation *> combiningOps =
1799 forallOp.getCombiningOps(blockArg);
1800 for (Operation *combiningOp : combiningOps)
1801 rewriter.eraseOp(combiningOp);
1802 }
1803
1804 // Step 3. Create a new scf.forall op with the new shared_outs' operands
1805 // fetched earlier
1806 auto newForallOp = scf::ForallOp::create(
1807 rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
1808 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1809 forallOp.getMapping(),
1810 /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1811
1812 // Step 4. Merge the block of the old scf.forall into the newly created
1813 // scf.forall using the new set of arguments.
1814 Block *loopBody = forallOp.getBody();
1815 Block *newLoopBody = newForallOp.getBody();
1816 ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1817 // Form initial new bbArg list with just the control operands of the new
1818 // scf.forall op.
1819 SmallVector<Value> newBlockArgs =
1820 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1821 [](BlockArgument b) -> Value { return b; });
1822 Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1823 unsigned index = 0;
1824 // Take the new corresponding bbArg if the old bbArg was used as a
1825 // destination in the in_parallel op. For all other bbArgs, use the
1826 // corresponding init_arg from the old scf.forall op.
1827 for (OpResult result : forallOp.getResults()) {
1828 if (resultToDelete.count(result)) {
1829 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1830 } else {
1831 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1832 }
1833 }
1834 rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1835
1836 // Step 5. Replace the uses of result of old scf.forall with that of the new
1837 // scf.forall.
1838 for (auto &&[oldResult, newResult] :
1839 llvm::zip(resultToReplace, newForallOp->getResults()))
1840 rewriter.replaceAllUsesWith(oldResult, newResult);
1841
1842 // Step 6. Replace the uses of those values that either has no use or are
1843 // not being modified within the loop with the corresponding
1844 // OpOperand.
1845 for (OpResult oldResult : resultToDelete)
1846 rewriter.replaceAllUsesWith(oldResult,
1847 forallOp.getTiedOpOperand(oldResult)->get());
1848 return success();
1849 }
1850};
1851
1852struct ForallOpSingleOrZeroIterationDimsFolder
1853 : public OpRewritePattern<ForallOp> {
1854 using OpRewritePattern<ForallOp>::OpRewritePattern;
1855
1856 LogicalResult matchAndRewrite(ForallOp op,
1857 PatternRewriter &rewriter) const override {
1858 // Do not fold dimensions if they are mapped to processing units.
1859 if (op.getMapping().has_value() && !op.getMapping()->empty())
1860 return failure();
1861 Location loc = op.getLoc();
1862
1863 // Compute new loop bounds that omit all single-iteration loop dimensions.
1864 SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1865 newMixedSteps;
1866 IRMapping mapping;
1867 for (auto [lb, ub, step, iv] :
1868 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1869 op.getMixedStep(), op.getInductionVars())) {
1870 auto numIterations =
1871 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
1872 if (numIterations.has_value()) {
1873 // Remove the loop if it performs zero iterations.
1874 if (*numIterations == 0) {
1875 rewriter.replaceOp(op, op.getOutputs());
1876 return success();
1877 }
1878 // Replace the loop induction variable by the lower bound if the loop
1879 // performs a single iteration. Otherwise, copy the loop bounds.
1880 if (*numIterations == 1) {
1881 mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1882 continue;
1883 }
1884 }
1885 newMixedLowerBounds.push_back(lb);
1886 newMixedUpperBounds.push_back(ub);
1887 newMixedSteps.push_back(step);
1888 }
1889
1890 // All of the loop dimensions perform a single iteration. Inline loop body.
1891 if (newMixedLowerBounds.empty()) {
1892 promote(rewriter, op);
1893 return success();
1894 }
1895
1896 // Exit if none of the loop dimensions perform a single iteration.
1897 if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1898 return rewriter.notifyMatchFailure(
1899 op, "no dimensions have 0 or 1 iterations");
1900 }
1901
1902 // Replace the loop by a lower-dimensional loop.
1903 ForallOp newOp;
1904 newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1905 newMixedUpperBounds, newMixedSteps,
1906 op.getOutputs(), std::nullopt, nullptr);
1907 newOp.getBodyRegion().getBlocks().clear();
1908 // The new loop needs to keep all attributes from the old one, except for
1909 // "operandSegmentSizes" and static loop bound attributes which capture
1910 // the outdated information of the old iteration domain.
1911 SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1912 newOp.getStaticLowerBoundAttrName(),
1913 newOp.getStaticUpperBoundAttrName(),
1914 newOp.getStaticStepAttrName()};
1915 for (const auto &namedAttr : op->getAttrs()) {
1916 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1917 continue;
1918 rewriter.modifyOpInPlace(newOp, [&]() {
1919 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1920 });
1921 }
1922 rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1923 newOp.getRegion().begin(), mapping);
1924 rewriter.replaceOp(op, newOp.getResults());
1925 return success();
1926 }
1927};
1928
1929/// Replace all induction vars with a single trip count with their lower bound.
1930struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1931 using OpRewritePattern<ForallOp>::OpRewritePattern;
1932
1933 LogicalResult matchAndRewrite(ForallOp op,
1934 PatternRewriter &rewriter) const override {
1935 Location loc = op.getLoc();
1936 bool changed = false;
1937 for (auto [lb, ub, step, iv] :
1938 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1939 op.getMixedStep(), op.getInductionVars())) {
1940 if (iv.hasNUses(0))
1941 continue;
1942 auto numIterations =
1943 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
1944 if (!numIterations.has_value() || numIterations.value() != 1) {
1945 continue;
1946 }
1947 rewriter.replaceAllUsesWith(
1948 iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1949 changed = true;
1950 }
1951 return success(changed);
1952 }
1953};
1954
1955struct FoldTensorCastOfOutputIntoForallOp
1956 : public OpRewritePattern<scf::ForallOp> {
1957 using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1958
1959 struct TypeCast {
1960 Type srcType;
1961 Type dstType;
1962 };
1963
1964 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1965 PatternRewriter &rewriter) const final {
1966 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1967 llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1968 for (auto en : llvm::enumerate(newOutputTensors)) {
1969 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1970 if (!castOp)
1971 continue;
1972
1973 // Only casts that that preserve static information, i.e. will make the
1974 // loop result type "more" static than before, will be folded.
1975 if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1976 castOp.getSource().getType())) {
1977 continue;
1978 }
1979
1980 tensorCastProducers[en.index()] =
1981 TypeCast{castOp.getSource().getType(), castOp.getType()};
1982 newOutputTensors[en.index()] = castOp.getSource();
1983 }
1984
1985 if (tensorCastProducers.empty())
1986 return failure();
1987
1988 // Create new loop.
1989 Location loc = forallOp.getLoc();
1990 auto newForallOp = ForallOp::create(
1991 rewriter, loc, forallOp.getMixedLowerBound(),
1992 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1993 newOutputTensors, forallOp.getMapping(),
1994 [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1995 auto castBlockArgs =
1996 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1997 for (auto [index, cast] : tensorCastProducers) {
1998 Value &oldTypeBBArg = castBlockArgs[index];
1999 oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
2000 cast.dstType, oldTypeBBArg);
2001 }
2002
2003 // Move old body into new parallel loop.
2004 SmallVector<Value> ivsBlockArgs =
2005 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
2006 ivsBlockArgs.append(castBlockArgs);
2007 rewriter.mergeBlocks(forallOp.getBody(),
2008 bbArgs.front().getParentBlock(), ivsBlockArgs);
2009 });
2010
2011 // After `mergeBlocks` happened, the destinations in the terminator were
2012 // mapped to the tensor.cast old-typed results of the output bbArgs. The
2013 // destination have to be updated to point to the output bbArgs directly.
2014 auto terminator = newForallOp.getTerminator();
2015 for (auto [yieldingOp, outputBlockArg] : llvm::zip(
2016 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
2017 if (auto parallelCombingingOp =
2018 dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
2019 parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
2020 }
2021 }
2022
2023 // Cast results back to the original types.
2024 rewriter.setInsertionPointAfter(newForallOp);
2025 SmallVector<Value> castResults = newForallOp.getResults();
2026 for (auto &item : tensorCastProducers) {
2027 Value &oldTypeResult = castResults[item.first];
2028 oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
2029 oldTypeResult);
2030 }
2031 rewriter.replaceOp(forallOp, castResults);
2032 return success();
2033 }
2034};
2035
2036} // namespace
2037
2038void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
2039 MLIRContext *context) {
2040 results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
2041 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
2042 ForallOpSingleOrZeroIterationDimsFolder,
2043 ForallOpReplaceConstantInductionVar>(context);
2044}
2045
2046/// Given the region at `index`, or the parent operation if `index` is None,
2047/// return the successor regions. These are the regions that may be selected
2048/// during the flow of control. `operands` is a set of optional attributes that
2049/// correspond to a constant value for each operand, or null if that operand is
2050/// not a constant.
2051void ForallOp::getSuccessorRegions(RegionBranchPoint point,
2052 SmallVectorImpl<RegionSuccessor> &regions) {
2053 // In accordance with the semantics of forall, its body is executed in
2054 // parallel by multiple threads. We should not expect to branch back into
2055 // the forall body after the region's execution is complete.
2056 if (point.isParent())
2057 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2058 else
2059 regions.push_back(
2060 RegionSuccessor(getOperation(), getOperation()->getResults()));
2061}
2062
2063//===----------------------------------------------------------------------===//
2064// InParallelOp
2065//===----------------------------------------------------------------------===//
2066
2067// Build a InParallelOp with mixed static and dynamic entries.
2068void InParallelOp::build(OpBuilder &b, OperationState &result) {
2069 OpBuilder::InsertionGuard g(b);
2070 Region *bodyRegion = result.addRegion();
2071 b.createBlock(bodyRegion);
2072}
2073
2074LogicalResult InParallelOp::verify() {
2075 scf::ForallOp forallOp =
2076 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
2077 if (!forallOp)
2078 return this->emitOpError("expected forall op parent");
2079
2080 for (Operation &op : getRegion().front().getOperations()) {
2081 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
2082 if (!parallelCombiningOp) {
2083 return this->emitOpError("expected only ParallelCombiningOpInterface")
2084 << " ops";
2085 }
2086
2087 // Verify that inserts are into out block arguments.
2088 MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
2089 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
2090 for (OpOperand &dest : dests) {
2091 if (!llvm::is_contained(regionOutArgs, dest.get()))
2092 return op.emitOpError("may only insert into an output block argument");
2093 }
2094 }
2095
2096 return success();
2097}
2098
2099void InParallelOp::print(OpAsmPrinter &p) {
2100 p << " ";
2101 p.printRegion(getRegion(),
2102 /*printEntryBlockArgs=*/false,
2103 /*printBlockTerminators=*/false);
2104 p.printOptionalAttrDict(getOperation()->getAttrs());
2105}
2106
2107ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2108 auto &builder = parser.getBuilder();
2109
2110 SmallVector<OpAsmParser::Argument, 8> regionOperands;
2111 std::unique_ptr<Region> region = std::make_unique<Region>();
2112 if (parser.parseRegion(*region, regionOperands))
2113 return failure();
2114
2115 if (region->empty())
2116 OpBuilder(builder.getContext()).createBlock(region.get());
2117 result.addRegion(std::move(region));
2118
2119 // Parse the optional attribute list.
2120 if (parser.parseOptionalAttrDict(result.attributes))
2121 return failure();
2122 return success();
2123}
2124
2125OpResult InParallelOp::getParentResult(int64_t idx) {
2126 return getOperation()->getParentOp()->getResult(idx);
2127}
2128
2129SmallVector<BlockArgument> InParallelOp::getDests() {
2130 SmallVector<BlockArgument> updatedDests;
2131 for (Operation &yieldingOp : getYieldingOps()) {
2132 auto parallelCombiningOp =
2133 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
2134 if (!parallelCombiningOp)
2135 continue;
2136 for (OpOperand &updatedOperand :
2137 parallelCombiningOp.getUpdatedDestinations())
2138 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
2139 }
2140 return updatedDests;
2141}
2142
2143llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
2144 return getRegion().front().getOperations();
2145}
2146
2147//===----------------------------------------------------------------------===//
2148// IfOp
2149//===----------------------------------------------------------------------===//
2150
2152 assert(a && "expected non-empty operation");
2153 assert(b && "expected non-empty operation");
2154
2155 IfOp ifOp = a->getParentOfType<IfOp>();
2156 while (ifOp) {
2157 // Check if b is inside ifOp. (We already know that a is.)
2158 if (ifOp->isProperAncestor(b))
2159 // b is contained in ifOp. a and b are in mutually exclusive branches if
2160 // they are in different blocks of ifOp.
2161 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2162 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
2163 // Check next enclosing IfOp.
2164 ifOp = ifOp->getParentOfType<IfOp>();
2165 }
2166
2167 // Could not find a common IfOp among a's and b's ancestors.
2168 return false;
2169}
2170
2171LogicalResult
2172IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2173 IfOp::Adaptor adaptor,
2174 SmallVectorImpl<Type> &inferredReturnTypes) {
2175 if (adaptor.getRegions().empty())
2176 return failure();
2177 Region *r = &adaptor.getThenRegion();
2178 if (r->empty())
2179 return failure();
2180 Block &b = r->front();
2181 if (b.empty())
2182 return failure();
2183 auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2184 if (!yieldOp)
2185 return failure();
2186 TypeRange types = yieldOp.getOperandTypes();
2187 llvm::append_range(inferredReturnTypes, types);
2188 return success();
2189}
2190
2191void IfOp::build(OpBuilder &builder, OperationState &result,
2192 TypeRange resultTypes, Value cond) {
2193 return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2194 /*addElseBlock=*/false);
2195}
2196
2197void IfOp::build(OpBuilder &builder, OperationState &result,
2198 TypeRange resultTypes, Value cond, bool addThenBlock,
2199 bool addElseBlock) {
2200 assert((!addElseBlock || addThenBlock) &&
2201 "must not create else block w/o then block");
2202 result.addTypes(resultTypes);
2203 result.addOperands(cond);
2204
2205 // Add regions and blocks.
2206 OpBuilder::InsertionGuard guard(builder);
2207 Region *thenRegion = result.addRegion();
2208 if (addThenBlock)
2209 builder.createBlock(thenRegion);
2210 Region *elseRegion = result.addRegion();
2211 if (addElseBlock)
2212 builder.createBlock(elseRegion);
2213}
2214
2215void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2216 bool withElseRegion) {
2217 build(builder, result, TypeRange{}, cond, withElseRegion);
2218}
2219
2220void IfOp::build(OpBuilder &builder, OperationState &result,
2221 TypeRange resultTypes, Value cond, bool withElseRegion) {
2222 result.addTypes(resultTypes);
2223 result.addOperands(cond);
2224
2225 // Build then region.
2226 OpBuilder::InsertionGuard guard(builder);
2227 Region *thenRegion = result.addRegion();
2228 builder.createBlock(thenRegion);
2229 if (resultTypes.empty())
2230 IfOp::ensureTerminator(*thenRegion, builder, result.location);
2231
2232 // Build else region.
2233 Region *elseRegion = result.addRegion();
2234 if (withElseRegion) {
2235 builder.createBlock(elseRegion);
2236 if (resultTypes.empty())
2237 IfOp::ensureTerminator(*elseRegion, builder, result.location);
2238 }
2239}
2240
2241void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2242 function_ref<void(OpBuilder &, Location)> thenBuilder,
2243 function_ref<void(OpBuilder &, Location)> elseBuilder) {
2244 assert(thenBuilder && "the builder callback for 'then' must be present");
2245 result.addOperands(cond);
2246
2247 // Build then region.
2248 OpBuilder::InsertionGuard guard(builder);
2249 Region *thenRegion = result.addRegion();
2250 builder.createBlock(thenRegion);
2251 thenBuilder(builder, result.location);
2252
2253 // Build else region.
2254 Region *elseRegion = result.addRegion();
2255 if (elseBuilder) {
2256 builder.createBlock(elseRegion);
2257 elseBuilder(builder, result.location);
2258 }
2259
2260 // Infer result types.
2261 SmallVector<Type> inferredReturnTypes;
2262 MLIRContext *ctx = builder.getContext();
2263 auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2264 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2265 /*properties=*/nullptr, result.regions,
2266 inferredReturnTypes))) {
2267 result.addTypes(inferredReturnTypes);
2268 }
2269}
2270
2271LogicalResult IfOp::verify() {
2272 if (getNumResults() != 0 && getElseRegion().empty())
2273 return emitOpError("must have an else block if defining values");
2274 return success();
2275}
2276
2277ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2278 // Create the regions for 'then'.
2279 result.regions.reserve(2);
2280 Region *thenRegion = result.addRegion();
2281 Region *elseRegion = result.addRegion();
2282
2283 auto &builder = parser.getBuilder();
2284 OpAsmParser::UnresolvedOperand cond;
2285 Type i1Type = builder.getIntegerType(1);
2286 if (parser.parseOperand(cond) ||
2287 parser.resolveOperand(cond, i1Type, result.operands))
2288 return failure();
2289 // Parse optional results type list.
2290 if (parser.parseOptionalArrowTypeList(result.types))
2291 return failure();
2292 // Parse the 'then' region.
2293 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2294 return failure();
2295 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2296
2297 // If we find an 'else' keyword then parse the 'else' region.
2298 if (!parser.parseOptionalKeyword("else")) {
2299 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2300 return failure();
2301 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2302 }
2303
2304 // Parse the optional attribute list.
2305 if (parser.parseOptionalAttrDict(result.attributes))
2306 return failure();
2307 return success();
2308}
2309
2310void IfOp::print(OpAsmPrinter &p) {
2311 bool printBlockTerminators = false;
2312
2313 p << " " << getCondition();
2314 if (!getResults().empty()) {
2315 p << " -> (" << getResultTypes() << ")";
2316 // Print yield explicitly if the op defines values.
2317 printBlockTerminators = true;
2318 }
2319 p << ' ';
2320 p.printRegion(getThenRegion(),
2321 /*printEntryBlockArgs=*/false,
2322 /*printBlockTerminators=*/printBlockTerminators);
2323
2324 // Print the 'else' regions if it exists and has a block.
2325 auto &elseRegion = getElseRegion();
2326 if (!elseRegion.empty()) {
2327 p << " else ";
2328 p.printRegion(elseRegion,
2329 /*printEntryBlockArgs=*/false,
2330 /*printBlockTerminators=*/printBlockTerminators);
2331 }
2332
2333 p.printOptionalAttrDict((*this)->getAttrs());
2334}
2335
2336void IfOp::getSuccessorRegions(RegionBranchPoint point,
2337 SmallVectorImpl<RegionSuccessor> &regions) {
2338 // The `then` and the `else` region branch back to the parent operation or one
2339 // of the recursive parent operations (early exit case).
2340 if (!point.isParent()) {
2341 regions.push_back(RegionSuccessor(getOperation(), getResults()));
2342 return;
2343 }
2344
2345 regions.push_back(RegionSuccessor(&getThenRegion()));
2346
2347 // Don't consider the else region if it is empty.
2348 Region *elseRegion = &this->getElseRegion();
2349 if (elseRegion->empty())
2350 regions.push_back(
2351 RegionSuccessor(getOperation(), getOperation()->getResults()));
2352 else
2353 regions.push_back(RegionSuccessor(elseRegion));
2354}
2355
2356void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2357 SmallVectorImpl<RegionSuccessor> &regions) {
2358 FoldAdaptor adaptor(operands, *this);
2359 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2360 if (!boolAttr || boolAttr.getValue())
2361 regions.emplace_back(&getThenRegion());
2362
2363 // If the else region is empty, execution continues after the parent op.
2364 if (!boolAttr || !boolAttr.getValue()) {
2365 if (!getElseRegion().empty())
2366 regions.emplace_back(&getElseRegion());
2367 else
2368 regions.emplace_back(getOperation(), getResults());
2369 }
2370}
2371
2372LogicalResult IfOp::fold(FoldAdaptor adaptor,
2373 SmallVectorImpl<OpFoldResult> &results) {
2374 // if (!c) then A() else B() -> if c then B() else A()
2375 if (getElseRegion().empty())
2376 return failure();
2377
2378 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2379 if (!xorStmt)
2380 return failure();
2381
2382 if (!matchPattern(xorStmt.getRhs(), m_One()))
2383 return failure();
2384
2385 getConditionMutable().assign(xorStmt.getLhs());
2386 Block *thenBlock = &getThenRegion().front();
2387 // It would be nicer to use iplist::swap, but that has no implemented
2388 // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2389 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2390 getElseRegion().getBlocks());
2391 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2392 getThenRegion().getBlocks(), thenBlock);
2393 return success();
2394}
2395
2396void IfOp::getRegionInvocationBounds(
2397 ArrayRef<Attribute> operands,
2398 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2399 if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2400 // If the condition is known, then one region is known to be executed once
2401 // and the other zero times.
2402 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2403 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2404 } else {
2405 // Non-constant condition. Each region may be executed 0 or 1 times.
2406 invocationBounds.assign(2, {0, 1});
2407 }
2408}
2409
2410namespace {
2411// Pattern to remove unused IfOp results.
2412struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2413 using OpRewritePattern<IfOp>::OpRewritePattern;
2414
2415 void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2416 PatternRewriter &rewriter) const {
2417 // Move all operations to the destination block.
2418 rewriter.mergeBlocks(source, dest);
2419 // Replace the yield op by one that returns only the used values.
2420 auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2421 SmallVector<Value, 4> usedOperands;
2422 llvm::transform(usedResults, std::back_inserter(usedOperands),
2423 [&](OpResult result) {
2424 return yieldOp.getOperand(result.getResultNumber());
2425 });
2426 rewriter.modifyOpInPlace(yieldOp,
2427 [&]() { yieldOp->setOperands(usedOperands); });
2428 }
2429
2430 LogicalResult matchAndRewrite(IfOp op,
2431 PatternRewriter &rewriter) const override {
2432 // Compute the list of used results.
2433 SmallVector<OpResult, 4> usedResults;
2434 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2435 [](OpResult result) { return !result.use_empty(); });
2436
2437 // Replace the operation if only a subset of its results have uses.
2438 if (usedResults.size() == op.getNumResults())
2439 return failure();
2440
2441 // Compute the result types of the replacement operation.
2442 SmallVector<Type, 4> newTypes;
2443 llvm::transform(usedResults, std::back_inserter(newTypes),
2444 [](OpResult result) { return result.getType(); });
2445
2446 // Create a replacement operation with empty then and else regions.
2447 auto newOp =
2448 IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
2449 rewriter.createBlock(&newOp.getThenRegion());
2450 rewriter.createBlock(&newOp.getElseRegion());
2451
2452 // Move the bodies and replace the terminators (note there is a then and
2453 // an else region since the operation returns results).
2454 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2455 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2456
2457 // Replace the operation by the new one.
2458 SmallVector<Value, 4> repResults(op.getNumResults());
2459 for (const auto &en : llvm::enumerate(usedResults))
2460 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2461 rewriter.replaceOp(op, repResults);
2462 return success();
2463 }
2464};
2465
2466struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2467 using OpRewritePattern<IfOp>::OpRewritePattern;
2468
2469 LogicalResult matchAndRewrite(IfOp op,
2470 PatternRewriter &rewriter) const override {
2471 BoolAttr condition;
2472 if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2473 return failure();
2474
2475 if (condition.getValue())
2476 replaceOpWithRegion(rewriter, op, op.getThenRegion());
2477 else if (!op.getElseRegion().empty())
2478 replaceOpWithRegion(rewriter, op, op.getElseRegion());
2479 else
2480 rewriter.eraseOp(op);
2481
2482 return success();
2483 }
2484};
2485
2486/// Hoist any yielded results whose operands are defined outside
2487/// the if, to a select instruction.
2488struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2489 using OpRewritePattern<IfOp>::OpRewritePattern;
2490
2491 LogicalResult matchAndRewrite(IfOp op,
2492 PatternRewriter &rewriter) const override {
2493 if (op->getNumResults() == 0)
2494 return failure();
2495
2496 auto cond = op.getCondition();
2497 auto thenYieldArgs = op.thenYield().getOperands();
2498 auto elseYieldArgs = op.elseYield().getOperands();
2499
2500 SmallVector<Type> nonHoistable;
2501 for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2502 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2503 &op.getElseRegion() == falseVal.getParentRegion())
2504 nonHoistable.push_back(trueVal.getType());
2505 }
2506 // Early exit if there aren't any yielded values we can
2507 // hoist outside the if.
2508 if (nonHoistable.size() == op->getNumResults())
2509 return failure();
2510
2511 IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2512 /*withElseRegion=*/false);
2513 if (replacement.thenBlock())
2514 rewriter.eraseBlock(replacement.thenBlock());
2515 replacement.getThenRegion().takeBody(op.getThenRegion());
2516 replacement.getElseRegion().takeBody(op.getElseRegion());
2517
2518 SmallVector<Value> results(op->getNumResults());
2519 assert(thenYieldArgs.size() == results.size());
2520 assert(elseYieldArgs.size() == results.size());
2521
2522 SmallVector<Value> trueYields;
2523 SmallVector<Value> falseYields;
2525 for (const auto &it :
2526 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2527 Value trueVal = std::get<0>(it.value());
2528 Value falseVal = std::get<1>(it.value());
2529 if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2530 &replacement.getElseRegion() == falseVal.getParentRegion()) {
2531 results[it.index()] = replacement.getResult(trueYields.size());
2532 trueYields.push_back(trueVal);
2533 falseYields.push_back(falseVal);
2534 } else if (trueVal == falseVal)
2535 results[it.index()] = trueVal;
2536 else
2537 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2538 cond, trueVal, falseVal);
2539 }
2540
2541 rewriter.setInsertionPointToEnd(replacement.thenBlock());
2542 rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2543
2544 rewriter.setInsertionPointToEnd(replacement.elseBlock());
2545 rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2546
2547 rewriter.replaceOp(op, results);
2548 return success();
2549 }
2550};
2551
2552/// Allow the true region of an if to assume the condition is true
2553/// and vice versa. For example:
2554///
2555/// scf.if %cmp {
2556/// print(%cmp)
2557/// }
2558///
2559/// becomes
2560///
2561/// scf.if %cmp {
2562/// print(true)
2563/// }
2564///
2565struct ConditionPropagation : public OpRewritePattern<IfOp> {
2566 using OpRewritePattern<IfOp>::OpRewritePattern;
2567
2568 /// Kind of parent region in the ancestor cache.
2569 enum class Parent { Then, Else, None };
2570
2571 /// Returns the kind of region ("then", "else", or "none") of the
2572 /// IfOp that the given region is transitively nested in. Updates
2573 /// the cache accordingly.
2574 static Parent getParentType(Region *toCheck, IfOp op,
2576 Region *endRegion) {
2577 SmallVector<Region *> seen;
2578 while (toCheck != endRegion) {
2579 auto found = cache.find(toCheck);
2580 if (found != cache.end())
2581 return found->second;
2582 seen.push_back(toCheck);
2583 if (&op.getThenRegion() == toCheck) {
2584 for (Region *region : seen)
2585 cache[region] = Parent::Then;
2586 return Parent::Then;
2587 }
2588 if (&op.getElseRegion() == toCheck) {
2589 for (Region *region : seen)
2590 cache[region] = Parent::Else;
2591 return Parent::Else;
2592 }
2593 toCheck = toCheck->getParentRegion();
2594 }
2595
2596 for (Region *region : seen)
2597 cache[region] = Parent::None;
2598 return Parent::None;
2599 }
2600
2601 LogicalResult matchAndRewrite(IfOp op,
2602 PatternRewriter &rewriter) const override {
2603 // Early exit if the condition is constant since replacing a constant
2604 // in the body with another constant isn't a simplification.
2605 if (matchPattern(op.getCondition(), m_Constant()))
2606 return failure();
2607
2608 bool changed = false;
2609 mlir::Type i1Ty = rewriter.getI1Type();
2610
2611 // These variables serve to prevent creating duplicate constants
2612 // and hold constant true or false values.
2613 Value constantTrue = nullptr;
2614 Value constantFalse = nullptr;
2615
2617 for (OpOperand &use :
2618 llvm::make_early_inc_range(op.getCondition().getUses())) {
2619 switch (getParentType(use.getOwner()->getParentRegion(), op, cache,
2620 op.getCondition().getParentRegion())) {
2621 case Parent::Then: {
2622 changed = true;
2623
2624 if (!constantTrue)
2625 constantTrue = arith::ConstantOp::create(
2626 rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2627
2628 rewriter.modifyOpInPlace(use.getOwner(),
2629 [&]() { use.set(constantTrue); });
2630 break;
2631 }
2632 case Parent::Else: {
2633 changed = true;
2634
2635 if (!constantFalse)
2636 constantFalse = arith::ConstantOp::create(
2637 rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2638
2639 rewriter.modifyOpInPlace(use.getOwner(),
2640 [&]() { use.set(constantFalse); });
2641 break;
2642 }
2643 case Parent::None:
2644 break;
2645 }
2646 }
2647
2648 return success(changed);
2649 }
2650};
2651
2652/// Remove any statements from an if that are equivalent to the condition
2653/// or its negation. For example:
2654///
2655/// %res:2 = scf.if %cmp {
2656/// yield something(), true
2657/// } else {
2658/// yield something2(), false
2659/// }
2660/// print(%res#1)
2661///
2662/// becomes
2663/// %res = scf.if %cmp {
2664/// yield something()
2665/// } else {
2666/// yield something2()
2667/// }
2668/// print(%cmp)
2669///
2670/// Additionally if both branches yield the same value, replace all uses
2671/// of the result with the yielded value.
2672///
2673/// %res:2 = scf.if %cmp {
2674/// yield something(), %arg1
2675/// } else {
2676/// yield something2(), %arg1
2677/// }
2678/// print(%res#1)
2679///
2680/// becomes
2681/// %res = scf.if %cmp {
2682/// yield something()
2683/// } else {
2684/// yield something2()
2685/// }
2686/// print(%arg1)
2687///
2688struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2689 using OpRewritePattern<IfOp>::OpRewritePattern;
2690
2691 LogicalResult matchAndRewrite(IfOp op,
2692 PatternRewriter &rewriter) const override {
2693 // Early exit if there are no results that could be replaced.
2694 if (op.getNumResults() == 0)
2695 return failure();
2696
2697 auto trueYield =
2698 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2699 auto falseYield =
2700 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2701
2702 rewriter.setInsertionPoint(op->getBlock(),
2703 op.getOperation()->getIterator());
2704 bool changed = false;
2705 Type i1Ty = rewriter.getI1Type();
2706 for (auto [trueResult, falseResult, opResult] :
2707 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2708 op.getResults())) {
2709 if (trueResult == falseResult) {
2710 if (!opResult.use_empty()) {
2711 opResult.replaceAllUsesWith(trueResult);
2712 changed = true;
2713 }
2714 continue;
2715 }
2716
2717 BoolAttr trueYield, falseYield;
2718 if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2719 !matchPattern(falseResult, m_Constant(&falseYield)))
2720 continue;
2721
2722 bool trueVal = trueYield.getValue();
2723 bool falseVal = falseYield.getValue();
2724 if (!trueVal && falseVal) {
2725 if (!opResult.use_empty()) {
2726 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2727 Value notCond = arith::XOrIOp::create(
2728 rewriter, op.getLoc(), op.getCondition(),
2729 constDialect
2730 ->materializeConstant(rewriter,
2731 rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2732 op.getLoc())
2733 ->getResult(0));
2734 opResult.replaceAllUsesWith(notCond);
2735 changed = true;
2736 }
2737 }
2738 if (trueVal && !falseVal) {
2739 if (!opResult.use_empty()) {
2740 opResult.replaceAllUsesWith(op.getCondition());
2741 changed = true;
2742 }
2743 }
2744 }
2745 return success(changed);
2746 }
2747};
2748
2749/// Merge any consecutive scf.if's with the same condition.
2750///
2751/// scf.if %cond {
2752/// firstCodeTrue();...
2753/// } else {
2754/// firstCodeFalse();...
2755/// }
2756/// %res = scf.if %cond {
2757/// secondCodeTrue();...
2758/// } else {
2759/// secondCodeFalse();...
2760/// }
2761///
2762/// becomes
2763/// %res = scf.if %cmp {
2764/// firstCodeTrue();...
2765/// secondCodeTrue();...
2766/// } else {
2767/// firstCodeFalse();...
2768/// secondCodeFalse();...
2769/// }
2770struct CombineIfs : public OpRewritePattern<IfOp> {
2771 using OpRewritePattern<IfOp>::OpRewritePattern;
2772
2773 LogicalResult matchAndRewrite(IfOp nextIf,
2774 PatternRewriter &rewriter) const override {
2775 Block *parent = nextIf->getBlock();
2776 if (nextIf == &parent->front())
2777 return failure();
2778
2779 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2780 if (!prevIf)
2781 return failure();
2782
2783 // Determine the logical then/else blocks when prevIf's
2784 // condition is used. Null means the block does not exist
2785 // in that case (e.g. empty else). If neither of these
2786 // are set, the two conditions cannot be compared.
2787 Block *nextThen = nullptr;
2788 Block *nextElse = nullptr;
2789 if (nextIf.getCondition() == prevIf.getCondition()) {
2790 nextThen = nextIf.thenBlock();
2791 if (!nextIf.getElseRegion().empty())
2792 nextElse = nextIf.elseBlock();
2793 }
2794 if (arith::XOrIOp notv =
2795 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2796 if (notv.getLhs() == prevIf.getCondition() &&
2797 matchPattern(notv.getRhs(), m_One())) {
2798 nextElse = nextIf.thenBlock();
2799 if (!nextIf.getElseRegion().empty())
2800 nextThen = nextIf.elseBlock();
2801 }
2802 }
2803 if (arith::XOrIOp notv =
2804 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2805 if (notv.getLhs() == nextIf.getCondition() &&
2806 matchPattern(notv.getRhs(), m_One())) {
2807 nextElse = nextIf.thenBlock();
2808 if (!nextIf.getElseRegion().empty())
2809 nextThen = nextIf.elseBlock();
2810 }
2811 }
2812
2813 if (!nextThen && !nextElse)
2814 return failure();
2815
2816 SmallVector<Value> prevElseYielded;
2817 if (!prevIf.getElseRegion().empty())
2818 prevElseYielded = prevIf.elseYield().getOperands();
2819 // Replace all uses of return values of op within nextIf with the
2820 // corresponding yields
2821 for (auto it : llvm::zip(prevIf.getResults(),
2822 prevIf.thenYield().getOperands(), prevElseYielded))
2823 for (OpOperand &use :
2824 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2825 if (nextThen && nextThen->getParent()->isAncestor(
2826 use.getOwner()->getParentRegion())) {
2827 rewriter.startOpModification(use.getOwner());
2828 use.set(std::get<1>(it));
2829 rewriter.finalizeOpModification(use.getOwner());
2830 } else if (nextElse && nextElse->getParent()->isAncestor(
2831 use.getOwner()->getParentRegion())) {
2832 rewriter.startOpModification(use.getOwner());
2833 use.set(std::get<2>(it));
2834 rewriter.finalizeOpModification(use.getOwner());
2835 }
2836 }
2837
2838 SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2839 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2840
2841 IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2842 prevIf.getCondition(), /*hasElse=*/false);
2843 rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2844
2845 rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2846 combinedIf.getThenRegion(),
2847 combinedIf.getThenRegion().begin());
2848
2849 if (nextThen) {
2850 YieldOp thenYield = combinedIf.thenYield();
2851 YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2852 rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2853 rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2854
2855 SmallVector<Value> mergedYields(thenYield.getOperands());
2856 llvm::append_range(mergedYields, thenYield2.getOperands());
2857 YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2858 rewriter.eraseOp(thenYield);
2859 rewriter.eraseOp(thenYield2);
2860 }
2861
2862 rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2863 combinedIf.getElseRegion(),
2864 combinedIf.getElseRegion().begin());
2865
2866 if (nextElse) {
2867 if (combinedIf.getElseRegion().empty()) {
2868 rewriter.inlineRegionBefore(*nextElse->getParent(),
2869 combinedIf.getElseRegion(),
2870 combinedIf.getElseRegion().begin());
2871 } else {
2872 YieldOp elseYield = combinedIf.elseYield();
2873 YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2874 rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2875
2876 rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2877
2878 SmallVector<Value> mergedElseYields(elseYield.getOperands());
2879 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2880
2881 YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2882 rewriter.eraseOp(elseYield);
2883 rewriter.eraseOp(elseYield2);
2884 }
2885 }
2886
2887 SmallVector<Value> prevValues;
2888 SmallVector<Value> nextValues;
2889 for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2890 if (pair.index() < prevIf.getNumResults())
2891 prevValues.push_back(pair.value());
2892 else
2893 nextValues.push_back(pair.value());
2894 }
2895 rewriter.replaceOp(prevIf, prevValues);
2896 rewriter.replaceOp(nextIf, nextValues);
2897 return success();
2898 }
2899};
2900
2901/// Pattern to remove an empty else branch.
2902struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2903 using OpRewritePattern<IfOp>::OpRewritePattern;
2904
2905 LogicalResult matchAndRewrite(IfOp ifOp,
2906 PatternRewriter &rewriter) const override {
2907 // Cannot remove else region when there are operation results.
2908 if (ifOp.getNumResults())
2909 return failure();
2910 Block *elseBlock = ifOp.elseBlock();
2911 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2912 return failure();
2913 auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2914 rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2915 newIfOp.getThenRegion().begin());
2916 rewriter.eraseOp(ifOp);
2917 return success();
2918 }
2919};
2920
2921/// Convert nested `if`s into `arith.andi` + single `if`.
2922///
2923/// scf.if %arg0 {
2924/// scf.if %arg1 {
2925/// ...
2926/// scf.yield
2927/// }
2928/// scf.yield
2929/// }
2930/// becomes
2931///
2932/// %0 = arith.andi %arg0, %arg1
2933/// scf.if %0 {
2934/// ...
2935/// scf.yield
2936/// }
2937struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2938 using OpRewritePattern<IfOp>::OpRewritePattern;
2939
2940 LogicalResult matchAndRewrite(IfOp op,
2941 PatternRewriter &rewriter) const override {
2942 auto nestedOps = op.thenBlock()->without_terminator();
2943 // Nested `if` must be the only op in block.
2944 if (!llvm::hasSingleElement(nestedOps))
2945 return failure();
2946
2947 // If there is an else block, it can only yield
2948 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2949 return failure();
2950
2951 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2952 if (!nestedIf)
2953 return failure();
2954
2955 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2956 return failure();
2957
2958 SmallVector<Value> thenYield(op.thenYield().getOperands());
2959 SmallVector<Value> elseYield;
2960 if (op.elseBlock())
2961 llvm::append_range(elseYield, op.elseYield().getOperands());
2962
2963 // A list of indices for which we should upgrade the value yielded
2964 // in the else to a select.
2965 SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2966
2967 // If the outer scf.if yields a value produced by the inner scf.if,
2968 // only permit combining if the value yielded when the condition
2969 // is false in the outer scf.if is the same value yielded when the
2970 // inner scf.if condition is false.
2971 // Note that the array access to elseYield will not go out of bounds
2972 // since it must have the same length as thenYield, since they both
2973 // come from the same scf.if.
2974 for (const auto &tup : llvm::enumerate(thenYield)) {
2975 if (tup.value().getDefiningOp() == nestedIf) {
2976 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2977 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2978 elseYield[tup.index()]) {
2979 return failure();
2980 }
2981 // If the correctness test passes, we will yield
2982 // corresponding value from the inner scf.if
2983 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2984 continue;
2985 }
2986
2987 // Otherwise, we need to ensure the else block of the combined
2988 // condition still returns the same value when the outer condition is
2989 // true and the inner condition is false. This can be accomplished if
2990 // the then value is defined outside the outer scf.if and we replace the
2991 // value with a select that considers just the outer condition. Since
2992 // the else region contains just the yield, its yielded value is
2993 // defined outside the scf.if, by definition.
2994
2995 // If the then value is defined within the scf.if, bail.
2996 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2997 return failure();
2998 }
2999 elseYieldsToUpgradeToSelect.push_back(tup.index());
3000 }
3001
3002 Location loc = op.getLoc();
3003 Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
3004 nestedIf.getCondition());
3005 auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
3006 Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
3007
3008 SmallVector<Value> results;
3009 llvm::append_range(results, newIf.getResults());
3010 rewriter.setInsertionPoint(newIf);
3011
3012 for (auto idx : elseYieldsToUpgradeToSelect)
3013 results[idx] =
3014 arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
3015 thenYield[idx], elseYield[idx]);
3016
3017 rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
3018 rewriter.setInsertionPointToEnd(newIf.thenBlock());
3019 rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
3020 if (!elseYield.empty()) {
3021 rewriter.createBlock(&newIf.getElseRegion());
3022 rewriter.setInsertionPointToEnd(newIf.elseBlock());
3023 YieldOp::create(rewriter, loc, elseYield);
3024 }
3025 rewriter.replaceOp(op, results);
3026 return success();
3027 }
3028};
3029
3030} // namespace
3031
3032void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
3033 MLIRContext *context) {
3034 results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
3035 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
3036 RemoveStaticCondition, RemoveUnusedResults,
3037 ReplaceIfYieldWithConditionOrValue>(context);
3038}
3039
3040Block *IfOp::thenBlock() { return &getThenRegion().back(); }
3041YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
3042Block *IfOp::elseBlock() {
3043 Region &r = getElseRegion();
3044 if (r.empty())
3045 return nullptr;
3046 return &r.back();
3047}
3048YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
3049
3050//===----------------------------------------------------------------------===//
3051// ParallelOp
3052//===----------------------------------------------------------------------===//
3053
3054void ParallelOp::build(
3055 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
3056 ValueRange upperBounds, ValueRange steps, ValueRange initVals,
3057 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
3058 bodyBuilderFn) {
3059 result.addOperands(lowerBounds);
3060 result.addOperands(upperBounds);
3061 result.addOperands(steps);
3062 result.addOperands(initVals);
3063 result.addAttribute(
3064 ParallelOp::getOperandSegmentSizeAttr(),
3065 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
3066 static_cast<int32_t>(upperBounds.size()),
3067 static_cast<int32_t>(steps.size()),
3068 static_cast<int32_t>(initVals.size())}));
3069 result.addTypes(initVals.getTypes());
3070
3071 OpBuilder::InsertionGuard guard(builder);
3072 unsigned numIVs = steps.size();
3073 SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
3074 SmallVector<Location, 8> argLocs(numIVs, result.location);
3075 Region *bodyRegion = result.addRegion();
3076 Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
3077
3078 if (bodyBuilderFn) {
3079 builder.setInsertionPointToStart(bodyBlock);
3080 bodyBuilderFn(builder, result.location,
3081 bodyBlock->getArguments().take_front(numIVs),
3082 bodyBlock->getArguments().drop_front(numIVs));
3083 }
3084 // Add terminator only if there are no reductions.
3085 if (initVals.empty())
3086 ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
3087}
3088
3089void ParallelOp::build(
3090 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
3091 ValueRange upperBounds, ValueRange steps,
3092 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
3093 // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
3094 // we don't capture a reference to a temporary by constructing the lambda at
3095 // function level.
3096 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
3097 Location nestedLoc, ValueRange ivs,
3098 ValueRange) {
3099 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
3100 };
3101 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
3102 if (bodyBuilderFn)
3103 wrapper = wrappedBuilderFn;
3104
3105 build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
3106 wrapper);
3107}
3108
3109LogicalResult ParallelOp::verify() {
3110 // Check that there is at least one value in lowerBound, upperBound and step.
3111 // It is sufficient to test only step, because it is ensured already that the
3112 // number of elements in lowerBound, upperBound and step are the same.
3113 Operation::operand_range stepValues = getStep();
3114 if (stepValues.empty())
3115 return emitOpError(
3116 "needs at least one tuple element for lowerBound, upperBound and step");
3117
3118 // Check whether all constant step values are positive.
3119 for (Value stepValue : stepValues)
3120 if (auto cst = getConstantIntValue(stepValue))
3121 if (*cst <= 0)
3122 return emitOpError("constant step operand must be positive");
3123
3124 // Check that the body defines the same number of block arguments as the
3125 // number of tuple elements in step.
3126 Block *body = getBody();
3127 if (body->getNumArguments() != stepValues.size())
3128 return emitOpError() << "expects the same number of induction variables: "
3129 << body->getNumArguments()
3130 << " as bound and step values: " << stepValues.size();
3131 for (auto arg : body->getArguments())
3132 if (!arg.getType().isIndex())
3133 return emitOpError(
3134 "expects arguments for the induction variable to be of index type");
3135
3136 // Check that the terminator is an scf.reduce op.
3138 *this, getRegion(), "expects body to terminate with 'scf.reduce'");
3139 if (!reduceOp)
3140 return failure();
3141
3142 // Check that the number of results is the same as the number of reductions.
3143 auto resultsSize = getResults().size();
3144 auto reductionsSize = reduceOp.getReductions().size();
3145 auto initValsSize = getInitVals().size();
3146 if (resultsSize != reductionsSize)
3147 return emitOpError() << "expects number of results: " << resultsSize
3148 << " to be the same as number of reductions: "
3149 << reductionsSize;
3150 if (resultsSize != initValsSize)
3151 return emitOpError() << "expects number of results: " << resultsSize
3152 << " to be the same as number of initial values: "
3153 << initValsSize;
3154
3155 // Check that the types of the results and reductions are the same.
3156 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
3157 auto resultType = getOperation()->getResult(i).getType();
3158 auto reductionOperandType = reduceOp.getOperands()[i].getType();
3159 if (resultType != reductionOperandType)
3160 return reduceOp.emitOpError()
3161 << "expects type of " << i
3162 << "-th reduction operand: " << reductionOperandType
3163 << " to be the same as the " << i
3164 << "-th result type: " << resultType;
3165 }
3166 return success();
3167}
3168
3169ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
3170 auto &builder = parser.getBuilder();
3171 // Parse an opening `(` followed by induction variables followed by `)`
3172 SmallVector<OpAsmParser::Argument, 4> ivs;
3174 return failure();
3175
3176 // Parse loop bounds.
3177 SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
3178 if (parser.parseEqual() ||
3179 parser.parseOperandList(lower, ivs.size(),
3181 parser.resolveOperands(lower, builder.getIndexType(), result.operands))
3182 return failure();
3183
3184 SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
3185 if (parser.parseKeyword("to") ||
3186 parser.parseOperandList(upper, ivs.size(),
3188 parser.resolveOperands(upper, builder.getIndexType(), result.operands))
3189 return failure();
3190
3191 // Parse step values.
3192 SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
3193 if (parser.parseKeyword("step") ||
3194 parser.parseOperandList(steps, ivs.size(),
3196 parser.resolveOperands(steps, builder.getIndexType(), result.operands))
3197 return failure();
3198
3199 // Parse init values.
3200 SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
3201 if (succeeded(parser.parseOptionalKeyword("init"))) {
3202 if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
3203 return failure();
3204 }
3205
3206 // Parse optional results in case there is a reduce.
3207 if (parser.parseOptionalArrowTypeList(result.types))
3208 return failure();
3209
3210 // Now parse the body.
3211 Region *body = result.addRegion();
3212 for (auto &iv : ivs)
3213 iv.type = builder.getIndexType();
3214 if (parser.parseRegion(*body, ivs))
3215 return failure();
3216
3217 // Set `operandSegmentSizes` attribute.
3218 result.addAttribute(
3219 ParallelOp::getOperandSegmentSizeAttr(),
3220 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
3221 static_cast<int32_t>(upper.size()),
3222 static_cast<int32_t>(steps.size()),
3223 static_cast<int32_t>(initVals.size())}));
3224
3225 // Parse attributes.
3226 if (parser.parseOptionalAttrDict(result.attributes) ||
3227 parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
3228 result.operands))
3229 return failure();
3230
3231 // Add a terminator if none was parsed.
3232 ParallelOp::ensureTerminator(*body, builder, result.location);
3233 return success();
3234}
3235
3236void ParallelOp::print(OpAsmPrinter &p) {
3237 p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3238 << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3239 if (!getInitVals().empty())
3240 p << " init (" << getInitVals() << ")";
3241 p.printOptionalArrowTypeList(getResultTypes());
3242 p << ' ';
3243 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3245 (*this)->getAttrs(),
3246 /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3247}
3248
3249SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3250
3251std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3252 return SmallVector<Value>{getBody()->getArguments()};
3253}
3254
3255std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3256 return getLowerBound();
3257}
3258
3259std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3260 return getUpperBound();
3261}
3262
3263std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3264 return getStep();
3265}
3266
3268 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3269 if (!ivArg)
3270 return ParallelOp();
3271 assert(ivArg.getOwner() && "unlinked block argument");
3272 auto *containingOp = ivArg.getOwner()->getParentOp();
3273 return dyn_cast<ParallelOp>(containingOp);
3274}
3275
3276namespace {
3277// Collapse loop dimensions that perform a single iteration.
3278struct ParallelOpSingleOrZeroIterationDimsFolder
3279 : public OpRewritePattern<ParallelOp> {
3280 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3281
3282 LogicalResult matchAndRewrite(ParallelOp op,
3283 PatternRewriter &rewriter) const override {
3284 Location loc = op.getLoc();
3285
3286 // Compute new loop bounds that omit all single-iteration loop dimensions.
3287 SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3288 IRMapping mapping;
3289 for (auto [lb, ub, step, iv] :
3290 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3291 op.getInductionVars())) {
3292 auto numIterations =
3293 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
3294 if (numIterations.has_value()) {
3295 // Remove the loop if it performs zero iterations.
3296 if (*numIterations == 0) {
3297 rewriter.replaceOp(op, op.getInitVals());
3298 return success();
3299 }
3300 // Replace the loop induction variable by the lower bound if the loop
3301 // performs a single iteration. Otherwise, copy the loop bounds.
3302 if (*numIterations == 1) {
3303 mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3304 continue;
3305 }
3306 }
3307 newLowerBounds.push_back(lb);
3308 newUpperBounds.push_back(ub);
3309 newSteps.push_back(step);
3310 }
3311 // Exit if none of the loop dimensions perform a single iteration.
3312 if (newLowerBounds.size() == op.getLowerBound().size())
3313 return failure();
3314
3315 if (newLowerBounds.empty()) {
3316 // All of the loop dimensions perform a single iteration. Inline
3317 // loop body and nested ReduceOp's
3318 SmallVector<Value> results;
3319 results.reserve(op.getInitVals().size());
3320 for (auto &bodyOp : op.getBody()->without_terminator())
3321 rewriter.clone(bodyOp, mapping);
3322 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3323 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3324 Block &reduceBlock = reduceOp.getReductions()[i].front();
3325 auto initValIndex = results.size();
3326 mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3327 mapping.map(reduceBlock.getArgument(1),
3328 mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3329 for (auto &reduceBodyOp : reduceBlock.without_terminator())
3330 rewriter.clone(reduceBodyOp, mapping);
3331
3332 auto result = mapping.lookupOrDefault(
3333 cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3334 results.push_back(result);
3335 }
3336
3337 rewriter.replaceOp(op, results);
3338 return success();
3339 }
3340 // Replace the parallel loop by lower-dimensional parallel loop.
3341 auto newOp =
3342 ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3343 newUpperBounds, newSteps, op.getInitVals(), nullptr);
3344 // Erase the empty block that was inserted by the builder.
3345 rewriter.eraseBlock(newOp.getBody());
3346 // Clone the loop body and remap the block arguments of the collapsed loops
3347 // (inlining does not support a cancellable block argument mapping).
3348 rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3349 newOp.getRegion().begin(), mapping);
3350 rewriter.replaceOp(op, newOp.getResults());
3351 return success();
3352 }
3353};
3354
3355struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3356 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3357
3358 LogicalResult matchAndRewrite(ParallelOp op,
3359 PatternRewriter &rewriter) const override {
3360 Block &outerBody = *op.getBody();
3361 if (!llvm::hasSingleElement(outerBody.without_terminator()))
3362 return failure();
3363
3364 auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3365 if (!innerOp)
3366 return failure();
3367
3368 for (auto val : outerBody.getArguments())
3369 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3370 llvm::is_contained(innerOp.getUpperBound(), val) ||
3371 llvm::is_contained(innerOp.getStep(), val))
3372 return failure();
3373
3374 // Reductions are not supported yet.
3375 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3376 return failure();
3377
3378 auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3379 ValueRange iterVals, ValueRange) {
3380 Block &innerBody = *innerOp.getBody();
3381 assert(iterVals.size() ==
3382 (outerBody.getNumArguments() + innerBody.getNumArguments()));
3383 IRMapping mapping;
3384 mapping.map(outerBody.getArguments(),
3385 iterVals.take_front(outerBody.getNumArguments()));
3386 mapping.map(innerBody.getArguments(),
3387 iterVals.take_back(innerBody.getNumArguments()));
3388 for (Operation &op : innerBody.without_terminator())
3389 builder.clone(op, mapping);
3390 };
3391
3392 auto concatValues = [](const auto &first, const auto &second) {
3393 SmallVector<Value> ret;
3394 ret.reserve(first.size() + second.size());
3395 ret.assign(first.begin(), first.end());
3396 ret.append(second.begin(), second.end());
3397 return ret;
3398 };
3399
3400 auto newLowerBounds =
3401 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3402 auto newUpperBounds =
3403 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3404 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3405
3406 rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3407 newSteps, ValueRange(),
3408 bodyBuilder);
3409 return success();
3410 }
3411};
3412
3413} // namespace
3414
3415void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3416 MLIRContext *context) {
3417 results
3418 .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3419 context);
3420}
3421
3422/// Given the region at `index`, or the parent operation if `index` is None,
3423/// return the successor regions. These are the regions that may be selected
3424/// during the flow of control. `operands` is a set of optional attributes that
3425/// correspond to a constant value for each operand, or null if that operand is
3426/// not a constant.
3427void ParallelOp::getSuccessorRegions(
3428 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3429 // Both the operation itself and the region may be branching into the body or
3430 // back into the operation itself. It is possible for loop not to enter the
3431 // body.
3432 regions.push_back(RegionSuccessor(&getRegion()));
3433 regions.push_back(RegionSuccessor(
3434 getOperation(), ResultRange{getResults().end(), getResults().end()}));
3435}
3436
3437//===----------------------------------------------------------------------===//
3438// ReduceOp
3439//===----------------------------------------------------------------------===//
3440
3441void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3442
3443void ReduceOp::build(OpBuilder &builder, OperationState &result,
3444 ValueRange operands) {
3445 result.addOperands(operands);
3446 for (Value v : operands) {
3447 OpBuilder::InsertionGuard guard(builder);
3448 Region *bodyRegion = result.addRegion();
3449 builder.createBlock(bodyRegion, {},
3450 ArrayRef<Type>{v.getType(), v.getType()},
3451 {result.location, result.location});
3452 }
3453}
3454
3455LogicalResult ReduceOp::verifyRegions() {
3456 // The region of a ReduceOp has two arguments of the same type as its
3457 // corresponding operand.
3458 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3459 auto type = getOperands()[i].getType();
3460 Block &block = getReductions()[i].front();
3461 if (block.empty())
3462 return emitOpError() << i << "-th reduction has an empty body";
3463 if (block.getNumArguments() != 2 ||
3464 llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3465 return arg.getType() != type;
3466 }))
3467 return emitOpError() << "expected two block arguments with type " << type
3468 << " in the " << i << "-th reduction region";
3469
3470 // Check that the block is terminated by a ReduceReturnOp.
3471 if (!isa<ReduceReturnOp>(block.getTerminator()))
3472 return emitOpError("reduction bodies must be terminated with an "
3473 "'scf.reduce.return' op");
3474 }
3475
3476 return success();
3477}
3478
3479MutableOperandRange
3480ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
3481 // No operands are forwarded to the next iteration.
3482 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3483}
3484
3485//===----------------------------------------------------------------------===//
3486// ReduceReturnOp
3487//===----------------------------------------------------------------------===//
3488
3489LogicalResult ReduceReturnOp::verify() {
3490 // The type of the return value should be the same type as the types of the
3491 // block arguments of the reduction body.
3492 Block *reductionBody = getOperation()->getBlock();
3493 // Should already be verified by an op trait.
3494 assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3495 Type expectedResultType = reductionBody->getArgument(0).getType();
3496 if (expectedResultType != getResult().getType())
3497 return emitOpError() << "must have type " << expectedResultType
3498 << " (the type of the reduction inputs)";
3499 return success();
3500}
3501
3502//===----------------------------------------------------------------------===//
3503// WhileOp
3504//===----------------------------------------------------------------------===//
3505
3506void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3507 ::mlir::OperationState &odsState, TypeRange resultTypes,
3508 ValueRange inits, BodyBuilderFn beforeBuilder,
3509 BodyBuilderFn afterBuilder) {
3510 odsState.addOperands(inits);
3511 odsState.addTypes(resultTypes);
3512
3513 OpBuilder::InsertionGuard guard(odsBuilder);
3514
3515 // Build before region.
3516 SmallVector<Location, 4> beforeArgLocs;
3517 beforeArgLocs.reserve(inits.size());
3518 for (Value operand : inits) {
3519 beforeArgLocs.push_back(operand.getLoc());
3520 }
3521
3522 Region *beforeRegion = odsState.addRegion();
3523 Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3524 inits.getTypes(), beforeArgLocs);
3525 if (beforeBuilder)
3526 beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3527
3528 // Build after region.
3529 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3530
3531 Region *afterRegion = odsState.addRegion();
3532 Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3533 resultTypes, afterArgLocs);
3534
3535 if (afterBuilder)
3536 afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3537}
3538
3539ConditionOp WhileOp::getConditionOp() {
3540 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3541}
3542
3543YieldOp WhileOp::getYieldOp() {
3544 return cast<YieldOp>(getAfterBody()->getTerminator());
3545}
3546
3547std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3548 return getYieldOp().getResultsMutable();
3549}
3550
3551Block::BlockArgListType WhileOp::getBeforeArguments() {
3552 return getBeforeBody()->getArguments();
3553}
3554
3555Block::BlockArgListType WhileOp::getAfterArguments() {
3556 return getAfterBody()->getArguments();
3557}
3558
3559Block::BlockArgListType WhileOp::getRegionIterArgs() {
3560 return getBeforeArguments();
3561}
3562
3563OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
3564 assert(successor.getSuccessor() == &getBefore() &&
3565 "WhileOp is expected to branch only to the first region");
3566 return getInits();
3567}
3568
3569void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3570 SmallVectorImpl<RegionSuccessor> &regions) {
3571 // The parent op always branches to the condition region.
3572 if (point.isParent()) {
3573 regions.emplace_back(&getBefore(), getBefore().getArguments());
3574 return;
3575 }
3576
3577 assert(llvm::is_contained(
3578 {&getAfter(), &getBefore()},
3579 point.getTerminatorPredecessorOrNull()->getParentRegion()) &&
3580 "there are only two regions in a WhileOp");
3581 // The body region always branches back to the condition region.
3583 &getAfter()) {
3584 regions.emplace_back(&getBefore(), getBefore().getArguments());
3585 return;
3586 }
3587
3588 regions.emplace_back(getOperation(), getResults());
3589 regions.emplace_back(&getAfter(), getAfter().getArguments());
3590}
3591
3592SmallVector<Region *> WhileOp::getLoopRegions() {
3593 return {&getBefore(), &getAfter()};
3594}
3595
3596/// Parses a `while` op.
3597///
3598/// op ::= `scf.while` assignments `:` function-type region `do` region
3599/// `attributes` attribute-dict
3600/// initializer ::= /* empty */ | `(` assignment-list `)`
3601/// assignment-list ::= assignment | assignment `,` assignment-list
3602/// assignment ::= ssa-value `=` ssa-value
3603ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3604 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3605 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3606 Region *before = result.addRegion();
3607 Region *after = result.addRegion();
3608
3609 OptionalParseResult listResult =
3610 parser.parseOptionalAssignmentList(regionArgs, operands);
3611 if (listResult.has_value() && failed(listResult.value()))
3612 return failure();
3613
3614 FunctionType functionType;
3615 SMLoc typeLoc = parser.getCurrentLocation();
3616 if (failed(parser.parseColonType(functionType)))
3617 return failure();
3618
3619 result.addTypes(functionType.getResults());
3620
3621 if (functionType.getNumInputs() != operands.size()) {
3622 return parser.emitError(typeLoc)
3623 << "expected as many input types as operands " << "(expected "
3624 << operands.size() << " got " << functionType.getNumInputs() << ")";
3625 }
3626
3627 // Resolve input operands.
3628 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3629 parser.getCurrentLocation(),
3630 result.operands)))
3631 return failure();
3632
3633 // Propagate the types into the region arguments.
3634 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3635 regionArgs[i].type = functionType.getInput(i);
3636
3637 return failure(parser.parseRegion(*before, regionArgs) ||
3638 parser.parseKeyword("do") || parser.parseRegion(*after) ||
3639 parser.parseOptionalAttrDictWithKeyword(result.attributes));
3640}
3641
3642/// Prints a `while` op.
3643void scf::WhileOp::print(OpAsmPrinter &p) {
3644 printInitializationList(p, getBeforeArguments(), getInits(), " ");
3645 p << " : ";
3646 p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3647 p << ' ';
3648 p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3649 p << " do ";
3650 p.printRegion(getAfter());
3651 p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3652}
3653
3654/// Verifies that two ranges of types match, i.e. have the same number of
3655/// entries and that types are pairwise equals. Reports errors on the given
3656/// operation in case of mismatch.
3657template <typename OpTy>
3658static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3659 TypeRange right, StringRef message) {
3660 if (left.size() != right.size())
3661 return op.emitOpError("expects the same number of ") << message;
3662
3663 for (unsigned i = 0, e = left.size(); i < e; ++i) {
3664 if (left[i] != right[i]) {
3665 InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3666 << message;
3667 diag.attachNote() << "for argument " << i << ", found " << left[i]
3668 << " and " << right[i];
3669 return diag;
3670 }
3671 }
3672
3673 return success();
3674}
3675
3676LogicalResult scf::WhileOp::verify() {
3677 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3678 *this, getBefore(),
3679 "expects the 'before' region to terminate with 'scf.condition'");
3680 if (!beforeTerminator)
3681 return failure();
3682
3683 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3684 *this, getAfter(),
3685 "expects the 'after' region to terminate with 'scf.yield'");
3686 return success(afterTerminator != nullptr);
3687}
3688
3689namespace {
3690/// Replace uses of the condition within the do block with true, since otherwise
3691/// the block would not be evaluated.
3692///
3693/// scf.while (..) : (i1, ...) -> ... {
3694/// %condition = call @evaluate_condition() : () -> i1
3695/// scf.condition(%condition) %condition : i1, ...
3696/// } do {
3697/// ^bb0(%arg0: i1, ...):
3698/// use(%arg0)
3699/// ...
3700///
3701/// becomes
3702/// scf.while (..) : (i1, ...) -> ... {
3703/// %condition = call @evaluate_condition() : () -> i1
3704/// scf.condition(%condition) %condition : i1, ...
3705/// } do {
3706/// ^bb0(%arg0: i1, ...):
3707/// use(%true)
3708/// ...
3709struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3710 using OpRewritePattern<WhileOp>::OpRewritePattern;
3711
3712 LogicalResult matchAndRewrite(WhileOp op,
3713 PatternRewriter &rewriter) const override {
3714 auto term = op.getConditionOp();
3715
3716 // These variables serve to prevent creating duplicate constants
3717 // and hold constant true or false values.
3718 Value constantTrue = nullptr;
3719
3720 bool replaced = false;
3721 for (auto yieldedAndBlockArgs :
3722 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3723 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3724 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3725 if (!constantTrue)
3726 constantTrue = arith::ConstantOp::create(
3727 rewriter, op.getLoc(), term.getCondition().getType(),
3728 rewriter.getBoolAttr(true));
3729
3730 rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3731 constantTrue);
3732 replaced = true;
3733 }
3734 }
3735 }
3736 return success(replaced);
3737 }
3738};
3739
3740/// Remove loop invariant arguments from `before` block of scf.while.
3741/// A before block argument is considered loop invariant if :-
3742/// 1. i-th yield operand is equal to the i-th while operand.
3743/// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3744/// condition operand AND this (k+1)-th condition operand is equal to i-th
3745/// iter argument/while operand.
3746/// For the arguments which are removed, their uses inside scf.while
3747/// are replaced with their corresponding initial value.
3748///
3749/// Eg:
3750/// INPUT :-
3751/// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3752/// ..., %argN_before = %N)
3753/// {
3754/// ...
3755/// scf.condition(%cond) %arg1_before, %arg0_before,
3756/// %arg2_before, %arg0_before, ...
3757/// } do {
3758/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3759/// ..., %argK_after):
3760/// ...
3761/// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3762/// }
3763///
3764/// OUTPUT :-
3765/// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3766/// %N)
3767/// {
3768/// ...
3769/// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3770/// } do {
3771/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3772/// ..., %argK_after):
3773/// ...
3774/// scf.yield %arg1_after, ..., %argN
3775/// }
3776///
3777/// EXPLANATION:
3778/// We iterate over each yield operand.
3779/// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3780/// %arg0_before, which in turn is the 0-th iter argument. So we
3781/// remove 0-th before block argument and yield operand, and replace
3782/// all uses of the 0-th before block argument with its initial value
3783/// %a.
3784/// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3785/// value. So we remove this operand and the corresponding before
3786/// block argument and replace all uses of 1-th before block argument
3787/// with %b.
3788struct RemoveLoopInvariantArgsFromBeforeBlock
3789 : public OpRewritePattern<WhileOp> {
3790 using OpRewritePattern<WhileOp>::OpRewritePattern;
3791
3792 LogicalResult matchAndRewrite(WhileOp op,
3793 PatternRewriter &rewriter) const override {
3794 Block &afterBlock = *op.getAfterBody();
3795 Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3796 ConditionOp condOp = op.getConditionOp();
3797 OperandRange condOpArgs = condOp.getArgs();
3798 Operation *yieldOp = afterBlock.getTerminator();
3799 ValueRange yieldOpArgs = yieldOp->getOperands();
3800
3801 bool canSimplify = false;
3802 for (const auto &it :
3803 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3804 auto index = static_cast<unsigned>(it.index());
3805 auto [initVal, yieldOpArg] = it.value();
3806 // If i-th yield operand is equal to the i-th operand of the scf.while,
3807 // the i-th before block argument is a loop invariant.
3808 if (yieldOpArg == initVal) {
3809 canSimplify = true;
3810 break;
3811 }
3812 // If the i-th yield operand is k-th after block argument, then we check
3813 // if the (k+1)-th condition op operand is equal to either the i-th before
3814 // block argument or the initial value of i-th before block argument. If
3815 // the comparison results `true`, i-th before block argument is a loop
3816 // invariant.
3817 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3818 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3819 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3820 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3821 canSimplify = true;
3822 break;
3823 }
3824 }
3825 }
3826
3827 if (!canSimplify)
3828 return failure();
3829
3830 SmallVector<Value> newInitArgs, newYieldOpArgs;
3831 DenseMap<unsigned, Value> beforeBlockInitValMap;
3832 SmallVector<Location> newBeforeBlockArgLocs;
3833 for (const auto &it :
3834 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3835 auto index = static_cast<unsigned>(it.index());
3836 auto [initVal, yieldOpArg] = it.value();
3837
3838 // If i-th yield operand is equal to the i-th operand of the scf.while,
3839 // the i-th before block argument is a loop invariant.
3840 if (yieldOpArg == initVal) {
3841 beforeBlockInitValMap.insert({index, initVal});
3842 continue;
3843 } else {
3844 // If the i-th yield operand is k-th after block argument, then we check
3845 // if the (k+1)-th condition op operand is equal to either the i-th
3846 // before block argument or the initial value of i-th before block
3847 // argument. If the comparison results `true`, i-th before block
3848 // argument is a loop invariant.
3849 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3850 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3851 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3852 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3853 beforeBlockInitValMap.insert({index, initVal});
3854 continue;
3855 }
3856 }
3857 }
3858 newInitArgs.emplace_back(initVal);
3859 newYieldOpArgs.emplace_back(yieldOpArg);
3860 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3861 }
3862
3863 {
3864 OpBuilder::InsertionGuard g(rewriter);
3865 rewriter.setInsertionPoint(yieldOp);
3866 rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3867 }
3868
3869 auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
3870 newInitArgs);
3871
3872 Block &newBeforeBlock = *rewriter.createBlock(
3873 &newWhile.getBefore(), /*insertPt*/ {},
3874 ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3875
3876 Block &beforeBlock = *op.getBeforeBody();
3877 SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3878 // For each i-th before block argument we find it's replacement value as :-
3879 // 1. If i-th before block argument is a loop invariant, we fetch it's
3880 // initial value from `beforeBlockInitValMap` by querying for key `i`.
3881 // 2. Else we fetch j-th new before block argument as the replacement
3882 // value of i-th before block argument.
3883 for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3884 // If the index 'i' argument was a loop invariant we fetch it's initial
3885 // value from `beforeBlockInitValMap`.
3886 if (beforeBlockInitValMap.count(i) != 0)
3887 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3888 else
3889 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3890 }
3891
3892 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3893 rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3894 newWhile.getAfter().begin());
3895
3896 rewriter.replaceOp(op, newWhile.getResults());
3897 return success();
3898 }
3899};
3900
3901/// Remove loop invariant value from result (condition op) of scf.while.
3902/// A value is considered loop invariant if the final value yielded by
3903/// scf.condition is defined outside of the `before` block. We remove the
3904/// corresponding argument in `after` block and replace the use with the value.
3905/// We also replace the use of the corresponding result of scf.while with the
3906/// value.
3907///
3908/// Eg:
3909/// INPUT :-
3910/// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3911/// %argN_before = %N) {
3912/// ...
3913/// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3914/// } do {
3915/// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3916/// ...
3917/// some_func(%arg1_after)
3918/// ...
3919/// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3920/// }
3921///
3922/// OUTPUT :-
3923/// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3924/// ...
3925/// scf.condition(%cond) %arg0, %arg1, ..., %argM
3926/// } do {
3927/// ^bb0(%arg0, %arg3, ..., %argM):
3928/// ...
3929/// some_func(%a)
3930/// ...
3931/// scf.yield %arg0, %b, ..., %argN
3932/// }
3933///
3934/// EXPLANATION:
3935/// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3936/// before block of scf.while, so they get removed.
3937/// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3938/// replaced by %b.
3939/// 3. The corresponding after block argument %arg1_after's uses are
3940/// replaced by %a and %arg2_after's uses are replaced by %b.
3941struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3942 using OpRewritePattern<WhileOp>::OpRewritePattern;
3943
3944 LogicalResult matchAndRewrite(WhileOp op,
3945 PatternRewriter &rewriter) const override {
3946 Block &beforeBlock = *op.getBeforeBody();
3947 ConditionOp condOp = op.getConditionOp();
3948 OperandRange condOpArgs = condOp.getArgs();
3949
3950 bool canSimplify = false;
3951 for (Value condOpArg : condOpArgs) {
3952 // Those values not defined within `before` block will be considered as
3953 // loop invariant values. We map the corresponding `index` with their
3954 // value.
3955 if (condOpArg.getParentBlock() != &beforeBlock) {
3956 canSimplify = true;
3957 break;
3958 }
3959 }
3960
3961 if (!canSimplify)
3962 return failure();
3963
3964 Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3965
3966 SmallVector<Value> newCondOpArgs;
3967 SmallVector<Type> newAfterBlockType;
3968 DenseMap<unsigned, Value> condOpInitValMap;
3969 SmallVector<Location> newAfterBlockArgLocs;
3970 for (const auto &it : llvm::enumerate(condOpArgs)) {
3971 auto index = static_cast<unsigned>(it.index());
3972 Value condOpArg = it.value();
3973 // Those values not defined within `before` block will be considered as
3974 // loop invariant values. We map the corresponding `index` with their
3975 // value.
3976 if (condOpArg.getParentBlock() != &beforeBlock) {
3977 condOpInitValMap.insert({index, condOpArg});
3978 } else {
3979 newCondOpArgs.emplace_back(condOpArg);
3980 newAfterBlockType.emplace_back(condOpArg.getType());
3981 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3982 }
3983 }
3984
3985 {
3986 OpBuilder::InsertionGuard g(rewriter);
3987 rewriter.setInsertionPoint(condOp);
3988 rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3989 newCondOpArgs);
3990 }
3991
3992 auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
3993 op.getOperands());
3994
3995 Block &newAfterBlock =
3996 *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3997 newAfterBlockType, newAfterBlockArgLocs);
3998
3999 Block &afterBlock = *op.getAfterBody();
4000 // Since a new scf.condition op was created, we need to fetch the new
4001 // `after` block arguments which will be used while replacing operations of
4002 // previous scf.while's `after` blocks. We'd also be fetching new result
4003 // values too.
4004 SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
4005 SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
4006 for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
4007 Value afterBlockArg, result;
4008 // If index 'i' argument was loop invariant we fetch it's value from the
4009 // `condOpInitMap` map.
4010 if (condOpInitValMap.count(i) != 0) {
4011 afterBlockArg = condOpInitValMap[i];
4012 result = afterBlockArg;
4013 } else {
4014 afterBlockArg = newAfterBlock.getArgument(j);
4015 result = newWhile.getResult(j);
4016 j++;
4017 }
4018 newAfterBlockArgs[i] = afterBlockArg;
4019 newWhileResults[i] = result;
4020 }
4021
4022 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
4023 rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
4024 newWhile.getBefore().begin());
4025
4026 rewriter.replaceOp(op, newWhileResults);
4027 return success();
4028 }
4029};
4030
4031/// Remove WhileOp results that are also unused in 'after' block.
4032///
4033/// %0:2 = scf.while () : () -> (i32, i64) {
4034/// %condition = "test.condition"() : () -> i1
4035/// %v1 = "test.get_some_value"() : () -> i32
4036/// %v2 = "test.get_some_value"() : () -> i64
4037/// scf.condition(%condition) %v1, %v2 : i32, i64
4038/// } do {
4039/// ^bb0(%arg0: i32, %arg1: i64):
4040/// "test.use"(%arg0) : (i32) -> ()
4041/// scf.yield
4042/// }
4043/// return %0#0 : i32
4044///
4045/// becomes
4046/// %0 = scf.while () : () -> (i32) {
4047/// %condition = "test.condition"() : () -> i1
4048/// %v1 = "test.get_some_value"() : () -> i32
4049/// %v2 = "test.get_some_value"() : () -> i64
4050/// scf.condition(%condition) %v1 : i32
4051/// } do {
4052/// ^bb0(%arg0: i32):
4053/// "test.use"(%arg0) : (i32) -> ()
4054/// scf.yield
4055/// }
4056/// return %0 : i32
4057struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
4058 using OpRewritePattern<WhileOp>::OpRewritePattern;
4059
4060 LogicalResult matchAndRewrite(WhileOp op,
4061 PatternRewriter &rewriter) const override {
4062 auto term = op.getConditionOp();
4063 auto afterArgs = op.getAfterArguments();
4064 auto termArgs = term.getArgs();
4065
4066 // Collect results mapping, new terminator args and new result types.
4067 SmallVector<unsigned> newResultsIndices;
4068 SmallVector<Type> newResultTypes;
4069 SmallVector<Value> newTermArgs;
4070 SmallVector<Location> newArgLocs;
4071 bool needUpdate = false;
4072 for (const auto &it :
4073 llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
4074 auto i = static_cast<unsigned>(it.index());
4075 Value result = std::get<0>(it.value());
4076 Value afterArg = std::get<1>(it.value());
4077 Value termArg = std::get<2>(it.value());
4078 if (result.use_empty() && afterArg.use_empty()) {
4079 needUpdate = true;
4080 } else {
4081 newResultsIndices.emplace_back(i);
4082 newTermArgs.emplace_back(termArg);
4083 newResultTypes.emplace_back(result.getType());
4084 newArgLocs.emplace_back(result.getLoc());
4085 }
4086 }
4087
4088 if (!needUpdate)
4089 return failure();
4090
4091 {
4092 OpBuilder::InsertionGuard g(rewriter);
4093 rewriter.setInsertionPoint(term);
4094 rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
4095 newTermArgs);
4096 }
4097
4098 auto newWhile =
4099 WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
4100
4101 Block &newAfterBlock = *rewriter.createBlock(
4102 &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
4103
4104 // Build new results list and new after block args (unused entries will be
4105 // null).
4106 SmallVector<Value> newResults(op.getNumResults());
4107 SmallVector<Value> newAfterBlockArgs(op.getNumResults());
4108 for (const auto &it : llvm::enumerate(newResultsIndices)) {
4109 newResults[it.value()] = newWhile.getResult(it.index());
4110 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
4111 }
4112
4113 rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
4114 newWhile.getBefore().begin());
4115
4116 Block &afterBlock = *op.getAfterBody();
4117 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
4118
4119 rewriter.replaceOp(op, newResults);
4120 return success();
4121 }
4122};
4123
4124/// Replace operations equivalent to the condition in the do block with true,
4125/// since otherwise the block would not be evaluated.
4126///
4127/// scf.while (..) : (i32, ...) -> ... {
4128/// %z = ... : i32
4129/// %condition = cmpi pred %z, %a
4130/// scf.condition(%condition) %z : i32, ...
4131/// } do {
4132/// ^bb0(%arg0: i32, ...):
4133/// %condition2 = cmpi pred %arg0, %a
4134/// use(%condition2)
4135/// ...
4136///
4137/// becomes
4138/// scf.while (..) : (i32, ...) -> ... {
4139/// %z = ... : i32
4140/// %condition = cmpi pred %z, %a
4141/// scf.condition(%condition) %z : i32, ...
4142/// } do {
4143/// ^bb0(%arg0: i32, ...):
4144/// use(%true)
4145/// ...
4146struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
4147 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
4148
4149 LogicalResult matchAndRewrite(scf::WhileOp op,
4150 PatternRewriter &rewriter) const override {
4151 using namespace scf;
4152 auto cond = op.getConditionOp();
4153 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
4154 if (!cmp)
4155 return failure();
4156 bool changed = false;
4157 for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
4158 for (size_t opIdx = 0; opIdx < 2; opIdx++) {
4159 if (std::get<0>(tup) != cmp.getOperand(opIdx))
4160 continue;
4161 for (OpOperand &u :
4162 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
4163 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
4164 if (!cmp2)
4165 continue;
4166 // For a binary operator 1-opIdx gets the other side.
4167 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
4168 continue;
4169 bool samePredicate;
4170 if (cmp2.getPredicate() == cmp.getPredicate())
4171 samePredicate = true;
4172 else if (cmp2.getPredicate() ==
4173 arith::invertPredicate(cmp.getPredicate()))
4174 samePredicate = false;
4175 else
4176 continue;
4177
4178 rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
4179 1);
4180 changed = true;
4181 }
4182 }
4183 }
4184 return success(changed);
4185 }
4186};
4187
4188/// Remove unused init/yield args.
4189struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
4190 using OpRewritePattern<WhileOp>::OpRewritePattern;
4191
4192 LogicalResult matchAndRewrite(WhileOp op,
4193 PatternRewriter &rewriter) const override {
4194
4195 if (!llvm::any_of(op.getBeforeArguments(),
4196 [](Value arg) { return arg.use_empty(); }))
4197 return rewriter.notifyMatchFailure(op, "No args to remove");
4198
4199 YieldOp yield = op.getYieldOp();
4200
4201 // Collect results mapping, new terminator args and new result types.
4202 SmallVector<Value> newYields;
4203 SmallVector<Value> newInits;
4204 llvm::BitVector argsToErase;
4205
4206 size_t argsCount = op.getBeforeArguments().size();
4207 newYields.reserve(argsCount);
4208 newInits.reserve(argsCount);
4209 argsToErase.reserve(argsCount);
4210 for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
4211 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
4212 if (beforeArg.use_empty()) {
4213 argsToErase.push_back(true);
4214 } else {
4215 argsToErase.push_back(false);
4216 newYields.emplace_back(yieldValue);
4217 newInits.emplace_back(initValue);
4218 }
4219 }
4220
4221 Block &beforeBlock = *op.getBeforeBody();
4222 Block &afterBlock = *op.getAfterBody();
4223
4224 beforeBlock.eraseArguments(argsToErase);
4225
4226 Location loc = op.getLoc();
4227 auto newWhileOp =
4228 WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
4229 /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
4230 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4231 Block &newAfterBlock = *newWhileOp.getAfterBody();
4232
4233 OpBuilder::InsertionGuard g(rewriter);
4234 rewriter.setInsertionPoint(yield);
4235 rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4236
4237 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4238 newBeforeBlock.getArguments());
4239 rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4240 newAfterBlock.getArguments());
4241
4242 rewriter.replaceOp(op, newWhileOp.getResults());
4243 return success();
4244 }
4245};
4246
4247/// Remove duplicated ConditionOp args.
4248struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4250
4251 LogicalResult matchAndRewrite(WhileOp op,
4252 PatternRewriter &rewriter) const override {
4253 ConditionOp condOp = op.getConditionOp();
4254 ValueRange condOpArgs = condOp.getArgs();
4255
4256 llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4257
4258 if (argsSet.size() == condOpArgs.size())
4259 return rewriter.notifyMatchFailure(op, "No results to remove");
4260
4261 llvm::SmallDenseMap<Value, unsigned> argsMap;
4262 SmallVector<Value> newArgs;
4263 argsMap.reserve(condOpArgs.size());
4264 newArgs.reserve(condOpArgs.size());
4265 for (Value arg : condOpArgs) {
4266 if (!argsMap.count(arg)) {
4267 auto pos = static_cast<unsigned>(argsMap.size());
4268 argsMap.insert({arg, pos});
4269 newArgs.emplace_back(arg);
4270 }
4271 }
4272
4273 ValueRange argsRange(newArgs);
4274
4275 Location loc = op.getLoc();
4276 auto newWhileOp =
4277 scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
4278 /*beforeBody*/ nullptr,
4279 /*afterBody*/ nullptr);
4280 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4281 Block &newAfterBlock = *newWhileOp.getAfterBody();
4282
4283 SmallVector<Value> afterArgsMapping;
4284 SmallVector<Value> resultsMapping;
4285 for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4286 auto it = argsMap.find(arg);
4287 assert(it != argsMap.end());
4288 auto pos = it->second;
4289 afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4290 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4291 }
4292
4293 OpBuilder::InsertionGuard g(rewriter);
4294 rewriter.setInsertionPoint(condOp);
4295 rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4296 argsRange);
4297
4298 Block &beforeBlock = *op.getBeforeBody();
4299 Block &afterBlock = *op.getAfterBody();
4300
4301 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4302 newBeforeBlock.getArguments());
4303 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4304 rewriter.replaceOp(op, resultsMapping);
4305 return success();
4306 }
4307};
4308
4309/// If both ranges contain same values return mappping indices from args2 to
4310/// args1. Otherwise return std::nullopt.
4311static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4312 ValueRange args2) {
4313 if (args1.size() != args2.size())
4314 return std::nullopt;
4315
4316 SmallVector<unsigned> ret(args1.size());
4317 for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4318 auto it = llvm::find(args2, arg1);
4319 if (it == args2.end())
4320 return std::nullopt;
4321
4322 ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4323 }
4324
4325 return ret;
4326}
4327
4328static bool hasDuplicates(ValueRange args) {
4329 llvm::SmallDenseSet<Value> set;
4330 for (Value arg : args) {
4331 if (!set.insert(arg).second)
4332 return true;
4333 }
4334 return false;
4335}
4336
4337/// If `before` block args are directly forwarded to `scf.condition`, rearrange
4338/// `scf.condition` args into same order as block args. Update `after` block
4339/// args and op result values accordingly.
4340/// Needed to simplify `scf.while` -> `scf.for` uplifting.
4341struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4343
4344 LogicalResult matchAndRewrite(WhileOp loop,
4345 PatternRewriter &rewriter) const override {
4346 auto oldBefore = loop.getBeforeBody();
4347 ConditionOp oldTerm = loop.getConditionOp();
4348 ValueRange beforeArgs = oldBefore->getArguments();
4349 ValueRange termArgs = oldTerm.getArgs();
4350 if (beforeArgs == termArgs)
4351 return failure();
4352
4353 if (hasDuplicates(termArgs))
4354 return failure();
4355
4356 auto mapping = getArgsMapping(beforeArgs, termArgs);
4357 if (!mapping)
4358 return failure();
4359
4360 {
4361 OpBuilder::InsertionGuard g(rewriter);
4362 rewriter.setInsertionPoint(oldTerm);
4363 rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4364 beforeArgs);
4365 }
4366
4367 auto oldAfter = loop.getAfterBody();
4368
4369 SmallVector<Type> newResultTypes(beforeArgs.size());
4370 for (auto &&[i, j] : llvm::enumerate(*mapping))
4371 newResultTypes[j] = loop.getResult(i).getType();
4372
4373 auto newLoop = WhileOp::create(
4374 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
4375 /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4376 auto newBefore = newLoop.getBeforeBody();
4377 auto newAfter = newLoop.getAfterBody();
4378
4379 SmallVector<Value> newResults(beforeArgs.size());
4380 SmallVector<Value> newAfterArgs(beforeArgs.size());
4381 for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4382 newResults[i] = newLoop.getResult(j);
4383 newAfterArgs[i] = newAfter->getArgument(j);
4384 }
4385
4386 rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4387 newBefore->getArguments());
4388 rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4389 newAfterArgs);
4390
4391 rewriter.replaceOp(loop, newResults);
4392 return success();
4393 }
4394};
4395} // namespace
4396
4397void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4398 MLIRContext *context) {
4399 results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4400 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4401 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4402 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4403}
4404
4405//===----------------------------------------------------------------------===//
4406// IndexSwitchOp
4407//===----------------------------------------------------------------------===//
4408
4409/// Parse the case regions and values.
4410static ParseResult
4412 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4413 SmallVector<int64_t> caseValues;
4414 while (succeeded(p.parseOptionalKeyword("case"))) {
4415 int64_t value;
4416 Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4417 if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4418 return failure();
4419 caseValues.push_back(value);
4420 }
4421 cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4422 return success();
4423}
4424
4425/// Print the case regions and values.
4427 DenseI64ArrayAttr cases, RegionRange caseRegions) {
4428 for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4429 p.printNewline();
4430 p << "case " << value << ' ';
4431 p.printRegion(*region, /*printEntryBlockArgs=*/false);
4432 }
4433}
4434
4435LogicalResult scf::IndexSwitchOp::verify() {
4436 if (getCases().size() != getCaseRegions().size()) {
4437 return emitOpError("has ")
4438 << getCaseRegions().size() << " case regions but "
4439 << getCases().size() << " case values";
4440 }
4441
4442 DenseSet<int64_t> valueSet;
4443 for (int64_t value : getCases())
4444 if (!valueSet.insert(value).second)
4445 return emitOpError("has duplicate case value: ") << value;
4446 auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4447 auto yield = dyn_cast<YieldOp>(region.front().back());
4448 if (!yield)
4449 return emitOpError("expected region to end with scf.yield, but got ")
4450 << region.front().back().getName();
4451
4452 if (yield.getNumOperands() != getNumResults()) {
4453 return (emitOpError("expected each region to return ")
4454 << getNumResults() << " values, but " << name << " returns "
4455 << yield.getNumOperands())
4456 .attachNote(yield.getLoc())
4457 << "see yield operation here";
4458 }
4459 for (auto [idx, result, operand] :
4460 llvm::enumerate(getResultTypes(), yield.getOperands())) {
4461 if (!operand)
4462 return yield.emitOpError() << "operand " << idx << " is null\n";
4463 if (result == operand.getType())
4464 continue;
4465 return (emitOpError("expected result #")
4466 << idx << " of each region to be " << result)
4467 .attachNote(yield.getLoc())
4468 << name << " returns " << operand.getType() << " here";
4469 }
4470 return success();
4471 };
4472
4473 if (failed(verifyRegion(getDefaultRegion(), "default region")))
4474 return failure();
4475 for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4476 if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4477 return failure();
4478
4479 return success();
4480}
4481
4482unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4483
4484Block &scf::IndexSwitchOp::getDefaultBlock() {
4485 return getDefaultRegion().front();
4486}
4487
4488Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4489 assert(idx < getNumCases() && "case index out-of-bounds");
4490 return getCaseRegions()[idx].front();
4491}
4492
4493void IndexSwitchOp::getSuccessorRegions(
4494 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4495 // All regions branch back to the parent op.
4496 if (!point.isParent()) {
4497 successors.emplace_back(getOperation(), getResults());
4498 return;
4499 }
4500
4501 llvm::append_range(successors, getRegions());
4502}
4503
4504void IndexSwitchOp::getEntrySuccessorRegions(
4505 ArrayRef<Attribute> operands,
4506 SmallVectorImpl<RegionSuccessor> &successors) {
4507 FoldAdaptor adaptor(operands, *this);
4508
4509 // If a constant was not provided, all regions are possible successors.
4510 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4511 if (!arg) {
4512 llvm::append_range(successors, getRegions());
4513 return;
4514 }
4515
4516 // Otherwise, try to find a case with a matching value. If not, the
4517 // default region is the only successor.
4518 for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4519 if (caseValue == arg.getInt()) {
4520 successors.emplace_back(&caseRegion);
4521 return;
4522 }
4523 }
4524 successors.emplace_back(&getDefaultRegion());
4525}
4526
4527void IndexSwitchOp::getRegionInvocationBounds(
4528 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4529 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4530 if (!operandValue) {
4531 // All regions are invoked at most once.
4532 bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4533 return;
4534 }
4535
4536 unsigned liveIndex = getNumRegions() - 1;
4537 const auto *it = llvm::find(getCases(), operandValue.getInt());
4538 if (it != getCases().end())
4539 liveIndex = std::distance(getCases().begin(), it);
4540 for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4541 bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4542}
4543
4544struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4545 using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
4546
4547 LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4548 PatternRewriter &rewriter) const override {
4549 // If `op.getArg()` is a constant, select the region that matches with
4550 // the constant value. Use the default region if no matche is found.
4551 std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4552 if (!maybeCst.has_value())
4553 return failure();
4554 int64_t cst = *maybeCst;
4555 int64_t caseIdx, e = op.getNumCases();
4556 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4557 if (cst == op.getCases()[caseIdx])
4558 break;
4559 }
4560
4561 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4562 : op.getDefaultRegion();
4563 Block &source = r.front();
4564 Operation *terminator = source.getTerminator();
4565 SmallVector<Value> results = terminator->getOperands();
4566
4567 rewriter.inlineBlockBefore(&source, op);
4568 rewriter.eraseOp(terminator);
4569 // Replace the operation with a potentially empty list of results.
4570 // Fold mechanism doesn't support the case where the result list is empty.
4571 rewriter.replaceOp(op, results);
4572
4573 return success();
4574 }
4575};
4576
4577void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4578 MLIRContext *context) {
4579 results.add<FoldConstantCase>(context);
4580}
4581
4582//===----------------------------------------------------------------------===//
4583// TableGen'd op method definitions
4584//===----------------------------------------------------------------------===//
4585
4586#define GET_OP_CLASSES
4587#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1371
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
Definition EmitC.cpp:1346
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition EmitC.cpp:1362
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition SCF.cpp:136
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition SCF.cpp:3658
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:565
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition SCF.cpp:100
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
@ None
static std::string diag(const llvm::Value &value)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:85
bool empty()
Definition Block.h:148
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:160
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation & front()
Definition Block.h:153
Operation & back()
Definition Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:153
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition Block.cpp:201
BlockArgListType getArguments()
Definition Block.h:87
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:212
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
bool getValue() const
Return the boolean value of this attribute.
UnitAttr getUnitAttr()
Definition Builders.cpp:98
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:167
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
IntegerType getI1Type()
Definition Builders.cpp:53
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition Builders.cpp:589
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition Builders.h:594
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OperandRange operand_range
Definition Operation.h:371
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:346
This class represents a successor of a region.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:222
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
BlockArgument getArgument(unsigned i)
Definition Region.h:124
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition SCF.cpp:3267
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition SCF.cpp:93
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition SCF.cpp:837
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition SCF.cpp:2151
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition SCF.cpp:792
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition SCF.cpp:114
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition SCF.cpp:744
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition SCF.h:64
llvm::function_ref< Value(OpBuilder &, Location loc, Type, Value)> ValueTypeCastFnTy
Perform a replacement of one iter OpOperand of an scf.for to the replacement value with a different t...
Definition SCF.h:107
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of an thread index variable.
Definition SCF.cpp:1609
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition SCF.cpp:926
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that the values have as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:301
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:4547
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:260
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:211
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.