MLIR 23.0.0git
SCF.cpp
Go to the documentation of this file.
1//===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/Operation.h"
30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/DebugLog.h"
35#include <optional>
36
37using namespace mlir;
38using namespace mlir::scf;
39
40#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
41
42//===----------------------------------------------------------------------===//
43// SCFDialect Dialect Interfaces
44//===----------------------------------------------------------------------===//
45
46namespace {
47struct SCFInlinerInterface : public DialectInlinerInterface {
48 using DialectInlinerInterface::DialectInlinerInterface;
49 // We don't have any special restrictions on what can be inlined into
50 // destination regions (e.g. while/conditional bodies). Always allow it.
51 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
52 IRMapping &valueMapping) const final {
53 return true;
54 }
55 // Operations in scf dialect are always legal to inline since they are
56 // pure.
57 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
58 return true;
59 }
60 // Handle the given inlined terminator by replacing it with a new operation
61 // as necessary. Required when the region has only one block.
62 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
63 auto retValOp = dyn_cast<scf::YieldOp>(op);
64 if (!retValOp)
65 return;
66
67 for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
68 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
69 }
70 }
71};
72} // namespace
73
74//===----------------------------------------------------------------------===//
75// SCFDialect
76//===----------------------------------------------------------------------===//
77
/// Registers all SCF operations (via the tablegen-generated op list), the
/// dialect's inliner interface, and interfaces that are promised here but
/// implemented in other libraries (bufferization, EmitC conversion,
/// value-bounds analysis).
78void SCFDialect::initialize() {
 79 addOperations<
80#define GET_OP_LIST
81#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
 82 >();
 83 addInterfaces<SCFInlinerInterface>();
// Promised interfaces: declared now, loaded lazily by the libraries that
// actually define them.
 84 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
 85 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
 86 InParallelOp, ReduceReturnOp>();
 87 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
 88 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
 89 ForallOp, InParallelOp, WhileOp, YieldOp>();
 90 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
 91}
92
93/// Default callback for IfOp builders. Inserts a yield without arguments.
95 scf::YieldOp::create(builder, loc);
96}
97
98/// Verifies that the first block of the given `region` is terminated by a
99/// TerminatorTy. Reports errors on the given operation if it is not the case.
100template <typename TerminatorTy>
101static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
102 StringRef errorMessage) {
103 Operation *terminatorOperation = nullptr;
104 if (!region.empty() && !region.front().empty()) {
105 terminatorOperation = &region.front().back();
106 if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
107 return yield;
108 }
109 auto diag = op->emitOpError(errorMessage);
110 if (terminatorOperation)
111 diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
112 return nullptr;
113}
114
115std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
116 bool isSigned) {
117 llvm::APSInt diff;
118 auto addOp = ub.getDefiningOp<arith::AddIOp>();
119 if (!addOp)
120 return std::nullopt;
121 if ((isSigned && !addOp.hasNoSignedWrap()) ||
122 (!isSigned && !addOp.hasNoUnsignedWrap()))
123 return std::nullopt;
124
125 if (addOp.getLhs() != lb ||
126 !matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
127 return std::nullopt;
128 return diff;
129}
130
131//===----------------------------------------------------------------------===//
132// ExecuteRegionOp
133//===----------------------------------------------------------------------===//
134
135/// Replaces the given op with the contents of the given single-block region,
136/// using the operands of the block terminator to replace operation results.
// NOTE(review): the first line of this function's signature (orig. line 137,
// presumably the rewriter and op parameters) is missing from this extraction.
 138 Region &region, ValueRange blockArgs = {}) {
 139 assert(region.hasOneBlock() && "expected single-block region");
 140 Block *block = &region.front();
// Capture the terminator's operands before inlining; they become the
// replacement values for the op's results.
 141 Operation *terminator = block->getTerminator();
 142 ValueRange results = terminator->getOperands();
 143 rewriter.inlineBlockBefore(block, op, blockArgs);
 144 rewriter.replaceOp(op, results);
// The inlined terminator is now dead in the enclosing block; erase it.
 145 rewriter.eraseOp(terminator);
 146}
147
148///
149/// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
150/// block+
151/// `}`
152///
153/// Example:
154/// scf.execute_region -> i32 {
155/// %idx = load %rI[%i] : memref<128xi32>
156/// return %idx : i32
157/// }
158///
/// Parses an scf.execute_region: optional `->` result types, optional
/// `no_inline` keyword (stored as a unit attribute), the body region, then an
/// optional attribute dictionary.
159ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
// NOTE(review): the signature continuation (orig. line 160, the
// `OperationState &result` parameter) is missing from this extraction.
 161 if (parser.parseOptionalArrowTypeList(result.types))
 162 return failure();
 163
// `no_inline` is surfaced as a keyword in the custom syntax but stored as a
// unit attribute on the op.
 164 if (succeeded(parser.parseOptionalKeyword("no_inline")))
 165 result.addAttribute("no_inline", parser.getBuilder().getUnitAttr());
 166
 167 // Introduce the body region and parse it.
 168 Region *body = result.addRegion();
 169 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
 170 parser.parseOptionalAttrDict(result.attributes))
 171 return failure();
 172
 173 return success();
 174}
175
176void ExecuteRegionOp::print(OpAsmPrinter &p) {
177 p.printOptionalArrowTypeList(getResultTypes());
178 p << ' ';
179 if (getNoInline())
180 p << "no_inline ";
181 p.printRegion(getRegion(),
182 /*printEntryBlockArgs=*/false,
183 /*printBlockTerminators=*/true);
184 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"no_inline"});
185}
186
187LogicalResult ExecuteRegionOp::verify() {
188 if (getRegion().empty())
189 return emitOpError("region needs to have at least one block");
190 if (getRegion().front().getNumArguments() > 0)
191 return emitOpError("region cannot have any arguments");
192 return success();
193}
194
195// Inline an ExecuteRegionOp if it only contains one block.
196// "test.foo"() : () -> ()
197// %v = scf.execute_region -> i64 {
198// %x = "test.val"() : () -> i64
199// scf.yield %x : i64
200// }
201// "test.bar"(%v) : (i64) -> ()
202//
203// becomes
204//
205// "test.foo"() : () -> ()
206// %x = "test.val"() : () -> i64
207// "test.bar"(%x) : (i64) -> ()
208//
209struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
210 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
211
212 LogicalResult matchAndRewrite(ExecuteRegionOp op,
213 PatternRewriter &rewriter) const override {
214 if (!op.getRegion().hasOneBlock() || op.getNoInline())
215 return failure();
216 replaceOpWithRegion(rewriter, op, op.getRegion());
217 return success();
218 }
219};
220
221// Inline an ExecuteRegionOp if its parent can contain multiple blocks.
222// TODO generalize the conditions for operations which can be inlined into.
223// func @func_execute_region_elim() {
224// "test.foo"() : () -> ()
225// %v = scf.execute_region -> i64 {
226// %c = "test.cmp"() : () -> i1
227// cf.cond_br %c, ^bb2, ^bb3
228// ^bb2:
229// %x = "test.val1"() : () -> i64
230// cf.br ^bb4(%x : i64)
231// ^bb3:
232// %y = "test.val2"() : () -> i64
233// cf.br ^bb4(%y : i64)
234// ^bb4(%z : i64):
235// scf.yield %z : i64
236// }
237// "test.bar"(%v) : (i64) -> ()
238// return
239// }
240//
241// becomes
242//
243// func @func_execute_region_elim() {
244// "test.foo"() : () -> ()
245// %c = "test.cmp"() : () -> i1
246// cf.cond_br %c, ^bb1, ^bb2
247// ^bb1: // pred: ^bb0
248// %x = "test.val1"() : () -> i64
249// cf.br ^bb3(%x : i64)
250// ^bb2: // pred: ^bb0
251// %y = "test.val2"() : () -> i64
252// cf.br ^bb3(%y : i64)
253// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
254// "test.bar"(%z) : (i64) -> ()
255// return
256// }
257//
/// Inlines a multi-block scf.execute_region into a parent that supports
/// multiple blocks (currently FunctionOpInterface ops and other
/// execute_regions). The enclosing block is split at the op; branches stitch
/// the region's blocks between the two halves, and yields become branches to
/// the continuation block, whose new arguments replace the op's results.
258struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
 259 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
 260
 261 LogicalResult matchAndRewrite(ExecuteRegionOp op,
 262 PatternRewriter &rewriter) const override {
 263 if (op.getNoInline())
 264 return failure();
// Only inline where the parent can legally hold multiple blocks.
 265 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
 266 return failure();
 267
// Split the enclosing block at the op: everything after it becomes the
// continuation (`postBlock`).
 268 Block *prevBlock = op->getBlock();
 269 Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
 270 rewriter.setInsertionPointToEnd(prevBlock);
 271
// Jump from the first half into the region's entry block.
 272 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
 273
// Rewrite every yield into a branch to the continuation block, forwarding
// the yielded values as block arguments.
 274 for (Block &blk : op.getRegion()) {
 275 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
 276 rewriter.setInsertionPoint(yieldOp);
 277 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
 278 yieldOp.getResults());
 279 rewriter.eraseOp(yieldOp);
 280 }
 281 }
 282
// Move the region's blocks between the two halves, then materialize the
// continuation block's arguments as replacements for the op's results.
 283 rewriter.inlineRegionBefore(op.getRegion(), postBlock);
 284 SmallVector<Value> blockArgs;
 285
 286 for (auto res : op.getResults())
 287 blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
 288
 289 rewriter.replaceOp(op, blockArgs);
 290 return success();
 291 }
 292};
293
/// Registers the canonicalization patterns for scf.execute_region.
// NOTE(review): orig. lines 296-297 (the pattern registration call, presumably
// adding SingleBlockExecuteInliner / MultiBlockExecuteInliner) are missing
// from this extraction.
294void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
 295 MLIRContext *context) {
 298 results, ExecuteRegionOp::getOperationName());
 299}
300
/// RegionBranchOpInterface: from the parent, control enters the body region;
/// from the body, control returns to the parent with the op's results.
301void ExecuteRegionOp::getSuccessorRegions(
// NOTE(review): the signature continuation (orig. line 302, presumably the
// RegionBranchPoint and successor-vector parameters) is missing from this
// extraction.
 303 // If the predecessor is the ExecuteRegionOp, branch into the body.
 304 if (point.isParent()) {
 305 regions.push_back(RegionSuccessor(&getRegion()));
 306 return;
 307 }
 308
 309 // Otherwise, the region branches back to the parent operation.
 310 regions.push_back(RegionSuccessor::parent(getResults()));
 311}
312
313//===----------------------------------------------------------------------===//
314// ConditionOp
315//===----------------------------------------------------------------------===//
316
/// Returns the operands forwarded to the successor: everything except the
/// i1 condition itself.
// NOTE(review): the return-type line (orig. line 317) is missing from this
// extraction. Also note the assert message below concatenates
// "...the after" "region" without a separating space, yielding
// "afterregion" — worth fixing upstream.
318ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
// A condition op may only exit the enclosing while loop or branch to its
// "after" region; anything else is a malformed query.
 319 assert(
 320 (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) &&
 321 "condition op can only exit the loop or branch to the after"
 322 "region");
 323 // Pass all operands except the condition to the successor region.
 324 return getArgsMutable();
 325}
326
/// Computes the possible successors of scf.condition. With a statically known
/// condition only one successor is reported; otherwise both the "after"
/// region (condition true) and the parent while op (condition false) are
/// possible.
327void ConditionOp::getSuccessorRegions(
// NOTE(review): the signature continuation (orig. line 328, presumably the
// operands and successor-vector parameters) is missing from this extraction.
 329 FoldAdaptor adaptor(operands, *this);
 330
 331 WhileOp whileOp = getParentOp();
 332
 333 // Condition can either lead to the after region or back to the parent op
 334 // depending on whether the condition is true or not.
 335 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
 336 if (!boolAttr || boolAttr.getValue())
 337 regions.emplace_back(&whileOp.getAfter(),
 338 whileOp.getAfter().getArguments());
 339 if (!boolAttr || !boolAttr.getValue())
 340 regions.push_back(RegionSuccessor::parent(whileOp.getResults()));
 341}
342
343//===----------------------------------------------------------------------===//
344// ForOp
345//===----------------------------------------------------------------------===//
346
/// Builds an scf.for: control operands (lb, ub, step) followed by the
/// loop-carried init values, one result per init value, and a body block
/// whose arguments are the induction variable followed by the iter_args.
/// `bodyBuilder`, when provided, populates the body; `unsignedCmp` selects
/// unsigned bound comparison via a unit attribute.
347void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
 348 Value ub, Value step, ValueRange initArgs,
 349 BodyBuilderFn bodyBuilder, bool unsignedCmp) {
 350 OpBuilder::InsertionGuard guard(builder);
 351
 352 if (unsignedCmp)
 353 result.addAttribute(getUnsignedCmpAttrName(result.name),
 354 builder.getUnitAttr());
// Operand order matters: the three control operands precede the inits.
 355 result.addOperands({lb, ub, step});
 356 result.addOperands(initArgs);
 357 for (Value v : initArgs)
 358 result.addTypes(v.getType());
// The induction variable takes the type of the lower bound.
 359 Type t = lb.getType();
 360 Region *bodyRegion = result.addRegion();
 361 Block *bodyBlock = builder.createBlock(bodyRegion);
 362 bodyBlock->addArgument(t, result.location);
 363 for (Value v : initArgs)
 364 bodyBlock->addArgument(v.getType(), v.getLoc());
 365
 366 // Create the default terminator if the builder is not provided and if the
 367 // iteration arguments are not provided. Otherwise, leave this to the caller
 368 // because we don't know which values to return from the loop.
 369 if (initArgs.empty() && !bodyBuilder) {
 370 ForOp::ensureTerminator(*bodyRegion, builder, result.location);
 371 } else if (bodyBuilder) {
 372 OpBuilder::InsertionGuard guard(builder);
 373 builder.setInsertionPointToStart(bodyBlock);
 374 bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
 375 bodyBlock->getArguments().drop_front());
 376 }
 377}
378
379LogicalResult ForOp::verify() {
380 // Check that the number of init args and op results is the same.
381 if (getInitArgs().size() != getNumResults())
382 return emitOpError(
383 "mismatch in number of loop-carried values and defined values");
384
385 return success();
386}
387
388LogicalResult ForOp::verifyRegions() {
389 // Check that the body defines as single block argument for the induction
390 // variable.
391 if (getInductionVar().getType() != getLowerBound().getType())
392 return emitOpError(
393 "expected induction variable to be same type as bounds and step");
394
395 if (getNumRegionIterArgs() != getNumResults())
396 return emitOpError(
397 "mismatch in number of basic block args and defined values");
398
399 auto initArgs = getInitArgs();
400 auto iterArgs = getRegionIterArgs();
401 auto opResults = getResults();
402 unsigned i = 0;
403 for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
404 if (std::get<0>(e).getType() != std::get<2>(e).getType())
405 return emitOpError() << "types mismatch between " << i
406 << "th iter operand and defined value";
407 if (std::get<1>(e).getType() != std::get<2>(e).getType())
408 return emitOpError() << "types mismatch between " << i
409 << "th iter region arg and defined value";
410
411 ++i;
412 }
413 return success();
414}
415
/// LoopLikeOpInterface: an scf.for has exactly one induction variable.
416std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
 417 return SmallVector<Value>{getInductionVar()};
 418}
 419
/// LoopLikeOpInterface: single lower bound.
// NOTE(review): the return statement (orig. line 421) is missing from this
// extraction.
420std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
 422}
 423
/// LoopLikeOpInterface: single step value.
424std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
 425 return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
 426}
 427
/// LoopLikeOpInterface: single upper bound.
// NOTE(review): the return statement (orig. line 429) is missing from this
// extraction.
428std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
 430}
 431
/// Induction variables must be `index` or a signless integer type.
432bool ForOp::isValidInductionVarType(Type type) {
 433 return type.isIndex() || type.isSignlessInteger();
 434}
435
436LogicalResult ForOp::setLoopLowerBounds(ArrayRef<OpFoldResult> bounds) {
437 if (bounds.size() != 1)
438 return failure();
439 if (auto val = dyn_cast<Value>(bounds[0])) {
440 setLowerBound(val);
441 return success();
442 }
443 return failure();
444}
445
446LogicalResult ForOp::setLoopUpperBounds(ArrayRef<OpFoldResult> bounds) {
447 if (bounds.size() != 1)
448 return failure();
449 if (auto val = dyn_cast<Value>(bounds[0])) {
450 setUpperBound(val);
451 return success();
452 }
453 return failure();
454}
455
456LogicalResult ForOp::setLoopSteps(ArrayRef<OpFoldResult> steps) {
457 if (steps.size() != 1)
458 return failure();
459 if (auto val = dyn_cast<Value>(steps[0])) {
460 setStep(val);
461 return success();
462 }
463 return failure();
464}
465
/// LoopLikeOpInterface: every scf.for result is a loop result.
466std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
467
/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop executes at most once. A zero-trip loop is
/// replaced by its init values; a single-iteration loop is inlined with the
/// lower bound substituted for the induction variable.
470LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
 471 std::optional<APInt> tripCount = getStaticTripCount();
 472 LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
 473 << " for loop "
 474 << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
