MLIR 23.0.0git
SCF.cpp
Go to the documentation of this file.
1//===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/Operation.h"
30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/DebugLog.h"
35#include <optional>
36
37using namespace mlir;
38using namespace mlir::scf;
39
40#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
41
42//===----------------------------------------------------------------------===//
43// SCFDialect Dialect Interfaces
44//===----------------------------------------------------------------------===//
45
46namespace {
47struct SCFInlinerInterface : public DialectInlinerInterface {
48 using DialectInlinerInterface::DialectInlinerInterface;
49 // We don't have any special restrictions on what can be inlined into
50 // destination regions (e.g. while/conditional bodies). Always allow it.
51 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
52 IRMapping &valueMapping) const final {
53 return true;
54 }
55 // Operations in scf dialect are always legal to inline since they are
56 // pure.
57 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
58 return true;
59 }
60 // Handle the given inlined terminator by replacing it with a new operation
61 // as necessary. Required when the region has only one block.
62 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
63 auto retValOp = dyn_cast<scf::YieldOp>(op);
64 if (!retValOp)
65 return;
66
67 for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
68 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
69 }
70 }
71};
72} // namespace
73
74//===----------------------------------------------------------------------===//
75// SCFDialect
76//===----------------------------------------------------------------------===//
77
// Registers everything the SCF dialect contributes: its operations, the
// inliner interface defined in this file, and a set of promised interfaces.
78void SCFDialect::initialize() {
// The operation list is expanded from the generated SCFOps.cpp.inc include,
// selected by defining GET_OP_LIST immediately before the include.
79 addOperations<
80#define GET_OP_LIST
81#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
82 >();
83 addInterfaces<SCFInlinerInterface>();
// NOTE(review): the promised interfaces below are only declared here; their
// implementations are presumably registered elsewhere - confirm.
84 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
85 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
86 InParallelOp, ReduceReturnOp>();
87 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
88 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
89 ForallOp, InParallelOp, WhileOp, YieldOp>();
90 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
91}
92
93/// Default callback for IfOp builders. Inserts a yield without arguments.
95 scf::YieldOp::create(builder, loc);
96}
97
98/// Verifies that the first block of the given `region` is terminated by a
99/// TerminatorTy. Reports errors on the given operation if it is not the case.
100template <typename TerminatorTy>
101static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
102 StringRef errorMessage) {
103 Operation *terminatorOperation = nullptr;
104 if (!region.empty() && !region.front().empty()) {
105 terminatorOperation = &region.front().back();
106 if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
107 return yield;
108 }
109 auto diag = op->emitOpError(errorMessage);
110 if (terminatorOperation)
111 diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
112 return nullptr;
113}
114
115std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
116 bool isSigned) {
117 llvm::APSInt diff;
118 auto addOp = ub.getDefiningOp<arith::AddIOp>();
119 if (!addOp)
120 return std::nullopt;
121 if ((isSigned && !addOp.hasNoSignedWrap()) ||
122 (!isSigned && !addOp.hasNoUnsignedWrap()))
123 return std::nullopt;
124
125 if (addOp.getLhs() != lb ||
126 !matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
127 return std::nullopt;
128 return diff;
129}
130
131//===----------------------------------------------------------------------===//
132// ExecuteRegionOp
133//===----------------------------------------------------------------------===//
134
135///
136/// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
137/// block+
138/// `}`
139///
140/// Example:
141/// scf.execute_region -> i32 {
142/// %idx = load %rI[%i] : memref<128xi32>
143/// return %idx : i32
144/// }
145///
146ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
148 if (parser.parseOptionalArrowTypeList(result.types))
149 return failure();
150
151 if (succeeded(parser.parseOptionalKeyword("no_inline")))
152 result.addAttribute("no_inline", parser.getBuilder().getUnitAttr());
153
154 // Introduce the body region and parse it.
155 Region *body = result.addRegion();
156 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
157 parser.parseOptionalAttrDict(result.attributes))
158 return failure();
159
160 return success();
161}
162
163void ExecuteRegionOp::print(OpAsmPrinter &p) {
164 p.printOptionalArrowTypeList(getResultTypes());
165 p << ' ';
166 if (getNoInline())
167 p << "no_inline ";
168 p.printRegion(getRegion(),
169 /*printEntryBlockArgs=*/false,
170 /*printBlockTerminators=*/true);
171 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"no_inline"});
172}
173
174LogicalResult ExecuteRegionOp::verify() {
175 if (getRegion().empty())
176 return emitOpError("region needs to have at least one block");
177 if (getRegion().front().getNumArguments() > 0)
178 return emitOpError("region cannot have any arguments");
179 return success();
180}
181
182// Inline an ExecuteRegionOp if its parent can contain multiple blocks.
183// TODO generalize the conditions for operations which can be inlined into.
184// func @func_execute_region_elim() {
185// "test.foo"() : () -> ()
186// %v = scf.execute_region -> i64 {
187// %c = "test.cmp"() : () -> i1
188// cf.cond_br %c, ^bb2, ^bb3
189// ^bb2:
190// %x = "test.val1"() : () -> i64
191// cf.br ^bb4(%x : i64)
192// ^bb3:
193// %y = "test.val2"() : () -> i64
194// cf.br ^bb4(%y : i64)
195// ^bb4(%z : i64):
196// scf.yield %z : i64
197// }
198// "test.bar"(%v) : (i64) -> ()
199// return
200// }
201//
202// becomes
203//
204// func @func_execute_region_elim() {
205// "test.foo"() : () -> ()
206// %c = "test.cmp"() : () -> i1
207// cf.cond_br %c, ^bb1, ^bb2
208// ^bb1: // pred: ^bb0
209// %x = "test.val1"() : () -> i64
210// cf.br ^bb3(%x : i64)
211// ^bb2: // pred: ^bb0
212// %y = "test.val2"() : () -> i64
213// cf.br ^bb3(%y : i64)
214// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
215// "test.bar"(%z) : (i64) -> ()
216// return
217// }
218//
219struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
220 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
221
222 LogicalResult matchAndRewrite(ExecuteRegionOp op,
223 PatternRewriter &rewriter) const override {
224 if (op.getNoInline())
225 return failure();
226 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
227 return failure();
228
229 Block *prevBlock = op->getBlock();
230 Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
231 rewriter.setInsertionPointToEnd(prevBlock);
232
233 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
234
235 for (Block &blk : op.getRegion()) {
236 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
237 rewriter.setInsertionPoint(yieldOp);
238 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
239 yieldOp.getResults());
240 rewriter.eraseOp(yieldOp);
241 }
242 }
243
244 rewriter.inlineRegionBefore(op.getRegion(), postBlock);
245 SmallVector<Value> blockArgs;
246
247 for (auto res : op.getResults())
248 blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
249
250 rewriter.replaceOp(op, blockArgs);
251 return success();
252 }
253};
254
// Adds the canonicalization patterns for scf.execute_region to `results`.
255void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
256 MLIRContext *context) {
// Pattern defined in this file that inlines an execute_region into its parent.
257 results.add<MultiBlockExecuteInliner>(context);
// NOTE(review): the names of the two helpers invoked below, and the header of
// the callback passed to the second one, are not visible in this listing -
// restore them from the original file before building.
259 results, ExecuteRegionOp::getOperationName());
260 // Inline ops with a single block that are not marked as "no_inline".
262 results, ExecuteRegionOp::getOperationName(),
// The callback yields failure exactly when the op is marked no_inline.
264 return failure(cast<ExecuteRegionOp>(op).getNoInline());
265 });
266}
267
268void ExecuteRegionOp::getSuccessorRegions(
270 // If the predecessor is the ExecuteRegionOp, branch into the body.
271 if (point.isParent()) {
272 regions.push_back(RegionSuccessor(&getRegion()));
273 return;
274 }
275
276 // Otherwise, the region branches back to the parent operation.
277 regions.push_back(RegionSuccessor::parent());
278}
279
280ValueRange ExecuteRegionOp::getSuccessorInputs(RegionSuccessor successor) {
281 return successor.isParent() ? ValueRange(getOperation()->getResults())
282 : ValueRange();
283}
284
285//===----------------------------------------------------------------------===//
286// ConditionOp
287//===----------------------------------------------------------------------===//
288
290ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
291 assert(
292 (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) &&
293 "condition op can only exit the loop or branch to the after"
294 "region");
295 // Pass all operands except the condition to the successor region.
296 return getArgsMutable();
297}
298
299void ConditionOp::getSuccessorRegions(
301 FoldAdaptor adaptor(operands, *this);
302
303 WhileOp whileOp = getParentOp();
304
305 // Condition can either lead to the after region or back to the parent op
306 // depending on whether the condition is true or not.
307 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
308 if (!boolAttr || boolAttr.getValue())
309 regions.emplace_back(&whileOp.getAfter());
310 if (!boolAttr || !boolAttr.getValue())
311 regions.push_back(RegionSuccessor::parent());
312}
313
314//===----------------------------------------------------------------------===//
315// ForOp
316//===----------------------------------------------------------------------===//
317
318void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
319 Value ub, Value step, ValueRange initArgs,
320 BodyBuilderFn bodyBuilder, bool unsignedCmp) {
321 OpBuilder::InsertionGuard guard(builder);
322
323 if (unsignedCmp)
324 result.addAttribute(getUnsignedCmpAttrName(result.name),
325 builder.getUnitAttr());
326 result.addOperands({lb, ub, step});
327 result.addOperands(initArgs);
328 for (Value v : initArgs)
329 result.addTypes(v.getType());
330 Type t = lb.getType();
331 Region *bodyRegion = result.addRegion();
332 Block *bodyBlock = builder.createBlock(bodyRegion);
333 bodyBlock->addArgument(t, result.location);
334 for (Value v : initArgs)
335 bodyBlock->addArgument(v.getType(), v.getLoc());
336
337 // Create the default terminator if the builder is not provided and if the
338 // iteration arguments are not provided. Otherwise, leave this to the caller
339 // because we don't know which values to return from the loop.
340 if (initArgs.empty() && !bodyBuilder) {
341 ForOp::ensureTerminator(*bodyRegion, builder, result.location);
342 } else if (bodyBuilder) {
343 OpBuilder::InsertionGuard guard(builder);
344 builder.setInsertionPointToStart(bodyBlock);
345 bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
346 bodyBlock->getArguments().drop_front());
347 }
348}
349
350LogicalResult ForOp::verify() {
351 // Check that the number of init args and op results is the same.
352 if (getInitArgs().size() != getNumResults())
353 return emitOpError(
354 "mismatch in number of loop-carried values and defined values");
355
356 return success();
357}
358
359LogicalResult ForOp::verifyRegions() {
360 // Check that the body defines as single block argument for the induction
361 // variable.
362 if (getInductionVar().getType() != getLowerBound().getType())
363 return emitOpError(
364 "expected induction variable to be same type as bounds and step");
365
366 if (getNumRegionIterArgs() != getNumResults())
367 return emitOpError(
368 "mismatch in number of basic block args and defined values");
369
370 auto initArgs = getInitArgs();
371 auto iterArgs = getRegionIterArgs();
372 auto opResults = getResults();
373 unsigned i = 0;
374 for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
375 if (std::get<0>(e).getType() != std::get<2>(e).getType())
376 return emitOpError() << "types mismatch between " << i
377 << "th iter operand and defined value";
378 if (std::get<1>(e).getType() != std::get<2>(e).getType())
379 return emitOpError() << "types mismatch between " << i
380 << "th iter region arg and defined value";
381
382 ++i;
383 }
384 return success();
385}
386
387std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
388 return SmallVector<Value>{getInductionVar()};
389}
390
391std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
393}
394
395std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
396 return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
397}
398
399std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
401}
402
403bool ForOp::isValidInductionVarType(Type type) {
404 return type.isIndex() || type.isSignlessInteger();
405}
406
407LogicalResult ForOp::setLoopLowerBounds(ArrayRef<OpFoldResult> bounds) {
408 if (bounds.size() != 1)
409 return failure();
410 if (auto val = dyn_cast<Value>(bounds[0])) {
411 setLowerBound(val);
412 return success();
413 }
414 return failure();
415}
416
417LogicalResult ForOp::setLoopUpperBounds(ArrayRef<OpFoldResult> bounds) {
418 if (bounds.size() != 1)
419 return failure();
420 if (auto val = dyn_cast<Value>(bounds[0])) {
421 setUpperBound(val);
422 return success();
423 }
424 return failure();
425}
426
427LogicalResult ForOp::setLoopSteps(ArrayRef<OpFoldResult> steps) {
428 if (steps.size() != 1)
429 return failure();
430 if (auto val = dyn_cast<Value>(steps[0])) {
431 setStep(val);
432 return success();
433 }
434 return failure();
435}
436
437std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
438
439/// Promotes the loop body of a forOp to its containing block if the forOp
440/// it can be determined that the loop has a single iteration.
441LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
442 std::optional<APInt> tripCount = getStaticTripCount();
443 LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
444 << " for loop "
445 << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
446 if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
447 return failure();
448
449 if (*tripCount == 0) {
450 rewriter.replaceAllUsesWith(getResults(), getInitArgs());
451 rewriter.eraseOp(*this);
452 return success();
453 }
454
455 // Replace all results with the yielded values.
456 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
457 rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
458
459 // Replace block arguments with lower bound (replacement for IV) and
460 // iter_args.
461 SmallVector<Value> bbArgReplacements;
462 bbArgReplacements.push_back(getLowerBound());
463 llvm::append_range(bbArgReplacements, getInitArgs());
464
465 // Move the loop body operations to the loop's containing block.
466 rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
467 getOperation()->getIterator(), bbArgReplacements);
468
469 // Erase the old terminator and the loop.
470 rewriter.eraseOp(yieldOp);
471 rewriter.eraseOp(*this);
472
473 return success();
474}
475
476/// Prints the initialization list in the form of
477/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
478/// where 'inner' values are assumed to be region arguments and 'outer' values
479/// are regular SSA values.
481 Block::BlockArgListType blocksArgs,
482 ValueRange initializers,
483 StringRef prefix = "") {
484 assert(blocksArgs.size() == initializers.size() &&
485 "expected same length of arguments and initializers");
486 if (initializers.empty())
487 return;
488
489 p << prefix << '(';
490 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
491 p << std::get<0>(it) << " = " << std::get<1>(it);
492 });
493 p << ")";
494}
495
496void ForOp::print(OpAsmPrinter &p) {
497 if (getUnsignedCmp())
498 p << " unsigned";
499
500 p << " " << getInductionVar() << " = " << getLowerBound() << " to "
501 << getUpperBound() << " step " << getStep();
502
503 printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
504 if (!getInitArgs().empty())
505 p << " -> (" << getInitArgs().getTypes() << ')';
506 p << ' ';
507 if (Type t = getInductionVar().getType(); !t.isIndex())
508 p << " : " << t << ' ';
509 p.printRegion(getRegion(),
510 /*printEntryBlockArgs=*/false,
511 /*printBlockTerminators=*/!getInitArgs().empty());
512 p.printOptionalAttrDict((*this)->getAttrs(),
513 /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
514}
515
516ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
517 auto &builder = parser.getBuilder();
518 Type type;
519
520 OpAsmParser::Argument inductionVariable;
522
523 if (succeeded(parser.parseOptionalKeyword("unsigned")))
524 result.addAttribute(getUnsignedCmpAttrName(result.name),
525 builder.getUnitAttr());
526
527 // Parse the induction variable followed by '='.
528 if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
529 // Parse loop bounds.
530 parser.parseOperand(lb) || parser.parseKeyword("to") ||
531 parser.parseOperand(ub) || parser.parseKeyword("step") ||
532 parser.parseOperand(step))
533 return failure();
534
535 // Parse the optional initial iteration arguments.
538 regionArgs.push_back(inductionVariable);
539
540 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
541 if (hasIterArgs) {
542 // Parse assignment list and results type list.
543 if (parser.parseAssignmentList(regionArgs, operands) ||
544 parser.parseArrowTypeList(result.types))
545 return failure();
546 }
547
548 if (regionArgs.size() != result.types.size() + 1)
549 return parser.emitError(
550 parser.getNameLoc(),
551 "mismatch in number of loop-carried values and defined values");
552
553 // Parse optional type, else assume Index.
554 if (parser.parseOptionalColon())
555 type = builder.getIndexType();
556 else if (parser.parseType(type))
557 return failure();
558
559 // Set block argument types, so that they are known when parsing the region.
560 regionArgs.front().type = type;
561 for (auto [iterArg, type] :
562 llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
563 iterArg.type = type;
564
565 // Parse the body region.
566 Region *body = result.addRegion();
567 if (parser.parseRegion(*body, regionArgs))
568 return failure();
569 ForOp::ensureTerminator(*body, builder, result.location);
570
571 // Resolve input operands. This should be done after parsing the region to
572 // catch invalid IR where operands were defined inside of the region.
573 if (parser.resolveOperand(lb, type, result.operands) ||
574 parser.resolveOperand(ub, type, result.operands) ||
575 parser.resolveOperand(step, type, result.operands))
576 return failure();
577 if (hasIterArgs) {
578 for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
579 operands, result.types)) {
580 Type type = std::get<2>(argOperandType);
581 std::get<0>(argOperandType).type = type;
582 if (parser.resolveOperand(std::get<1>(argOperandType), type,
583 result.operands))
584 return failure();
585 }
586 }
587
588 // Parse the optional attribute list.
589 if (parser.parseOptionalAttrDict(result.attributes))
590 return failure();
591
592 return success();
593}
594
595SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
596
597Block::BlockArgListType ForOp::getRegionIterArgs() {
598 return getBody()->getArguments().drop_front(getNumInductionVars());
599}
600
601MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
602 return getInitArgsMutable();
603}
604
605FailureOr<LoopLikeOpInterface>
606ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
607 ValueRange newInitOperands,
608 bool replaceInitOperandUsesInLoop,
609 const NewYieldValuesFn &newYieldValuesFn) {
610 // Create a new loop before the existing one, with the extra operands.
611 OpBuilder::InsertionGuard g(rewriter);
612 rewriter.setInsertionPoint(getOperation());
613 auto inits = llvm::to_vector(getInitArgs());
614 inits.append(newInitOperands.begin(), newInitOperands.end());
615 scf::ForOp newLoop = scf::ForOp::create(
616 rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
617 [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp());
618 newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
619
620 // Generate the new yield values and append them to the scf.yield operation.
621 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
622 ArrayRef<BlockArgument> newIterArgs =
623 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
624 {
625 OpBuilder::InsertionGuard g(rewriter);
626 rewriter.setInsertionPoint(yieldOp);
627 SmallVector<Value> newYieldedValues =
628 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
629 assert(newInitOperands.size() == newYieldedValues.size() &&
630 "expected as many new yield values as new iter operands");
631 rewriter.modifyOpInPlace(yieldOp, [&]() {
632 yieldOp.getResultsMutable().append(newYieldedValues);
633 });
634 }
635
636 // Move the loop body to the new op.
637 rewriter.mergeBlocks(getBody(), newLoop.getBody(),
638 newLoop.getBody()->getArguments().take_front(
639 getBody()->getNumArguments()));
640
641 if (replaceInitOperandUsesInLoop) {
642 // Replace all uses of `newInitOperands` with the corresponding basic block
643 // arguments.
644 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
645 rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
646 [&](OpOperand &use) {
647 Operation *user = use.getOwner();
648 return newLoop->isProperAncestor(user);
649 });
650 }
651 }
652
653 // Replace the old loop.
654 rewriter.replaceOp(getOperation(),
655 newLoop->getResults().take_front(getNumResults()));
656 return cast<LoopLikeOpInterface>(newLoop.getOperation());
657}
658
660 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
661 if (!ivArg)
662 return ForOp();
663 assert(ivArg.getOwner() && "unlinked block argument");
664 auto *containingOp = ivArg.getOwner()->getParentOp();
665 return dyn_cast_or_null<ForOp>(containingOp);
666}
667
668OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
669 return getInitArgs();
670}
671
672void ForOp::getSuccessorRegions(RegionBranchPoint point,
674 if (std::optional<APInt> tripCount = getStaticTripCount()) {
675 // The loop has a known static trip count.
676 if (point.isParent()) {
677 if (*tripCount == 0) {
678 // The loop has zero iterations. It branches directly back to the
679 // parent.
680 regions.push_back(RegionSuccessor::parent());
681 } else {
682 // The loop has at least one iteration. It branches into the body.
683 regions.push_back(RegionSuccessor(&getRegion()));
684 }
685 return;
686 } else if (*tripCount == 1) {
687 // The loop has exactly 1 iteration. Therefore, it branches from the
688 // region to the parent. (No further iteration.)
689 regions.push_back(RegionSuccessor::parent());
690 return;
691 }
692 }
693
694 // Both the operation itself and the region may be branching into the body or
695 // back into the operation itself. It is possible for loop not to enter the
696 // body.
697 regions.push_back(RegionSuccessor(&getRegion()));
698 regions.push_back(RegionSuccessor::parent());
699}
700
701ValueRange ForOp::getSuccessorInputs(RegionSuccessor successor) {
702 return successor.isParent() ? ValueRange(getResults())
703 : ValueRange(getRegionIterArgs());
704}
705
706SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
707
708/// Promotes the loop body of a forallOp to its containing block if it can be
709/// determined that the loop has a single iteration.
710LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
711 for (auto [lb, ub, step] :
712 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
713 auto tripCount =
714 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
715 if (!tripCount.has_value() || *tripCount != 1)
716 return failure();
717 }
718
719 promote(rewriter, *this);
720 return success();
721}
722
723Block::BlockArgListType ForallOp::getRegionIterArgs() {
724 return getBody()->getArguments().drop_front(getRank());
725}
726
727MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
728 return getOutputsMutable();
729}
730
731/// Promotes the loop body of a scf::ForallOp to its containing block.
732void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
733 OpBuilder::InsertionGuard g(rewriter);
734 scf::InParallelOp terminator = forallOp.getTerminator();
735
736 // Replace block arguments with lower bounds (replacements for IVs) and
737 // outputs.
738 SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
739 bbArgReplacements.append(forallOp.getOutputs().begin(),
740 forallOp.getOutputs().end());
741
742 // Move the loop body operations to the loop's containing block.
743 rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
744 forallOp->getIterator(), bbArgReplacements);
745
746 // Replace the terminator with tensor.insert_slice ops.
747 rewriter.setInsertionPointAfter(forallOp);
748 SmallVector<Value> results;
749 results.reserve(forallOp.getResults().size());
750 for (auto &yieldingOp : terminator.getYieldingOps()) {
751 auto parallelInsertSliceOp =
752 dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
753 if (!parallelInsertSliceOp)
754 continue;
755
756 Value dst = parallelInsertSliceOp.getDest();
757 Value src = parallelInsertSliceOp.getSource();
758 if (llvm::isa<TensorType>(src.getType())) {
759 results.push_back(tensor::InsertSliceOp::create(
760 rewriter, forallOp.getLoc(), dst.getType(), src, dst,
761 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
762 parallelInsertSliceOp.getStrides(),
763 parallelInsertSliceOp.getStaticOffsets(),
764 parallelInsertSliceOp.getStaticSizes(),
765 parallelInsertSliceOp.getStaticStrides()));
766 } else {
767 llvm_unreachable("unsupported terminator");
768 }
769 }
770 rewriter.replaceAllUsesWith(forallOp.getResults(), results);
771
772 // Erase the old terminator and the loop.
773 rewriter.eraseOp(terminator);
774 rewriter.eraseOp(forallOp);
775}
776
778 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
779 ValueRange steps, ValueRange iterArgs,
781 bodyBuilder) {
782 assert(lbs.size() == ubs.size() &&
783 "expected the same number of lower and upper bounds");
784 assert(lbs.size() == steps.size() &&
785 "expected the same number of lower bounds and steps");
786
787 // If there are no bounds, call the body-building function and return early.
788 if (lbs.empty()) {
789 ValueVector results =
790 bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
791 : ValueVector();
792 assert(results.size() == iterArgs.size() &&
793 "loop nest body must return as many values as loop has iteration "
794 "arguments");
795 return LoopNest{{}, std::move(results)};
796 }
797
798 // First, create the loop structure iteratively using the body-builder
799 // callback of `ForOp::build`. Do not create `YieldOp`s yet.
800 OpBuilder::InsertionGuard guard(builder);
803 loops.reserve(lbs.size());
804 ivs.reserve(lbs.size());
805 ValueRange currentIterArgs = iterArgs;
806 Location currentLoc = loc;
807 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
808 auto loop = scf::ForOp::create(
809 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
810 [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
811 ValueRange args) {
812 ivs.push_back(iv);
813 // It is safe to store ValueRange args because it points to block
814 // arguments of a loop operation that we also own.
815 currentIterArgs = args;
816 currentLoc = nestedLoc;
817 });
818 // Set the builder to point to the body of the newly created loop. We don't
819 // do this in the callback because the builder is reset when the callback
820 // returns.
821 builder.setInsertionPointToStart(loop.getBody());
822 loops.push_back(loop);
823 }
824
825 // For all loops but the innermost, yield the results of the nested loop.
826 for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
827 builder.setInsertionPointToEnd(loops[i].getBody());
828 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
829 }
830
831 // In the body of the innermost loop, call the body building function if any
832 // and yield its results.
833 builder.setInsertionPointToStart(loops.back().getBody());
834 ValueVector results = bodyBuilder
835 ? bodyBuilder(builder, currentLoc, ivs,
836 loops.back().getRegionIterArgs())
837 : ValueVector();
838 assert(results.size() == iterArgs.size() &&
839 "loop nest body must return as many values as loop has iteration "
840 "arguments");
841 builder.setInsertionPointToEnd(loops.back().getBody());
842 scf::YieldOp::create(builder, loc, results);
843
844 // Return the loops.
845 ValueVector nestResults;
846 llvm::append_range(nestResults, loops.front().getResults());
847 return LoopNest{std::move(loops), std::move(nestResults)};
848}
849
851 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
852 ValueRange steps,
853 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
854 // Delegate to the main function by wrapping the body builder.
855 return buildLoopNest(builder, loc, lbs, ubs, steps, {},
856 [&bodyBuilder](OpBuilder &nestedBuilder,
857 Location nestedLoc, ValueRange ivs,
859 if (bodyBuilder)
860 bodyBuilder(nestedBuilder, nestedLoc, ivs);
861 return {};
862 });
863}
864
867 OpOperand &operand, Value replacement,
868 const ValueTypeCastFnTy &castFn) {
// Rebuilds `forOp` with the iter operand `operand` replaced by `replacement`,
// whose type differs from the original operand type. `castFn` is used to
// bridge the two types in three places: at the start of the new body (new
// type -> old type), right before the yield (old type -> new type), and after
// the new loop (new type -> old type). Returns the new loop's results, with
// the replaced position cast back to the old type.
869 assert(operand.getOwner() == forOp);
870 Type oldType = operand.get().getType(), newType = replacement.getType();
871
872 // 1. Create new iter operands, exactly 1 is replaced.
873 assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
874 "expected an iter OpOperand");
875 assert(operand.get().getType() != replacement.getType() &&
876 "Expected a different type");
877 SmallVector<Value> newIterOperands;
878 for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
879 if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
880 newIterOperands.push_back(replacement);
881 continue;
882 }
883 newIterOperands.push_back(opOperand.get());
884 }
885
886 // 2. Create the new forOp shell.
887 scf::ForOp newForOp = scf::ForOp::create(
888 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
889 forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr,
890 forOp.getUnsignedCmp());
891 newForOp->setAttrs(forOp->getAttrs());
892 Block &newBlock = newForOp.getRegion().front();
// Start from the new block's own arguments; the entry for the replaced iter
// arg is overwritten with the incoming cast in step 3.
893 SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
894 newBlock.getArguments().end());
895
896 // 3. Inject an incoming cast op at the beginning of the block for the bbArg
897 // corresponding to the `replacement` value.
898 OpBuilder::InsertionGuard g(rewriter);
899 rewriter.setInsertionPointToStart(&newBlock);
900 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
901 &newForOp->getOpOperand(operand.getOperandNumber()));
902 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
903 newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
904
905 // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
906 Block &oldBlock = forOp.getRegion().front();
907 rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
908
909 // 5. Inject an outgoing cast op at the end of the block and yield it instead.
910 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
911 rewriter.setInsertionPoint(clonedYieldOp);
// Block arguments start with the induction variable(s), so subtract them to
// get the position among the yielded values / loop results.
912 unsigned yieldIdx =
913 newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
914 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
915 clonedYieldOp.getOperand(yieldIdx));
916 SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
917 newYieldOperands[yieldIdx] = castOut;
918 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
919 rewriter.eraseOp(clonedYieldOp);
920
921 // 6. Inject an outgoing cast op after the forOp.
922 rewriter.setInsertionPointAfter(newForOp);
923 SmallVector<Value> newResults = newForOp.getResults();
924 newResults[yieldIdx] =
925 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
926
927 return newResults;
928}
929
930namespace {
931/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
932/// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
933///
934/// ```
935/// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
936/// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
937/// -> (tensor<?x?xf32>) {
938/// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
939/// scf.yield %2 : tensor<?x?xf32>
940/// }
941/// use_of(%1)
942/// ```
943///
944/// folds into:
945///
946/// ```
947/// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
948/// -> (tensor<32x1024xf32>) {
949/// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
950/// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
951/// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
952/// scf.yield %4 : tensor<32x1024xf32>
953/// }
954/// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
955/// use_of(%1)
956/// ```
957struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
959
960 LogicalResult matchAndRewrite(ForOp op,
961 PatternRewriter &rewriter) const override {
// Walk (init operand, tied result) pairs and rewrite the first pair that
// qualifies; the pattern returns after a single replacement.
962 for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
963 OpOperand &iterOpOperand = std::get<0>(it);
964 auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
// Skip operands that are not produced by a tensor.cast, and casts whose
// source and result types are already identical.
965 if (!incomingCast ||
966 incomingCast.getSource().getType() == incomingCast.getType())
967 continue;
968 // Skip the operand if the dest type of the cast does not preserve static
969 // information in the source type.
971 incomingCast.getDest().getType(),
972 incomingCast.getSource().getType()))
973 continue;
// Only fold when the tied loop result has exactly one use.
974 if (!std::get<1>(it).hasOneUse())
975 continue;
976
977 // Create a new ForOp with that iter operand replaced.
978 rewriter.replaceOp(
980 rewriter, op, iterOpOperand, incomingCast.getSource(),
981 [](OpBuilder &b, Location loc, Type type, Value source) {
982 return tensor::CastOp::create(b, loc, type, source);
983 }));
984 return success();
985 }
986 return failure();
987 }
988};
989} // namespace
990
// Registers the canonicalization patterns for scf.for: ForOpTensorCastFolder
// plus further pattern sets keyed on the ForOp operation name. The last set
// receives a callback that supplies a replacement value for the loop's
// induction variable.
991void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
992 MLIRContext *context) {
993 results.add<ForOpTensorCastFolder>(context);
995 results, ForOp::getOperationName());
997 results, ForOp::getOperationName(),
998 /*replBuilderFn=*/[](OpBuilder &builder, Location loc, Value value) {
999 // scf.for has only one non-successor input value: the loop induction
1000 // variable. In case of a single acyclic path through the op, the IV can
1001 // be safely replaced with the lower bound.
1002 auto blockArg = cast<BlockArgument>(value);
1003 assert(blockArg.getArgNumber() == 0 && "expected induction variable");
// Recover the owning loop from the block argument's parent op.
1004 auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp());
1005 return forOp.getLowerBound();
1006 });
1007}
1008
1009std::optional<APInt> ForOp::getConstantStep() {
1010 IntegerAttr step;
1011 if (matchPattern(getStep(), m_Constant(&step)))
1012 return step.getValue();
1013 return {};
1014}
1015
1016std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1017 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1018}
1019
// Reports the speculatability of scf.for, distinguishing the case where the
// step is the constant 1 from every other step.
// NOTE(review): the two return statements are not visible in this listing;
// confirm which Speculatability value each branch yields.
1020Speculation::Speculatability ForOp::getSpeculatability() {
1021 // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1022 // and End.
1023 if (auto constantStep = getConstantStep())
1024 if (*constantStep == 1)
1026
1027 // For Step != 1, the loop may not terminate. We can add more smarts here if
1028 // needed.
1030}
1031
1032std::optional<APInt> ForOp::getStaticTripCount() {
1033 return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
1034 /*isSigned=*/!getUnsignedCmp(), computeUbMinusLb);
1035}
1036
1037//===----------------------------------------------------------------------===//
1038// ForallOp
1039//===----------------------------------------------------------------------===//
1040
1041LogicalResult ForallOp::verify() {
1042 unsigned numLoops = getRank();
1043 // Check number of outputs.
1044 if (getNumResults() != getOutputs().size())
1045 return emitOpError("produces ")
1046 << getNumResults() << " results, but has only "
1047 << getOutputs().size() << " outputs";
1048
1049 // Check that the body defines block arguments for thread indices and outputs.
1050 auto *body = getBody();
1051 if (body->getNumArguments() != numLoops + getOutputs().size())
1052 return emitOpError("region expects ") << numLoops << " arguments";
1053 for (int64_t i = 0; i < numLoops; ++i)
1054 if (!body->getArgument(i).getType().isIndex())
1055 return emitOpError("expects ")
1056 << i << "-th block argument to be an index";
1057 for (unsigned i = 0; i < getOutputs().size(); ++i)
1058 if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1059 return emitOpError("type mismatch between ")
1060 << i << "-th output and corresponding block argument";
1061 if (getMapping().has_value() && !getMapping()->empty()) {
1062 if (getDeviceMappingAttrs().size() != numLoops)
1063 return emitOpError() << "mapping attribute size must match op rank";
1064 if (failed(getDeviceMaskingAttr()))
1065 return emitOpError() << getMappingAttrName()
1066 << " supports at most one device masking attribute";
1067 }
1068
1069 // Verify mixed static/dynamic control variables.
1070 Operation *op = getOperation();
1071 if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1072 getStaticLowerBound(),
1073 getDynamicLowerBound())))
1074 return failure();
1075 if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1076 getStaticUpperBound(),
1077 getDynamicUpperBound())))
1078 return failure();
1079 if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1080 getStaticStep(), getDynamicStep())))
1081 return failure();
1082
1083 return success();
1084}
1085
// Prints the custom assembly of scf.forall:
//   - the induction variables in parentheses;
//   - for a normalized loop only `in <upper bounds>`, otherwise
//     `= <lower bounds> to <upper bounds> step <steps>`;
//   - the ` shared_outs` initialization list, followed by `-> (result types)`
//     when there is at least one region out argument;
//   - the region, without entry block arguments, with its terminator printed
//     only when the op has results;
//   - the attribute dictionary, eliding the operand-segment-sizes and static
//     bound/step attributes that are already encoded above.
1086void ForallOp::print(OpAsmPrinter &p) {
1087 Operation *op = getOperation();
1088 p << " (" << getInductionVars();
1089 if (isNormalized()) {
1090 p << ") in ";
1091 printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1092 /*valueTypes=*/{}, /*scalables=*/{},
1094 } else {
1095 p << ") = ";
1096 printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1097 /*valueTypes=*/{}, /*scalables=*/{},
1099 p << " to ";
1100 printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1101 /*valueTypes=*/{}, /*scalables=*/{},
1103 p << " step ";
1104 printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1105 /*valueTypes=*/{}, /*scalables=*/{},
1107 }
1108 printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1109 p << " ";
1110 if (!getRegionOutArgs().empty())
1111 p << "-> (" << getResultTypes() << ") ";
1112 p.printRegion(getRegion(),
1113 /*printEntryBlockArgs=*/false,
1114 /*printBlockTerminators=*/getNumResults() > 0);
1115 p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1116 getStaticLowerBoundAttrName(),
1117 getStaticUpperBoundAttrName(),
1118 getStaticStepAttrName()});
1119}
1120
// Parses the custom assembly of scf.forall into `result`:
//   - the parenthesized induction variables;
//   - either `in <upper bounds>` (lower bounds default to 0 and steps to 1,
//     one entry per induction variable) or
//     `= <lower bounds> to <upper bounds> step <steps>`;
//   - an optional `shared_outs` assignment list with an optional arrow type
//     list;
//   - the body region, whose arguments are the induction variables (index
//     typed) followed by the shared outputs (typed like the results);
//   - an optional attribute dictionary.
// Operands are resolved in the order lower bounds, upper bounds, steps,
// outputs, and the matching segment sizes are recorded at the end.
1121ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1122 OpBuilder b(parser.getContext());
1123 auto indexType = b.getIndexType();
1124
1125 // Parse an opening `(` followed by thread index variables followed by `)`
1126 // TODO: when we can refer to such "induction variable"-like handles from the
1127 // declarative assembly format, we can implement the parser as a custom hook.
1128 SmallVector<OpAsmParser::Argument, 4> ivs;
1130 return failure();
1131
1132 DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1133 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1134 dynamicSteps;
1135 if (succeeded(parser.parseOptionalKeyword("in"))) {
1136 // Parse upper bounds.
1137 if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1138 /*valueTypes=*/nullptr,
1140 parser.resolveOperands(dynamicUbs, indexType, result.operands))
1141 return failure();
1142
// Normalized form: synthesize all-zero lower bounds and all-one steps.
1143 unsigned numLoops = ivs.size();
1144 staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1145 staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1146 } else {
1147 // Parse lower bounds.
1148 if (parser.parseEqual() ||
1149 parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1150 /*valueTypes=*/nullptr,
1152
1153 parser.resolveOperands(dynamicLbs, indexType, result.operands))
1154 return failure();
1155
1156 // Parse upper bounds.
1157 if (parser.parseKeyword("to") ||
1158 parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1159 /*valueTypes=*/nullptr,
1161 parser.resolveOperands(dynamicUbs, indexType, result.operands))
1162 return failure();
1163
1164 // Parse step values.
1165 if (parser.parseKeyword("step") ||
1166 parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1167 /*valueTypes=*/nullptr,
1169 parser.resolveOperands(dynamicSteps, indexType, result.operands))
1170 return failure();
1171 }
1172
1173 // Parse out operands and results.
1174 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1175 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1176 SMLoc outOperandsLoc = parser.getCurrentLocation();
1177 if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
// NOTE(review): this size comparison runs before `outOperands` and
// `result.types` are populated by the parsing calls below, so it looks like
// it cannot fire here. Confirm whether it was meant to follow the parsing.
1178 if (outOperands.size() != result.types.size())
1179 return parser.emitError(outOperandsLoc,
1180 "mismatch between out operands and types");
1181 if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1182 parser.parseOptionalArrowTypeList(result.types) ||
1183 parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1184 result.operands))
1185 return failure();
1186 }
1187
1188 // Parse region.
1189 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1190 std::unique_ptr<Region> region = std::make_unique<Region>();
// Region arguments: index-typed induction variables first ...
1191 for (auto &iv : ivs) {
1192 iv.type = b.getIndexType();
1193 regionArgs.push_back(iv);
1194 }
// ... then one argument per shared output, typed like the matching result.
1195 for (const auto &it : llvm::enumerate(regionOutArgs)) {
1196 auto &out = it.value();
1197 out.type = result.types[it.index()];
1198 regionArgs.push_back(out);
1199 }
1200 if (parser.parseRegion(*region, regionArgs))
1201 return failure();
1202
1203 // Ensure terminator and move region.
1204 ForallOp::ensureTerminator(*region, b, result.location);
1205 result.addRegion(std::move(region));
1206
1207 // Parse the optional attribute list.
1208 if (parser.parseOptionalAttrDict(result.attributes))
1209 return failure();
1210
1211 result.addAttribute("staticLowerBound", staticLbs);
1212 result.addAttribute("staticUpperBound", staticUbs);
1213 result.addAttribute("staticStep", staticSteps);
// Segment sizes follow the operand order used above: lbs, ubs, steps, outs.
1214 result.addAttribute("operandSegmentSizes",
1216 {static_cast<int32_t>(dynamicLbs.size()),
1217 static_cast<int32_t>(dynamicUbs.size()),
1218 static_cast<int32_t>(dynamicSteps.size()),
1219 static_cast<int32_t>(outOperands.size())}));
1220 return success();
1221}
1222
1223// Builder that takes loop bounds.
1224void ForallOp::build(
1225 mlir::OpBuilder &b, mlir::OperationState &result,
1226 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1227 ArrayRef<OpFoldResult> steps, ValueRange outputs,
1228 std::optional<ArrayAttr> mapping,
1229 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1230 SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1231 SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1232 dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1233 dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1234 dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1235
1236 result.addOperands(dynamicLbs);
1237 result.addOperands(dynamicUbs);
1238 result.addOperands(dynamicSteps);
1239 result.addOperands(outputs);
1240 result.addTypes(TypeRange(outputs));
1241
1242 result.addAttribute(getStaticLowerBoundAttrName(result.name),
1243 b.getDenseI64ArrayAttr(staticLbs));
1244 result.addAttribute(getStaticUpperBoundAttrName(result.name),
1245 b.getDenseI64ArrayAttr(staticUbs));
1246 result.addAttribute(getStaticStepAttrName(result.name),
1247 b.getDenseI64ArrayAttr(staticSteps));
1248 result.addAttribute(
1249 "operandSegmentSizes",
1250 b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1251 static_cast<int32_t>(dynamicUbs.size()),
1252 static_cast<int32_t>(dynamicSteps.size()),
1253 static_cast<int32_t>(outputs.size())}));
1254 if (mapping.has_value()) {
1255 result.addAttribute(ForallOp::getMappingAttrName(result.name),
1256 mapping.value());
1257 }
1258
1259 Region *bodyRegion = result.addRegion();
1260 OpBuilder::InsertionGuard g(b);
1261 b.createBlock(bodyRegion);
1262 Block &bodyBlock = bodyRegion->front();
1263
1264 // Add block arguments for indices and outputs.
1265 bodyBlock.addArguments(
1266 SmallVector<Type>(lbs.size(), b.getIndexType()),
1267 SmallVector<Location>(staticLbs.size(), result.location));
1268 bodyBlock.addArguments(
1269 TypeRange(outputs),
1270 SmallVector<Location>(outputs.size(), result.location));
1271
1272 b.setInsertionPointToStart(&bodyBlock);
1273 if (!bodyBuilderFn) {
1274 ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1275 return;
1276 }
1277 bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1278}
1279
1280// Builder that takes loop bounds.
1281void ForallOp::build(
1282 mlir::OpBuilder &b, mlir::OperationState &result,
1283 ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1284 std::optional<ArrayAttr> mapping,
1285 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1286 unsigned numLoops = ubs.size();
1287 SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1288 SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1289 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1290}
1291
1292// Checks if the lbs are zeros and steps are ones.
1293bool ForallOp::isNormalized() {
1294 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1295 return llvm::all_of(results, [&](OpFoldResult ofr) {
1296 auto intValue = getConstantIntValue(ofr);
1297 return intValue.has_value() && intValue == val;
1298 });
1299 };
1300 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1301}
1302
1303InParallelOp ForallOp::getTerminator() {
1304 return cast<InParallelOp>(getBody()->getTerminator());
1305}
1306
1307SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1308 SmallVector<Operation *> storeOps;
1309 for (Operation *user : bbArg.getUsers()) {
1310 if (auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1311 storeOps.push_back(parallelOp);
1312 }
1313 }
1314 return storeOps;
1315}
1316
1317SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1318 SmallVector<DeviceMappingAttrInterface> res;
1319 if (!getMapping())
1320 return res;
1321 for (auto attr : getMapping()->getValue()) {
1322 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1323 if (m)
1324 res.push_back(m);
1325 }
1326 return res;
1327}
1328
1329FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1330 DeviceMaskingAttrInterface res;
1331 if (!getMapping())
1332 return res;
1333 for (auto attr : getMapping()->getValue()) {
1334 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1335 if (m && res)
1336 return failure();
1337 if (m)
1338 res = m;
1339 }
1340 return res;
1341}
1342
1343bool ForallOp::usesLinearMapping() {
1344 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1345 if (ifaces.empty())
1346 return false;
1347 return ifaces.front().isLinearMapping();
1348}
1349
1350std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1351 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1352}
1353
1354// Get lower bounds as OpFoldResult.
1355std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1356 Builder b(getOperation()->getContext());
1357 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1358}
1359
1360// Get upper bounds as OpFoldResult.
1361std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1362 Builder b(getOperation()->getContext());
1363 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1364}
1365
1366// Get steps as OpFoldResult.
1367std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1368 Builder b(getOperation()->getContext());
1369 return getMixedValues(getStaticStep(), getDynamicStep(), b);
1370}
1371
// Maps `val` to the ForallOp that owns it as a block argument. Returns a null
// ForallOp when `val` is not a block argument or when the op owning its block
// is not a ForallOp.
1373 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1374 if (!tidxArg)
1375 return ForallOp();
1376 assert(tidxArg.getOwner() && "unlinked block argument");
1377 auto *containingOp = tidxArg.getOwner()->getParentOp();
1378 return dyn_cast<ForallOp>(containingOp);
1379}
1380
1381namespace {
1382/// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1383struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1384 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1385
1386 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1387 PatternRewriter &rewriter) const final {
1388 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1389 if (!forallOp)
1390 return failure();
1391 Value sharedOut =
1392 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1393 ->get();
1394 rewriter.modifyOpInPlace(
1395 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1396 return success();
1397 }
1398};
1399
1400class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1401public:
1402 using OpRewritePattern<ForallOp>::OpRewritePattern;
1403
1404 LogicalResult matchAndRewrite(ForallOp op,
1405 PatternRewriter &rewriter) const override {
1406 SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1407 SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1408 SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1409 if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1410 failed(foldDynamicIndexList(mixedUpperBound)) &&
1411 failed(foldDynamicIndexList(mixedStep)))
1412 return failure();
1413
1414 rewriter.modifyOpInPlace(op, [&]() {
1415 SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1416 SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1417 dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1418 staticLowerBound);
1419 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1420 op.setStaticLowerBound(staticLowerBound);
1421
1422 dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1423 staticUpperBound);
1424 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1425 op.setStaticUpperBound(staticUpperBound);
1426
1427 dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1428 op.getDynamicStepMutable().assign(dynamicStep);
1429 op.setStaticStep(staticStep);
1430
1431 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1432 rewriter.getDenseI32ArrayAttr(
1433 {static_cast<int32_t>(dynamicLowerBound.size()),
1434 static_cast<int32_t>(dynamicUpperBound.size()),
1435 static_cast<int32_t>(dynamicStep.size()),
1436 static_cast<int32_t>(op.getNumResults())}));
1437 });
1438 return success();
1439 }
1440};
1441
1442/// The following canonicalization pattern folds the iter arguments of
1443/// scf.forall op if :-
1444/// 1. The corresponding result has zero uses.
1445/// 2. The iter argument is NOT being modified within the loop body.
1446/// uses.
1447///
1448/// Example of first case :-
1449/// INPUT:
1450/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1451/// {
1452/// ...
1453/// <SOME USE OF %arg0>
1454/// <SOME USE OF %arg1>
1455/// <SOME USE OF %arg2>
1456/// ...
1457/// scf.forall.in_parallel {
1458/// <STORE OP WITH DESTINATION %arg1>
1459/// <STORE OP WITH DESTINATION %arg0>
1460/// <STORE OP WITH DESTINATION %arg2>
1461/// }
1462/// }
1463/// return %res#1
1464///
1465/// OUTPUT:
1466/// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
1467/// {
1468/// ...
1469/// <SOME USE OF %a>
1470/// <SOME USE OF %new_arg0>
1471/// <SOME USE OF %c>
1472/// ...
1473/// scf.forall.in_parallel {
1474/// <STORE OP WITH DESTINATION %new_arg0>
1475/// }
1476/// }
1477/// return %res
1478///
1479/// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1480/// scf.forall is replaced by their corresponding operands.
1481/// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1482/// of the scf.forall besides within scf.forall.in_parallel terminator,
1483/// this canonicalization remains valid. For more details, please refer
1484/// to :
1485/// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1486/// 3. TODO(avarma): Generalize it for other store ops. Currently it
1487/// handles tensor.parallel_insert_slice ops only.
1488///
1489/// Example of second case :-
1490/// INPUT:
1491/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1492/// {
1493/// ...
1494/// <SOME USE OF %arg0>
1495/// <SOME USE OF %arg1>
1496/// ...
1497/// scf.forall.in_parallel {
1498/// <STORE OP WITH DESTINATION %arg1>
1499/// }
1500/// }
1501/// return %res#0, %res#1
1502///
1503/// OUTPUT:
1504/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1505/// {
1506/// ...
1507/// <SOME USE OF %a>
1508/// <SOME USE OF %new_arg0>
1509/// ...
1510/// scf.forall.in_parallel {
1511/// <STORE OP WITH DESTINATION %new_arg0>
1512/// }
1513/// }
1514/// return %a, %res
1515struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1516 using OpRewritePattern<ForallOp>::OpRewritePattern;
1517
1518 LogicalResult matchAndRewrite(ForallOp forallOp,
1519 PatternRewriter &rewriter) const final {
1520 // Step 1: For a given i-th result of scf.forall, check the following :-
1521 // a. If it has any use.
1522 // b. If the corresponding iter argument is being modified within
1523 // the loop, i.e. has at least one store op with the iter arg as
1524 // its destination operand. For this we use
1525 // ForallOp::getCombiningOps(iter_arg).
1526 //
1527 // Based on the check we maintain the following :-
1528 // a. op results, block arguments, outputs to delete
1529 // b. new outputs (i.e., outputs to retain)
1530 SmallVector<Value> resultsToDelete;
1531 SmallVector<Value> outsToDelete;
1532 SmallVector<BlockArgument> blockArgsToDelete;
1533 SmallVector<Value> newOuts;
1534 BitVector resultIndicesToDelete(forallOp.getNumResults(), false);
1535 BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
1536 false);
1537 for (OpResult result : forallOp.getResults()) {
1538 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1539 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1540 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1541 resultsToDelete.push_back(result);
1542 outsToDelete.push_back(opOperand->get());
1543 blockArgsToDelete.push_back(blockArg);
1544 resultIndicesToDelete[result.getResultNumber()] = true;
1545 blockIndicesToDelete[blockArg.getArgNumber()] = true;
1546 } else {
1547 newOuts.push_back(opOperand->get());
1548 }
1549 }
1550
1551 // Return early if all results of scf.forall have at least one use and being
1552 // modified within the loop.
1553 if (resultsToDelete.empty())
1554 return failure();
1555
1556 // Step 2: Erase combining ops and replace uses of deleted results and
1557 // block arguments with the corresponding outputs.
1558 for (auto blockArg : blockArgsToDelete) {
1559 SmallVector<Operation *> combiningOps =
1560 forallOp.getCombiningOps(blockArg);
1561 for (Operation *combiningOp : combiningOps)
1562 rewriter.eraseOp(combiningOp);
1563 }
1564 for (auto [blockArg, result, out] :
1565 llvm::zip_equal(blockArgsToDelete, resultsToDelete, outsToDelete)) {
1566 rewriter.replaceAllUsesWith(blockArg, out);
1567 rewriter.replaceAllUsesWith(result, out);
1568 }
1569 // TODO: There is no rewriter API for erasing block arguments.
1570 rewriter.modifyOpInPlace(forallOp, [&]() {
1571 forallOp.getBody()->eraseArguments(blockIndicesToDelete);
1572 });
1573
1574 // Step 3. Create a new scf.forall op with only the shared_outs/results
1575 // that should be retained.
1576 auto newForallOp = cast<scf::ForallOp>(
1577 rewriter.eraseOpResults(forallOp, resultIndicesToDelete));
1578 newForallOp.getOutputsMutable().assign(newOuts);
1579
1580 return success();
1581 }
1582};
1583
1584struct ForallOpSingleOrZeroIterationDimsFolder
1585 : public OpRewritePattern<ForallOp> {
1586 using OpRewritePattern<ForallOp>::OpRewritePattern;
1587
1588 LogicalResult matchAndRewrite(ForallOp op,
1589 PatternRewriter &rewriter) const override {
1590 // Do not fold dimensions if they are mapped to processing units.
1591 if (op.getMapping().has_value() && !op.getMapping()->empty())
1592 return failure();
1593 Location loc = op.getLoc();
1594
1595 // Compute new loop bounds that omit all single-iteration loop dimensions.
1596 SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1597 newMixedSteps;
1598 IRMapping mapping;
1599 for (auto [lb, ub, step, iv] :
1600 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1601 op.getMixedStep(), op.getInductionVars())) {
1602 auto numIterations =
1603 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
1604 if (numIterations.has_value()) {
1605 // Remove the loop if it performs zero iterations.
1606 if (*numIterations == 0) {
1607 rewriter.replaceOp(op, op.getOutputs());
1608 return success();
1609 }
1610 // Replace the loop induction variable by the lower bound if the loop
1611 // performs a single iteration. Otherwise, copy the loop bounds.
1612 if (*numIterations == 1) {
1613 mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1614 continue;
1615 }
1616 }
1617 newMixedLowerBounds.push_back(lb);
1618 newMixedUpperBounds.push_back(ub);
1619 newMixedSteps.push_back(step);
1620 }
1621
1622 // All of the loop dimensions perform a single iteration. Inline loop body.
1623 if (newMixedLowerBounds.empty()) {
1624 promote(rewriter, op);
1625 return success();
1626 }
1627
1628 // Exit if none of the loop dimensions perform a single iteration.
1629 if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1630 return rewriter.notifyMatchFailure(
1631 op, "no dimensions have 0 or 1 iterations");
1632 }
1633
1634 // Replace the loop by a lower-dimensional loop.
1635 ForallOp newOp;
1636 newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1637 newMixedUpperBounds, newMixedSteps,
1638 op.getOutputs(), std::nullopt, nullptr);
1639 newOp.getBodyRegion().getBlocks().clear();
1640 // The new loop needs to keep all attributes from the old one, except for
1641 // "operandSegmentSizes" and static loop bound attributes which capture
1642 // the outdated information of the old iteration domain.
1643 SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1644 newOp.getStaticLowerBoundAttrName(),
1645 newOp.getStaticUpperBoundAttrName(),
1646 newOp.getStaticStepAttrName()};
1647 for (const auto &namedAttr : op->getAttrs()) {
1648 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1649 continue;
1650 rewriter.modifyOpInPlace(newOp, [&]() {
1651 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1652 });
1653 }
1654 rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1655 newOp.getRegion().begin(), mapping);
1656 rewriter.replaceOp(op, newOp.getResults());
1657 return success();
1658 }
1659};
1660
1661/// Replace all induction vars with a single trip count with their lower bound.
1662struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1663 using OpRewritePattern<ForallOp>::OpRewritePattern;
1664
1665 LogicalResult matchAndRewrite(ForallOp op,
1666 PatternRewriter &rewriter) const override {
1667 Location loc = op.getLoc();
1668 bool changed = false;
1669 for (auto [lb, ub, step, iv] :
1670 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1671 op.getMixedStep(), op.getInductionVars())) {
1672 if (iv.hasNUses(0))
1673 continue;
1674 auto numIterations =
1675 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
1676 if (!numIterations.has_value() || numIterations.value() != 1) {
1677 continue;
1678 }
1679 rewriter.replaceAllUsesWith(
1680 iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1681 changed = true;
1682 }
1683 return success(changed);
1684 }
1685};
1686
1687struct FoldTensorCastOfOutputIntoForallOp
1688 : public OpRewritePattern<scf::ForallOp> {
1689 using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1690
1691 struct TypeCast {
1692 Type srcType;
1693 Type dstType;
1694 };
1695
1696 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1697 PatternRewriter &rewriter) const final {
1698 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1699 llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1700 for (auto en : llvm::enumerate(newOutputTensors)) {
1701 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1702 if (!castOp)
1703 continue;
1704
1705 // Only casts that that preserve static information, i.e. will make the
1706 // loop result type "more" static than before, will be folded.
1707 if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1708 castOp.getSource().getType())) {
1709 continue;
1710 }
1711
1712 tensorCastProducers[en.index()] =
1713 TypeCast{castOp.getSource().getType(), castOp.getType()};
1714 newOutputTensors[en.index()] = castOp.getSource();
1715 }
1716
1717 if (tensorCastProducers.empty())
1718 return failure();
1719
1720 // Create new loop.
1721 Location loc = forallOp.getLoc();
1722 auto newForallOp = ForallOp::create(
1723 rewriter, loc, forallOp.getMixedLowerBound(),
1724 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1725 newOutputTensors, forallOp.getMapping(),
1726 [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1727 auto castBlockArgs =
1728 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1729 for (auto [index, cast] : tensorCastProducers) {
1730 Value &oldTypeBBArg = castBlockArgs[index];
1731 oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
1732 cast.dstType, oldTypeBBArg);
1733 }
1734
1735 // Move old body into new parallel loop.
1736 SmallVector<Value> ivsBlockArgs =
1737 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1738 ivsBlockArgs.append(castBlockArgs);
1739 rewriter.mergeBlocks(forallOp.getBody(),
1740 bbArgs.front().getParentBlock(), ivsBlockArgs);
1741 });
1742
1743 // After `mergeBlocks` happened, the destinations in the terminator were
1744 // mapped to the tensor.cast old-typed results of the output bbArgs. The
1745 // destination have to be updated to point to the output bbArgs directly.
1746 auto terminator = newForallOp.getTerminator();
1747 for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1748 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1749 if (auto parallelCombingingOp =
1750 dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
1751 parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
1752 }
1753 }
1754
1755 // Cast results back to the original types.
1756 rewriter.setInsertionPointAfter(newForallOp);
1757 SmallVector<Value> castResults = newForallOp.getResults();
1758 for (auto &item : tensorCastProducers) {
1759 Value &oldTypeResult = castResults[item.first];
1760 oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
1761 oldTypeResult);
1762 }
1763 rewriter.replaceOp(forallOp, castResults);
1764 return success();
1765 }
1766};
1767
1768} // namespace
1769
1770void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1771 MLIRContext *context) {
1772 results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1773 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1774 ForallOpSingleOrZeroIterationDimsFolder,
1775 ForallOpReplaceConstantInductionVar>(context);
1776}
1777
1778void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1779 SmallVectorImpl<RegionSuccessor> &regions) {
1780 // There are two region branch points:
1781 // 1. "parent": entering the forall op for the first time.
1782 // 2. scf.in_parallel terminator
1783 if (point.isParent()) {
1784 // When first entering the forall op, the control flow typically branches
1785 // into the forall body. (In parallel for multiple threads.)
1786 regions.push_back(RegionSuccessor(&getRegion()));
1787 // However, when there are 0 threads, the control flow may branch back to
1788 // the parent immediately.
1789 regions.push_back(RegionSuccessor::parent());
1790 } else {
1791 // In accordance with the semantics of forall, its body is executed in
1792 // parallel by multiple threads. We should not expect to branch back into
1793 // the forall body after the region's execution is complete.
1794 regions.push_back(RegionSuccessor::parent());
1795 }
1796}
1797
1798//===----------------------------------------------------------------------===//
1799// InParallelOp
1800//===----------------------------------------------------------------------===//
1801
1802// Build a InParallelOp with mixed static and dynamic entries.
1803void InParallelOp::build(OpBuilder &b, OperationState &result) {
1804 OpBuilder::InsertionGuard g(b);
1805 Region *bodyRegion = result.addRegion();
1806 b.createBlock(bodyRegion);
1807}
1808
1809LogicalResult InParallelOp::verify() {
1810 scf::ForallOp forallOp =
1811 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1812 if (!forallOp)
1813 return this->emitOpError("expected forall op parent");
1814
1815 for (Operation &op : getRegion().front().getOperations()) {
1816 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
1817 if (!parallelCombiningOp) {
1818 return this->emitOpError("expected only ParallelCombiningOpInterface")
1819 << " ops";
1820 }
1821
1822 // Verify that inserts are into out block arguments.
1823 MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
1824 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1825 for (OpOperand &dest : dests) {
1826 if (!llvm::is_contained(regionOutArgs, dest.get()))
1827 return op.emitOpError("may only insert into an output block argument");
1828 }
1829 }
1830
1831 return success();
1832}
1833
1834void InParallelOp::print(OpAsmPrinter &p) {
1835 p << " ";
1836 p.printRegion(getRegion(),
1837 /*printEntryBlockArgs=*/false,
1838 /*printBlockTerminators=*/false);
1839 p.printOptionalAttrDict(getOperation()->getAttrs());
1840}
1841
1842ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1843 auto &builder = parser.getBuilder();
1844
1845 SmallVector<OpAsmParser::Argument, 8> regionOperands;
1846 std::unique_ptr<Region> region = std::make_unique<Region>();
1847 if (parser.parseRegion(*region, regionOperands))
1848 return failure();
1849
1850 if (region->empty())
1851 OpBuilder(builder.getContext()).createBlock(region.get());
1852 result.addRegion(std::move(region));
1853
1854 // Parse the optional attribute list.
1855 if (parser.parseOptionalAttrDict(result.attributes))
1856 return failure();
1857 return success();
1858}
1859
1860OpResult InParallelOp::getParentResult(int64_t idx) {
1861 return getOperation()->getParentOp()->getResult(idx);
1862}
1863
1864SmallVector<BlockArgument> InParallelOp::getDests() {
1865 SmallVector<BlockArgument> updatedDests;
1866 for (Operation &yieldingOp : getYieldingOps()) {
1867 auto parallelCombiningOp =
1868 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
1869 if (!parallelCombiningOp)
1870 continue;
1871 for (OpOperand &updatedOperand :
1872 parallelCombiningOp.getUpdatedDestinations())
1873 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
1874 }
1875 return updatedDests;
1876}
1877
1878llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1879 return getRegion().front().getOperations();
1880}
1881
1882//===----------------------------------------------------------------------===//
1883// IfOp
1884//===----------------------------------------------------------------------===//
1885
1887 assert(a && "expected non-empty operation");
1888 assert(b && "expected non-empty operation");
1889
1890 IfOp ifOp = a->getParentOfType<IfOp>();
1891 while (ifOp) {
1892 // Check if b is inside ifOp. (We already know that a is.)
1893 if (ifOp->isProperAncestor(b))
1894 // b is contained in ifOp. a and b are in mutually exclusive branches if
1895 // they are in different blocks of ifOp.
1896 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1897 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1898 // Check next enclosing IfOp.
1899 ifOp = ifOp->getParentOfType<IfOp>();
1900 }
1901
1902 // Could not find a common IfOp among a's and b's ancestors.
1903 return false;
1904}
1905
1906LogicalResult
1907IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1908 IfOp::Adaptor adaptor,
1909 SmallVectorImpl<Type> &inferredReturnTypes) {
1910 if (adaptor.getRegions().empty())
1911 return failure();
1912 Region *r = &adaptor.getThenRegion();
1913 if (r->empty())
1914 return failure();
1915 Block &b = r->front();
1916 if (b.empty())
1917 return failure();
1918 auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
1919 if (!yieldOp)
1920 return failure();
1921 TypeRange types = yieldOp.getOperandTypes();
1922 llvm::append_range(inferredReturnTypes, types);
1923 return success();
1924}
1925
1926void IfOp::build(OpBuilder &builder, OperationState &result,
1927 TypeRange resultTypes, Value cond) {
1928 return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
1929 /*addElseBlock=*/false);
1930}
1931
1932void IfOp::build(OpBuilder &builder, OperationState &result,
1933 TypeRange resultTypes, Value cond, bool addThenBlock,
1934 bool addElseBlock) {
1935 assert((!addElseBlock || addThenBlock) &&
1936 "must not create else block w/o then block");
1937 result.addTypes(resultTypes);
1938 result.addOperands(cond);
1939
1940 // Add regions and blocks.
1941 OpBuilder::InsertionGuard guard(builder);
1942 Region *thenRegion = result.addRegion();
1943 if (addThenBlock)
1944 builder.createBlock(thenRegion);
1945 Region *elseRegion = result.addRegion();
1946 if (addElseBlock)
1947 builder.createBlock(elseRegion);
1948}
1949
1950void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1951 bool withElseRegion) {
1952 build(builder, result, TypeRange{}, cond, withElseRegion);
1953}
1954
1955void IfOp::build(OpBuilder &builder, OperationState &result,
1956 TypeRange resultTypes, Value cond, bool withElseRegion) {
1957 result.addTypes(resultTypes);
1958 result.addOperands(cond);
1959
1960 // Build then region.
1961 OpBuilder::InsertionGuard guard(builder);
1962 Region *thenRegion = result.addRegion();
1963 builder.createBlock(thenRegion);
1964 if (resultTypes.empty())
1965 IfOp::ensureTerminator(*thenRegion, builder, result.location);
1966
1967 // Build else region.
1968 Region *elseRegion = result.addRegion();
1969 if (withElseRegion) {
1970 builder.createBlock(elseRegion);
1971 if (resultTypes.empty())
1972 IfOp::ensureTerminator(*elseRegion, builder, result.location);
1973 }
1974}
1975
1976void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1977 function_ref<void(OpBuilder &, Location)> thenBuilder,
1978 function_ref<void(OpBuilder &, Location)> elseBuilder) {
1979 assert(thenBuilder && "the builder callback for 'then' must be present");
1980 result.addOperands(cond);
1981
1982 // Build then region.
1983 OpBuilder::InsertionGuard guard(builder);
1984 Region *thenRegion = result.addRegion();
1985 builder.createBlock(thenRegion);
1986 thenBuilder(builder, result.location);
1987
1988 // Build else region.
1989 Region *elseRegion = result.addRegion();
1990 if (elseBuilder) {
1991 builder.createBlock(elseRegion);
1992 elseBuilder(builder, result.location);
1993 }
1994
1995 // Infer result types.
1996 SmallVector<Type> inferredReturnTypes;
1997 MLIRContext *ctx = builder.getContext();
1998 auto attrDict = DictionaryAttr::get(ctx, result.attributes);
1999 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2000 /*properties=*/nullptr, result.regions,
2001 inferredReturnTypes))) {
2002 result.addTypes(inferredReturnTypes);
2003 }
2004}
2005
2006LogicalResult IfOp::verify() {
2007 if (getNumResults() != 0 && getElseRegion().empty())
2008 return emitOpError("must have an else block if defining values");
2009 return success();
2010}
2011
2012ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2013 // Create the regions for 'then'.
2014 result.regions.reserve(2);
2015 Region *thenRegion = result.addRegion();
2016 Region *elseRegion = result.addRegion();
2017
2018 auto &builder = parser.getBuilder();
2019 OpAsmParser::UnresolvedOperand cond;
2020 Type i1Type = builder.getIntegerType(1);
2021 if (parser.parseOperand(cond) ||
2022 parser.resolveOperand(cond, i1Type, result.operands))
2023 return failure();
2024 // Parse optional results type list.
2025 if (parser.parseOptionalArrowTypeList(result.types))
2026 return failure();
2027 // Parse the 'then' region.
2028 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2029 return failure();
2030 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2031
2032 // If we find an 'else' keyword then parse the 'else' region.
2033 if (!parser.parseOptionalKeyword("else")) {
2034 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2035 return failure();
2036 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2037 }
2038
2039 // Parse the optional attribute list.
2040 if (parser.parseOptionalAttrDict(result.attributes))
2041 return failure();
2042 return success();
2043}
2044
2045void IfOp::print(OpAsmPrinter &p) {
2046 bool printBlockTerminators = false;
2047
2048 p << " " << getCondition();
2049 if (!getResults().empty()) {
2050 p << " -> (" << getResultTypes() << ")";
2051 // Print yield explicitly if the op defines values.
2052 printBlockTerminators = true;
2053 }
2054 p << ' ';
2055 p.printRegion(getThenRegion(),
2056 /*printEntryBlockArgs=*/false,
2057 /*printBlockTerminators=*/printBlockTerminators);
2058
2059 // Print the 'else' regions if it exists and has a block.
2060 auto &elseRegion = getElseRegion();
2061 if (!elseRegion.empty()) {
2062 p << " else ";
2063 p.printRegion(elseRegion,
2064 /*printEntryBlockArgs=*/false,
2065 /*printBlockTerminators=*/printBlockTerminators);
2066 }
2067
2068 p.printOptionalAttrDict((*this)->getAttrs());
2069}
2070
2071void IfOp::getSuccessorRegions(RegionBranchPoint point,
2072 SmallVectorImpl<RegionSuccessor> &regions) {
2073 // The `then` and the `else` region branch back to the parent operation or one
2074 // of the recursive parent operations (early exit case).
2075 if (!point.isParent()) {
2076 regions.push_back(RegionSuccessor::parent());
2077 return;
2078 }
2079
2080 regions.push_back(RegionSuccessor(&getThenRegion()));
2081
2082 // Don't consider the else region if it is empty.
2083 Region *elseRegion = &this->getElseRegion();
2084 if (elseRegion->empty())
2085 regions.push_back(RegionSuccessor::parent());
2086 else
2087 regions.push_back(RegionSuccessor(elseRegion));
2088}
2089
2090ValueRange IfOp::getSuccessorInputs(RegionSuccessor successor) {
2091 return successor.isParent() ? ValueRange(getOperation()->getResults())
2092 : ValueRange();
2093}
2094
2095void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2096 SmallVectorImpl<RegionSuccessor> &regions) {
2097 FoldAdaptor adaptor(operands, *this);
2098 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2099 if (!boolAttr || boolAttr.getValue())
2100 regions.emplace_back(&getThenRegion());
2101
2102 // If the else region is empty, execution continues after the parent op.
2103 if (!boolAttr || !boolAttr.getValue()) {
2104 if (!getElseRegion().empty())
2105 regions.emplace_back(&getElseRegion());
2106 else
2107 regions.emplace_back(RegionSuccessor::parent());
2108 }
2109}
2110
2111LogicalResult IfOp::fold(FoldAdaptor adaptor,
2112 SmallVectorImpl<OpFoldResult> &results) {
2113 // if (!c) then A() else B() -> if c then B() else A()
2114 if (getElseRegion().empty())
2115 return failure();
2116
2117 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2118 if (!xorStmt)
2119 return failure();
2120
2121 if (!matchPattern(xorStmt.getRhs(), m_One()))
2122 return failure();
2123
2124 getConditionMutable().assign(xorStmt.getLhs());
2125 Block *thenBlock = &getThenRegion().front();
2126 // It would be nicer to use iplist::swap, but that has no implemented
2127 // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2128 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2129 getElseRegion().getBlocks());
2130 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2131 getThenRegion().getBlocks(), thenBlock);
2132 return success();
2133}
2134
2135void IfOp::getRegionInvocationBounds(
2136 ArrayRef<Attribute> operands,
2137 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2138 if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2139 // If the condition is known, then one region is known to be executed once
2140 // and the other zero times.
2141 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2142 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2143 } else {
2144 // Non-constant condition. Each region may be executed 0 or 1 times.
2145 invocationBounds.assign(2, {0, 1});
2146 }
2147}
2148
2149namespace {
2150/// Hoist any yielded results whose operands are defined outside
2151/// the if, to a select instruction.
2152struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2153 using OpRewritePattern<IfOp>::OpRewritePattern;
2154
2155 LogicalResult matchAndRewrite(IfOp op,
2156 PatternRewriter &rewriter) const override {
2157 if (op->getNumResults() == 0)
2158 return failure();
2159
2160 auto cond = op.getCondition();
2161 auto thenYieldArgs = op.thenYield().getOperands();
2162 auto elseYieldArgs = op.elseYield().getOperands();
2163
2164 SmallVector<Type> nonHoistable;
2165 for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2166 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2167 &op.getElseRegion() == falseVal.getParentRegion())
2168 nonHoistable.push_back(trueVal.getType());
2169 }
2170 // Early exit if there aren't any yielded values we can
2171 // hoist outside the if.
2172 if (nonHoistable.size() == op->getNumResults())
2173 return failure();
2174
2175 IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2176 /*withElseRegion=*/false);
2177 if (replacement.thenBlock())
2178 rewriter.eraseBlock(replacement.thenBlock());
2179 replacement.getThenRegion().takeBody(op.getThenRegion());
2180 replacement.getElseRegion().takeBody(op.getElseRegion());
2181
2182 SmallVector<Value> results(op->getNumResults());
2183 assert(thenYieldArgs.size() == results.size());
2184 assert(elseYieldArgs.size() == results.size());
2185
2186 SmallVector<Value> trueYields;
2187 SmallVector<Value> falseYields;
2189 for (const auto &it :
2190 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2191 Value trueVal = std::get<0>(it.value());
2192 Value falseVal = std::get<1>(it.value());
2193 if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2194 &replacement.getElseRegion() == falseVal.getParentRegion()) {
2195 results[it.index()] = replacement.getResult(trueYields.size());
2196 trueYields.push_back(trueVal);
2197 falseYields.push_back(falseVal);
2198 } else if (trueVal == falseVal)
2199 results[it.index()] = trueVal;
2200 else
2201 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2202 cond, trueVal, falseVal);
2203 }
2204
2205 rewriter.setInsertionPointToEnd(replacement.thenBlock());
2206 rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2207
2208 rewriter.setInsertionPointToEnd(replacement.elseBlock());
2209 rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2210
2211 rewriter.replaceOp(op, results);
2212 return success();
2213 }
2214};
2215
2216/// Allow the true region of an if to assume the condition is true
2217/// and vice versa. For example:
2218///
2219/// scf.if %cmp {
2220/// print(%cmp)
2221/// }
2222///
2223/// becomes
2224///
2225/// scf.if %cmp {
2226/// print(true)
2227/// }
2228///
2229struct ConditionPropagation : public OpRewritePattern<IfOp> {
2230 using OpRewritePattern<IfOp>::OpRewritePattern;
2231
2232 /// Kind of parent region in the ancestor cache.
2233 enum class Parent { Then, Else, None };
2234
2235 /// Returns the kind of region ("then", "else", or "none") of the
2236 /// IfOp that the given region is transitively nested in. Updates
2237 /// the cache accordingly.
2238 static Parent getParentType(Region *toCheck, IfOp op,
2240 Region *endRegion) {
2241 SmallVector<Region *> seen;
2242 while (toCheck != endRegion) {
2243 auto found = cache.find(toCheck);
2244 if (found != cache.end())
2245 return found->second;
2246 seen.push_back(toCheck);
2247 if (&op.getThenRegion() == toCheck) {
2248 for (Region *region : seen)
2249 cache[region] = Parent::Then;
2250 return Parent::Then;
2251 }
2252 if (&op.getElseRegion() == toCheck) {
2253 for (Region *region : seen)
2254 cache[region] = Parent::Else;
2255 return Parent::Else;
2256 }
2257 toCheck = toCheck->getParentRegion();
2258 }
2259
2260 for (Region *region : seen)
2261 cache[region] = Parent::None;
2262 return Parent::None;
2263 }
2264
2265 LogicalResult matchAndRewrite(IfOp op,
2266 PatternRewriter &rewriter) const override {
2267 // Early exit if the condition is constant since replacing a constant
2268 // in the body with another constant isn't a simplification.
2269 if (matchPattern(op.getCondition(), m_Constant()))
2270 return failure();
2271
2272 bool changed = false;
2273 mlir::Type i1Ty = rewriter.getI1Type();
2274
2275 // These variables serve to prevent creating duplicate constants
2276 // and hold constant true or false values.
2277 Value constantTrue = nullptr;
2278 Value constantFalse = nullptr;
2279
2281 for (OpOperand &use :
2282 llvm::make_early_inc_range(op.getCondition().getUses())) {
2283 switch (getParentType(use.getOwner()->getParentRegion(), op, cache,
2284 op.getCondition().getParentRegion())) {
2285 case Parent::Then: {
2286 changed = true;
2287
2288 if (!constantTrue)
2289 constantTrue = arith::ConstantOp::create(
2290 rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2291
2292 rewriter.modifyOpInPlace(use.getOwner(),
2293 [&]() { use.set(constantTrue); });
2294 break;
2295 }
2296 case Parent::Else: {
2297 changed = true;
2298
2299 if (!constantFalse)
2300 constantFalse = arith::ConstantOp::create(
2301 rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2302
2303 rewriter.modifyOpInPlace(use.getOwner(),
2304 [&]() { use.set(constantFalse); });
2305 break;
2306 }
2307 case Parent::None:
2308 break;
2309 }
2310 }
2311
2312 return success(changed);
2313 }
2314};
2315
2316/// Remove any statements from an if that are equivalent to the condition
2317/// or its negation. For example:
2318///
2319/// %res:2 = scf.if %cmp {
2320/// yield something(), true
2321/// } else {
2322/// yield something2(), false
2323/// }
2324/// print(%res#1)
2325///
2326/// becomes
2327/// %res = scf.if %cmp {
2328/// yield something()
2329/// } else {
2330/// yield something2()
2331/// }
2332/// print(%cmp)
2333///
2334/// Additionally if both branches yield the same value, replace all uses
2335/// of the result with the yielded value.
2336///
2337/// %res:2 = scf.if %cmp {
2338/// yield something(), %arg1
2339/// } else {
2340/// yield something2(), %arg1
2341/// }
2342/// print(%res#1)
2343///
2344/// becomes
2345/// %res = scf.if %cmp {
2346/// yield something()
2347/// } else {
2348/// yield something2()
2349/// }
2350/// print(%arg1)
2351///
2352struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2353 using OpRewritePattern<IfOp>::OpRewritePattern;
2354
2355 LogicalResult matchAndRewrite(IfOp op,
2356 PatternRewriter &rewriter) const override {
2357 // Early exit if there are no results that could be replaced.
2358 if (op.getNumResults() == 0)
2359 return failure();
2360
2361 auto trueYield =
2362 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2363 auto falseYield =
2364 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2365
2366 rewriter.setInsertionPoint(op->getBlock(),
2367 op.getOperation()->getIterator());
2368 bool changed = false;
2369 Type i1Ty = rewriter.getI1Type();
2370 for (auto [trueResult, falseResult, opResult] :
2371 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2372 op.getResults())) {
2373 if (trueResult == falseResult) {
2374 if (!opResult.use_empty()) {
2375 opResult.replaceAllUsesWith(trueResult);
2376 changed = true;
2377 }
2378 continue;
2379 }
2380
2381 BoolAttr trueYield, falseYield;
2382 if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2383 !matchPattern(falseResult, m_Constant(&falseYield)))
2384 continue;
2385
2386 bool trueVal = trueYield.getValue();
2387 bool falseVal = falseYield.getValue();
2388 if (!trueVal && falseVal) {
2389 if (!opResult.use_empty()) {
2390 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2391 Value notCond = arith::XOrIOp::create(
2392 rewriter, op.getLoc(), op.getCondition(),
2393 constDialect
2394 ->materializeConstant(rewriter,
2395 rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2396 op.getLoc())
2397 ->getResult(0));
2398 opResult.replaceAllUsesWith(notCond);
2399 changed = true;
2400 }
2401 }
2402 if (trueVal && !falseVal) {
2403 if (!opResult.use_empty()) {
2404 opResult.replaceAllUsesWith(op.getCondition());
2405 changed = true;
2406 }
2407 }
2408 }
2409 return success(changed);
2410 }
2411};
2412
2413/// Merge any consecutive scf.if's with the same condition.
2414///
2415/// scf.if %cond {
2416/// firstCodeTrue();...
2417/// } else {
2418/// firstCodeFalse();...
2419/// }
2420/// %res = scf.if %cond {
2421/// secondCodeTrue();...
2422/// } else {
2423/// secondCodeFalse();...
2424/// }
2425///
2426/// becomes
2427/// %res = scf.if %cmp {
2428/// firstCodeTrue();...
2429/// secondCodeTrue();...
2430/// } else {
2431/// firstCodeFalse();...
2432/// secondCodeFalse();...
2433/// }
2434struct CombineIfs : public OpRewritePattern<IfOp> {
2435 using OpRewritePattern<IfOp>::OpRewritePattern;
2436
2437 LogicalResult matchAndRewrite(IfOp nextIf,
2438 PatternRewriter &rewriter) const override {
2439 Block *parent = nextIf->getBlock();
2440 if (nextIf == &parent->front())
2441 return failure();
2442
2443 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2444 if (!prevIf)
2445 return failure();
2446
2447 // Determine the logical then/else blocks when prevIf's
2448 // condition is used. Null means the block does not exist
2449 // in that case (e.g. empty else). If neither of these
2450 // are set, the two conditions cannot be compared.
2451 Block *nextThen = nullptr;
2452 Block *nextElse = nullptr;
2453 if (nextIf.getCondition() == prevIf.getCondition()) {
2454 nextThen = nextIf.thenBlock();
2455 if (!nextIf.getElseRegion().empty())
2456 nextElse = nextIf.elseBlock();
2457 }
2458 if (arith::XOrIOp notv =
2459 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2460 if (notv.getLhs() == prevIf.getCondition() &&
2461 matchPattern(notv.getRhs(), m_One())) {
2462 nextElse = nextIf.thenBlock();
2463 if (!nextIf.getElseRegion().empty())
2464 nextThen = nextIf.elseBlock();
2465 }
2466 }
2467 if (arith::XOrIOp notv =
2468 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2469 if (notv.getLhs() == nextIf.getCondition() &&
2470 matchPattern(notv.getRhs(), m_One())) {
2471 nextElse = nextIf.thenBlock();
2472 if (!nextIf.getElseRegion().empty())
2473 nextThen = nextIf.elseBlock();
2474 }
2475 }
2476
2477 if (!nextThen && !nextElse)
2478 return failure();
2479
2480 SmallVector<Value> prevElseYielded;
2481 if (!prevIf.getElseRegion().empty())
2482 prevElseYielded = prevIf.elseYield().getOperands();
2483 // Replace all uses of return values of op within nextIf with the
2484 // corresponding yields
2485 for (auto it : llvm::zip(prevIf.getResults(),
2486 prevIf.thenYield().getOperands(), prevElseYielded))
2487 for (OpOperand &use :
2488 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2489 if (nextThen && nextThen->getParent()->isAncestor(
2490 use.getOwner()->getParentRegion())) {
2491 rewriter.startOpModification(use.getOwner());
2492 use.set(std::get<1>(it));
2493 rewriter.finalizeOpModification(use.getOwner());
2494 } else if (nextElse && nextElse->getParent()->isAncestor(
2495 use.getOwner()->getParentRegion())) {
2496 rewriter.startOpModification(use.getOwner());
2497 use.set(std::get<2>(it));
2498 rewriter.finalizeOpModification(use.getOwner());
2499 }
2500 }
2501
2502 SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2503 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2504
2505 IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2506 prevIf.getCondition(), /*hasElse=*/false);
2507 rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2508
2509 rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2510 combinedIf.getThenRegion(),
2511 combinedIf.getThenRegion().begin());
2512
2513 if (nextThen) {
2514 YieldOp thenYield = combinedIf.thenYield();
2515 YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2516 rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2517 rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2518
2519 SmallVector<Value> mergedYields(thenYield.getOperands());
2520 llvm::append_range(mergedYields, thenYield2.getOperands());
2521 YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2522 rewriter.eraseOp(thenYield);
2523 rewriter.eraseOp(thenYield2);
2524 }
2525
2526 rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2527 combinedIf.getElseRegion(),
2528 combinedIf.getElseRegion().begin());
2529
2530 if (nextElse) {
2531 if (combinedIf.getElseRegion().empty()) {
2532 rewriter.inlineRegionBefore(*nextElse->getParent(),
2533 combinedIf.getElseRegion(),
2534 combinedIf.getElseRegion().begin());
2535 } else {
2536 YieldOp elseYield = combinedIf.elseYield();
2537 YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2538 rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2539
2540 rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2541
2542 SmallVector<Value> mergedElseYields(elseYield.getOperands());
2543 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2544
2545 YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2546 rewriter.eraseOp(elseYield);
2547 rewriter.eraseOp(elseYield2);
2548 }
2549 }
2550
2551 SmallVector<Value> prevValues;
2552 SmallVector<Value> nextValues;
2553 for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2554 if (pair.index() < prevIf.getNumResults())
2555 prevValues.push_back(pair.value());
2556 else
2557 nextValues.push_back(pair.value());
2558 }
2559 rewriter.replaceOp(prevIf, prevValues);
2560 rewriter.replaceOp(nextIf, nextValues);
2561 return success();
2562 }
2563};
2564
2565/// Pattern to remove an empty else branch.
2566struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2567 using OpRewritePattern<IfOp>::OpRewritePattern;
2568
2569 LogicalResult matchAndRewrite(IfOp ifOp,
2570 PatternRewriter &rewriter) const override {
2571 // Cannot remove else region when there are operation results.
2572 if (ifOp.getNumResults())
2573 return failure();
2574 Block *elseBlock = ifOp.elseBlock();
2575 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2576 return failure();
2577 auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2578 rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2579 newIfOp.getThenRegion().begin());
2580 rewriter.eraseOp(ifOp);
2581 return success();
2582 }
2583};
2584
2585/// Convert nested `if`s into `arith.andi` + single `if`.
2586///
2587/// scf.if %arg0 {
2588/// scf.if %arg1 {
2589/// ...
2590/// scf.yield
2591/// }
2592/// scf.yield
2593/// }
2594/// becomes
2595///
2596/// %0 = arith.andi %arg0, %arg1
2597/// scf.if %0 {
2598/// ...
2599/// scf.yield
2600/// }
2601struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2602 using OpRewritePattern<IfOp>::OpRewritePattern;
2603
2604 LogicalResult matchAndRewrite(IfOp op,
2605 PatternRewriter &rewriter) const override {
2606 auto nestedOps = op.thenBlock()->without_terminator();
2607 // Nested `if` must be the only op in block.
2608 if (!llvm::hasSingleElement(nestedOps))
2609 return failure();
2610
2611 // If there is an else block, it can only yield
2612 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2613 return failure();
2614
2615 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2616 if (!nestedIf)
2617 return failure();
2618
2619 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2620 return failure();
2621
2622 SmallVector<Value> thenYield(op.thenYield().getOperands());
2623 SmallVector<Value> elseYield;
2624 if (op.elseBlock())
2625 llvm::append_range(elseYield, op.elseYield().getOperands());
2626
2627 // A list of indices for which we should upgrade the value yielded
2628 // in the else to a select.
2629 SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2630
2631 // If the outer scf.if yields a value produced by the inner scf.if,
2632 // only permit combining if the value yielded when the condition
2633 // is false in the outer scf.if is the same value yielded when the
2634 // inner scf.if condition is false.
2635 // Note that the array access to elseYield will not go out of bounds
2636 // since it must have the same length as thenYield, since they both
2637 // come from the same scf.if.
2638 for (const auto &tup : llvm::enumerate(thenYield)) {
2639 if (tup.value().getDefiningOp() == nestedIf) {
2640 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2641 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2642 elseYield[tup.index()]) {
2643 return failure();
2644 }
2645 // If the correctness test passes, we will yield
2646 // corresponding value from the inner scf.if
2647 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2648 continue;
2649 }
2650
2651 // Otherwise, we need to ensure the else block of the combined
2652 // condition still returns the same value when the outer condition is
2653 // true and the inner condition is false. This can be accomplished if
2654 // the then value is defined outside the outer scf.if and we replace the
2655 // value with a select that considers just the outer condition. Since
2656 // the else region contains just the yield, its yielded value is
2657 // defined outside the scf.if, by definition.
2658
2659 // If the then value is defined within the scf.if, bail.
2660 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2661 return failure();
2662 }
2663 elseYieldsToUpgradeToSelect.push_back(tup.index());
2664 }
2665
2666 Location loc = op.getLoc();
2667 Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
2668 nestedIf.getCondition());
2669 auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
2670 Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2671
2672 SmallVector<Value> results;
2673 llvm::append_range(results, newIf.getResults());
2674 rewriter.setInsertionPoint(newIf);
2675
2676 for (auto idx : elseYieldsToUpgradeToSelect)
2677 results[idx] =
2678 arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
2679 thenYield[idx], elseYield[idx]);
2680
2681 rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2682 rewriter.setInsertionPointToEnd(newIf.thenBlock());
2683 rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2684 if (!elseYield.empty()) {
2685 rewriter.createBlock(&newIf.getElseRegion());
2686 rewriter.setInsertionPointToEnd(newIf.elseBlock());
2687 YieldOp::create(rewriter, loc, elseYield);
2688 }
2689 rewriter.replaceOp(op, results);
2690 return success();
2691 }
2692};
2693
2694} // namespace
2695
// Registers the canonicalization patterns for scf.if into `results`: the six
// rewrite patterns named in the `results.add<...>` call, followed by two
// further registrations that are keyed on IfOp::getOperationName().
2696void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2697 MLIRContext *context) {
2698 results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2699 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2700 ReplaceIfYieldWithConditionOrValue>(context);
// NOTE(review): the first line of each of the two calls below is not present
// in this listing (only their trailing argument lines are visible), so the
// callee names cannot be confirmed from here — check the upstream file.
2702 results, IfOp::getOperationName());
2704 IfOp::getOperationName());
2705}
2706
2707Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2708YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2709Block *IfOp::elseBlock() {
2710 Region &r = getElseRegion();
2711 if (r.empty())
2712 return nullptr;
2713 return &r.back();
2714}
2715YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2716
2717//===----------------------------------------------------------------------===//
2718// ParallelOp
2719//===----------------------------------------------------------------------===//
2720
2721void ParallelOp::build(
2722 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2723 ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2724 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2725 bodyBuilderFn) {
2726 result.addOperands(lowerBounds);
2727 result.addOperands(upperBounds);
2728 result.addOperands(steps);
2729 result.addOperands(initVals);
2730 result.addAttribute(
2731 ParallelOp::getOperandSegmentSizeAttr(),
2732 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2733 static_cast<int32_t>(upperBounds.size()),
2734 static_cast<int32_t>(steps.size()),
2735 static_cast<int32_t>(initVals.size())}));
2736 result.addTypes(initVals.getTypes());
2737
2738 OpBuilder::InsertionGuard guard(builder);
2739 unsigned numIVs = steps.size();
2740 SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2741 SmallVector<Location, 8> argLocs(numIVs, result.location);
2742 Region *bodyRegion = result.addRegion();
2743 Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2744
2745 if (bodyBuilderFn) {
2746 builder.setInsertionPointToStart(bodyBlock);
2747 bodyBuilderFn(builder, result.location,
2748 bodyBlock->getArguments().take_front(numIVs),
2749 bodyBlock->getArguments().drop_front(numIVs));
2750 }
2751 // Add terminator only if there are no reductions.
2752 if (initVals.empty())
2753 ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2754}
2755
2756void ParallelOp::build(
2757 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2758 ValueRange upperBounds, ValueRange steps,
2759 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2760 // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2761 // we don't capture a reference to a temporary by constructing the lambda at
2762 // function level.
2763 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2764 Location nestedLoc, ValueRange ivs,
2765 ValueRange) {
2766 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2767 };
2768 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2769 if (bodyBuilderFn)
2770 wrapper = wrappedBuilderFn;
2771
2772 build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2773 wrapper);
2774}
2775
// Verifies the structural invariants of scf.parallel, in this order:
//  - there is at least one step operand;
//  - every step with a known constant value is strictly positive;
//  - the body has exactly one block argument per step, all of index type;
//  - the body terminator check succeeded (`reduceOp` is non-null);
//  - the number of results equals the number of reductions and the number of
//    init values;
//  - each reduction operand type equals the corresponding result type.
// Returns failure after emitting a diagnostic on the first violated check.
2776LogicalResult ParallelOp::verify() {
2777 // Check that there is at least one value in lowerBound, upperBound and step.
2778 // It is sufficient to test only step, because it is ensured already that the
2779 // number of elements in lowerBound, upperBound and step are the same.
2780 Operation::operand_range stepValues = getStep();
2781 if (stepValues.empty())
2782 return emitOpError(
2783 "needs at least one tuple element for lowerBound, upperBound and step");
2784
2785 // Check whether all constant step values are positive.
2786 for (Value stepValue : stepValues)
2787 if (auto cst = getConstantIntValue(stepValue))
2788 if (*cst <= 0)
2789 return emitOpError("constant step operand must be positive");
2790
2791 // Check that the body defines the same number of block arguments as the
2792 // number of tuple elements in step.
2793 Block *body = getBody();
2794 if (body->getNumArguments() != stepValues.size())
2795 return emitOpError() << "expects the same number of induction variables: "
2796 << body->getNumArguments()
2797 << " as bound and step values: " << stepValues.size();
2798 for (auto arg : body->getArguments())
2799 if (!arg.getType().isIndex())
2800 return emitOpError(
2801 "expects arguments for the induction variable to be of index type");
2802
2803 // Check that the terminator is an scf.reduce op.
// NOTE(review): the line that begins the declaration of `reduceOp` is not
// present in this listing; only its argument line is visible below.
// Presumably it calls a terminator-checking helper similar to the
// verifyAndGetTerminator<...> calls in scf::WhileOp::verify — confirm upstream.
2805 *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2806 if (!reduceOp)
2807 return failure();
2808
2809 // Check that the number of results is the same as the number of reductions.
2810 auto resultsSize = getResults().size();
2811 auto reductionsSize = reduceOp.getReductions().size();
2812 auto initValsSize = getInitVals().size();
2813 if (resultsSize != reductionsSize)
2814 return emitOpError() << "expects number of results: " << resultsSize
2815 << " to be the same as number of reductions: "
2816 << reductionsSize;
2817 if (resultsSize != initValsSize)
2818 return emitOpError() << "expects number of results: " << resultsSize
2819 << " to be the same as number of initial values: "
2820 << initValsSize;
// An operand-count mismatch on the scf.reduce is deliberately not reported
// here; returning success avoids indexing past its operands in the loop below.
2821 if (reduceOp.getNumOperands() != initValsSize)
2822 // Delegate error reporting to ReduceOp
2823 return success();
2824
2825 // Check that the types of the results and reductions are the same.
2826 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2827 auto resultType = getOperation()->getResult(i).getType();
2828 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2829 if (resultType != reductionOperandType)
2830 return reduceOp.emitOpError()
2831 << "expects type of " << i
2832 << "-th reduction operand: " << reductionOperandType
2833 << " to be the same as the " << i
2834 << "-th result type: " << resultType;
2835 }
2836 return success();
2837}
2838
// Parses the custom assembly of scf.parallel. From the visible code the
// accepted shape is: induction variables, `=`, lower bounds, `to`, upper
// bounds, `step`, steps, an optional `init (...)` list, an optional arrow type
// list, the body region, and an optional attribute dictionary. Bounds and
// steps are resolved as index-typed operands; init values are resolved against
// the parsed result types. The operand-segment-size attribute is set from the
// parsed list lengths, and a terminator is ensured for the body.
// NOTE(review): several lines of this function are not present in this listing
// (marked below); the affected calls cannot be fully confirmed from here.
2839ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2840 auto &builder = parser.getBuilder();
2841 // Parse an opening `(` followed by induction variables followed by `)`
2842 SmallVector<OpAsmParser::Argument, 4> ivs;
// NOTE(review): the condition guarding the `return failure()` below (the call
// that fills `ivs`) is on a line missing from this listing.
2844 return failure();
2845
2846 // Parse loop bounds.
2847 SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2848 if (parser.parseEqual() ||
2849 parser.parseOperandList(lower, ivs.size(),
// NOTE(review): the final argument of the parseOperandList call above is on a
// line missing from this listing (same for the `upper` and `steps` lists).
2851 parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2852 return failure();
2853
2854 SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2855 if (parser.parseKeyword("to") ||
2856 parser.parseOperandList(upper, ivs.size(),
2858 parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2859 return failure();
2860
2861 // Parse step values.
2862 SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2863 if (parser.parseKeyword("step") ||
2864 parser.parseOperandList(steps, ivs.size(),
2866 parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2867 return failure();
2868
2869 // Parse init values.
2870 SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2871 if (succeeded(parser.parseOptionalKeyword("init"))) {
2872 if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2873 return failure();
2874 }
2875
2876 // Parse optional results in case there is a reduce.
2877 if (parser.parseOptionalArrowTypeList(result.types))
2878 return failure();
2879
2880 // Now parse the body.
2881 Region *body = result.addRegion();
// Induction variables carry no type in the assembly; they are always index.
2882 for (auto &iv : ivs)
2883 iv.type = builder.getIndexType();
2884 if (parser.parseRegion(*body, ivs))
2885 return failure();
2886
2887 // Set `operandSegmentSizes` attribute.
2888 result.addAttribute(
2889 ParallelOp::getOperandSegmentSizeAttr(),
2890 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
2891 static_cast<int32_t>(upper.size()),
2892 static_cast<int32_t>(steps.size()),
2893 static_cast<int32_t>(initVals.size())}));
2894
2895 // Parse attributes.
// Init values are resolved last, so they are appended to result.operands after
// the lower bounds, upper bounds and steps.
2896 if (parser.parseOptionalAttrDict(result.attributes) ||
2897 parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2898 result.operands))
2899 return failure();
2900
2901 // Add a terminator if none was parsed.
2902 ParallelOp::ensureTerminator(*body, builder, result.location);
2903 return success();
2904}
2905
// Prints scf.parallel in its custom form:
//   (ivs) = (lower bounds) to (upper bounds) step (steps) [init (init values)]
//   [-> result types] region
// The region is printed without its entry block arguments, since the induction
// variables were already printed in the header.
2906void ParallelOp::print(OpAsmPrinter &p) {
2907 p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
2908 << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
2909 if (!getInitVals().empty())
2910 p << " init (" << getInitVals() << ")";
2911 p.printOptionalArrowTypeList(getResultTypes());
2912 p << ' ';
2913 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
// NOTE(review): the line that begins the call below is missing from this
// listing. Judging by the arguments it prints the op's attributes while
// eliding the operand-segment-size attribute — confirm the callee upstream.
2915 (*this)->getAttrs(),
2916 /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
2917}
2918
2919SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
2920
/// Returns the arguments of the body block as the loop induction variables.
std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
  Block::BlockArgListType bodyArgs = getBody()->getArguments();
  return SmallVector<Value>(bodyArgs.begin(), bodyArgs.end());
}
2924
/// Returns the lower-bound operands, one OpFoldResult per operand.
std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
  Operation::operand_range lowerBounds = getLowerBound();
  return SmallVector<OpFoldResult>(lowerBounds.begin(), lowerBounds.end());
}
2928
/// Returns the upper-bound operands, one OpFoldResult per operand.
std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
  Operation::operand_range upperBounds = getUpperBound();
  return SmallVector<OpFoldResult>(upperBounds.begin(), upperBounds.end());
}
2932
/// Returns the step operands, one OpFoldResult per operand.
std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
  Operation::operand_range stepOperands = getStep();
  return SmallVector<OpFoldResult>(stepOperands.begin(), stepOperands.end());
}
2936
// NOTE(review): the line carrying this function's signature is not present in
// this listing. From the body it receives a value named `val` and returns a
// ParallelOp — confirm the exact declaration upstream.
//
// Returns a null ParallelOp when `val` is not a block argument. Otherwise
// returns the result of dyn_cast<ParallelOp> on the operation that contains
// the argument's block, which is null if that operation is not an
// scf.parallel.
2938 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2939 if (!ivArg)
2940 return ParallelOp();
2941 assert(ivArg.getOwner() && "unlinked block argument");
2942 auto *containingOp = ivArg.getOwner()->getParentOp();
2943 return dyn_cast<ParallelOp>(containingOp);
2944}
2945
2946namespace {
2947// Collapse loop dimensions that perform a single iteration.
2948struct ParallelOpSingleOrZeroIterationDimsFolder
2949 : public OpRewritePattern<ParallelOp> {
2950 using OpRewritePattern<ParallelOp>::OpRewritePattern;
2951
2952 LogicalResult matchAndRewrite(ParallelOp op,
2953 PatternRewriter &rewriter) const override {
2954 Location loc = op.getLoc();
2955
2956 // Compute new loop bounds that omit all single-iteration loop dimensions.
2957 SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
2958 IRMapping mapping;
2959 for (auto [lb, ub, step, iv] :
2960 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
2961 op.getInductionVars())) {
2962 auto numIterations =
2963 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
2964 if (numIterations.has_value()) {
2965 // Remove the loop if it performs zero iterations.
2966 if (*numIterations == 0) {
2967 rewriter.replaceOp(op, op.getInitVals());
2968 return success();
2969 }
2970 // Replace the loop induction variable by the lower bound if the loop
2971 // performs a single iteration. Otherwise, copy the loop bounds.
2972 if (*numIterations == 1) {
2973 mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
2974 continue;
2975 }
2976 }
2977 newLowerBounds.push_back(lb);
2978 newUpperBounds.push_back(ub);
2979 newSteps.push_back(step);
2980 }
2981 // Exit if none of the loop dimensions perform a single iteration.
2982 if (newLowerBounds.size() == op.getLowerBound().size())
2983 return failure();
2984
2985 if (newLowerBounds.empty()) {
2986 // All of the loop dimensions perform a single iteration. Inline
2987 // loop body and nested ReduceOp's
2988 SmallVector<Value> results;
2989 results.reserve(op.getInitVals().size());
2990 for (auto &bodyOp : op.getBody()->without_terminator())
2991 rewriter.clone(bodyOp, mapping);
2992 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
2993 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
2994 Block &reduceBlock = reduceOp.getReductions()[i].front();
2995 auto initValIndex = results.size();
2996 mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
2997 mapping.map(reduceBlock.getArgument(1),
2998 mapping.lookupOrDefault(reduceOp.getOperands()[i]));
2999 for (auto &reduceBodyOp : reduceBlock.without_terminator())
3000 rewriter.clone(reduceBodyOp, mapping);
3001
3002 auto result = mapping.lookupOrDefault(
3003 cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3004 results.push_back(result);
3005 }
3006
3007 rewriter.replaceOp(op, results);
3008 return success();
3009 }
3010 // Replace the parallel loop by lower-dimensional parallel loop.
3011 auto newOp =
3012 ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3013 newUpperBounds, newSteps, op.getInitVals(), nullptr);
3014 // Erase the empty block that was inserted by the builder.
3015 rewriter.eraseBlock(newOp.getBody());
3016 // Clone the loop body and remap the block arguments of the collapsed loops
3017 // (inlining does not support a cancellable block argument mapping).
3018 rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3019 newOp.getRegion().begin(), mapping);
3020 rewriter.replaceOp(op, newOp.getResults());
3021 return success();
3022 }
3023};
3024
3025struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3026 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3027
3028 LogicalResult matchAndRewrite(ParallelOp op,
3029 PatternRewriter &rewriter) const override {
3030 Block &outerBody = *op.getBody();
3031 if (!llvm::hasSingleElement(outerBody.without_terminator()))
3032 return failure();
3033
3034 auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3035 if (!innerOp)
3036 return failure();
3037
3038 for (auto val : outerBody.getArguments())
3039 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3040 llvm::is_contained(innerOp.getUpperBound(), val) ||
3041 llvm::is_contained(innerOp.getStep(), val))
3042 return failure();
3043
3044 // Reductions are not supported yet.
3045 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3046 return failure();
3047
3048 auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3049 ValueRange iterVals, ValueRange) {
3050 Block &innerBody = *innerOp.getBody();
3051 assert(iterVals.size() ==
3052 (outerBody.getNumArguments() + innerBody.getNumArguments()));
3053 IRMapping mapping;
3054 mapping.map(outerBody.getArguments(),
3055 iterVals.take_front(outerBody.getNumArguments()));
3056 mapping.map(innerBody.getArguments(),
3057 iterVals.take_back(innerBody.getNumArguments()));
3058 for (Operation &op : innerBody.without_terminator())
3059 builder.clone(op, mapping);
3060 };
3061
3062 auto concatValues = [](const auto &first, const auto &second) {
3063 SmallVector<Value> ret;
3064 ret.reserve(first.size() + second.size());
3065 ret.assign(first.begin(), first.end());
3066 ret.append(second.begin(), second.end());
3067 return ret;
3068 };
3069
3070 auto newLowerBounds =
3071 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3072 auto newUpperBounds =
3073 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3074 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3075
3076 rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3077 newSteps, ValueRange(),
3078 bodyBuilder);
3079 return success();
3080 }
3081};
3082
3083} // namespace
3084
3085void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3086 MLIRContext *context) {
3087 results
3088 .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3089 context);
3090}
3091
3092/// Given the region at `index`, or the parent operation if `index` is None,
3093/// return the successor regions. These are the regions that may be selected
3094/// during the flow of control. `operands` is a set of optional attributes that
3095/// correspond to a constant value for each operand, or null if that operand is
3096/// not a constant.
3097void ParallelOp::getSuccessorRegions(
3098 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3099 // Both the operation itself and the region may be branching into the body or
3100 // back into the operation itself. It is possible for loop not to enter the
3101 // body.
3102 regions.push_back(RegionSuccessor(&getRegion()));
3103 regions.push_back(RegionSuccessor::parent());
3104}
3105
3106//===----------------------------------------------------------------------===//
3107// ReduceOp
3108//===----------------------------------------------------------------------===//
3109
3110void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3111
3112void ReduceOp::build(OpBuilder &builder, OperationState &result,
3113 ValueRange operands) {
3114 result.addOperands(operands);
3115 for (Value v : operands) {
3116 OpBuilder::InsertionGuard guard(builder);
3117 Region *bodyRegion = result.addRegion();
3118 builder.createBlock(bodyRegion, {},
3119 ArrayRef<Type>{v.getType(), v.getType()},
3120 {result.location, result.location});
3121 }
3122}
3123
3124LogicalResult ReduceOp::verifyRegions() {
3125 if (getReductions().size() != getOperands().size())
3126 return emitOpError() << "expects number of reduction regions: "
3127 << getReductions().size()
3128 << " to be the same as number of reduction operands: "
3129 << getOperands().size();
3130 // The region of a ReduceOp has two arguments of the same type as its
3131 // corresponding operand.
3132 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3133 auto type = getOperands()[i].getType();
3134 Block &block = getReductions()[i].front();
3135 if (block.empty())
3136 return emitOpError() << i << "-th reduction has an empty body";
3137 if (block.getNumArguments() != 2 ||
3138 llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3139 return arg.getType() != type;
3140 }))
3141 return emitOpError() << "expected two block arguments with type " << type
3142 << " in the " << i << "-th reduction region";
3143
3144 // Check that the block is terminated by a ReduceReturnOp.
3145 if (!isa<ReduceReturnOp>(block.getTerminator()))
3146 return emitOpError("reduction bodies must be terminated with an "
3147 "'scf.reduce.return' op");
3148 }
3149
3150 return success();
3151}
3152
3153MutableOperandRange
3154ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
3155 // No operands are forwarded to the next iteration.
3156 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3157}
3158
3159//===----------------------------------------------------------------------===//
3160// ReduceReturnOp
3161//===----------------------------------------------------------------------===//
3162
3163LogicalResult ReduceReturnOp::verify() {
3164 // The type of the return value should be the same type as the types of the
3165 // block arguments of the reduction body.
3166 Block *reductionBody = getOperation()->getBlock();
3167 // Should already be verified by an op trait.
3168 assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3169 Type expectedResultType = reductionBody->getArgument(0).getType();
3170 if (expectedResultType != getResult().getType())
3171 return emitOpError() << "must have type " << expectedResultType
3172 << " (the type of the reduction inputs)";
3173 return success();
3174}
3175
3176//===----------------------------------------------------------------------===//
3177// WhileOp
3178//===----------------------------------------------------------------------===//
3179
3180void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3181 ::mlir::OperationState &odsState, TypeRange resultTypes,
3182 ValueRange inits, BodyBuilderFn beforeBuilder,
3183 BodyBuilderFn afterBuilder) {
3184 odsState.addOperands(inits);
3185 odsState.addTypes(resultTypes);
3186
3187 OpBuilder::InsertionGuard guard(odsBuilder);
3188
3189 // Build before region.
3190 SmallVector<Location, 4> beforeArgLocs;
3191 beforeArgLocs.reserve(inits.size());
3192 for (Value operand : inits) {
3193 beforeArgLocs.push_back(operand.getLoc());
3194 }
3195
3196 Region *beforeRegion = odsState.addRegion();
3197 Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3198 inits.getTypes(), beforeArgLocs);
3199 if (beforeBuilder)
3200 beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3201
3202 // Build after region.
3203 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3204
3205 Region *afterRegion = odsState.addRegion();
3206 Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3207 resultTypes, afterArgLocs);
3208
3209 if (afterBuilder)
3210 afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3211}
3212
3213ConditionOp WhileOp::getConditionOp() {
3214 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3215}
3216
3217YieldOp WhileOp::getYieldOp() {
3218 return cast<YieldOp>(getAfterBody()->getTerminator());
3219}
3220
/// Returns the mutable operands of the scf.yield that terminates the "after"
/// body.
std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
  YieldOp afterYield = getYieldOp();
  return afterYield.getResultsMutable();
}
3224
3225Block::BlockArgListType WhileOp::getBeforeArguments() {
3226 return getBeforeBody()->getArguments();
3227}
3228
3229Block::BlockArgListType WhileOp::getAfterArguments() {
3230 return getAfterBody()->getArguments();
3231}
3232
3233Block::BlockArgListType WhileOp::getRegionIterArgs() {
3234 return getBeforeArguments();
3235}
3236
3237OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
3238 assert(successor.getSuccessor() == &getBefore() &&
3239 "WhileOp is expected to branch only to the first region");
3240 return getInits();
3241}
3242
3243void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3244 SmallVectorImpl<RegionSuccessor> &regions) {
3245 // The parent op always branches to the condition region.
3246 if (point.isParent()) {
3247 regions.emplace_back(&getBefore());
3248 return;
3249 }
3250
3251 assert(llvm::is_contained(
3252 {&getAfter(), &getBefore()},
3253 point.getTerminatorPredecessorOrNull()->getParentRegion()) &&
3254 "there are only two regions in a WhileOp");
3255 // The body region always branches back to the condition region.
3256 if (point.getTerminatorPredecessorOrNull()->getParentRegion() ==
3257 &getAfter()) {
3258 regions.emplace_back(&getBefore());
3259 return;
3260 }
3261
3262 regions.push_back(RegionSuccessor::parent());
3263 regions.emplace_back(&getAfter());
3264}
3265
3266ValueRange WhileOp::getSuccessorInputs(RegionSuccessor successor) {
3267 if (successor.isParent())
3268 return getOperation()->getResults();
3269 if (successor == &getBefore())
3270 return getBefore().getArguments();
3271 if (successor == &getAfter())
3272 return getAfter().getArguments();
3273 llvm_unreachable("invalid region successor");
3274}
3275
3276SmallVector<Region *> WhileOp::getLoopRegions() {
3277 return {&getBefore(), &getAfter()};
3278}
3279
3280/// Parses a `while` op.
3281///
3282/// op ::= `scf.while` assignments `:` function-type region `do` region
3283/// `attributes` attribute-dict
3284/// initializer ::= /* empty */ | `(` assignment-list `)`
3285/// assignment-list ::= assignment | assignment `,` assignment-list
3286/// assignment ::= ssa-value `=` ssa-value
3287ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3288 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3289 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3290 Region *before = result.addRegion();
3291 Region *after = result.addRegion();
3292
3293 OptionalParseResult listResult =
3294 parser.parseOptionalAssignmentList(regionArgs, operands);
3295 if (listResult.has_value() && failed(listResult.value()))
3296 return failure();
3297
3298 FunctionType functionType;
3299 SMLoc typeLoc = parser.getCurrentLocation();
3300 if (failed(parser.parseColonType(functionType)))
3301 return failure();
3302
3303 result.addTypes(functionType.getResults());
3304
3305 if (functionType.getNumInputs() != operands.size()) {
3306 return parser.emitError(typeLoc)
3307 << "expected as many input types as operands " << "(expected "
3308 << operands.size() << " got " << functionType.getNumInputs() << ")";
3309 }
3310
3311 // Resolve input operands.
3312 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3313 parser.getCurrentLocation(),
3314 result.operands)))
3315 return failure();
3316
3317 // Propagate the types into the region arguments.
3318 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3319 regionArgs[i].type = functionType.getInput(i);
3320
3321 return failure(parser.parseRegion(*before, regionArgs) ||
3322 parser.parseKeyword("do") || parser.parseRegion(*after) ||
3323 parser.parseOptionalAttrDictWithKeyword(result.attributes));
3324}
3325
3326/// Prints a `while` op.
3327void scf::WhileOp::print(OpAsmPrinter &p) {
3328 printInitializationList(p, getBeforeArguments(), getInits(), " ");
3329 p << " : ";
3330 p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3331 p << ' ';
3332 p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3333 p << " do ";
3334 p.printRegion(getAfter());
3335 p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3336}
3337
3338/// Verifies that two ranges of types match, i.e. have the same number of
3339/// entries and that types are pairwise equals. Reports errors on the given
3340/// operation in case of mismatch.
3341template <typename OpTy>
3342static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3343 TypeRange right, StringRef message) {
3344 if (left.size() != right.size())
3345 return op.emitOpError("expects the same number of ") << message;
3346
3347 for (unsigned i = 0, e = left.size(); i < e; ++i) {
3348 if (left[i] != right[i]) {
3349 InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3350 << message;
3351 diag.attachNote() << "for argument " << i << ", found " << left[i]
3352 << " and " << right[i];
3353 return diag;
3354 }
3355 }
3356
3357 return success();
3358}
3359
3360LogicalResult scf::WhileOp::verify() {
3361 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3362 *this, getBefore(),
3363 "expects the 'before' region to terminate with 'scf.condition'");
3364 if (!beforeTerminator)
3365 return failure();
3366
3367 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3368 *this, getAfter(),
3369 "expects the 'after' region to terminate with 'scf.yield'");
3370 return success(afterTerminator != nullptr);
3371}
3372
3373namespace {
3374/// Move a scf.if op that is directly before the scf.condition op in the while
3375/// before region, and whose condition matches the condition of the
3376/// scf.condition op, down into the while after region.
3377///
3378/// scf.while (..) : (...) -> ... {
3379/// %additional_used_values = ...
3380/// %cond = ...
3381/// ...
3382/// %res = scf.if %cond -> (...) {
3383/// use(%additional_used_values)
3384/// ... // then block
3385/// scf.yield %then_value
3386/// } else {
3387/// scf.yield %else_value
3388/// }
3389/// scf.condition(%cond) %res, ...
3390/// } do {
3391/// ^bb0(%res_arg, ...):
3392/// use(%res_arg)
3393/// ...
3394///
3395/// becomes
3396/// scf.while (..) : (...) -> ... {
3397/// %additional_used_values = ...
3398/// %cond = ...
3399/// ...
3400/// scf.condition(%cond) %else_value, ..., %additional_used_values
3401/// } do {
3402/// ^bb0(%res_arg ..., %additional_args): :
3403/// use(%additional_args)
3404/// ... // if then block
3405/// use(%then_value)
3406/// ...
3407struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
3408 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3409
3410 LogicalResult matchAndRewrite(scf::WhileOp op,
3411 PatternRewriter &rewriter) const override {
3412 auto conditionOp = op.getConditionOp();
3413
3414 // Only support ifOp right before the condition at the moment. Relaxing this
3415 // would require to:
3416 // - check that the body does not have side-effects conflicting with
3417 // operations between the if and the condition.
3418 // - check that results of the if operation are only used as arguments to
3419 // the condition.
3420 auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
3421
3422 // Check that the ifOp is directly before the conditionOp and that it
3423 // matches the condition of the conditionOp. Also ensure that the ifOp has
3424 // no else block with content, as that would complicate the transformation.
3425 // TODO: support else blocks with content.
3426 if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
3427 (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
3428 return failure();
3429
3430 assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
3431 *ifOp->user_begin() == conditionOp)) &&
3432 "ifOp has unexpected uses");
3433
3434 Location loc = op.getLoc();
3435
3436 // Replace uses of ifOp results in the conditionOp with the yielded values
3437 // from the ifOp branches.
3438 for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
3439 auto it = llvm::find(ifOp->getResults(), arg);
3440 if (it != ifOp->getResults().end()) {
3441 size_t ifOpIdx = it.getIndex();
3442 Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
3443 Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
3444
3445 rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
3446 rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
3447 }
3448 }
3449
3450 // Collect additional used values from before region.
3451 SetVector<Value> additionalUsedValuesSet;
3452 visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
3453 if (&op.getBefore() == operand->get().getParentRegion())
3454 additionalUsedValuesSet.insert(operand->get());
3455 });
3456
3457 // Create new whileOp with additional used values as results.
3458 auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
3459 auto additionalValueTypes = llvm::map_to_vector(
3460 additionalUsedValues, [](Value val) { return val.getType(); });
3461 size_t additionalValueSize = additionalUsedValues.size();
3462 SmallVector<Type> newResultTypes(op.getResultTypes());
3463 newResultTypes.append(additionalValueTypes);
3464
3465 auto newWhileOp =
3466 scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
3467
3468 rewriter.modifyOpInPlace(newWhileOp, [&] {
3469 newWhileOp.getBefore().takeBody(op.getBefore());
3470 newWhileOp.getAfter().takeBody(op.getAfter());
3471 newWhileOp.getAfter().addArguments(
3472 additionalValueTypes,
3473 SmallVector<Location>(additionalValueSize, loc));
3474 });
3475
3476 rewriter.modifyOpInPlace(conditionOp, [&] {
3477 conditionOp.getArgsMutable().append(additionalUsedValues);
3478 });
3479
3480 // Replace uses of additional used values inside the ifOp then region with
3481 // the whileOp after region arguments.
3482 rewriter.replaceUsesWithIf(
3483 additionalUsedValues,
3484 newWhileOp.getAfterArguments().take_back(additionalValueSize),
3485 [&](OpOperand &use) {
3486 return ifOp.getThenRegion().isAncestor(
3487 use.getOwner()->getParentRegion());
3488 });
3489
3490 // Inline ifOp then region into new whileOp after region.
3491 rewriter.eraseOp(ifOp.thenYield());
3492 rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
3493 newWhileOp.getAfterBody()->begin());
3494 rewriter.eraseOp(ifOp);
3495 rewriter.replaceOp(op,
3496 newWhileOp->getResults().drop_back(additionalValueSize));
3497 return success();
3498 }
3499};
3500
3501/// Replace uses of the condition within the do block with true, since otherwise
3502/// the block would not be evaluated.
3503///
3504/// scf.while (..) : (i1, ...) -> ... {
3505/// %condition = call @evaluate_condition() : () -> i1
3506/// scf.condition(%condition) %condition : i1, ...
3507/// } do {
3508/// ^bb0(%arg0: i1, ...):
3509/// use(%arg0)
3510/// ...
3511///
3512/// becomes
3513/// scf.while (..) : (i1, ...) -> ... {
3514/// %condition = call @evaluate_condition() : () -> i1
3515/// scf.condition(%condition) %condition : i1, ...
3516/// } do {
3517/// ^bb0(%arg0: i1, ...):
3518/// use(%true)
3519/// ...
3520struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3521 using OpRewritePattern<WhileOp>::OpRewritePattern;
3522
3523 LogicalResult matchAndRewrite(WhileOp op,
3524 PatternRewriter &rewriter) const override {
3525 auto term = op.getConditionOp();
3526
3527 // These variables serve to prevent creating duplicate constants
3528 // and hold constant true or false values.
3529 Value constantTrue = nullptr;
3530
3531 bool replaced = false;
3532 for (auto yieldedAndBlockArgs :
3533 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3534 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3535 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3536 if (!constantTrue)
3537 constantTrue = arith::ConstantOp::create(
3538 rewriter, op.getLoc(), term.getCondition().getType(),
3539 rewriter.getBoolAttr(true));
3540
3541 rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3542 constantTrue);
3543 replaced = true;
3544 }
3545 }
3546 }
3547 return success(replaced);
3548 }
3549};
3550
3551/// Replace operations equivalent to the condition in the do block with true,
3552/// since otherwise the block would not be evaluated.
3553///
3554/// scf.while (..) : (i32, ...) -> ... {
3555/// %z = ... : i32
3556/// %condition = cmpi pred %z, %a
3557/// scf.condition(%condition) %z : i32, ...
3558/// } do {
3559/// ^bb0(%arg0: i32, ...):
3560/// %condition2 = cmpi pred %arg0, %a
3561/// use(%condition2)
3562/// ...
3563///
3564/// becomes
3565/// scf.while (..) : (i32, ...) -> ... {
3566/// %z = ... : i32
3567/// %condition = cmpi pred %z, %a
3568/// scf.condition(%condition) %z : i32, ...
3569/// } do {
3570/// ^bb0(%arg0: i32, ...):
3571/// use(%true)
3572/// ...
3573struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3574 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3575
3576 LogicalResult matchAndRewrite(scf::WhileOp op,
3577 PatternRewriter &rewriter) const override {
3578 using namespace scf;
3579 auto cond = op.getConditionOp();
3580 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3581 if (!cmp)
3582 return failure();
3583 bool changed = false;
3584 for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3585 for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3586 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3587 continue;
3588 for (OpOperand &u :
3589 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3590 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3591 if (!cmp2)
3592 continue;
3593 // For a binary operator 1-opIdx gets the other side.
3594 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3595 continue;
3596 bool samePredicate;
3597 if (cmp2.getPredicate() == cmp.getPredicate())
3598 samePredicate = true;
3599 else if (cmp2.getPredicate() ==
3600 arith::invertPredicate(cmp.getPredicate()))
3601 samePredicate = false;
3602 else
3603 continue;
3604
3605 rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3606 1);
3607 changed = true;
3608 }
3609 }
3610 }
3611 return success(changed);
3612 }
3613};
3614
3615/// If both ranges contain same values return mappping indices from args2 to
3616/// args1. Otherwise return std::nullopt.
3617static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
3618 ValueRange args2) {
3619 if (args1.size() != args2.size())
3620 return std::nullopt;
3621
3622 SmallVector<unsigned> ret(args1.size());
3623 for (auto &&[i, arg1] : llvm::enumerate(args1)) {
3624 auto it = llvm::find(args2, arg1);
3625 if (it == args2.end())
3626 return std::nullopt;
3627
3628 ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
3629 }
3630
3631 return ret;
3632}
3633
3634static bool hasDuplicates(ValueRange args) {
3635 llvm::SmallDenseSet<Value> set;
3636 for (Value arg : args) {
3637 if (!set.insert(arg).second)
3638 return true;
3639 }
3640 return false;
3641}
3642
3643/// If `before` block args are directly forwarded to `scf.condition`, rearrange
3644/// `scf.condition` args into same order as block args. Update `after` block
3645/// args and op result values accordingly.
3646/// Needed to simplify `scf.while` -> `scf.for` uplifting.
3647struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
3649
3650 LogicalResult matchAndRewrite(WhileOp loop,
3651 PatternRewriter &rewriter) const override {
3652 auto *oldBefore = loop.getBeforeBody();
3653 ConditionOp oldTerm = loop.getConditionOp();
3654 ValueRange beforeArgs = oldBefore->getArguments();
3655 ValueRange termArgs = oldTerm.getArgs();
3656 if (beforeArgs == termArgs)
3657 return failure();
3658
3659 if (hasDuplicates(termArgs))
3660 return failure();
3661
3662 auto mapping = getArgsMapping(beforeArgs, termArgs);
3663 if (!mapping)
3664 return failure();
3665
3666 {
3667 OpBuilder::InsertionGuard g(rewriter);
3668 rewriter.setInsertionPoint(oldTerm);
3669 rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
3670 beforeArgs);
3671 }
3672
3673 auto *oldAfter = loop.getAfterBody();
3674
3675 SmallVector<Type> newResultTypes(beforeArgs.size());
3676 for (auto &&[i, j] : llvm::enumerate(*mapping))
3677 newResultTypes[j] = loop.getResult(i).getType();
3678
3679 auto newLoop = WhileOp::create(
3680 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
3681 /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
3682 auto *newBefore = newLoop.getBeforeBody();
3683 auto *newAfter = newLoop.getAfterBody();
3684
3685 SmallVector<Value> newResults(beforeArgs.size());
3686 SmallVector<Value> newAfterArgs(beforeArgs.size());
3687 for (auto &&[i, j] : llvm::enumerate(*mapping)) {
3688 newResults[i] = newLoop.getResult(j);
3689 newAfterArgs[i] = newAfter->getArgument(j);
3690 }
3691
3692 rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
3693 newBefore->getArguments());
3694 rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
3695 newAfterArgs);
3696
3697 rewriter.replaceOp(loop, newResults);
3698 return success();
3699 }
3700};
3701} // namespace
3702
/// Registers the canonicalization patterns for `scf.while` into `results`.
3703void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
3704 MLIRContext *context) {
// Patterns specific to scf.while.
3705 results.add<WhileConditionTruth, WhileCmpCond, WhileOpAlignBeforeArgs,
3706 WhileMoveIfDown>(context);
// NOTE(review): the two statements below are missing their callee names in
// this listing (only the argument lists survive). Both take `results` and
// the op name; restore the callees from the upstream file before building.
3708 results, WhileOp::getOperationName());
3710 WhileOp::getOperationName());
3711}
3712
3713//===----------------------------------------------------------------------===//
3714// IndexSwitchOp
3715//===----------------------------------------------------------------------===//
3716
3717/// Parse the case regions and values.
3718static ParseResult
3720 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
3721 SmallVector<int64_t> caseValues;
3722 while (succeeded(p.parseOptionalKeyword("case"))) {
3723 int64_t value;
3724 Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
3725 if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
3726 return failure();
3727 caseValues.push_back(value);
3728 }
3729 cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
3730 return success();
3731}
3732
3733/// Print the case regions and values.
3735 DenseI64ArrayAttr cases, RegionRange caseRegions) {
3736 for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
3737 p.printNewline();
3738 p << "case " << value << ' ';
3739 p.printRegion(*region, /*printEntryBlockArgs=*/false);
3740 }
3741}
3742
3743LogicalResult scf::IndexSwitchOp::verify() {
3744 if (getCases().size() != getCaseRegions().size()) {
3745 return emitOpError("has ")
3746 << getCaseRegions().size() << " case regions but "
3747 << getCases().size() << " case values";
3748 }
3749
3750 DenseSet<int64_t> valueSet;
3751 for (int64_t value : getCases())
3752 if (!valueSet.insert(value).second)
3753 return emitOpError("has duplicate case value: ") << value;
3754 auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
3755 auto yield = dyn_cast<YieldOp>(region.front().back());
3756 if (!yield)
3757 return emitOpError("expected region to end with scf.yield, but got ")
3758 << region.front().back().getName();
3759
3760 if (yield.getNumOperands() != getNumResults()) {
3761 return (emitOpError("expected each region to return ")
3762 << getNumResults() << " values, but " << name << " returns "
3763 << yield.getNumOperands())
3764 .attachNote(yield.getLoc())
3765 << "see yield operation here";
3766 }
3767 for (auto [idx, result, operand] :
3768 llvm::enumerate(getResultTypes(), yield.getOperands())) {
3769 if (!operand)
3770 return yield.emitOpError() << "operand " << idx << " is null\n";
3771 if (result == operand.getType())
3772 continue;
3773 return (emitOpError("expected result #")
3774 << idx << " of each region to be " << result)
3775 .attachNote(yield.getLoc())
3776 << name << " returns " << operand.getType() << " here";
3777 }
3778 return success();
3779 };
3780
3781 if (failed(verifyRegion(getDefaultRegion(), "default region")))
3782 return failure();
3783 for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
3784 if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
3785 return failure();
3786
3787 return success();
3788}
3789
3790unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
3791
3792Block &scf::IndexSwitchOp::getDefaultBlock() {
3793 return getDefaultRegion().front();
3794}
3795
3796Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
3797 assert(idx < getNumCases() && "case index out-of-bounds");
3798 return getCaseRegions()[idx].front();
3799}
3800
3801void IndexSwitchOp::getSuccessorRegions(
3802 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
3803 // All regions branch back to the parent op.
3804 if (!point.isParent()) {
3805 successors.push_back(RegionSuccessor::parent());
3806 return;
3807 }
3808
3809 llvm::append_range(successors, getRegions());
3810}
3811
3812ValueRange IndexSwitchOp::getSuccessorInputs(RegionSuccessor successor) {
3813 return successor.isParent() ? ValueRange(getOperation()->getResults())
3814 : ValueRange();
3815}
3816
3817void IndexSwitchOp::getEntrySuccessorRegions(
3818 ArrayRef<Attribute> operands,
3819 SmallVectorImpl<RegionSuccessor> &successors) {
3820 FoldAdaptor adaptor(operands, *this);
3821
3822 // If a constant was not provided, all regions are possible successors.
3823 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
3824 if (!arg) {
3825 llvm::append_range(successors, getRegions());
3826 return;
3827 }
3828
3829 // Otherwise, try to find a case with a matching value. If not, the
3830 // default region is the only successor.
3831 for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
3832 if (caseValue == arg.getInt()) {
3833 successors.emplace_back(&caseRegion);
3834 return;
3835 }
3836 }
3837 successors.emplace_back(&getDefaultRegion());
3838}
3839
3840void IndexSwitchOp::getRegionInvocationBounds(
3841 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
3842 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
3843 if (!operandValue) {
3844 // All regions are invoked at most once.
3845 bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
3846 return;
3847 }
3848
3849 unsigned liveIndex = getNumRegions() - 1;
3850 const auto *it = llvm::find(getCases(), operandValue.getInt());
3851 if (it != getCases().end())
3852 liveIndex = std::distance(getCases().begin(), it);
3853 for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
3854 bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
3855}
3856
/// Registers the canonicalization patterns for `scf.index_switch` into
/// `results`.
3857void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3858 MLIRContext *context) {
// NOTE(review): the two statements below are missing their callee names in
// this listing (only the argument lists survive). Both take `results` and
// the op name; restore the callees from the upstream file before building.
3860 results, IndexSwitchOp::getOperationName());
3862 results, IndexSwitchOp::getOperationName());
3863}
3864
3865//===----------------------------------------------------------------------===//
3866// TableGen'd op method definitions
3867//===----------------------------------------------------------------------===//
3868
3869#define GET_OP_CLASSES
3870#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1403
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
Definition EmitC.cpp:1378
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition EmitC.cpp:1394
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition SCF.cpp:3342
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:480
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition SCF.cpp:101
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
@ None
static std::string diag(const llvm::Value &value)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
bool getValue() const
Return the boolean value of this attribute.
UnitAttr getUnitAttr()
Definition Builders.cpp:98
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:167
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
IntegerType getI1Type()
Definition Builders.cpp:53
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition Builders.cpp:589
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition Builders.h:594
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OperandRange operand_range
Definition Operation.h:371
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:346
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:222
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
BlockArgument getArgument(unsigned i)
Definition Region.h:124
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:64
bool isIndex() const
Definition Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
static Value defaultReplBuilderFn(OpBuilder &builder, Location loc, Value value)
Default implementation of the non-successor-input replacement builder function.
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition SCF.cpp:2937
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition SCF.cpp:94
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e. all loops but the innermost contain only another loop and a terminator.
Definition SCF.cpp:777
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition SCF.cpp:1886
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition SCF.cpp:732
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition SCF.cpp:115
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition SCF.cpp:659
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition SCF.h:64
llvm::function_ref< Value(OpBuilder &, Location loc, Type, Value)> ValueTypeCastFnTy
Perform a replacement of one iter OpOperand of an scf.for to the replacement value with a different t...
Definition SCF.h:107
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition SCF.cpp:1372
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition SCF.cpp:866
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
void populateRegionBranchOpInterfaceInliningPattern(RewritePatternSet &patterns, StringRef opName, NonSuccessorInputReplacementBuilderFn replBuilderFn=detail::defaultReplBuilderFn, PatternMatcherFn matcherFn=detail::defaultMatcherFn, PatternBenefit benefit=1)
Populate a pattern that inlines the body of region branch ops when there is a single acyclic path thr...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that the values has as many elements as the number of entries in attr for which isDynamic ev...
void populateRegionBranchOpInterfaceCanonicalizationPatterns(RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit=1)
Populate canonicalization patterns that simplify successor operands/inputs of region branch operation...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:222
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.