// Bail out on unknown trip counts or loops that iterate more than once.
 475 if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
 476 return failure();
 477
// Zero iterations: the results are exactly the init values.
 478 if (*tripCount == 0) {
 479 rewriter.replaceAllUsesWith(getResults(), getInitArgs());
 480 rewriter.eraseOp(*this);
 481 return success();
 482 }
 483
 484 // Replace all results with the yielded values.
 485 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
 486 rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
 487
 488 // Replace block arguments with lower bound (replacement for IV) and
 489 // iter_args.
 490 SmallVector<Value> bbArgReplacements;
 491 bbArgReplacements.push_back(getLowerBound());
 492 llvm::append_range(bbArgReplacements, getInitArgs());
 493
 494 // Move the loop body operations to the loop's containing block.
 495 rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
 496 getOperation()->getIterator(), bbArgReplacements);
 497
 498 // Erase the old terminator and the loop.
 499 rewriter.eraseOp(yieldOp);
 500 rewriter.eraseOp(*this);
 501
 502 return success();
 503}
504
505/// Prints the initialization list in the form of
506/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
507/// where 'inner' values are assumed to be region arguments and 'outer' values
508/// are regular SSA values. Prints nothing when there are no initializers.
// NOTE(review): the first line of this function's signature (orig. line 509,
// presumably `static void printInitializationList(OpAsmPrinter &p,`) is
// missing from this extraction.
 510 Block::BlockArgListType blocksArgs,
 511 ValueRange initializers,
 512 StringRef prefix = "") {
 513 assert(blocksArgs.size() == initializers.size() &&
 514 "expected same length of arguments and initializers");
 515 if (initializers.empty())
 516 return;
 517
 518 p << prefix << '(';
 519 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
 520 p << std::get<0>(it) << " = " << std::get<1>(it);
 521 });
 522 p << ")";
 523}
524
/// Prints an scf.for: optional `unsigned` keyword, induction variable with
/// bounds and step, optional iter_args with result types, optional induction
/// variable type (elided for `index`), the body, then remaining attributes.
525void ForOp::print(OpAsmPrinter &p) {
 526 if (getUnsignedCmp())
 527 p << " unsigned";
 528
 529 p << " " << getInductionVar() << " = " << getLowerBound() << " to "
 530 << getUpperBound() << " step " << getStep();
 531
 532 printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
 533 if (!getInitArgs().empty())
 534 p << " -> (" << getInitArgs().getTypes() << ')';
 535 p << ' ';
// The `index` type is the default and therefore elided.
 536 if (Type t = getInductionVar().getType(); !t.isIndex())
 537 p << " : " << t << ' ';
// The implicit terminator may be elided only when there are no iter_args.
 538 p.printRegion(getRegion(),
 539 /*printEntryBlockArgs=*/false,
 540 /*printBlockTerminators=*/!getInitArgs().empty());
// `unsignedCmp` was printed as the `unsigned` keyword above.
 541 p.printOptionalAttrDict((*this)->getAttrs(),
 542 /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
 543}
544
/// Parses an scf.for. Mirrors ForOp::print: optional `unsigned`, induction
/// variable and bounds, optional iter_args assignment list with result types,
/// optional induction variable type (defaulting to `index`), the body region,
/// then operand resolution and the optional attribute dictionary.
545ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
 546 auto &builder = parser.getBuilder();
 547 Type type;
 548
 549 OpAsmParser::Argument inductionVariable;
// NOTE(review): orig. line 550 (presumably the declarations of the lb/ub/step
// unresolved operands) is missing from this extraction.
 551
 552 if (succeeded(parser.parseOptionalKeyword("unsigned")))
 553 result.addAttribute(getUnsignedCmpAttrName(result.name),
 554 builder.getUnitAttr());
 555
 556 // Parse the induction variable followed by '='.
 557 if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
 558 // Parse loop bounds.
 559 parser.parseOperand(lb) || parser.parseKeyword("to") ||
 560 parser.parseOperand(ub) || parser.parseKeyword("step") ||
 561 parser.parseOperand(step))
 562 return failure();
 563
 564 // Parse the optional initial iteration arguments.
// NOTE(review): orig. lines 565-566 (presumably the declarations of
// `regionArgs` and `operands`) are missing from this extraction.
 567 regionArgs.push_back(inductionVariable);
 568
 569 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
 570 if (hasIterArgs) {
 571 // Parse assignment list and results type list.
 572 if (parser.parseAssignmentList(regionArgs, operands) ||
 573 parser.parseArrowTypeList(result.types))
 574 return failure();
 575 }
 576
// +1 accounts for the induction variable, which carries no result.
 577 if (regionArgs.size() != result.types.size() + 1)
 578 return parser.emitError(
 579 parser.getNameLoc(),
 580 "mismatch in number of loop-carried values and defined values");
 581
 582 // Parse optional type, else assume Index.
 583 if (parser.parseOptionalColon())
 584 type = builder.getIndexType();
 585 else if (parser.parseType(type))
 586 return failure();
 587
 588 // Set block argument types, so that they are known when parsing the region.
 589 regionArgs.front().type = type;
 590 for (auto [iterArg, type] :
 591 llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
 592 iterArg.type = type;
 593
 594 // Parse the body region.
 595 Region *body = result.addRegion();
 596 if (parser.parseRegion(*body, regionArgs))
 597 return failure();
 598 ForOp::ensureTerminator(*body, builder, result.location);
 599
 600 // Resolve input operands. This should be done after parsing the region to
 601 // catch invalid IR where operands were defined inside of the region.
 602 if (parser.resolveOperand(lb, type, result.operands) ||
 603 parser.resolveOperand(ub, type, result.operands) ||
 604 parser.resolveOperand(step, type, result.operands))
 605 return failure();
 606 if (hasIterArgs) {
// Each iter operand resolves against its declared result type.
 607 for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
 608 operands, result.types)) {
 609 Type type = std::get<2>(argOperandType);
 610 std::get<0>(argOperandType).type = type;
 611 if (parser.resolveOperand(std::get<1>(argOperandType), type,
 612 result.operands))
 613 return failure();
 614 }
 615 }
 616
 617 // Parse the optional attribute list.
 618 if (parser.parseOptionalAttrDict(result.attributes))
 619 return failure();
 620
 621 return success();
 622}
623
/// LoopLikeOpInterface: the single body region of scf.for.
624SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
 625
/// The body's block arguments after the induction variable are the
/// loop-carried iter_args.
626Block::BlockArgListType ForOp::getRegionIterArgs() {
 627 return getBody()->getArguments().drop_front(getNumInductionVars());
 628}
 629
/// The mutable init operands are exactly the init_args operand segment.
630MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
 631 return getInitArgsMutable();
 632}
633
/// Rebuilds this scf.for with `newInitOperands` appended to its iter_args.
/// The body is moved to the new loop, the yield is extended with the values
/// produced by `newYieldValuesFn`, and (optionally) in-loop uses of the new
/// init operands are redirected to the new region iter args. The original
/// loop is replaced; its results map to the leading results of the new loop.
634FailureOr<LoopLikeOpInterface>
635ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
 636 ValueRange newInitOperands,
 637 bool replaceInitOperandUsesInLoop,
 638 const NewYieldValuesFn &newYieldValuesFn) {
 639 // Create a new loop before the existing one, with the extra operands.
 640 OpBuilder::InsertionGuard g(rewriter);
 641 rewriter.setInsertionPoint(getOperation());
 642 auto inits = llvm::to_vector(getInitArgs());
 643 inits.append(newInitOperands.begin(), newInitOperands.end());
// Empty body builder: the body is transplanted from the old loop below.
 644 scf::ForOp newLoop = scf::ForOp::create(
 645 rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
 646 [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp());
 647 newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
 648
 649 // Generate the new yield values and append them to the scf.yield operation.
 650 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
 651 ArrayRef<BlockArgument> newIterArgs =
 652 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
 653 {
 654 OpBuilder::InsertionGuard g(rewriter);
 655 rewriter.setInsertionPoint(yieldOp);
 656 SmallVector<Value> newYieldedValues =
 657 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
 658 assert(newInitOperands.size() == newYieldedValues.size() &&
 659 "expected as many new yield values as new iter operands");
 660 rewriter.modifyOpInPlace(yieldOp, [&]() {
 661 yieldOp.getResultsMutable().append(newYieldedValues);
 662 });
 663 }
 664
 665 // Move the loop body to the new op.
 666 rewriter.mergeBlocks(getBody(), newLoop.getBody(),
 667 newLoop.getBody()->getArguments().take_front(
 668 getBody()->getNumArguments()));
 669
 670 if (replaceInitOperandUsesInLoop) {
 671 // Replace all uses of `newInitOperands` with the corresponding basic block
 672 // arguments.
 673 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
// Only uses strictly inside the new loop are redirected.
 674 rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
 675 [&](OpOperand &use) {
 676 Operation *user = use.getOwner();
 677 return newLoop->isProperAncestor(user);
 678 });
 679 }
 680 }
 681
 682 // Replace the old loop.
 683 rewriter.replaceOp(getOperation(),
 684 newLoop->getResults().take_front(getNumResults()));
 685 return cast<LoopLikeOpInterface>(newLoop.getOperation());
 686}
687
// NOTE(review): the signature line (orig. line 688, presumably
// `ForOp mlir::scf::getForInductionVarOwner(Value val)`) is missing from this
// extraction. Given a value, this returns the scf.for whose entry block owns
// it as a block argument, or a null ForOp otherwise.
 689 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
 690 if (!ivArg)
 691 return ForOp();
 692 assert(ivArg.getOwner() && "unlinked block argument");
// dyn_cast_or_null: the owning block may be detached from any op.
 693 auto *containingOp = ivArg.getOwner()->getParentOp();
 694 return dyn_cast_or_null<ForOp>(containingOp);
 695}
696
/// On first entry into the body, the region iter args receive the init_args.
697OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
 698 return getInitArgs();
 699}
 700
/// RegionBranchOpInterface: control may enter the body (feeding the region
/// iter args) or skip/exit to the parent (yielding the op's results); a loop
/// with an empty iteration space never enters the body.
701void ForOp::getSuccessorRegions(RegionBranchPoint point,
// NOTE(review): the signature continuation (orig. line 702, presumably the
// successor-vector parameter) is missing from this extraction.
 703 // Both the operation itself and the region may be branching into the body or
 704 // back into the operation itself. It is possible for loop not to enter the
 705 // body.
 706 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
 707 regions.push_back(RegionSuccessor::parent(getResults()));
 708}
709
/// LoopLikeOpInterface: the single body region of scf.forall.
710SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
 711
712/// Promotes the loop body of a forallOp to its containing block if it can be
713/// determined that the loop has a single iteration, i.e. every dimension has
/// a provable trip count of exactly one.
714LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
 715 for (auto [lb, ub, step] :
 716 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
 717 auto tripCount =
 718 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
 719 if (!tripCount.has_value() || *tripCount != 1)
 720 return failure();
 721 }
 722
 723 promote(rewriter, *this);
 724 return success();
 725}
 726
/// Block arguments after the per-dimension induction variables are the
/// shared-output iter args.
727Block::BlockArgListType ForallOp::getRegionIterArgs() {
 728 return getBody()->getArguments().drop_front(getRank());
 729}
 730
/// For scf.forall the loop-carried inits are its shared outputs.
731MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
 732 return getOutputsMutable();
 733}
734
735/// Promotes the loop body of a scf::ForallOp to its containing block:
/// induction variables are replaced by the lower bounds, shared outputs feed
/// the corresponding block arguments, and each tensor.parallel_insert_slice
/// in the terminator becomes a plain tensor.insert_slice producing the
/// replacement for the matching loop result.
736void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
 737 OpBuilder::InsertionGuard g(rewriter);
 738 scf::InParallelOp terminator = forallOp.getTerminator();
 739
 740 // Replace block arguments with lower bounds (replacements for IVs) and
 741 // outputs.
 742 SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
 743 bbArgReplacements.append(forallOp.getOutputs().begin(),
 744 forallOp.getOutputs().end());
 745
 746 // Move the loop body operations to the loop's containing block.
 747 rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
 748 forallOp->getIterator(), bbArgReplacements);
 749
 750 // Replace the terminator with tensor.insert_slice ops.
 751 rewriter.setInsertionPointAfter(forallOp);
 752 SmallVector<Value> results;
 753 results.reserve(forallOp.getResults().size());
 754 for (auto &yieldingOp : terminator.getYieldingOps()) {
 755 auto parallelInsertSliceOp =
 756 dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
 757 if (!parallelInsertSliceOp)
 758 continue;
 759
 760 Value dst = parallelInsertSliceOp.getDest();
 761 Value src = parallelInsertSliceOp.getSource();
// Only tensor-typed sources have a standalone insert_slice equivalent.
 762 if (llvm::isa<TensorType>(src.getType())) {
 763 results.push_back(tensor::InsertSliceOp::create(
 764 rewriter, forallOp.getLoc(), dst.getType(), src, dst,
 765 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
 766 parallelInsertSliceOp.getStrides(),
 767 parallelInsertSliceOp.getStaticOffsets(),
 768 parallelInsertSliceOp.getStaticSizes(),
 769 parallelInsertSliceOp.getStaticStrides()));
 770 } else {
 771 llvm_unreachable("unsupported terminator");
 772 }
 773 }
 774 rewriter.replaceAllUsesWith(forallOp.getResults(), results);
 775
 776 // Erase the old terminator and the loop.
 777 rewriter.eraseOp(terminator);
 778 rewriter.eraseOp(forallOp);
 779}
780
/// Builds a perfect nest of scf.for loops (one per lb/ub/step triple),
/// threading `iterArgs` through the nest, invoking `bodyBuilder` in the
/// innermost body, and yielding inner results outward. Returns the loops and
/// the outermost results.
// NOTE(review): the signature opening (orig. line 781, presumably
// `LoopNest mlir::scf::buildLoopNest(`) and orig. line 784 (the body-builder
// function_ref parameter type) are missing from this extraction.
 782 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
 783 ValueRange steps, ValueRange iterArgs,
 785 bodyBuilder) {
 786 assert(lbs.size() == ubs.size() &&
 787 "expected the same number of lower and upper bounds");
 788 assert(lbs.size() == steps.size() &&
 789 "expected the same number of lower bounds and steps");
 790
 791 // If there are no bounds, call the body-building function and return early.
 792 if (lbs.empty()) {
 793 ValueVector results =
 794 bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
 795 : ValueVector();
 796 assert(results.size() == iterArgs.size() &&
 797 "loop nest body must return as many values as loop has iteration "
 798 "arguments");
 799 return LoopNest{{}, std::move(results)};
 800 }
 801
 802 // First, create the loop structure iteratively using the body-builder
 803 // callback of `ForOp::build`. Do not create `YieldOp`s yet.
 804 OpBuilder::InsertionGuard guard(builder);
// NOTE(review): orig. lines 805-806 (presumably the declarations of `loops`
// and `ivs`) are missing from this extraction.
 807 loops.reserve(lbs.size());
 808 ivs.reserve(lbs.size());
 809 ValueRange currentIterArgs = iterArgs;
 810 Location currentLoc = loc;
 811 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
 812 auto loop = scf::ForOp::create(
 813 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
 814 [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
 815 ValueRange args) {
 816 ivs.push_back(iv);
 817 // It is safe to store ValueRange args because it points to block
 818 // arguments of a loop operation that we also own.
 819 currentIterArgs = args;
 820 currentLoc = nestedLoc;
 821 });
 822 // Set the builder to point to the body of the newly created loop. We don't
 823 // do this in the callback because the builder is reset when the callback
 824 // returns.
 825 builder.setInsertionPointToStart(loop.getBody());
 826 loops.push_back(loop);
 827 }
 828
 829 // For all loops but the innermost, yield the results of the nested loop.
 830 for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
 831 builder.setInsertionPointToEnd(loops[i].getBody());
 832 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
 833 }
 834
 835 // In the body of the innermost loop, call the body building function if any
 836 // and yield its results.
 837 builder.setInsertionPointToStart(loops.back().getBody());
 838 ValueVector results = bodyBuilder
 839 ? bodyBuilder(builder, currentLoc, ivs,
 840 loops.back().getRegionIterArgs())
 841 : ValueVector();
 842 assert(results.size() == iterArgs.size() &&
 843 "loop nest body must return as many values as loop has iteration "
 844 "arguments");
 845 builder.setInsertionPointToEnd(loops.back().getBody());
 846 scf::YieldOp::create(builder, loc, results);
 847
 848 // Return the loops.
 849 ValueVector nestResults;
 850 llvm::append_range(nestResults, loops.front().getResults());
 851 return LoopNest{std::move(loops), std::move(nestResults)};
 852}
 853
/// Convenience overload without iteration arguments: wraps the simpler body
/// builder and delegates to the main buildLoopNest.
// NOTE(review): the signature opening (orig. line 854) and orig. line 862
// (the tail of the wrapping lambda's parameter list) are missing from this
// extraction.
 855 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
 856 ValueRange steps,
 857 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
 858 // Delegate to the main function by wrapping the body builder.
 859 return buildLoopNest(builder, loc, lbs, ubs, steps, {},
 860 [&bodyBuilder](OpBuilder &nestedBuilder,
 861 Location nestedLoc, ValueRange ivs,
 863 if (bodyBuilder)
 864 bodyBuilder(nestedBuilder, nestedLoc, ivs);
 865 return {};
 866 });
 867}
868
/// Rebuilds a scf.for with one iter operand replaced by `replacement` (of a
/// different type), inserting `castFn`-generated casts: into the old type at
/// the top of the body, back to the new type before the yield, and back to
/// the old type after the loop for the corresponding result.
// NOTE(review): the signature opening (orig. lines 869-870, presumably the
// return type, function name, and the rewriter/forOp parameters) is missing
// from this extraction.
 871 OpOperand &operand, Value replacement,
 872 const ValueTypeCastFnTy &castFn) {
 873 assert(operand.getOwner() == forOp);
 874 Type oldType = operand.get().getType(), newType = replacement.getType();
 875
 876 // 1. Create new iter operands, exactly 1 is replaced.
 877 assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
 878 "expected an iter OpOperand");
 879 assert(operand.get().getType() != replacement.getType() &&
 880 "Expected a different type");
 881 SmallVector<Value> newIterOperands;
 882 for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
 883 if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
 884 newIterOperands.push_back(replacement);
 885 continue;
 886 }
 887 newIterOperands.push_back(opOperand.get());
 888 }
 889
 890 // 2. Create the new forOp shell.
 891 scf::ForOp newForOp = scf::ForOp::create(
 892 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
 893 forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr,
 894 forOp.getUnsignedCmp());
 895 newForOp->setAttrs(forOp->getAttrs());
 896 Block &newBlock = newForOp.getRegion().front();
 897 SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
 898 newBlock.getArguments().end());
 899
 900 // 3. Inject an incoming cast op at the beginning of the block for the bbArg
 901 // corresponding to the `replacement` value.
 902 OpBuilder::InsertionGuard g(rewriter);
 903 rewriter.setInsertionPointToStart(&newBlock);
 904 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
 905 &newForOp->getOpOperand(operand.getOperandNumber()));
 906 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
 907 newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
 908
 909 // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
 910 Block &oldBlock = forOp.getRegion().front();
 911 rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
 912
 913 // 5. Inject an outgoing cast op at the end of the block and yield it instead.
 914 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
 915 rewriter.setInsertionPoint(clonedYieldOp);
// The yield index skips the induction variable(s).
 916 unsigned yieldIdx =
 917 newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
 918 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
 919 clonedYieldOp.getOperand(yieldIdx));
 920 SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
 921 newYieldOperands[yieldIdx] = castOut;
 922 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
 923 rewriter.eraseOp(clonedYieldOp);
 924
 925 // 6. Inject an outgoing cast op after the forOp.
 926 rewriter.setInsertionPointAfter(newForOp);
 927 SmallVector<Value> newResults = newForOp.getResults();
 928 newResults[yieldIdx] =
 929 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
 930
 931 return newResults;
 932}
933
934namespace {
935/// Rewriting pattern that erases loops that are known not to iterate, replaces
936/// single-iteration loops with their bodies, and removes empty loops that
937/// iterate at least once and only return values defined outside of the loop.
938struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
940
941 LogicalResult matchAndRewrite(ForOp op,
942 PatternRewriter &rewriter) const override {
943 std::optional<APInt> tripCount = op.getStaticTripCount();
944 if (!tripCount.has_value())
945 return rewriter.notifyMatchFailure(op,
946 "can't compute constant trip count");
947
948 if (tripCount->isZero()) {
949 LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop "
950 << OpWithFlags(op, OpPrintingFlags().skipRegions());
951 rewriter.replaceOp(op, op.getInitArgs());
952 return success();
953 }
954
955 if (tripCount->getSExtValue() == 1) {
956 LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop "
957 << OpWithFlags(op, OpPrintingFlags().skipRegions());
958 SmallVector<Value, 4> blockArgs;
959 blockArgs.reserve(op.getInitArgs().size() + 1);
960 blockArgs.push_back(op.getLowerBound());
961 llvm::append_range(blockArgs, op.getInitArgs());
962 replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
963 return success();
964 }
965
966 // Now we are left with loops that have more than 1 iterations.
967 Block &block = op.getRegion().front();
968 if (!llvm::hasSingleElement(block))
969 return failure();
970 // The loop is empty and iterates at least once, if it only returns values
971 // defined outside of the loop, remove it and replace it with yield values.
972 if (llvm::any_of(op.getYieldedValues(),
973 [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
974 return failure();
975 LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with "
976 "yield operands for loop "
977 << OpWithFlags(op, OpPrintingFlags().skipRegions());
978 rewriter.replaceOp(op, op.getYieldedValues());
979 return success();
980 }
981};
982
983/// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
984/// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
985///
986/// ```
987/// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
988/// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
989/// -> (tensor<?x?xf32>) {
990/// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
991/// scf.yield %2 : tensor<?x?xf32>
992/// }
993/// use_of(%1)
994/// ```
995///
996/// folds into:
997///
998/// ```
999/// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1000/// -> (tensor<32x1024xf32>) {
1001/// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1002/// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1003/// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1004/// scf.yield %4 : tensor<32x1024xf32>
1005/// }
1006/// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1007/// use_of(%1)
1008/// ```
1009struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1010 using OpRewritePattern<ForOp>::OpRewritePattern;
1011
1012 LogicalResult matchAndRewrite(ForOp op,
1013 PatternRewriter &rewriter) const override {
1014 for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1015 OpOperand &iterOpOperand = std::get<0>(it);
1016 auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1017 if (!incomingCast ||
1018 incomingCast.getSource().getType() == incomingCast.getType())
1019 continue;
1020 // If the dest type of the cast does not preserve static information in
1021 // the source type.
1023 incomingCast.getDest().getType(),
1024 incomingCast.getSource().getType()))
1025 continue;
1026 if (!std::get<1>(it).hasOneUse())
1027 continue;
1028
1029 // Create a new ForOp with that iter operand replaced.
1030 rewriter.replaceOp(
1032 rewriter, op, iterOpOperand, incomingCast.getSource(),
1033 [](OpBuilder &b, Location loc, Type type, Value source) {
1034 return tensor::CastOp::create(b, loc, type, source);
1035 }));
1036 return success();
1037 }
1038 return failure();
1039 }
1040};
1041} // namespace
1042
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  // Patterns defined locally in this file.
  results.add<SimplifyTrivialLoops, ForOpTensorCastFolder>(context);
  // NOTE(review): the first line of the statement below was lost in this copy
  // of the file; it invokes a pattern-population helper taking
  // (results, ForOp::getOperationName()). Restore it from upstream SCF.cpp
  // before building.
      results, ForOp::getOperationName());
}
1049
1050std::optional<APInt> ForOp::getConstantStep() {
1051 IntegerAttr step;
1052 if (matchPattern(getStep(), m_Constant(&step)))
1053 return step.getValue();
1054 return {};
1055}
1056
1057std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1058 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1059}
1060
1061Speculation::Speculatability ForOp::getSpeculatability() {
1062 // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1063 // and End.
1064 if (auto constantStep = getConstantStep())
1065 if (*constantStep == 1)
1067
1068 // For Step != 1, the loop may not terminate. We can add more smarts here if
1069 // needed.
1071}
1072
1073std::optional<APInt> ForOp::getStaticTripCount() {
1074 return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
1075 /*isSigned=*/!getUnsignedCmp(), computeUbMinusLb);
1076}
1077
1078//===----------------------------------------------------------------------===//
1079// ForallOp
1080//===----------------------------------------------------------------------===//
1081
1082LogicalResult ForallOp::verify() {
1083 unsigned numLoops = getRank();
1084 // Check number of outputs.
1085 if (getNumResults() != getOutputs().size())
1086 return emitOpError("produces ")
1087 << getNumResults() << " results, but has only "
1088 << getOutputs().size() << " outputs";
1089
1090 // Check that the body defines block arguments for thread indices and outputs.
1091 auto *body = getBody();
1092 if (body->getNumArguments() != numLoops + getOutputs().size())
1093 return emitOpError("region expects ") << numLoops << " arguments";
1094 for (int64_t i = 0; i < numLoops; ++i)
1095 if (!body->getArgument(i).getType().isIndex())
1096 return emitOpError("expects ")
1097 << i << "-th block argument to be an index";
1098 for (unsigned i = 0; i < getOutputs().size(); ++i)
1099 if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1100 return emitOpError("type mismatch between ")
1101 << i << "-th output and corresponding block argument";
1102 if (getMapping().has_value() && !getMapping()->empty()) {
1103 if (getDeviceMappingAttrs().size() != numLoops)
1104 return emitOpError() << "mapping attribute size must match op rank";
1105 if (failed(getDeviceMaskingAttr()))
1106 return emitOpError() << getMappingAttrName()
1107 << " supports at most one device masking attribute";
1108 }
1109
1110 // Verify mixed static/dynamic control variables.
1111 Operation *op = getOperation();
1112 if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1113 getStaticLowerBound(),
1114 getDynamicLowerBound())))
1115 return failure();
1116 if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1117 getStaticUpperBound(),
1118 getDynamicUpperBound())))
1119 return failure();
1120 if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1121 getStaticStep(), getDynamicStep())))
1122 return failure();
1123
1124 return success();
1125}
1126
1127void ForallOp::print(OpAsmPrinter &p) {
1128 Operation *op = getOperation();
1129 p << " (" << getInductionVars();
1130 if (isNormalized()) {
1131 p << ") in ";
1132 printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1133 /*valueTypes=*/{}, /*scalables=*/{},
1135 } else {
1136 p << ") = ";
1137 printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1138 /*valueTypes=*/{}, /*scalables=*/{},
1140 p << " to ";
1141 printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1142 /*valueTypes=*/{}, /*scalables=*/{},
1144 p << " step ";
1145 printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1146 /*valueTypes=*/{}, /*scalables=*/{},
1148 }
1149 printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1150 p << " ";
1151 if (!getRegionOutArgs().empty())
1152 p << "-> (" << getResultTypes() << ") ";
1153 p.printRegion(getRegion(),
1154 /*printEntryBlockArgs=*/false,
1155 /*printBlockTerminators=*/getNumResults() > 0);
1156 p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1157 getStaticLowerBoundAttrName(),
1158 getStaticUpperBoundAttrName(),
1159 getStaticStepAttrName()});
1160}
1161
1162ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1163 OpBuilder b(parser.getContext());
1164 auto indexType = b.getIndexType();
1165
1166 // Parse an opening `(` followed by thread index variables followed by `)`
1167 // TODO: when we can refer to such "induction variable"-like handles from the
1168 // declarative assembly format, we can implement the parser as a custom hook.
1169 SmallVector<OpAsmParser::Argument, 4> ivs;
1171 return failure();
1172
1173 DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1174 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1175 dynamicSteps;
1176 if (succeeded(parser.parseOptionalKeyword("in"))) {
1177 // Parse upper bounds.
1178 if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1179 /*valueTypes=*/nullptr,
1181 parser.resolveOperands(dynamicUbs, indexType, result.operands))
1182 return failure();
1183
1184 unsigned numLoops = ivs.size();
1185 staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1186 staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1187 } else {
1188 // Parse lower bounds.
1189 if (parser.parseEqual() ||
1190 parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1191 /*valueTypes=*/nullptr,
1193
1194 parser.resolveOperands(dynamicLbs, indexType, result.operands))
1195 return failure();
1196
1197 // Parse upper bounds.
1198 if (parser.parseKeyword("to") ||
1199 parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1200 /*valueTypes=*/nullptr,
1202 parser.resolveOperands(dynamicUbs, indexType, result.operands))
1203 return failure();
1204
1205 // Parse step values.
1206 if (parser.parseKeyword("step") ||
1207 parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1208 /*valueTypes=*/nullptr,
1210 parser.resolveOperands(dynamicSteps, indexType, result.operands))
1211 return failure();
1212 }
1213
1214 // Parse out operands and results.
1215 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1216 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1217 SMLoc outOperandsLoc = parser.getCurrentLocation();
1218 if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1219 if (outOperands.size() != result.types.size())
1220 return parser.emitError(outOperandsLoc,
1221 "mismatch between out operands and types");
1222 if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1223 parser.parseOptionalArrowTypeList(result.types) ||
1224 parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1225 result.operands))
1226 return failure();
1227 }
1228
1229 // Parse region.
1230 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1231 std::unique_ptr<Region> region = std::make_unique<Region>();
1232 for (auto &iv : ivs) {
1233 iv.type = b.getIndexType();
1234 regionArgs.push_back(iv);
1235 }
1236 for (const auto &it : llvm::enumerate(regionOutArgs)) {
1237 auto &out = it.value();
1238 out.type = result.types[it.index()];
1239 regionArgs.push_back(out);
1240 }
1241 if (parser.parseRegion(*region, regionArgs))
1242 return failure();
1243
1244 // Ensure terminator and move region.
1245 ForallOp::ensureTerminator(*region, b, result.location);
1246 result.addRegion(std::move(region));
1247
1248 // Parse the optional attribute list.
1249 if (parser.parseOptionalAttrDict(result.attributes))
1250 return failure();
1251
1252 result.addAttribute("staticLowerBound", staticLbs);
1253 result.addAttribute("staticUpperBound", staticUbs);
1254 result.addAttribute("staticStep", staticSteps);
1255 result.addAttribute("operandSegmentSizes",
1257 {static_cast<int32_t>(dynamicLbs.size()),
1258 static_cast<int32_t>(dynamicUbs.size()),
1259 static_cast<int32_t>(dynamicSteps.size()),
1260 static_cast<int32_t>(outOperands.size())}));
1261 return success();
1262}
1263
1264// Builder that takes loop bounds.
1265void ForallOp::build(
1266 mlir::OpBuilder &b, mlir::OperationState &result,
1267 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1268 ArrayRef<OpFoldResult> steps, ValueRange outputs,
1269 std::optional<ArrayAttr> mapping,
1270 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1271 SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1272 SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1273 dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1274 dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1275 dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1276
1277 result.addOperands(dynamicLbs);
1278 result.addOperands(dynamicUbs);
1279 result.addOperands(dynamicSteps);
1280 result.addOperands(outputs);
1281 result.addTypes(TypeRange(outputs));
1282
1283 result.addAttribute(getStaticLowerBoundAttrName(result.name),
1284 b.getDenseI64ArrayAttr(staticLbs));
1285 result.addAttribute(getStaticUpperBoundAttrName(result.name),
1286 b.getDenseI64ArrayAttr(staticUbs));
1287 result.addAttribute(getStaticStepAttrName(result.name),
1288 b.getDenseI64ArrayAttr(staticSteps));
1289 result.addAttribute(
1290 "operandSegmentSizes",
1291 b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1292 static_cast<int32_t>(dynamicUbs.size()),
1293 static_cast<int32_t>(dynamicSteps.size()),
1294 static_cast<int32_t>(outputs.size())}));
1295 if (mapping.has_value()) {
1296 result.addAttribute(ForallOp::getMappingAttrName(result.name),
1297 mapping.value());
1298 }
1299
1300 Region *bodyRegion = result.addRegion();
1301 OpBuilder::InsertionGuard g(b);
1302 b.createBlock(bodyRegion);
1303 Block &bodyBlock = bodyRegion->front();
1304
1305 // Add block arguments for indices and outputs.
1306 bodyBlock.addArguments(
1307 SmallVector<Type>(lbs.size(), b.getIndexType()),
1308 SmallVector<Location>(staticLbs.size(), result.location));
1309 bodyBlock.addArguments(
1310 TypeRange(outputs),
1311 SmallVector<Location>(outputs.size(), result.location));
1312
1313 b.setInsertionPointToStart(&bodyBlock);
1314 if (!bodyBuilderFn) {
1315 ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1316 return;
1317 }
1318 bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1319}
1320
1321// Builder that takes loop bounds.
1322void ForallOp::build(
1323 mlir::OpBuilder &b, mlir::OperationState &result,
1324 ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1325 std::optional<ArrayAttr> mapping,
1326 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1327 unsigned numLoops = ubs.size();
1328 SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1329 SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1330 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1331}
1332
1333// Checks if the lbs are zeros and steps are ones.
1334bool ForallOp::isNormalized() {
1335 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1336 return llvm::all_of(results, [&](OpFoldResult ofr) {
1337 auto intValue = getConstantIntValue(ofr);
1338 return intValue.has_value() && intValue == val;
1339 });
1340 };
1341 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1342}
1343
1344InParallelOp ForallOp::getTerminator() {
1345 return cast<InParallelOp>(getBody()->getTerminator());
1346}
1347
1348SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1349 SmallVector<Operation *> storeOps;
1350 for (Operation *user : bbArg.getUsers()) {
1351 if (auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1352 storeOps.push_back(parallelOp);
1353 }
1354 }
1355 return storeOps;
1356}
1357
1358SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1359 SmallVector<DeviceMappingAttrInterface> res;
1360 if (!getMapping())
1361 return res;
1362 for (auto attr : getMapping()->getValue()) {
1363 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1364 if (m)
1365 res.push_back(m);
1366 }
1367 return res;
1368}
1369
1370FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1371 DeviceMaskingAttrInterface res;
1372 if (!getMapping())
1373 return res;
1374 for (auto attr : getMapping()->getValue()) {
1375 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1376 if (m && res)
1377 return failure();
1378 if (m)
1379 res = m;
1380 }
1381 return res;
1382}
1383
1384bool ForallOp::usesLinearMapping() {
1385 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1386 if (ifaces.empty())
1387 return false;
1388 return ifaces.front().isLinearMapping();
1389}
1390
1391std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1392 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1393}
1394
1395// Get lower bounds as OpFoldResult.
1396std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1397 Builder b(getOperation()->getContext());
1398 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1399}
1400
1401// Get upper bounds as OpFoldResult.
1402std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1403 Builder b(getOperation()->getContext());
1404 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1405}
1406
1407// Get steps as OpFoldResult.
1408std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1409 Builder b(getOperation()->getContext());
1410 return getMixedValues(getStaticStep(), getDynamicStep(), b);
1411}
1412
1414 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1415 if (!tidxArg)
1416 return ForallOp();
1417 assert(tidxArg.getOwner() && "unlinked block argument");
1418 auto *containingOp = tidxArg.getOwner()->getParentOp();
1419 return dyn_cast<ForallOp>(containingOp);
1420}
1421
1422namespace {
1423/// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1424struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1425 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1426
1427 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1428 PatternRewriter &rewriter) const final {
1429 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1430 if (!forallOp)
1431 return failure();
1432 Value sharedOut =
1433 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1434 ->get();
1435 rewriter.modifyOpInPlace(
1436 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1437 return success();
1438 }
1439};
1440
/// Canonicalizes dynamic lower bounds / upper bounds / steps that are defined
/// by constants into their static (attribute) form.
class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
public:
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
    SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
    SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
    // NOTE: `&&` short-circuits, so once an earlier list folds successfully
    // the later lists are not folded in this invocation; subsequent
    // applications of the pattern pick them up.
    if (failed(foldDynamicIndexList(mixedLowerBound)) &&
        failed(foldDynamicIndexList(mixedUpperBound)) &&
        failed(foldDynamicIndexList(mixedStep)))
      return failure();

    // Write the folded lists back: split each mixed list into dynamic values
    // and static integers, then refresh the operand segment sizes by hand
    // since the number of dynamic operands may have changed.
    rewriter.modifyOpInPlace(op, [&]() {
      SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
      SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
      dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
                                 staticLowerBound);
      op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
      op.setStaticLowerBound(staticLowerBound);

      dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
                                 staticUpperBound);
      op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
      op.setStaticUpperBound(staticUpperBound);

      dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
      op.getDynamicStepMutable().assign(dynamicStep);
      op.setStaticStep(staticStep);

      op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
                  rewriter.getDenseI32ArrayAttr(
                      {static_cast<int32_t>(dynamicLowerBound.size()),
                       static_cast<int32_t>(dynamicUpperBound.size()),
                       static_cast<int32_t>(dynamicStep.size()),
                       static_cast<int32_t>(op.getNumResults())}));
    });
    return success();
  }
};
1482
1483/// The following canonicalization pattern folds the iter arguments of
1484/// scf.forall op if :-
/// 1. The corresponding result has zero uses.
/// 2. The iter argument is NOT being modified within the loop body.
1488///
1489/// Example of first case :-
1490/// INPUT:
1491/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1492/// {
1493/// ...
1494/// <SOME USE OF %arg0>
1495/// <SOME USE OF %arg1>
1496/// <SOME USE OF %arg2>
1497/// ...
1498/// scf.forall.in_parallel {
1499/// <STORE OP WITH DESTINATION %arg1>
1500/// <STORE OP WITH DESTINATION %arg0>
1501/// <STORE OP WITH DESTINATION %arg2>
1502/// }
1503/// }
1504/// return %res#1
1505///
1506/// OUTPUT:
1507/// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
1508/// {
1509/// ...
1510/// <SOME USE OF %a>
1511/// <SOME USE OF %new_arg0>
1512/// <SOME USE OF %c>
1513/// ...
1514/// scf.forall.in_parallel {
1515/// <STORE OP WITH DESTINATION %new_arg0>
1516/// }
1517/// }
1518/// return %res
1519///
1520/// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1521/// scf.forall is replaced by their corresponding operands.
1522/// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1523/// of the scf.forall besides within scf.forall.in_parallel terminator,
1524/// this canonicalization remains valid. For more details, please refer
1525/// to :
1526/// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1527/// 3. TODO(avarma): Generalize it for other store ops. Currently it
1528/// handles tensor.parallel_insert_slice ops only.
1529///
1530/// Example of second case :-
1531/// INPUT:
1532/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1533/// {
1534/// ...
1535/// <SOME USE OF %arg0>
1536/// <SOME USE OF %arg1>
1537/// ...
1538/// scf.forall.in_parallel {
1539/// <STORE OP WITH DESTINATION %arg1>
1540/// }
1541/// }
1542/// return %res#0, %res#1
1543///
1544/// OUTPUT:
1545/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1546/// {
1547/// ...
1548/// <SOME USE OF %a>
1549/// <SOME USE OF %new_arg0>
1550/// ...
1551/// scf.forall.in_parallel {
1552/// <STORE OP WITH DESTINATION %new_arg0>
1553/// }
1554/// }
1555/// return %a, %res
struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp forallOp,
                                PatternRewriter &rewriter) const final {
    // Step 1: For a given i-th result of scf.forall, check the following :-
    //   a. If it has any use.
    //   b. If the corresponding iter argument is being modified within
    //      the loop, i.e. has at least one store op with the iter arg as
    //      its destination operand. For this we use
    //      ForallOp::getCombiningOps(iter_arg).
    //
    // Based on the check we maintain the following :-
    //   a. op results, block arguments, outputs to delete
    //   b. new outputs (i.e., outputs to retain)
    SmallVector<Value> resultsToDelete;
    SmallVector<Value> outsToDelete;
    SmallVector<BlockArgument> blockArgsToDelete;
    SmallVector<Value> newOuts;
    BitVector resultIndicesToDelete(forallOp.getNumResults(), false);
    BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
                                   false);
    for (OpResult result : forallOp.getResults()) {
      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
      // Foldable when the result is dead or the bbArg is never combined into.
      if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
        resultsToDelete.push_back(result);
        outsToDelete.push_back(opOperand->get());
        blockArgsToDelete.push_back(blockArg);
        resultIndicesToDelete[result.getResultNumber()] = true;
        blockIndicesToDelete[blockArg.getArgNumber()] = true;
      } else {
        newOuts.push_back(opOperand->get());
      }
    }

    // Return early if all results of scf.forall have at least one use and being
    // modified within the loop.
    if (resultsToDelete.empty())
      return failure();

    // Step 2: Erase combining ops and replace uses of deleted results and
    // block arguments with the corresponding outputs.
    for (auto blockArg : blockArgsToDelete) {
      SmallVector<Operation *> combiningOps =
          forallOp.getCombiningOps(blockArg);
      for (Operation *combiningOp : combiningOps)
        rewriter.eraseOp(combiningOp);
    }
    for (auto [blockArg, result, out] :
         llvm::zip_equal(blockArgsToDelete, resultsToDelete, outsToDelete)) {
      rewriter.replaceAllUsesWith(blockArg, out);
      rewriter.replaceAllUsesWith(result, out);
    }
    // TODO: There is no rewriter API for erasing block arguments.
    rewriter.modifyOpInPlace(forallOp, [&]() {
      forallOp.getBody()->eraseArguments(blockIndicesToDelete);
    });

    // Step 3. Create a new scf.forall op with only the shared_outs/results
    // that should be retained.
    auto newForallOp = cast<scf::ForallOp>(
        rewriter.eraseOpResults(forallOp, resultIndicesToDelete));
    newForallOp.getOutputsMutable().assign(newOuts);

    return success();
  }
};
1624
/// Folds away scf.forall dimensions with a statically-known trip count of
/// zero or one: a zero-trip dimension makes the whole loop dead (replaced by
/// its outputs), an all-single-iteration loop is inlined, and a mix is
/// rebuilt as a lower-rank loop with the single-iteration IVs replaced by
/// their lower bounds.
struct ForallOpSingleOrZeroIterationDimsFolder
    : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp op,
                                PatternRewriter &rewriter) const override {
    // Do not fold dimensions if they are mapped to processing units.
    if (op.getMapping().has_value() && !op.getMapping()->empty())
      return failure();
    Location loc = op.getLoc();

    // Compute new loop bounds that omit all single-iteration loop dimensions.
    SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
        newMixedSteps;
    IRMapping mapping;
    for (auto [lb, ub, step, iv] :
         llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
                   op.getMixedStep(), op.getInductionVars())) {
      auto numIterations =
          constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
      if (numIterations.has_value()) {
        // Remove the loop if it performs zero iterations.
        if (*numIterations == 0) {
          rewriter.replaceOp(op, op.getOutputs());
          return success();
        }
        // Replace the loop induction variable by the lower bound if the loop
        // performs a single iteration. Otherwise, copy the loop bounds.
        if (*numIterations == 1) {
          mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
          continue;
        }
      }
      newMixedLowerBounds.push_back(lb);
      newMixedUpperBounds.push_back(ub);
      newMixedSteps.push_back(step);
    }

    // All of the loop dimensions perform a single iteration. Inline loop body.
    if (newMixedLowerBounds.empty()) {
      promote(rewriter, op);
      return success();
    }

    // Exit if none of the loop dimensions perform a single iteration.
    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
      return rewriter.notifyMatchFailure(
          op, "no dimensions have 0 or 1 iterations");
    }

    // Replace the loop by a lower-dimensional loop.
    ForallOp newOp;
    newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
                             newMixedUpperBounds, newMixedSteps,
                             op.getOutputs(), std::nullopt, nullptr);
    // Drop the default-built body; the original region is cloned in below.
    newOp.getBodyRegion().getBlocks().clear();
    // The new loop needs to keep all attributes from the old one, except for
    // "operandSegmentSizes" and static loop bound attributes which capture
    // the outdated information of the old iteration domain.
    SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
                                        newOp.getStaticLowerBoundAttrName(),
                                        newOp.getStaticUpperBoundAttrName(),
                                        newOp.getStaticStepAttrName()};
    for (const auto &namedAttr : op->getAttrs()) {
      if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
        continue;
      rewriter.modifyOpInPlace(newOp, [&]() {
        newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
      });
    }
    // `mapping` substitutes the dropped IVs with their constant lower bounds.
    rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
                               newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};
1701
1702/// Replace all induction vars with a single trip count with their lower bound.
1703struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1704 using OpRewritePattern<ForallOp>::OpRewritePattern;
1705
1706 LogicalResult matchAndRewrite(ForallOp op,
1707 PatternRewriter &rewriter) const override {
1708 Location loc = op.getLoc();
1709 bool changed = false;
1710 for (auto [lb, ub, step, iv] :
1711 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1712 op.getMixedStep(), op.getInductionVars())) {
1713 if (iv.hasNUses(0))
1714 continue;
1715 auto numIterations =
1716 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
1717 if (!numIterations.has_value() || numIterations.value() != 1) {
1718 continue;
1719 }
1720 rewriter.replaceAllUsesWith(
1721 iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1722 changed = true;
1723 }
1724 return success(changed);
1725 }
1726};
1727
/// Folds `tensor.cast` ops feeding the shared outputs of an `scf.forall` into
/// the loop itself, so the loop operates on (and returns) the more static
/// source type; the results are cast back to the original types afterwards.
struct FoldTensorCastOfOutputIntoForallOp
    : public OpRewritePattern<scf::ForallOp> {
  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;

  /// Records the source/destination types of one folded cast, keyed by the
  /// index of the affected output operand.
  struct TypeCast {
    Type srcType;
    Type dstType;
  };

  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
                                PatternRewriter &rewriter) const final {
    // Map from output-operand index to the cast folded at that position.
    llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
    llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
    for (auto en : llvm::enumerate(newOutputTensors)) {
      auto castOp = en.value().getDefiningOp<tensor::CastOp>();
      if (!castOp)
        continue;

      // Only casts that preserve static information, i.e. will make the
      // loop result type "more" static than before, will be folded.
      if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
                                              castOp.getSource().getType())) {
        continue;
      }

      tensorCastProducers[en.index()] =
          TypeCast{castOp.getSource().getType(), castOp.getType()};
      newOutputTensors[en.index()] = castOp.getSource();
    }

    if (tensorCastProducers.empty())
      return failure();

    // Create new loop.
    Location loc = forallOp.getLoc();
    auto newForallOp = ForallOp::create(
        rewriter, loc, forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        newOutputTensors, forallOp.getMapping(),
        [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
          // Re-introduce casts inside the body so the merged old body still
          // sees block arguments of the original (old) types.
          auto castBlockArgs =
              llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
          for (auto [index, cast] : tensorCastProducers) {
            Value &oldTypeBBArg = castBlockArgs[index];
            oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
                                                  cast.dstType, oldTypeBBArg);
          }

          // Move old body into new parallel loop.
          SmallVector<Value> ivsBlockArgs =
              llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
          ivsBlockArgs.append(castBlockArgs);
          rewriter.mergeBlocks(forallOp.getBody(),
                               bbArgs.front().getParentBlock(), ivsBlockArgs);
        });

    // After `mergeBlocks` happened, the destinations in the terminator were
    // mapped to the tensor.cast old-typed results of the output bbArgs. The
    // destination have to be updated to point to the output bbArgs directly.
    auto terminator = newForallOp.getTerminator();
    for (auto [yieldingOp, outputBlockArg] : llvm::zip(
             terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
      if (auto parallelCombingingOp =
              dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
        parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
      }
    }

    // Cast results back to the original types.
    rewriter.setInsertionPointAfter(newForallOp);
    SmallVector<Value> castResults = newForallOp.getResults();
    for (auto &item : tensorCastProducers) {
      Value &oldTypeResult = castResults[item.first];
      oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
                                             oldTypeResult);
    }
    rewriter.replaceOp(forallOp, castResults);
    return success();
  }
};
1808
1809} // namespace
1810
/// Registers all canonicalization patterns for scf.forall: dim-query
/// forwarding, tensor.cast folding, operand/iter-arg simplification,
/// degenerate-dimension elimination, and constant induction-var replacement.
void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
              ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
              ForallOpSingleOrZeroIterationDimsFolder,
              ForallOpReplaceConstantInductionVar>(context);
}
1818
1819void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1820 SmallVectorImpl<RegionSuccessor> &regions) {
1821 // There are two region branch points:
1822 // 1. "parent": entering the forall op for the first time.
1823 // 2. scf.in_parallel terminator
1824 if (point.isParent()) {
1825 // When first entering the forall op, the control flow typically branches
1826 // into the forall body. (In parallel for multiple threads.)
1827 regions.push_back(RegionSuccessor(&getRegion()));
1828 // However, when there are 0 threads, the control flow may branch back to
1829 // the parent immediately.
1830 regions.push_back(RegionSuccessor::parent(
1831 ResultRange{getResults().end(), getResults().end()}));
1832 } else {
1833 // In accordance with the semantics of forall, its body is executed in
1834 // parallel by multiple threads. We should not expect to branch back into
1835 // the forall body after the region's execution is complete.
1836 regions.push_back(RegionSuccessor::parent(
1837 ResultRange{getResults().end(), getResults().end()}));
1838 }
1839}
1840
1841//===----------------------------------------------------------------------===//
1842// InParallelOp
1843//===----------------------------------------------------------------------===//
1844
1845// Build a InParallelOp with mixed static and dynamic entries.
1846void InParallelOp::build(OpBuilder &b, OperationState &result) {
1847 OpBuilder::InsertionGuard g(b);
1848 Region *bodyRegion = result.addRegion();
1849 b.createBlock(bodyRegion);
1850}
1851
1852LogicalResult InParallelOp::verify() {
1853 scf::ForallOp forallOp =
1854 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1855 if (!forallOp)
1856 return this->emitOpError("expected forall op parent");
1857
1858 for (Operation &op : getRegion().front().getOperations()) {
1859 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
1860 if (!parallelCombiningOp) {
1861 return this->emitOpError("expected only ParallelCombiningOpInterface")
1862 << " ops";
1863 }
1864
1865 // Verify that inserts are into out block arguments.
1866 MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
1867 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1868 for (OpOperand &dest : dests) {
1869 if (!llvm::is_contained(regionOutArgs, dest.get()))
1870 return op.emitOpError("may only insert into an output block argument");
1871 }
1872 }
1873
1874 return success();
1875}
1876
1877void InParallelOp::print(OpAsmPrinter &p) {
1878 p << " ";
1879 p.printRegion(getRegion(),
1880 /*printEntryBlockArgs=*/false,
1881 /*printBlockTerminators=*/false);
1882 p.printOptionalAttrDict(getOperation()->getAttrs());
1883}
1884
1885ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1886 auto &builder = parser.getBuilder();
1887
1888 SmallVector<OpAsmParser::Argument, 8> regionOperands;
1889 std::unique_ptr<Region> region = std::make_unique<Region>();
1890 if (parser.parseRegion(*region, regionOperands))
1891 return failure();
1892
1893 if (region->empty())
1894 OpBuilder(builder.getContext()).createBlock(region.get());
1895 result.addRegion(std::move(region));
1896
1897 // Parse the optional attribute list.
1898 if (parser.parseOptionalAttrDict(result.attributes))
1899 return failure();
1900 return success();
1901}
1902
1903OpResult InParallelOp::getParentResult(int64_t idx) {
1904 return getOperation()->getParentOp()->getResult(idx);
1905}
1906
1907SmallVector<BlockArgument> InParallelOp::getDests() {
1908 SmallVector<BlockArgument> updatedDests;
1909 for (Operation &yieldingOp : getYieldingOps()) {
1910 auto parallelCombiningOp =
1911 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
1912 if (!parallelCombiningOp)
1913 continue;
1914 for (OpOperand &updatedOperand :
1915 parallelCombiningOp.getUpdatedDestinations())
1916 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
1917 }
1918 return updatedDests;
1919}
1920
1921llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1922 return getRegion().front().getOperations();
1923}
1924
1925//===----------------------------------------------------------------------===//
1926// IfOp
1927//===----------------------------------------------------------------------===//
1928
1930 assert(a && "expected non-empty operation");
1931 assert(b && "expected non-empty operation");
1932
1933 IfOp ifOp = a->getParentOfType<IfOp>();
1934 while (ifOp) {
1935 // Check if b is inside ifOp. (We already know that a is.)
1936 if (ifOp->isProperAncestor(b))
1937 // b is contained in ifOp. a and b are in mutually exclusive branches if
1938 // they are in different blocks of ifOp.
1939 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1940 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1941 // Check next enclosing IfOp.
1942 ifOp = ifOp->getParentOfType<IfOp>();
1943 }
1944
1945 // Could not find a common IfOp among a's and b's ancestors.
1946 return false;
1947}
1948
1949LogicalResult
1950IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1951 IfOp::Adaptor adaptor,
1952 SmallVectorImpl<Type> &inferredReturnTypes) {
1953 if (adaptor.getRegions().empty())
1954 return failure();
1955 Region *r = &adaptor.getThenRegion();
1956 if (r->empty())
1957 return failure();
1958 Block &b = r->front();
1959 if (b.empty())
1960 return failure();
1961 auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
1962 if (!yieldOp)
1963 return failure();
1964 TypeRange types = yieldOp.getOperandTypes();
1965 llvm::append_range(inferredReturnTypes, types);
1966 return success();
1967}
1968
1969void IfOp::build(OpBuilder &builder, OperationState &result,
1970 TypeRange resultTypes, Value cond) {
1971 return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
1972 /*addElseBlock=*/false);
1973}
1974
1975void IfOp::build(OpBuilder &builder, OperationState &result,
1976 TypeRange resultTypes, Value cond, bool addThenBlock,
1977 bool addElseBlock) {
1978 assert((!addElseBlock || addThenBlock) &&
1979 "must not create else block w/o then block");
1980 result.addTypes(resultTypes);
1981 result.addOperands(cond);
1982
1983 // Add regions and blocks.
1984 OpBuilder::InsertionGuard guard(builder);
1985 Region *thenRegion = result.addRegion();
1986 if (addThenBlock)
1987 builder.createBlock(thenRegion);
1988 Region *elseRegion = result.addRegion();
1989 if (addElseBlock)
1990 builder.createBlock(elseRegion);
1991}
1992
1993void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1994 bool withElseRegion) {
1995 build(builder, result, TypeRange{}, cond, withElseRegion);
1996}
1997
1998void IfOp::build(OpBuilder &builder, OperationState &result,
1999 TypeRange resultTypes, Value cond, bool withElseRegion) {
2000 result.addTypes(resultTypes);
2001 result.addOperands(cond);
2002
2003 // Build then region.
2004 OpBuilder::InsertionGuard guard(builder);
2005 Region *thenRegion = result.addRegion();
2006 builder.createBlock(thenRegion);
2007 if (resultTypes.empty())
2008 IfOp::ensureTerminator(*thenRegion, builder, result.location);
2009
2010 // Build else region.
2011 Region *elseRegion = result.addRegion();
2012 if (withElseRegion) {
2013 builder.createBlock(elseRegion);
2014 if (resultTypes.empty())
2015 IfOp::ensureTerminator(*elseRegion, builder, result.location);
2016 }
2017}
2018
2019void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2020 function_ref<void(OpBuilder &, Location)> thenBuilder,
2021 function_ref<void(OpBuilder &, Location)> elseBuilder) {
2022 assert(thenBuilder && "the builder callback for 'then' must be present");
2023 result.addOperands(cond);
2024
2025 // Build then region.
2026 OpBuilder::InsertionGuard guard(builder);
2027 Region *thenRegion = result.addRegion();
2028 builder.createBlock(thenRegion);
2029 thenBuilder(builder, result.location);
2030
2031 // Build else region.
2032 Region *elseRegion = result.addRegion();
2033 if (elseBuilder) {
2034 builder.createBlock(elseRegion);
2035 elseBuilder(builder, result.location);
2036 }
2037
2038 // Infer result types.
2039 SmallVector<Type> inferredReturnTypes;
2040 MLIRContext *ctx = builder.getContext();
2041 auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2042 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2043 /*properties=*/nullptr, result.regions,
2044 inferredReturnTypes))) {
2045 result.addTypes(inferredReturnTypes);
2046 }
2047}
2048
2049LogicalResult IfOp::verify() {
2050 if (getNumResults() != 0 && getElseRegion().empty())
2051 return emitOpError("must have an else block if defining values");
2052 return success();
2053}
2054
2055ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2056 // Create the regions for 'then'.
2057 result.regions.reserve(2);
2058 Region *thenRegion = result.addRegion();
2059 Region *elseRegion = result.addRegion();
2060
2061 auto &builder = parser.getBuilder();
2062 OpAsmParser::UnresolvedOperand cond;
2063 Type i1Type = builder.getIntegerType(1);
2064 if (parser.parseOperand(cond) ||
2065 parser.resolveOperand(cond, i1Type, result.operands))
2066 return failure();
2067 // Parse optional results type list.
2068 if (parser.parseOptionalArrowTypeList(result.types))
2069 return failure();
2070 // Parse the 'then' region.
2071 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2072 return failure();
2073 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2074
2075 // If we find an 'else' keyword then parse the 'else' region.
2076 if (!parser.parseOptionalKeyword("else")) {
2077 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2078 return failure();
2079 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2080 }
2081
2082 // Parse the optional attribute list.
2083 if (parser.parseOptionalAttrDict(result.attributes))
2084 return failure();
2085 return success();
2086}
2087
2088void IfOp::print(OpAsmPrinter &p) {
2089 bool printBlockTerminators = false;
2090
2091 p << " " << getCondition();
2092 if (!getResults().empty()) {
2093 p << " -> (" << getResultTypes() << ")";
2094 // Print yield explicitly if the op defines values.
2095 printBlockTerminators = true;
2096 }
2097 p << ' ';
2098 p.printRegion(getThenRegion(),
2099 /*printEntryBlockArgs=*/false,
2100 /*printBlockTerminators=*/printBlockTerminators);
2101
2102 // Print the 'else' regions if it exists and has a block.
2103 auto &elseRegion = getElseRegion();
2104 if (!elseRegion.empty()) {
2105 p << " else ";
2106 p.printRegion(elseRegion,
2107 /*printEntryBlockArgs=*/false,
2108 /*printBlockTerminators=*/printBlockTerminators);
2109 }
2110
2111 p.printOptionalAttrDict((*this)->getAttrs());
2112}
2113
2114void IfOp::getSuccessorRegions(RegionBranchPoint point,
2115 SmallVectorImpl<RegionSuccessor> &regions) {
2116 // The `then` and the `else` region branch back to the parent operation or one
2117 // of the recursive parent operations (early exit case).
2118 if (!point.isParent()) {
2119 regions.push_back(RegionSuccessor::parent(getResults()));
2120 return;
2121 }
2122
2123 regions.push_back(RegionSuccessor(&getThenRegion()));
2124
2125 // Don't consider the else region if it is empty.
2126 Region *elseRegion = &this->getElseRegion();
2127 if (elseRegion->empty())
2128 regions.push_back(RegionSuccessor::parent(getResults()));
2129 else
2130 regions.push_back(RegionSuccessor(elseRegion));
2131}
2132
2133void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2134 SmallVectorImpl<RegionSuccessor> &regions) {
2135 FoldAdaptor adaptor(operands, *this);
2136 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2137 if (!boolAttr || boolAttr.getValue())
2138 regions.emplace_back(&getThenRegion());
2139
2140 // If the else region is empty, execution continues after the parent op.
2141 if (!boolAttr || !boolAttr.getValue()) {
2142 if (!getElseRegion().empty())
2143 regions.emplace_back(&getElseRegion());
2144 else
2145 regions.emplace_back(RegionSuccessor::parent(getResults()));
2146 }
2147}
2148
/// Folds a negated condition into a branch swap:
///   if (!c) then A() else B() -> if c then B() else A()
LogicalResult IfOp::fold(FoldAdaptor adaptor,
                         SmallVectorImpl<OpFoldResult> &results) {
  // if (!c) then A() else B() -> if c then B() else A()
  // Both branches must exist for the swap to be meaningful.
  if (getElseRegion().empty())
    return failure();

  // The condition must be produced by `xor %c, true`, i.e. a boolean "not".
  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
  if (!xorStmt)
    return failure();

  if (!matchPattern(xorStmt.getRhs(), m_One()))
    return failure();

  // Use the un-negated condition and swap the two branch bodies. The two
  // splices together exchange the block lists of the regions in place.
  getConditionMutable().assign(xorStmt.getLhs());
  Block *thenBlock = &getThenRegion().front();
  // It would be nicer to use iplist::swap, but that has no implemented
  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
                                     getElseRegion().getBlocks());
  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
                                     getThenRegion().getBlocks(), thenBlock);
  return success();
}
2172
2173void IfOp::getRegionInvocationBounds(
2174 ArrayRef<Attribute> operands,
2175 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2176 if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2177 // If the condition is known, then one region is known to be executed once
2178 // and the other zero times.
2179 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2180 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2181 } else {
2182 // Non-constant condition. Each region may be executed 0 or 1 times.
2183 invocationBounds.assign(2, {0, 1});
2184 }
2185}
2186
2187namespace {
2188struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2189 using OpRewritePattern<IfOp>::OpRewritePattern;
2190
2191 LogicalResult matchAndRewrite(IfOp op,
2192 PatternRewriter &rewriter) const override {
2193 BoolAttr condition;
2194 if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2195 return failure();
2196
2197 if (condition.getValue())
2198 replaceOpWithRegion(rewriter, op, op.getThenRegion());
2199 else if (!op.getElseRegion().empty())
2200 replaceOpWithRegion(rewriter, op, op.getElseRegion());
2201 else
2202 rewriter.eraseOp(op);
2203
2204 return success();
2205 }
2206};
2207
2208/// Hoist any yielded results whose operands are defined outside
2209/// the if, to a select instruction.
2210struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2211 using OpRewritePattern<IfOp>::OpRewritePattern;
2212
2213 LogicalResult matchAndRewrite(IfOp op,
2214 PatternRewriter &rewriter) const override {
2215 if (op->getNumResults() == 0)
2216 return failure();
2217
2218 auto cond = op.getCondition();
2219 auto thenYieldArgs = op.thenYield().getOperands();
2220 auto elseYieldArgs = op.elseYield().getOperands();
2221
2222 SmallVector<Type> nonHoistable;
2223 for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2224 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2225 &op.getElseRegion() == falseVal.getParentRegion())
2226 nonHoistable.push_back(trueVal.getType());
2227 }
2228 // Early exit if there aren't any yielded values we can
2229 // hoist outside the if.
2230 if (nonHoistable.size() == op->getNumResults())
2231 return failure();
2232
2233 IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2234 /*withElseRegion=*/false);
2235 if (replacement.thenBlock())
2236 rewriter.eraseBlock(replacement.thenBlock());
2237 replacement.getThenRegion().takeBody(op.getThenRegion());
2238 replacement.getElseRegion().takeBody(op.getElseRegion());
2239
2240 SmallVector<Value> results(op->getNumResults());
2241 assert(thenYieldArgs.size() == results.size());
2242 assert(elseYieldArgs.size() == results.size());
2243
2244 SmallVector<Value> trueYields;
2245 SmallVector<Value> falseYields;
2247 for (const auto &it :
2248 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2249 Value trueVal = std::get<0>(it.value());
2250 Value falseVal = std::get<1>(it.value());
2251 if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2252 &replacement.getElseRegion() == falseVal.getParentRegion()) {
2253 results[it.index()] = replacement.getResult(trueYields.size());
2254 trueYields.push_back(trueVal);
2255 falseYields.push_back(falseVal);
2256 } else if (trueVal == falseVal)
2257 results[it.index()] = trueVal;
2258 else
2259 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2260 cond, trueVal, falseVal);
2261 }
2262
2263 rewriter.setInsertionPointToEnd(replacement.thenBlock());
2264 rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2265
2266 rewriter.setInsertionPointToEnd(replacement.elseBlock());
2267 rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2268
2269 rewriter.replaceOp(op, results);
2270 return success();
2271 }
2272};
2273
2274/// Allow the true region of an if to assume the condition is true
2275/// and vice versa. For example:
2276///
2277/// scf.if %cmp {
2278/// print(%cmp)
2279/// }
2280///
2281/// becomes
2282///
2283/// scf.if %cmp {
2284/// print(true)
2285/// }
2286///
2287struct ConditionPropagation : public OpRewritePattern<IfOp> {
2288 using OpRewritePattern<IfOp>::OpRewritePattern;
2289
2290 /// Kind of parent region in the ancestor cache.
2291 enum class Parent { Then, Else, None };
2292
2293 /// Returns the kind of region ("then", "else", or "none") of the
2294 /// IfOp that the given region is transitively nested in. Updates
2295 /// the cache accordingly.
2296 static Parent getParentType(Region *toCheck, IfOp op,
2298 Region *endRegion) {
2299 SmallVector<Region *> seen;
2300 while (toCheck != endRegion) {
2301 auto found = cache.find(toCheck);
2302 if (found != cache.end())
2303 return found->second;
2304 seen.push_back(toCheck);
2305 if (&op.getThenRegion() == toCheck) {
2306 for (Region *region : seen)
2307 cache[region] = Parent::Then;
2308 return Parent::Then;
2309 }
2310 if (&op.getElseRegion() == toCheck) {
2311 for (Region *region : seen)
2312 cache[region] = Parent::Else;
2313 return Parent::Else;
2314 }
2315 toCheck = toCheck->getParentRegion();
2316 }
2317
2318 for (Region *region : seen)
2319 cache[region] = Parent::None;
2320 return Parent::None;
2321 }
2322
2323 LogicalResult matchAndRewrite(IfOp op,
2324 PatternRewriter &rewriter) const override {
2325 // Early exit if the condition is constant since replacing a constant
2326 // in the body with another constant isn't a simplification.
2327 if (matchPattern(op.getCondition(), m_Constant()))
2328 return failure();
2329
2330 bool changed = false;
2331 mlir::Type i1Ty = rewriter.getI1Type();
2332
2333 // These variables serve to prevent creating duplicate constants
2334 // and hold constant true or false values.
2335 Value constantTrue = nullptr;
2336 Value constantFalse = nullptr;
2337
2339 for (OpOperand &use :
2340 llvm::make_early_inc_range(op.getCondition().getUses())) {
2341 switch (getParentType(use.getOwner()->getParentRegion(), op, cache,
2342 op.getCondition().getParentRegion())) {
2343 case Parent::Then: {
2344 changed = true;
2345
2346 if (!constantTrue)
2347 constantTrue = arith::ConstantOp::create(
2348 rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2349
2350 rewriter.modifyOpInPlace(use.getOwner(),
2351 [&]() { use.set(constantTrue); });
2352 break;
2353 }
2354 case Parent::Else: {
2355 changed = true;
2356
2357 if (!constantFalse)
2358 constantFalse = arith::ConstantOp::create(
2359 rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2360
2361 rewriter.modifyOpInPlace(use.getOwner(),
2362 [&]() { use.set(constantFalse); });
2363 break;
2364 }
2365 case Parent::None:
2366 break;
2367 }
2368 }
2369
2370 return success(changed);
2371 }
2372};
2373
2374/// Remove any statements from an if that are equivalent to the condition
2375/// or its negation. For example:
2376///
2377/// %res:2 = scf.if %cmp {
2378/// yield something(), true
2379/// } else {
2380/// yield something2(), false
2381/// }
2382/// print(%res#1)
2383///
2384/// becomes
2385/// %res = scf.if %cmp {
2386/// yield something()
2387/// } else {
2388/// yield something2()
2389/// }
2390/// print(%cmp)
2391///
2392/// Additionally if both branches yield the same value, replace all uses
2393/// of the result with the yielded value.
2394///
2395/// %res:2 = scf.if %cmp {
2396/// yield something(), %arg1
2397/// } else {
2398/// yield something2(), %arg1
2399/// }
2400/// print(%res#1)
2401///
2402/// becomes
2403/// %res = scf.if %cmp {
2404/// yield something()
2405/// } else {
2406/// yield something2()
2407/// }
2408/// print(%arg1)
2409///
2410struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2411 using OpRewritePattern<IfOp>::OpRewritePattern;
2412
2413 LogicalResult matchAndRewrite(IfOp op,
2414 PatternRewriter &rewriter) const override {
2415 // Early exit if there are no results that could be replaced.
2416 if (op.getNumResults() == 0)
2417 return failure();
2418
2419 auto trueYield =
2420 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2421 auto falseYield =
2422 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2423
2424 rewriter.setInsertionPoint(op->getBlock(),
2425 op.getOperation()->getIterator());
2426 bool changed = false;
2427 Type i1Ty = rewriter.getI1Type();
2428 for (auto [trueResult, falseResult, opResult] :
2429 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2430 op.getResults())) {
2431 if (trueResult == falseResult) {
2432 if (!opResult.use_empty()) {
2433 opResult.replaceAllUsesWith(trueResult);
2434 changed = true;
2435 }
2436 continue;
2437 }
2438
2439 BoolAttr trueYield, falseYield;
2440 if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2441 !matchPattern(falseResult, m_Constant(&falseYield)))
2442 continue;
2443
2444 bool trueVal = trueYield.getValue();
2445 bool falseVal = falseYield.getValue();
2446 if (!trueVal && falseVal) {
2447 if (!opResult.use_empty()) {
2448 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2449 Value notCond = arith::XOrIOp::create(
2450 rewriter, op.getLoc(), op.getCondition(),
2451 constDialect
2452 ->materializeConstant(rewriter,
2453 rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2454 op.getLoc())
2455 ->getResult(0));
2456 opResult.replaceAllUsesWith(notCond);
2457 changed = true;
2458 }
2459 }
2460 if (trueVal && !falseVal) {
2461 if (!opResult.use_empty()) {
2462 opResult.replaceAllUsesWith(op.getCondition());
2463 changed = true;
2464 }
2465 }
2466 }
2467 return success(changed);
2468 }
2469};
2470
2471/// Merge any consecutive scf.if's with the same condition.
2472///
2473/// scf.if %cond {
2474/// firstCodeTrue();...
2475/// } else {
2476/// firstCodeFalse();...
2477/// }
2478/// %res = scf.if %cond {
2479/// secondCodeTrue();...
2480/// } else {
2481/// secondCodeFalse();...
2482/// }
2483///
2484/// becomes
2485/// %res = scf.if %cmp {
2486/// firstCodeTrue();...
2487/// secondCodeTrue();...
2488/// } else {
2489/// firstCodeFalse();...
2490/// secondCodeFalse();...
2491/// }
struct CombineIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp nextIf,
                                PatternRewriter &rewriter) const override {
    // Only fires when the immediately preceding op is another scf.if.
    Block *parent = nextIf->getBlock();
    if (nextIf == &parent->front())
      return failure();

    auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
    if (!prevIf)
      return failure();

    // Determine the logical then/else blocks when prevIf's
    // condition is used. Null means the block does not exist
    // in that case (e.g. empty else). If neither of these
    // are set, the two conditions cannot be compared.
    Block *nextThen = nullptr;
    Block *nextElse = nullptr;
    // Case 1: identical conditions.
    if (nextIf.getCondition() == prevIf.getCondition()) {
      nextThen = nextIf.thenBlock();
      if (!nextIf.getElseRegion().empty())
        nextElse = nextIf.elseBlock();
    }
    // Case 2: nextIf's condition is the negation (xor with 1) of prevIf's,
    // so nextIf's branches map to the opposite branches of prevIf.
    if (arith::XOrIOp notv =
            nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == prevIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }
    // Case 3: prevIf's condition is the negation of nextIf's.
    if (arith::XOrIOp notv =
            prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == nextIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }

    if (!nextThen && !nextElse)
      return failure();

    SmallVector<Value> prevElseYielded;
    if (!prevIf.getElseRegion().empty())
      prevElseYielded = prevIf.elseYield().getOperands();
    // Replace all uses of return values of op within nextIf with the
    // corresponding yields
    for (auto it : llvm::zip(prevIf.getResults(),
                             prevIf.thenYield().getOperands(), prevElseYielded))
      for (OpOperand &use :
           llvm::make_early_inc_range(std::get<0>(it).getUses())) {
        if (nextThen && nextThen->getParent()->isAncestor(
                            use.getOwner()->getParentRegion())) {
          rewriter.startOpModification(use.getOwner());
          use.set(std::get<1>(it));
          rewriter.finalizeOpModification(use.getOwner());
        } else if (nextElse && nextElse->getParent()->isAncestor(
                                   use.getOwner()->getParentRegion())) {
          rewriter.startOpModification(use.getOwner());
          use.set(std::get<2>(it));
          rewriter.finalizeOpModification(use.getOwner());
        }
      }

    // The combined op yields prevIf's results followed by nextIf's.
    SmallVector<Type> mergedTypes(prevIf.getResultTypes());
    llvm::append_range(mergedTypes, nextIf.getResultTypes());

    IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
                                   prevIf.getCondition(), /*hasElse=*/false);
    // Drop the auto-created empty then block; prevIf's body replaces it.
    rewriter.eraseBlock(&combinedIf.getThenRegion().back());

    rewriter.inlineRegionBefore(prevIf.getThenRegion(),
                                combinedIf.getThenRegion(),
                                combinedIf.getThenRegion().begin());

    if (nextThen) {
      // Concatenate the two then bodies and merge their yields.
      YieldOp thenYield = combinedIf.thenYield();
      YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
      rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
      rewriter.setInsertionPointToEnd(combinedIf.thenBlock());

      SmallVector<Value> mergedYields(thenYield.getOperands());
      llvm::append_range(mergedYields, thenYield2.getOperands());
      YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
      rewriter.eraseOp(thenYield);
      rewriter.eraseOp(thenYield2);
    }

    rewriter.inlineRegionBefore(prevIf.getElseRegion(),
                                combinedIf.getElseRegion(),
                                combinedIf.getElseRegion().begin());

    if (nextElse) {
      if (combinedIf.getElseRegion().empty()) {
        // prevIf had no else: nextIf's else becomes the combined else.
        rewriter.inlineRegionBefore(*nextElse->getParent(),
                                    combinedIf.getElseRegion(),
                                    combinedIf.getElseRegion().begin());
      } else {
        // Concatenate the two else bodies and merge their yields.
        YieldOp elseYield = combinedIf.elseYield();
        YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
        rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());

        rewriter.setInsertionPointToEnd(combinedIf.elseBlock());

        SmallVector<Value> mergedElseYields(elseYield.getOperands());
        llvm::append_range(mergedElseYields, elseYield2.getOperands());

        YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
        rewriter.eraseOp(elseYield);
        rewriter.eraseOp(elseYield2);
      }
    }

    // Split the combined results back into prevIf's and nextIf's shares.
    SmallVector<Value> prevValues;
    SmallVector<Value> nextValues;
    for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
      if (pair.index() < prevIf.getNumResults())
        prevValues.push_back(pair.value());
      else
        nextValues.push_back(pair.value());
    }
    rewriter.replaceOp(prevIf, prevValues);
    rewriter.replaceOp(nextIf, nextValues);
    return success();
  }
};
2622
2623/// Pattern to remove an empty else branch.
2624struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2625 using OpRewritePattern<IfOp>::OpRewritePattern;
2626
2627 LogicalResult matchAndRewrite(IfOp ifOp,
2628 PatternRewriter &rewriter) const override {
2629 // Cannot remove else region when there are operation results.
2630 if (ifOp.getNumResults())
2631 return failure();
2632 Block *elseBlock = ifOp.elseBlock();
2633 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2634 return failure();
2635 auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2636 rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2637 newIfOp.getThenRegion().begin());
2638 rewriter.eraseOp(ifOp);
2639 return success();
2640 }
2641};
2642
/// Convert nested `if`s into `arith.andi` + single `if`.
///
///    scf.if %arg0 {
///      scf.if %arg1 {
///        ...
///        scf.yield
///      }
///      scf.yield
///    }
///  becomes
///
///    %0 = arith.andi %arg0, %arg1
///    scf.if %0 {
///      ...
///      scf.yield
///    }
struct CombineNestedIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    auto nestedOps = op.thenBlock()->without_terminator();
    // Nested `if` must be the only op in block.
    if (!llvm::hasSingleElement(nestedOps))
      return failure();

    // If there is an else block, it can only yield
    if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
      return failure();

    auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
    if (!nestedIf)
      return failure();

    // Same restriction on the inner `if`: its else region, if present, may
    // only contain the terminator.
    if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
      return failure();

    SmallVector<Value> thenYield(op.thenYield().getOperands());
    SmallVector<Value> elseYield;
    if (op.elseBlock())
      llvm::append_range(elseYield, op.elseYield().getOperands());

    // A list of indices for which we should upgrade the value yielded
    // in the else to a select.
    SmallVector<unsigned> elseYieldsToUpgradeToSelect;

    // If the outer scf.if yields a value produced by the inner scf.if,
    // only permit combining if the value yielded when the condition
    // is false in the outer scf.if is the same value yielded when the
    // inner scf.if condition is false.
    // Note that the array access to elseYield will not go out of bounds
    // since it must have the same length as thenYield, since they both
    // come from the same scf.if.
    for (const auto &tup : llvm::enumerate(thenYield)) {
      if (tup.value().getDefiningOp() == nestedIf) {
        auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
        if (nestedIf.elseYield().getOperand(nestedIdx) !=
            elseYield[tup.index()]) {
          return failure();
        }
        // If the correctness test passes, we will yield
        // corresponding value from the inner scf.if
        thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
        continue;
      }

      // Otherwise, we need to ensure the else block of the combined
      // condition still returns the same value when the outer condition is
      // true and the inner condition is false. This can be accomplished if
      // the then value is defined outside the outer scf.if and we replace the
      // value with a select that considers just the outer condition. Since
      // the else region contains just the yield, its yielded value is
      // defined outside the scf.if, by definition.

      // If the then value is defined within the scf.if, bail.
      if (tup.value().getParentRegion() == &op.getThenRegion()) {
        return failure();
      }
      elseYieldsToUpgradeToSelect.push_back(tup.index());
    }

    Location loc = op.getLoc();
    // The combined condition holds only when both conditions hold.
    Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
                                               nestedIf.getCondition());
    auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
    Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());

    SmallVector<Value> results;
    llvm::append_range(results, newIf.getResults());
    rewriter.setInsertionPoint(newIf);

    // Selects are created before the combined `if` so they depend only on
    // the outer condition.
    for (auto idx : elseYieldsToUpgradeToSelect)
      results[idx] =
          arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
                                  thenYield[idx], elseYield[idx]);

    rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
    rewriter.setInsertionPointToEnd(newIf.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
    if (!elseYield.empty()) {
      rewriter.createBlock(&newIf.getElseRegion());
      rewriter.setInsertionPointToEnd(newIf.elseBlock());
      YieldOp::create(rewriter, loc, elseYield);
    }
    rewriter.replaceOp(op, results);
    return success();
  }
};
2751
2752} // namespace
2753
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                       MLIRContext *context) {
  // Register the scf.if canonicalization patterns defined above.
  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
              ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
              RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>(
      context);
  // NOTE(review): the line beginning this call appears truncated in this
  // excerpt; `results, IfOp::getOperationName())` is the argument list of a
  // pattern-population helper — verify against the full source.
      results, IfOp::getOperationName());
}
2763
2764Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2765YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2766Block *IfOp::elseBlock() {
2767 Region &r = getElseRegion();
2768 if (r.empty())
2769 return nullptr;
2770 return &r.back();
2771}
2772YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2773
2774//===----------------------------------------------------------------------===//
2775// ParallelOp
2776//===----------------------------------------------------------------------===//
2777
2778void ParallelOp::build(
2779 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2780 ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2781 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2782 bodyBuilderFn) {
2783 result.addOperands(lowerBounds);
2784 result.addOperands(upperBounds);
2785 result.addOperands(steps);
2786 result.addOperands(initVals);
2787 result.addAttribute(
2788 ParallelOp::getOperandSegmentSizeAttr(),
2789 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2790 static_cast<int32_t>(upperBounds.size()),
2791 static_cast<int32_t>(steps.size()),
2792 static_cast<int32_t>(initVals.size())}));
2793 result.addTypes(initVals.getTypes());
2794
2795 OpBuilder::InsertionGuard guard(builder);
2796 unsigned numIVs = steps.size();
2797 SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2798 SmallVector<Location, 8> argLocs(numIVs, result.location);
2799 Region *bodyRegion = result.addRegion();
2800 Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2801
2802 if (bodyBuilderFn) {
2803 builder.setInsertionPointToStart(bodyBlock);
2804 bodyBuilderFn(builder, result.location,
2805 bodyBlock->getArguments().take_front(numIVs),
2806 bodyBlock->getArguments().drop_front(numIVs));
2807 }
2808 // Add terminator only if there are no reductions.
2809 if (initVals.empty())
2810 ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2811}
2812
2813void ParallelOp::build(
2814 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2815 ValueRange upperBounds, ValueRange steps,
2816 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2817 // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2818 // we don't capture a reference to a temporary by constructing the lambda at
2819 // function level.
2820 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2821 Location nestedLoc, ValueRange ivs,
2822 ValueRange) {
2823 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2824 };
2825 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2826 if (bodyBuilderFn)
2827 wrapper = wrappedBuilderFn;
2828
2829 build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2830 wrapper);
2831}
2832
LogicalResult ParallelOp::verify() {
  // Check that there is at least one value in lowerBound, upperBound and step.
  // It is sufficient to test only step, because it is ensured already that the
  // number of elements in lowerBound, upperBound and step are the same.
  Operation::operand_range stepValues = getStep();
  if (stepValues.empty())
    return emitOpError(
        "needs at least one tuple element for lowerBound, upperBound and step");

  // Check whether all constant step values are positive.
  for (Value stepValue : stepValues)
    if (auto cst = getConstantIntValue(stepValue))
      if (*cst <= 0)
        return emitOpError("constant step operand must be positive");

  // Check that the body defines the same number of block arguments as the
  // number of tuple elements in step.
  Block *body = getBody();
  if (body->getNumArguments() != stepValues.size())
    return emitOpError() << "expects the same number of induction variables: "
                         << body->getNumArguments()
                         << " as bound and step values: " << stepValues.size();
  for (auto arg : body->getArguments())
    if (!arg.getType().isIndex())
      return emitOpError(
          "expects arguments for the induction variable to be of index type");

  // Check that the terminator is an scf.reduce op.
  // NOTE(review): the line producing `reduceOp` (presumably a
  // verifyAndGetTerminator<ReduceOp>(...) call) appears truncated in this
  // excerpt — verify against the full source.
      *this, getRegion(), "expects body to terminate with 'scf.reduce'");
  if (!reduceOp)
    return failure();

  // Check that the number of results is the same as the number of reductions.
  auto resultsSize = getResults().size();
  auto reductionsSize = reduceOp.getReductions().size();
  auto initValsSize = getInitVals().size();
  if (resultsSize != reductionsSize)
    return emitOpError() << "expects number of results: " << resultsSize
                         << " to be the same as number of reductions: "
                         << reductionsSize;
  if (resultsSize != initValsSize)
    return emitOpError() << "expects number of results: " << resultsSize
                         << " to be the same as number of initial values: "
                         << initValsSize;
  if (reduceOp.getNumOperands() != initValsSize)
    // Delegate error reporting to ReduceOp
    return success();

  // Check that the types of the results and reductions are the same.
  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
    auto resultType = getOperation()->getResult(i).getType();
    auto reductionOperandType = reduceOp.getOperands()[i].getType();
    if (resultType != reductionOperandType)
      return reduceOp.emitOpError()
             << "expects type of " << i
             << "-th reduction operand: " << reductionOperandType
             << " to be the same as the " << i
             << "-th result type: " << resultType;
  }
  return success();
}
2895
// Parses the custom assembly of scf.parallel:
//   (%ivs...) = (%lbs...) to (%ubs...) step (%steps...) [init (%inits...)]
//   [-> result-types] region [attr-dict]
// NOTE(review): several continuation lines of the parse conditions (the
// OpAsmParser::Delimiter arguments and the induction-variable list parse)
// appear truncated in this excerpt — verify against the full source.
ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  // Parse an opening `(` followed by induction variables followed by `)`
  SmallVector<OpAsmParser::Argument, 4> ivs;
    return failure();

  // Parse loop bounds.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
  if (parser.parseEqual() ||
      parser.parseOperandList(lower, ivs.size(),
      parser.resolveOperands(lower, builder.getIndexType(), result.operands))
    return failure();

  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
  if (parser.parseKeyword("to") ||
      parser.parseOperandList(upper, ivs.size(),
      parser.resolveOperands(upper, builder.getIndexType(), result.operands))
    return failure();

  // Parse step values.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
  if (parser.parseKeyword("step") ||
      parser.parseOperandList(steps, ivs.size(),
      parser.resolveOperands(steps, builder.getIndexType(), result.operands))
    return failure();

  // Parse init values.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
  if (succeeded(parser.parseOptionalKeyword("init"))) {
    if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
      return failure();
  }

  // Parse optional results in case there is a reduce.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Now parse the body.
  Region *body = result.addRegion();
  // All induction variables are of the builtin `index` type.
  for (auto &iv : ivs)
    iv.type = builder.getIndexType();
  if (parser.parseRegion(*body, ivs))
    return failure();

  // Set `operandSegmentSizes` attribute.
  result.addAttribute(
      ParallelOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
                                    static_cast<int32_t>(upper.size()),
                                    static_cast<int32_t>(steps.size()),
                                    static_cast<int32_t>(initVals.size())}));

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Add a terminator if none was parsed.
  ParallelOp::ensureTerminator(*body, builder, result.location);
  return success();
}
2962
// Prints the custom assembly of scf.parallel; the inverse of the parser
// above.
void ParallelOp::print(OpAsmPrinter &p) {
  // `(%ivs) = (%lbs) to (%ubs) step (%steps)`
  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
    << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
  if (!getInitVals().empty())
    p << " init (" << getInitVals() << ")";
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
  // NOTE(review): the `p.printOptionalAttrDict(` line appears truncated in
  // this excerpt. The segment-size attribute is elided because the parser
  // reconstructs it.
      (*this)->getAttrs(),
      /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
}
2975
2976SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
2977
2978std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
2979 return SmallVector<Value>{getBody()->getArguments()};
2980}
2981
2982std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
2983 return getLowerBound();
2984}
2985
2986std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
2987 return getUpperBound();
2988}
2989
2990std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
2991 return getStep();
2992}
2993
  // NOTE(review): the signature line of this function is not visible in this
  // excerpt. The body maps a value back to the ParallelOp that owns it as an
  // induction variable, or returns a null ParallelOp otherwise.
  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
  if (!ivArg)
    return ParallelOp();
  assert(ivArg.getOwner() && "unlinked block argument");
  // Only a block argument of a ParallelOp body qualifies; dyn_cast yields a
  // null op for any other parent.
  auto *containingOp = ivArg.getOwner()->getParentOp();
  return dyn_cast<ParallelOp>(containingOp);
}
3002
3003namespace {
// Collapse loop dimensions that perform a single iteration.
struct ParallelOpSingleOrZeroIterationDimsFolder
    : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Compute new loop bounds that omit all single-iteration loop dimensions.
    SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
    IRMapping mapping;
    for (auto [lb, ub, step, iv] :
         llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
                   op.getInductionVars())) {
      auto numIterations =
          constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
      if (numIterations.has_value()) {
        // Remove the loop if it performs zero iterations.
        if (*numIterations == 0) {
          // A zero-trip loop yields its init values unchanged.
          rewriter.replaceOp(op, op.getInitVals());
          return success();
        }
        // Replace the loop induction variable by the lower bound if the loop
        // performs a single iteration. Otherwise, copy the loop bounds.
        if (*numIterations == 1) {
          mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
          continue;
        }
      }
      newLowerBounds.push_back(lb);
      newUpperBounds.push_back(ub);
      newSteps.push_back(step);
    }
    // Exit if none of the loop dimensions perform a single iteration.
    if (newLowerBounds.size() == op.getLowerBound().size())
      return failure();

    if (newLowerBounds.empty()) {
      // All of the loop dimensions perform a single iteration. Inline
      // loop body and nested ReduceOp's
      SmallVector<Value> results;
      results.reserve(op.getInitVals().size());
      for (auto &bodyOp : op.getBody()->without_terminator())
        rewriter.clone(bodyOp, mapping);
      auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
      for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
        Block &reduceBlock = reduceOp.getReductions()[i].front();
        auto initValIndex = results.size();
        // Wire the reduction block arguments: argument 0 receives the init
        // value, argument 1 the value produced by the (cloned) body.
        mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
        mapping.map(reduceBlock.getArgument(1),
                    mapping.lookupOrDefault(reduceOp.getOperands()[i]));
        for (auto &reduceBodyOp : reduceBlock.without_terminator())
          rewriter.clone(reduceBodyOp, mapping);

        // The cloned reduce.return operand becomes the op's result.
        auto result = mapping.lookupOrDefault(
            cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
        results.push_back(result);
      }

      rewriter.replaceOp(op, results);
      return success();
    }
    // Replace the parallel loop by lower-dimensional parallel loop.
    auto newOp =
        ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
                           newUpperBounds, newSteps, op.getInitVals(), nullptr);
    // Erase the empty block that was inserted by the builder.
    rewriter.eraseBlock(newOp.getBody());
    // Clone the loop body and remap the block arguments of the collapsed loops
    // (inlining does not support a cancellable block argument mapping).
    rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
                               newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};
3081
/// Merge a perfectly nested pair of parallel loops (where the outer body
/// contains only the inner loop and its terminator) into a single
/// multi-dimensional scf.parallel.
struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    Block &outerBody = *op.getBody();
    // The inner loop must be the only operation in the outer body besides
    // the terminator.
    if (!llvm::hasSingleElement(outerBody.without_terminator()))
      return failure();

    auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
    if (!innerOp)
      return failure();

    // The inner bounds/steps must not depend on the outer induction
    // variables, otherwise the loops cannot be flattened.
    for (auto val : outerBody.getArguments())
      if (llvm::is_contained(innerOp.getLowerBound(), val) ||
          llvm::is_contained(innerOp.getUpperBound(), val) ||
          llvm::is_contained(innerOp.getStep(), val))
        return failure();

    // Reductions are not supported yet.
    if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
      return failure();

    auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
                           ValueRange iterVals, ValueRange) {
      Block &innerBody = *innerOp.getBody();
      assert(iterVals.size() ==
             (outerBody.getNumArguments() + innerBody.getNumArguments()));
      // The merged loop's IVs are the outer IVs followed by the inner IVs.
      IRMapping mapping;
      mapping.map(outerBody.getArguments(),
                  iterVals.take_front(outerBody.getNumArguments()));
      mapping.map(innerBody.getArguments(),
                  iterVals.take_back(innerBody.getNumArguments()));
      for (Operation &op : innerBody.without_terminator())
        builder.clone(op, mapping);
    };

    auto concatValues = [](const auto &first, const auto &second) {
      SmallVector<Value> ret;
      ret.reserve(first.size() + second.size());
      ret.assign(first.begin(), first.end());
      ret.append(second.begin(), second.end());
      return ret;
    };

    auto newLowerBounds =
        concatValues(op.getLowerBound(), innerOp.getLowerBound());
    auto newUpperBounds =
        concatValues(op.getUpperBound(), innerOp.getUpperBound());
    auto newSteps = concatValues(op.getStep(), innerOp.getStep());

    rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
                                            newSteps, ValueRange(),
                                            bodyBuilder);
    return success();
  }
};
3139
3140} // namespace
3141
3142void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3143 MLIRContext *context) {
3144 results
3145 .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3146 context);
3147}
3148
3149/// Given the region at `index`, or the parent operation if `index` is None,
3150/// return the successor regions. These are the regions that may be selected
3151/// during the flow of control. `operands` is a set of optional attributes that
3152/// correspond to a constant value for each operand, or null if that operand is
3153/// not a constant.
3154void ParallelOp::getSuccessorRegions(
3155 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3156 // Both the operation itself and the region may be branching into the body or
3157 // back into the operation itself. It is possible for loop not to enter the
3158 // body.
3159 regions.push_back(RegionSuccessor(&getRegion()));
3160 regions.push_back(RegionSuccessor::parent(
3161 ResultRange{getResults().end(), getResults().end()}));
3162}
3163
3164//===----------------------------------------------------------------------===//
3165// ReduceOp
3166//===----------------------------------------------------------------------===//
3167
/// Builds an scf.reduce with no operands and no reduction regions.
void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3169
3170void ReduceOp::build(OpBuilder &builder, OperationState &result,
3171 ValueRange operands) {
3172 result.addOperands(operands);
3173 for (Value v : operands) {
3174 OpBuilder::InsertionGuard guard(builder);
3175 Region *bodyRegion = result.addRegion();
3176 builder.createBlock(bodyRegion, {},
3177 ArrayRef<Type>{v.getType(), v.getType()},
3178 {result.location, result.location});
3179 }
3180}
3181
3182LogicalResult ReduceOp::verifyRegions() {
3183 if (getReductions().size() != getOperands().size())
3184 return emitOpError() << "expects number of reduction regions: "
3185 << getReductions().size()
3186 << " to be the same as number of reduction operands: "
3187 << getOperands().size();
3188 // The region of a ReduceOp has two arguments of the same type as its
3189 // corresponding operand.
3190 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3191 auto type = getOperands()[i].getType();
3192 Block &block = getReductions()[i].front();
3193 if (block.empty())
3194 return emitOpError() << i << "-th reduction has an empty body";
3195 if (block.getNumArguments() != 2 ||
3196 llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3197 return arg.getType() != type;
3198 }))
3199 return emitOpError() << "expected two block arguments with type " << type
3200 << " in the " << i << "-th reduction region";
3201
3202 // Check that the block is terminated by a ReduceReturnOp.
3203 if (!isa<ReduceReturnOp>(block.getTerminator()))
3204 return emitOpError("reduction bodies must be terminated with an "
3205 "'scf.reduce.return' op");
3206 }
3207
3208 return success();
3209}
3210
3211MutableOperandRange
3212ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
3213 // No operands are forwarded to the next iteration.
3214 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3215}
3216
3217//===----------------------------------------------------------------------===//
3218// ReduceReturnOp
3219//===----------------------------------------------------------------------===//
3220
3221LogicalResult ReduceReturnOp::verify() {
3222 // The type of the return value should be the same type as the types of the
3223 // block arguments of the reduction body.
3224 Block *reductionBody = getOperation()->getBlock();
3225 // Should already be verified by an op trait.
3226 assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3227 Type expectedResultType = reductionBody->getArgument(0).getType();
3228 if (expectedResultType != getResult().getType())
3229 return emitOpError() << "must have type " << expectedResultType
3230 << " (the type of the reduction inputs)";
3231 return success();
3232}
3233
3234//===----------------------------------------------------------------------===//
3235// WhileOp
3236//===----------------------------------------------------------------------===//
3237
3238void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3239 ::mlir::OperationState &odsState, TypeRange resultTypes,
3240 ValueRange inits, BodyBuilderFn beforeBuilder,
3241 BodyBuilderFn afterBuilder) {
3242 odsState.addOperands(inits);
3243 odsState.addTypes(resultTypes);
3244
3245 OpBuilder::InsertionGuard guard(odsBuilder);
3246
3247 // Build before region.
3248 SmallVector<Location, 4> beforeArgLocs;
3249 beforeArgLocs.reserve(inits.size());
3250 for (Value operand : inits) {
3251 beforeArgLocs.push_back(operand.getLoc());
3252 }
3253
3254 Region *beforeRegion = odsState.addRegion();
3255 Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3256 inits.getTypes(), beforeArgLocs);
3257 if (beforeBuilder)
3258 beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3259
3260 // Build after region.
3261 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3262
3263 Region *afterRegion = odsState.addRegion();
3264 Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3265 resultTypes, afterArgLocs);
3266
3267 if (afterBuilder)
3268 afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3269}
3270
3271ConditionOp WhileOp::getConditionOp() {
3272 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3273}
3274
3275YieldOp WhileOp::getYieldOp() {
3276 return cast<YieldOp>(getAfterBody()->getTerminator());
3277}
3278
3279std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3280 return getYieldOp().getResultsMutable();
3281}
3282
3283Block::BlockArgListType WhileOp::getBeforeArguments() {
3284 return getBeforeBody()->getArguments();
3285}
3286
3287Block::BlockArgListType WhileOp::getAfterArguments() {
3288 return getAfterBody()->getArguments();
3289}
3290
3291Block::BlockArgListType WhileOp::getRegionIterArgs() {
3292 return getBeforeArguments();
3293}
3294
3295OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
3296 assert(successor.getSuccessor() == &getBefore() &&
3297 "WhileOp is expected to branch only to the first region");
3298 return getInits();
3299}
3300
void WhileOp::getSuccessorRegions(RegionBranchPoint point,
                                  SmallVectorImpl<RegionSuccessor> &regions) {
  // The parent op always branches to the condition region.
  if (point.isParent()) {
    regions.emplace_back(&getBefore(), getBefore().getArguments());
    return;
  }

  assert(llvm::is_contained(
             {&getAfter(), &getBefore()},
             point.getTerminatorPredecessorOrNull()->getParentRegion()) &&
         "there are only two regions in a WhileOp");
  // The body region always branches back to the condition region.
  // NOTE(review): the `if (...` line guarding this branch appears truncated
  // in this excerpt — verify against the full source.
      &getAfter()) {
    regions.emplace_back(&getBefore(), getBefore().getArguments());
    return;
  }

  // From the condition region, control either exits the loop (forwarding the
  // results) or enters the body region.
  regions.push_back(RegionSuccessor::parent(getResults()));
  regions.emplace_back(&getAfter(), getAfter().getArguments());
}
3323
3324SmallVector<Region *> WhileOp::getLoopRegions() {
3325 return {&getBefore(), &getAfter()};
3326}
3327
3328/// Parses a `while` op.
3329///
3330/// op ::= `scf.while` assignments `:` function-type region `do` region
3331/// `attributes` attribute-dict
3332/// initializer ::= /* empty */ | `(` assignment-list `)`
3333/// assignment-list ::= assignment | assignment `,` assignment-list
3334/// assignment ::= ssa-value `=` ssa-value
3335ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3336 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3337 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3338 Region *before = result.addRegion();
3339 Region *after = result.addRegion();
3340
3341 OptionalParseResult listResult =
3342 parser.parseOptionalAssignmentList(regionArgs, operands);
3343 if (listResult.has_value() && failed(listResult.value()))
3344 return failure();
3345
3346 FunctionType functionType;
3347 SMLoc typeLoc = parser.getCurrentLocation();
3348 if (failed(parser.parseColonType(functionType)))
3349 return failure();
3350
3351 result.addTypes(functionType.getResults());
3352
3353 if (functionType.getNumInputs() != operands.size()) {
3354 return parser.emitError(typeLoc)
3355 << "expected as many input types as operands " << "(expected "
3356 << operands.size() << " got " << functionType.getNumInputs() << ")";
3357 }
3358
3359 // Resolve input operands.
3360 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3361 parser.getCurrentLocation(),
3362 result.operands)))
3363 return failure();
3364
3365 // Propagate the types into the region arguments.
3366 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3367 regionArgs[i].type = functionType.getInput(i);
3368
3369 return failure(parser.parseRegion(*before, regionArgs) ||
3370 parser.parseKeyword("do") || parser.parseRegion(*after) ||
3371 parser.parseOptionalAttrDictWithKeyword(result.attributes));
3372}
3373
/// Prints a `while` op; the inverse of the parser above.
void scf::WhileOp::print(OpAsmPrinter &p) {
  // `(%arg = %init, ...)` assignment list for the before-region arguments.
  printInitializationList(p, getBeforeArguments(), getInits(), " ");
  p << " : ";
  // Functional type: init operand types -> result types.
  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
  p << ' ';
  // Before-region argument names were already printed in the assignment list.
  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
  p << " do ";
  p.printRegion(getAfter());
  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}
3385
3386/// Verifies that two ranges of types match, i.e. have the same number of
3387/// entries and that types are pairwise equals. Reports errors on the given
3388/// operation in case of mismatch.
3389template <typename OpTy>
3390static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3391 TypeRange right, StringRef message) {
3392 if (left.size() != right.size())
3393 return op.emitOpError("expects the same number of ") << message;
3394
3395 for (unsigned i = 0, e = left.size(); i < e; ++i) {
3396 if (left[i] != right[i]) {
3397 InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3398 << message;
3399 diag.attachNote() << "for argument " << i << ", found " << left[i]
3400 << " and " << right[i];
3401 return diag;
3402 }
3403 }
3404
3405 return success();
3406}
3407
3408LogicalResult scf::WhileOp::verify() {
3409 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3410 *this, getBefore(),
3411 "expects the 'before' region to terminate with 'scf.condition'");
3412 if (!beforeTerminator)
3413 return failure();
3414
3415 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3416 *this, getAfter(),
3417 "expects the 'after' region to terminate with 'scf.yield'");
3418 return success(afterTerminator != nullptr);
3419}
3420
3421namespace {
3422/// Move a scf.if op that is directly before the scf.condition op in the while
3423/// before region, and whose condition matches the condition of the
3424/// scf.condition op, down into the while after region.
3425///
3426/// scf.while (..) : (...) -> ... {
3427/// %additional_used_values = ...
3428/// %cond = ...
3429/// ...
3430/// %res = scf.if %cond -> (...) {
3431/// use(%additional_used_values)
3432/// ... // then block
3433/// scf.yield %then_value
3434/// } else {
3435/// scf.yield %else_value
3436/// }
3437/// scf.condition(%cond) %res, ...
3438/// } do {
3439/// ^bb0(%res_arg, ...):
3440/// use(%res_arg)
3441/// ...
3442///
3443/// becomes
3444/// scf.while (..) : (...) -> ... {
3445/// %additional_used_values = ...
3446/// %cond = ...
3447/// ...
3448/// scf.condition(%cond) %else_value, ..., %additional_used_values
3449/// } do {
3450/// ^bb0(%res_arg ..., %additional_args): :
3451/// use(%additional_args)
3452/// ... // if then block
3453/// use(%then_value)
3454/// ...
struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::WhileOp op,
                                PatternRewriter &rewriter) const override {
    auto conditionOp = op.getConditionOp();

    // Only support ifOp right before the condition at the moment. Relaxing this
    // would require to:
    // - check that the body does not have side-effects conflicting with
    //   operations between the if and the condition.
    // - check that results of the if operation are only used as arguments to
    //   the condition.
    auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());

    // Check that the ifOp is directly before the conditionOp and that it
    // matches the condition of the conditionOp. Also ensure that the ifOp has
    // no else block with content, as that would complicate the transformation.
    // TODO: support else blocks with content.
    if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
        (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
      return failure();

    // Since the ifOp immediately precedes the condition, its results can only
    // be consumed by the condition terminator.
    assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
                                  *ifOp->user_begin() == conditionOp)) &&
           "ifOp has unexpected uses");

    Location loc = op.getLoc();

    // Replace uses of ifOp results in the conditionOp with the yielded values
    // from the ifOp branches: the condition forwards the "else" value (loop
    // exit path), while the after-region argument observes the "then" value
    // (loop-continues path, where the condition is known true).
    for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
      auto it = llvm::find(ifOp->getResults(), arg);
      if (it != ifOp->getResults().end()) {
        size_t ifOpIdx = it.getIndex();
        Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
        Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);

        rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
        rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
      }
    }

    // Collect values defined in the before region that the then-block uses;
    // they must be forwarded through scf.condition once the block moves.
    SetVector<Value> additionalUsedValuesSet;
    visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
      if (&op.getBefore() == operand->get().getParentRegion())
        additionalUsedValuesSet.insert(operand->get());
    });

    // Create new whileOp with additional used values as results.
    auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
    auto additionalValueTypes = llvm::map_to_vector(
        additionalUsedValues, [](Value val) { return val.getType(); });
    size_t additionalValueSize = additionalUsedValues.size();
    SmallVector<Type> newResultTypes(op.getResultTypes());
    newResultTypes.append(additionalValueTypes);

    auto newWhileOp =
        scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());

    // Move both regions wholesale into the new op, then extend the after
    // block's argument list to receive the additional forwarded values.
    rewriter.modifyOpInPlace(newWhileOp, [&] {
      newWhileOp.getBefore().takeBody(op.getBefore());
      newWhileOp.getAfter().takeBody(op.getAfter());
      newWhileOp.getAfter().addArguments(
          additionalValueTypes,
          SmallVector<Location>(additionalValueSize, loc));
    });

    rewriter.modifyOpInPlace(conditionOp, [&] {
      conditionOp.getArgsMutable().append(additionalUsedValues);
    });

    // Replace uses of additional used values inside the ifOp then region with
    // the whileOp after region arguments.
    rewriter.replaceUsesWithIf(
        additionalUsedValues,
        newWhileOp.getAfterArguments().take_back(additionalValueSize),
        [&](OpOperand &use) {
          return ifOp.getThenRegion().isAncestor(
              use.getOwner()->getParentRegion());
        });

    // Inline ifOp then region into new whileOp after region. The then-yield
    // must be erased first so the inlined block carries no terminator.
    rewriter.eraseOp(ifOp.thenYield());
    rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
                               newWhileOp.getAfterBody()->begin());
    rewriter.eraseOp(ifOp);
    // Callers only see the original results; drop the forwarding extras.
    rewriter.replaceOp(op,
                       newWhileOp->getResults().drop_back(additionalValueSize));
    return success();
  }
};
3548
3549/// Replace uses of the condition within the do block with true, since otherwise
3550/// the block would not be evaluated.
3551///
3552/// scf.while (..) : (i1, ...) -> ... {
3553/// %condition = call @evaluate_condition() : () -> i1
3554/// scf.condition(%condition) %condition : i1, ...
3555/// } do {
3556/// ^bb0(%arg0: i1, ...):
3557/// use(%arg0)
3558/// ...
3559///
3560/// becomes
3561/// scf.while (..) : (i1, ...) -> ... {
3562/// %condition = call @evaluate_condition() : () -> i1
3563/// scf.condition(%condition) %condition : i1, ...
3564/// } do {
3565/// ^bb0(%arg0: i1, ...):
3566/// use(%true)
3567/// ...
3568struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3569 using OpRewritePattern<WhileOp>::OpRewritePattern;
3570
3571 LogicalResult matchAndRewrite(WhileOp op,
3572 PatternRewriter &rewriter) const override {
3573 auto term = op.getConditionOp();
3574
3575 // These variables serve to prevent creating duplicate constants
3576 // and hold constant true or false values.
3577 Value constantTrue = nullptr;
3578
3579 bool replaced = false;
3580 for (auto yieldedAndBlockArgs :
3581 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3582 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3583 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3584 if (!constantTrue)
3585 constantTrue = arith::ConstantOp::create(
3586 rewriter, op.getLoc(), term.getCondition().getType(),
3587 rewriter.getBoolAttr(true));
3588
3589 rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3590 constantTrue);
3591 replaced = true;
3592 }
3593 }
3594 }
3595 return success(replaced);
3596 }
3597};
3598
3599/// Replace operations equivalent to the condition in the do block with true,
3600/// since otherwise the block would not be evaluated.
3601///
3602/// scf.while (..) : (i32, ...) -> ... {
3603/// %z = ... : i32
3604/// %condition = cmpi pred %z, %a
3605/// scf.condition(%condition) %z : i32, ...
3606/// } do {
3607/// ^bb0(%arg0: i32, ...):
3608/// %condition2 = cmpi pred %arg0, %a
3609/// use(%condition2)
3610/// ...
3611///
3612/// becomes
3613/// scf.while (..) : (i32, ...) -> ... {
3614/// %z = ... : i32
3615/// %condition = cmpi pred %z, %a
3616/// scf.condition(%condition) %z : i32, ...
3617/// } do {
3618/// ^bb0(%arg0: i32, ...):
3619/// use(%true)
3620/// ...
3621struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3622 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3623
3624 LogicalResult matchAndRewrite(scf::WhileOp op,
3625 PatternRewriter &rewriter) const override {
3626 using namespace scf;
3627 auto cond = op.getConditionOp();
3628 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3629 if (!cmp)
3630 return failure();
3631 bool changed = false;
3632 for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3633 for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3634 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3635 continue;
3636 for (OpOperand &u :
3637 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3638 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3639 if (!cmp2)
3640 continue;
3641 // For a binary operator 1-opIdx gets the other side.
3642 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3643 continue;
3644 bool samePredicate;
3645 if (cmp2.getPredicate() == cmp.getPredicate())
3646 samePredicate = true;
3647 else if (cmp2.getPredicate() ==
3648 arith::invertPredicate(cmp.getPredicate()))
3649 samePredicate = false;
3650 else
3651 continue;
3652
3653 rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3654 1);
3655 changed = true;
3656 }
3657 }
3658 }
3659 return success(changed);
3660 }
3661};
3662
3663/// If both ranges contain same values return mappping indices from args2 to
3664/// args1. Otherwise return std::nullopt.
3665static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
3666 ValueRange args2) {
3667 if (args1.size() != args2.size())
3668 return std::nullopt;
3669
3670 SmallVector<unsigned> ret(args1.size());
3671 for (auto &&[i, arg1] : llvm::enumerate(args1)) {
3672 auto it = llvm::find(args2, arg1);
3673 if (it == args2.end())
3674 return std::nullopt;
3675
3676 ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
3677 }
3678
3679 return ret;
3680}
3681
3682static bool hasDuplicates(ValueRange args) {
3683 llvm::SmallDenseSet<Value> set;
3684 for (Value arg : args) {
3685 if (!set.insert(arg).second)
3686 return true;
3687 }
3688 return false;
3689}
3690
3691/// If `before` block args are directly forwarded to `scf.condition`, rearrange
3692/// `scf.condition` args into same order as block args. Update `after` block
3693/// args and op result values accordingly.
3694/// Needed to simplify `scf.while` -> `scf.for` uplifting.
3695struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
3697
3698 LogicalResult matchAndRewrite(WhileOp loop,
3699 PatternRewriter &rewriter) const override {
3700 auto *oldBefore = loop.getBeforeBody();
3701 ConditionOp oldTerm = loop.getConditionOp();
3702 ValueRange beforeArgs = oldBefore->getArguments();
3703 ValueRange termArgs = oldTerm.getArgs();
3704 if (beforeArgs == termArgs)
3705 return failure();
3706
3707 if (hasDuplicates(termArgs))
3708 return failure();
3709
3710 auto mapping = getArgsMapping(beforeArgs, termArgs);
3711 if (!mapping)
3712 return failure();
3713
3714 {
3715 OpBuilder::InsertionGuard g(rewriter);
3716 rewriter.setInsertionPoint(oldTerm);
3717 rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
3718 beforeArgs);
3719 }
3720
3721 auto *oldAfter = loop.getAfterBody();
3722
3723 SmallVector<Type> newResultTypes(beforeArgs.size());
3724 for (auto &&[i, j] : llvm::enumerate(*mapping))
3725 newResultTypes[j] = loop.getResult(i).getType();
3726
3727 auto newLoop = WhileOp::create(
3728 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
3729 /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
3730 auto *newBefore = newLoop.getBeforeBody();
3731 auto *newAfter = newLoop.getAfterBody();
3732
3733 SmallVector<Value> newResults(beforeArgs.size());
3734 SmallVector<Value> newAfterArgs(beforeArgs.size());
3735 for (auto &&[i, j] : llvm::enumerate(*mapping)) {
3736 newResults[i] = newLoop.getResult(j);
3737 newAfterArgs[i] = newAfter->getArgument(j);
3738 }
3739
3740 rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
3741 newBefore->getArguments());
3742 rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
3743 newAfterArgs);
3744
3745 rewriter.replaceOp(loop, newResults);
3746 return success();
3747 }
3748};
3749} // namespace
3750
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<WhileConditionTruth, WhileCmpCond, WhileOpAlignBeforeArgs,
              WhileMoveIfDown>(context);
  // NOTE(review): a source line appears to be missing here in this view — the
  // continuation below belongs to a helper call of the form
  // `populate...Patterns(results, WhileOp::getOperationName());`. Restore the
  // dropped call line from upstream before building.
      results, WhileOp::getOperationName());
}
3758
3759//===----------------------------------------------------------------------===//
3760// IndexSwitchOp
3761//===----------------------------------------------------------------------===//
3762
3763/// Parse the case regions and values.
3764static ParseResult
3766 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
3767 SmallVector<int64_t> caseValues;
3768 while (succeeded(p.parseOptionalKeyword("case"))) {
3769 int64_t value;
3770 Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
3771 if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
3772 return failure();
3773 caseValues.push_back(value);
3774 }
3775 cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
3776 return success();
3777}
3778
3779/// Print the case regions and values.
3781 DenseI64ArrayAttr cases, RegionRange caseRegions) {
3782 for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
3783 p.printNewline();
3784 p << "case " << value << ' ';
3785 p.printRegion(*region, /*printEntryBlockArgs=*/false);
3786 }
3787}
3788
LogicalResult scf::IndexSwitchOp::verify() {
  // Case values and case regions must be in one-to-one correspondence.
  if (getCases().size() != getCaseRegions().size()) {
    return emitOpError("has ")
           << getCaseRegions().size() << " case regions but "
           << getCases().size() << " case values";
  }

  // Case values must be unique.
  DenseSet<int64_t> valueSet;
  for (int64_t value : getCases())
    if (!valueSet.insert(value).second)
      return emitOpError("has duplicate case value: ") << value;
  // Check one region: it must end in scf.yield whose operand count and types
  // match this op's results. `name` is only used for diagnostics.
  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
    auto yield = dyn_cast<YieldOp>(region.front().back());
    if (!yield)
      return emitOpError("expected region to end with scf.yield, but got ")
             << region.front().back().getName();

    if (yield.getNumOperands() != getNumResults()) {
      return (emitOpError("expected each region to return ")
              << getNumResults() << " values, but " << name << " returns "
              << yield.getNumOperands())
                 .attachNote(yield.getLoc())
             << "see yield operation here";
    }
    // Element-wise type check between result types and yielded operands.
    for (auto [idx, result, operand] :
         llvm::enumerate(getResultTypes(), yield.getOperands())) {
      if (!operand)
        return yield.emitOpError() << "operand " << idx << " is null\n";
      if (result == operand.getType())
        continue;
      return (emitOpError("expected result #")
              << idx << " of each region to be " << result)
                 .attachNote(yield.getLoc())
             << name << " returns " << operand.getType() << " here";
    }
    return success();
  };

  // The default region plus every case region must satisfy the same contract.
  if (failed(verifyRegion(getDefaultRegion(), "default region")))
    return failure();
  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
    if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
      return failure();

  return success();
}
3835
3836unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
3837
3838Block &scf::IndexSwitchOp::getDefaultBlock() {
3839 return getDefaultRegion().front();
3840}
3841
3842Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
3843 assert(idx < getNumCases() && "case index out-of-bounds");
3844 return getCaseRegions()[idx].front();
3845}
3846
3847void IndexSwitchOp::getSuccessorRegions(
3848 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
3849 // All regions branch back to the parent op.
3850 if (!point.isParent()) {
3851 successors.push_back(RegionSuccessor::parent(getResults()));
3852 return;
3853 }
3854
3855 llvm::append_range(successors, getRegions());
3856}
3857
3858void IndexSwitchOp::getEntrySuccessorRegions(
3859 ArrayRef<Attribute> operands,
3860 SmallVectorImpl<RegionSuccessor> &successors) {
3861 FoldAdaptor adaptor(operands, *this);
3862
3863 // If a constant was not provided, all regions are possible successors.
3864 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
3865 if (!arg) {
3866 llvm::append_range(successors, getRegions());
3867 return;
3868 }
3869
3870 // Otherwise, try to find a case with a matching value. If not, the
3871 // default region is the only successor.
3872 for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
3873 if (caseValue == arg.getInt()) {
3874 successors.emplace_back(&caseRegion);
3875 return;
3876 }
3877 }
3878 successors.emplace_back(&getDefaultRegion());
3879}
3880
3881void IndexSwitchOp::getRegionInvocationBounds(
3882 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
3883 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
3884 if (!operandValue) {
3885 // All regions are invoked at most once.
3886 bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
3887 return;
3888 }
3889
3890 unsigned liveIndex = getNumRegions() - 1;
3891 const auto *it = llvm::find(getCases(), operandValue.getInt());
3892 if (it != getCases().end())
3893 liveIndex = std::distance(getCases().begin(), it);
3894 for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
3895 bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
3896}
3897
3898struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
3899 using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
3900
3901 LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
3902 PatternRewriter &rewriter) const override {
3903 // If `op.getArg()` is a constant, select the region that matches with
3904 // the constant value. Use the default region if no matche is found.
3905 std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
3906 if (!maybeCst.has_value())
3907 return failure();
3908 int64_t cst = *maybeCst;
3909 int64_t caseIdx, e = op.getNumCases();
3910 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
3911 if (cst == op.getCases()[caseIdx])
3912 break;
3913 }
3914
3915 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
3916 : op.getDefaultRegion();
3917 Block &source = r.front();
3918 Operation *terminator = source.getTerminator();
3919 SmallVector<Value> results = terminator->getOperands();
3920
3921 rewriter.inlineBlockBefore(&source, op);
3922 rewriter.eraseOp(terminator);
3923 // Replace the operation with a potentially empty list of results.
3924 // Fold mechanism doesn't support the case where the result list is empty.
3925 rewriter.replaceOp(op, results);
3926
3927 return success();
3928 }
3929};
3930
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldConstantCase>(context);
  // NOTE(review): a source line appears to be missing here in this view — the
  // continuation below belongs to a helper call of the form
  // `populate...Patterns(results, IndexSwitchOp::getOperationName());`.
  // Restore the dropped call line from upstream before building.
      results, IndexSwitchOp::getOperationName());
}
3937
3938//===----------------------------------------------------------------------===//
3939// TableGen'd op method definitions
3940//===----------------------------------------------------------------------===//
3941
3942#define GET_OP_CLASSES
3943#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1399
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
Definition EmitC.cpp:1374
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition EmitC.cpp:1390
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition SCF.cpp:137
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition SCF.cpp:3390
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:509
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition SCF.cpp:101
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
@ None
static std::string diag(const llvm::Value &value)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
bool getValue() const
Return the boolean value of this attribute.
UnitAttr getUnitAttr()
Definition Builders.cpp:98
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:167
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
IntegerType getI1Type()
Definition Builders.cpp:53
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition Builders.cpp:589
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition Builders.h:594
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OperandRange operand_range
Definition Operation.h:371
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:346
This class represents a successor of a region.
static RegionSuccessor parent(Operation::result_range results)
Initialize a successor that branches back to/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:222
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
BlockArgument getArgument(unsigned i)
Definition Region.h:124
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:64
bool isIndex() const
Definition Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition SCF.cpp:2994
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition SCF.cpp:94
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition SCF.cpp:781
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition SCF.cpp:1929
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition SCF.cpp:736
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition SCF.cpp:115
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition SCF.cpp:688
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition SCF.h:64
llvm::function_ref< Value(OpBuilder &, Location loc, Type, Value)> ValueTypeCastFnTy
Perform a replacement of one iter OpOperand of an scf.for to the replacement value with a different t...
Definition SCF.h:107
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition SCF.cpp:1413
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition SCF.cpp:870
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic evaluates to true.
void populateRegionBranchOpInterfaceCanonicalizationPatterns(RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit=1)
Populate canonicalization patterns that simplify successor operands/inputs of region branch operation...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:3901
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:261
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:212
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.