MLIR 23.0.0git
SCF.cpp
Go to the documentation of this file.
1//===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/Operation.h"
30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/DebugLog.h"
35#include <optional>
36
37using namespace mlir;
38using namespace mlir::scf;
39
40#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
41
42//===----------------------------------------------------------------------===//
43// SCFDialect Dialect Interfaces
44//===----------------------------------------------------------------------===//
45
46namespace {
47struct SCFInlinerInterface : public DialectInlinerInterface {
48 using DialectInlinerInterface::DialectInlinerInterface;
49 // We don't have any special restrictions on what can be inlined into
50 // destination regions (e.g. while/conditional bodies). Always allow it.
51 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
52 IRMapping &valueMapping) const final {
53 return true;
54 }
55 // Operations in scf dialect are always legal to inline since they are
56 // pure.
57 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
58 return true;
59 }
60 // Handle the given inlined terminator by replacing it with a new operation
61 // as necessary. Required when the region has only one block.
62 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
63 auto retValOp = dyn_cast<scf::YieldOp>(op);
64 if (!retValOp)
65 return;
66
67 for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
68 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
69 }
70 }
71};
72} // namespace
73
74//===----------------------------------------------------------------------===//
75// SCFDialect
76//===----------------------------------------------------------------------===//
77
/// Registers all SCF operations, the inliner interface, and promised external
/// interfaces (resolved lazily when the owning dialect/library is loaded).
void SCFDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
      >();
  addInterfaces<SCFInlinerInterface>();
  // Promised interfaces are implemented in other libraries; declaring them
  // here avoids a hard link-time dependency on those libraries.
  declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
                            InParallelOp, ReduceReturnOp>();
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
                            ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
                            ForallOp, InParallelOp, WhileOp, YieldOp>();
  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
}
92
93/// Default callback for IfOp builders. Inserts a yield without arguments.
95 scf::YieldOp::create(builder, loc);
96}
97
98/// Verifies that the first block of the given `region` is terminated by a
99/// TerminatorTy. Reports errors on the given operation if it is not the case.
100template <typename TerminatorTy>
101static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
102 StringRef errorMessage) {
103 Operation *terminatorOperation = nullptr;
104 if (!region.empty() && !region.front().empty()) {
105 terminatorOperation = &region.front().back();
106 if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
107 return yield;
108 }
109 auto diag = op->emitOpError(errorMessage);
110 if (terminatorOperation)
111 diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
112 return nullptr;
113}
114
115std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
116 bool isSigned) {
117 llvm::APSInt diff;
118 auto addOp = ub.getDefiningOp<arith::AddIOp>();
119 if (!addOp)
120 return std::nullopt;
121 if ((isSigned && !addOp.hasNoSignedWrap()) ||
122 (!isSigned && !addOp.hasNoUnsignedWrap()))
123 return std::nullopt;
124
125 if (addOp.getLhs() != lb ||
126 !matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
127 return std::nullopt;
128 return diff;
129}
130
131//===----------------------------------------------------------------------===//
132// ExecuteRegionOp
133//===----------------------------------------------------------------------===//
134
135///
136/// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
137/// block+
138/// `}`
139///
140/// Example:
141/// scf.execute_region -> i32 {
142/// %idx = load %rI[%i] : memref<128xi32>
143/// return %idx : i32
144/// }
145///
146ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
148 if (parser.parseOptionalArrowTypeList(result.types))
149 return failure();
150
151 if (succeeded(parser.parseOptionalKeyword("no_inline")))
152 result.addAttribute("no_inline", parser.getBuilder().getUnitAttr());
153
154 // Introduce the body region and parse it.
155 Region *body = result.addRegion();
156 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
157 parser.parseOptionalAttrDict(result.attributes))
158 return failure();
159
160 return success();
161}
162
void ExecuteRegionOp::print(OpAsmPrinter &p) {
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  // The `no_inline` unit attribute is rendered as a keyword before the region.
  if (getNoInline())
    p << "no_inline ";
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
  // `no_inline` was already printed as a keyword; elide it from the dict.
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"no_inline"});
}
173
174LogicalResult ExecuteRegionOp::verify() {
175 if (getRegion().empty())
176 return emitOpError("region needs to have at least one block");
177 if (getRegion().front().getNumArguments() > 0)
178 return emitOpError("region cannot have any arguments");
179 return success();
180}
181
182// Inline an ExecuteRegionOp if its parent can contain multiple blocks.
183// TODO generalize the conditions for operations which can be inlined into.
184// func @func_execute_region_elim() {
185// "test.foo"() : () -> ()
186// %v = scf.execute_region -> i64 {
187// %c = "test.cmp"() : () -> i1
188// cf.cond_br %c, ^bb2, ^bb3
189// ^bb2:
190// %x = "test.val1"() : () -> i64
191// cf.br ^bb4(%x : i64)
192// ^bb3:
193// %y = "test.val2"() : () -> i64
194// cf.br ^bb4(%y : i64)
195// ^bb4(%z : i64):
196// scf.yield %z : i64
197// }
198// "test.bar"(%v) : (i64) -> ()
199// return
200// }
201//
202// becomes
203//
204// func @func_execute_region_elim() {
205// "test.foo"() : () -> ()
206// %c = "test.cmp"() : () -> i1
207// cf.cond_br %c, ^bb1, ^bb2
208// ^bb1: // pred: ^bb0
209// %x = "test.val1"() : () -> i64
210// cf.br ^bb3(%x : i64)
211// ^bb2: // pred: ^bb0
212// %y = "test.val2"() : () -> i64
213// cf.br ^bb3(%y : i64)
214// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
215// "test.bar"(%z) : (i64) -> ()
216// return
217// }
218//
struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExecuteRegionOp op,
                                PatternRewriter &rewriter) const override {
    // Respect an explicit request not to inline.
    if (op.getNoInline())
      return failure();
    // Only inline into parents known to tolerate multiple blocks.
    if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
      return failure();

    // Split the enclosing block at the op; everything after it moves to
    // `postBlock`, which will also receive the op's results as arguments.
    Block *prevBlock = op->getBlock();
    Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
    rewriter.setInsertionPointToEnd(prevBlock);

    // Branch from the code before the op into the region's entry block.
    cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());

    // Turn every scf.yield into a branch to `postBlock`, forwarding the
    // yielded values as block arguments.
    for (Block &blk : op.getRegion()) {
      if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
        rewriter.setInsertionPoint(yieldOp);
        cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
                             yieldOp.getResults());
        rewriter.eraseOp(yieldOp);
      }
    }

    // Splice the region's blocks between `prevBlock` and `postBlock`.
    rewriter.inlineRegionBefore(op.getRegion(), postBlock);
    SmallVector<Value> blockArgs;

    // Each op result becomes a corresponding block argument of `postBlock`.
    for (auto res : op.getResults())
      blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));

    rewriter.replaceOp(op, blockArgs);
    return success();
  }
};
254
255void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
256 MLIRContext *context) {
257 results.add<MultiBlockExecuteInliner>(context);
259 results, ExecuteRegionOp::getOperationName());
260 // Inline ops with a single block that are not marked as "no_inline".
262 results, ExecuteRegionOp::getOperationName(),
264 return failure(cast<ExecuteRegionOp>(op).getNoInline());
265 });
266}
267
268void ExecuteRegionOp::getSuccessorRegions(
270 // If the predecessor is the ExecuteRegionOp, branch into the body.
271 if (point.isParent()) {
272 regions.push_back(RegionSuccessor(&getRegion()));
273 return;
274 }
275
276 // Otherwise, the region branches back to the parent operation.
277 regions.push_back(RegionSuccessor::parent());
278}
279
280ValueRange ExecuteRegionOp::getSuccessorInputs(RegionSuccessor successor) {
281 return successor.isParent() ? ValueRange(getOperation()->getResults())
282 : ValueRange();
283}
284
285//===----------------------------------------------------------------------===//
286// ConditionOp
287//===----------------------------------------------------------------------===//
288
290ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
291 assert(
292 (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) &&
293 "condition op can only exit the loop or branch to the after"
294 "region");
295 // Pass all operands except the condition to the successor region.
296 return getArgsMutable();
297}
298
299void ConditionOp::getSuccessorRegions(
301 FoldAdaptor adaptor(operands, *this);
302
303 WhileOp whileOp = getParentOp();
304
305 // Condition can either lead to the after region or back to the parent op
306 // depending on whether the condition is true or not.
307 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
308 if (!boolAttr || boolAttr.getValue())
309 regions.emplace_back(&whileOp.getAfter());
310 if (!boolAttr || !boolAttr.getValue())
311 regions.push_back(RegionSuccessor::parent());
312}
313
314//===----------------------------------------------------------------------===//
315// ForOp
316//===----------------------------------------------------------------------===//
317
/// Builds an scf.for with the given bounds, step and initial loop-carried
/// values. When `bodyBuilder` is provided, it populates the loop body;
/// otherwise a default terminator is created only if there are no iter args.
void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
                  Value ub, Value step, ValueRange initArgs,
                  BodyBuilderFn bodyBuilder, bool unsignedCmp) {
  OpBuilder::InsertionGuard guard(builder);

  // Unsigned bound comparison is recorded as a unit attribute.
  if (unsignedCmp)
    result.addAttribute(getUnsignedCmpAttrName(result.name),
                        builder.getUnitAttr());
  result.addOperands({lb, ub, step});
  result.addOperands(initArgs);
  // One op result per loop-carried value.
  for (Value v : initArgs)
    result.addTypes(v.getType());
  // The induction variable takes the type of the lower bound.
  Type t = lb.getType();
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(bodyRegion);
  bodyBlock->addArgument(t, result.location);
  for (Value v : initArgs)
    bodyBlock->addArgument(v.getType(), v.getLoc());

  // Create the default terminator if the builder is not provided and if the
  // iteration arguments are not provided. Otherwise, leave this to the caller
  // because we don't know which values to return from the loop.
  if (initArgs.empty() && !bodyBuilder) {
    ForOp::ensureTerminator(*bodyRegion, builder, result.location);
  } else if (bodyBuilder) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(bodyBlock);
    bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
                bodyBlock->getArguments().drop_front());
  }
}
349
350LogicalResult ForOp::verify() {
351 // Check that the body block has at least the induction variable argument.
352 // This must be checked before verifyRegions() and before any region trait
353 // verifiers (e.g. LoopLikeOpInterface) that call getRegionIterArgs(), to
354 // avoid crashing with an out-of-bounds drop_front on an empty arg list.
355 if (getBody()->getNumArguments() < getNumInductionVars())
356 return emitOpError("expected body to have at least ")
357 << getNumInductionVars()
358 << " argument(s) for the induction variable, but got "
359 << getBody()->getNumArguments();
360
361 // Check that the number of init args and op results is the same.
362 if (getInitArgs().size() != getNumResults())
363 return emitOpError(
364 "mismatch in number of loop-carried values and defined values");
365
366 return success();
367}
368
369LogicalResult ForOp::verifyRegions() {
370 // Check that the body block has at least the induction variable argument.
371 if (getBody()->getNumArguments() < getNumInductionVars())
372 return emitOpError("expected body to have at least ")
373 << getNumInductionVars() << " argument(s) for the induction "
374 << "variable, but got " << getBody()->getNumArguments();
375
376 // Check that the body defines as single block argument for the induction
377 // variable.
378 if (getInductionVar().getType() != getLowerBound().getType())
379 return emitOpError(
380 "expected induction variable to be same type as bounds and step");
381
382 if (getNumRegionIterArgs() != getNumResults())
383 return emitOpError(
384 "mismatch in number of basic block args and defined values");
385
386 auto initArgs = getInitArgs();
387 auto iterArgs = getRegionIterArgs();
388 auto opResults = getResults();
389 unsigned i = 0;
390 for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
391 if (std::get<0>(e).getType() != std::get<2>(e).getType())
392 return emitOpError() << "types mismatch between " << i
393 << "th iter operand and defined value";
394 if (std::get<1>(e).getType() != std::get<2>(e).getType())
395 return emitOpError() << "types mismatch between " << i
396 << "th iter region arg and defined value";
397
398 ++i;
399 }
400 return success();
401}
402
403std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
404 return SmallVector<Value>{getInductionVar()};
405}
406
407std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
409}
410
411std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
412 return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
413}
414
415std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
417}
418
419bool ForOp::isValidInductionVarType(Type type) {
420 return type.isIndex() || type.isSignlessInteger();
421}
422
423LogicalResult ForOp::setLoopLowerBounds(ArrayRef<OpFoldResult> bounds) {
424 if (bounds.size() != 1)
425 return failure();
426 if (auto val = dyn_cast<Value>(bounds[0])) {
427 setLowerBound(val);
428 return success();
429 }
430 return failure();
431}
432
433LogicalResult ForOp::setLoopUpperBounds(ArrayRef<OpFoldResult> bounds) {
434 if (bounds.size() != 1)
435 return failure();
436 if (auto val = dyn_cast<Value>(bounds[0])) {
437 setUpperBound(val);
438 return success();
439 }
440 return failure();
441}
442
443LogicalResult ForOp::setLoopSteps(ArrayRef<OpFoldResult> steps) {
444 if (steps.size() != 1)
445 return failure();
446 if (auto val = dyn_cast<Value>(steps[0])) {
447 setStep(val);
448 return success();
449 }
450 return failure();
451}
452
/// All scf.for results are loop-carried values, so expose all of them.
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
454
/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop has at most a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
  std::optional<APInt> tripCount = getStaticTripCount();
  LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
         << " for loop "
         << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
  // Bail out when the trip count is unknown or greater than one.
  if (!tripCount.has_value() || tripCount->getZExtValue() > 1)
    return failure();

  // Zero iterations: the loop results are just the init values.
  if (*tripCount == 0) {
    rewriter.replaceAllUsesWith(getResults(), getInitArgs());
    rewriter.eraseOp(*this);
    return success();
  }

  // Exactly one iteration: replace all results with the yielded values.
  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());

  // Replace block arguments with lower bound (replacement for IV) and
  // iter_args.
  SmallVector<Value> bbArgReplacements;
  bbArgReplacements.push_back(getLowerBound());
  llvm::append_range(bbArgReplacements, getInitArgs());

  // Move the loop body operations to the loop's containing block.
  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
                             getOperation()->getIterator(), bbArgReplacements);

  // Erase the old terminator and the loop.
  rewriter.eraseOp(yieldOp);
  rewriter.eraseOp(*this);

  return success();
}
491
492/// Prints the initialization list in the form of
493/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
494/// where 'inner' values are assumed to be region arguments and 'outer' values
495/// are regular SSA values.
497 Block::BlockArgListType blocksArgs,
498 ValueRange initializers,
499 StringRef prefix = "") {
500 assert(blocksArgs.size() == initializers.size() &&
501 "expected same length of arguments and initializers");
502 if (initializers.empty())
503 return;
504
505 p << prefix << '(';
506 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
507 p << std::get<0>(it) << " = " << std::get<1>(it);
508 });
509 p << ")";
510}
511
void ForOp::print(OpAsmPrinter &p) {
  // The optional `unsigned` keyword marks unsigned bound comparison.
  if (getUnsignedCmp())
    p << " unsigned";

  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
    << getUpperBound() << " step " << getStep();

  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
  if (!getInitArgs().empty())
    p << " -> (" << getInitArgs().getTypes() << ')';
  p << ' ';
  // The induction-variable type is elided when it is the default `index`.
  if (Type t = getInductionVar().getType(); !t.isIndex())
    p << " : " << t << ' ';
  // Terminators are only printed when there are iter_args, since only then
  // does scf.yield carry meaningful operands.
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/!getInitArgs().empty());
  // `unsignedCmp` was already printed as a keyword; elide it from the dict.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
}
531
532ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
533 auto &builder = parser.getBuilder();
534 Type type;
535
536 OpAsmParser::Argument inductionVariable;
538
539 if (succeeded(parser.parseOptionalKeyword("unsigned")))
540 result.addAttribute(getUnsignedCmpAttrName(result.name),
541 builder.getUnitAttr());
542
543 // Parse the induction variable followed by '='.
544 if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
545 // Parse loop bounds.
546 parser.parseOperand(lb) || parser.parseKeyword("to") ||
547 parser.parseOperand(ub) || parser.parseKeyword("step") ||
548 parser.parseOperand(step))
549 return failure();
550
551 // Parse the optional initial iteration arguments.
554 regionArgs.push_back(inductionVariable);
555
556 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
557 if (hasIterArgs) {
558 // Parse assignment list and results type list.
559 if (parser.parseAssignmentList(regionArgs, operands) ||
560 parser.parseArrowTypeList(result.types))
561 return failure();
562 }
563
564 if (regionArgs.size() != result.types.size() + 1)
565 return parser.emitError(
566 parser.getNameLoc(),
567 "mismatch in number of loop-carried values and defined values");
568
569 // Parse optional type, else assume Index.
570 if (parser.parseOptionalColon())
571 type = builder.getIndexType();
572 else if (parser.parseType(type))
573 return failure();
574
575 // Set block argument types, so that they are known when parsing the region.
576 regionArgs.front().type = type;
577 for (auto [iterArg, type] :
578 llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
579 iterArg.type = type;
580
581 // Parse the body region.
582 Region *body = result.addRegion();
583 if (parser.parseRegion(*body, regionArgs))
584 return failure();
585 ForOp::ensureTerminator(*body, builder, result.location);
586
587 // Resolve input operands. This should be done after parsing the region to
588 // catch invalid IR where operands were defined inside of the region.
589 if (parser.resolveOperand(lb, type, result.operands) ||
590 parser.resolveOperand(ub, type, result.operands) ||
591 parser.resolveOperand(step, type, result.operands))
592 return failure();
593 if (hasIterArgs) {
594 for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
595 operands, result.types)) {
596 Type type = std::get<2>(argOperandType);
597 std::get<0>(argOperandType).type = type;
598 if (parser.resolveOperand(std::get<1>(argOperandType), type,
599 result.operands))
600 return failure();
601 }
602 }
603
604 // Parse the optional attribute list.
605 if (parser.parseOptionalAttrDict(result.attributes))
606 return failure();
607
608 return success();
609}
610
611SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
612
613Block::BlockArgListType ForOp::getRegionIterArgs() {
614 return getBody()->getArguments().drop_front(getNumInductionVars());
615}
616
/// LoopLikeOpInterface hook: the loop inits are exactly the init args.
MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
  return getInitArgsMutable();
}
620
/// Replaces this loop with a new scf.for that carries `newInitOperands` as
/// additional iter_args, yielding the values produced by `newYieldValuesFn`.
/// Returns the new loop; the old loop is erased.
FailureOr<LoopLikeOpInterface>
ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
                                   ValueRange newInitOperands,
                                   bool replaceInitOperandUsesInLoop,
                                   const NewYieldValuesFn &newYieldValuesFn) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(getOperation());
  auto inits = llvm::to_vector(getInitArgs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  // Build an empty-bodied shell; the original body is merged in below.
  scf::ForOp newLoop = scf::ForOp::create(
      rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
      [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp());
  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));

  // Generate the new yield values and append them to the scf.yield operation.
  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
  ArrayRef<BlockArgument> newIterArgs =
      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
    assert(newInitOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands");
    rewriter.modifyOpInPlace(yieldOp, [&]() {
      yieldOp.getResultsMutable().append(newYieldedValues);
    });
  }

  // Move the loop body to the new op.
  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                       newLoop.getBody()->getArguments().take_front(
                           getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    // Replace all uses of `newInitOperands` with the corresponding basic block
    // arguments, but only for uses nested inside the new loop.
    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
      rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
                                 [&](OpOperand &use) {
                                   Operation *user = use.getOwner();
                                   return newLoop->isProperAncestor(user);
                                 });
    }
  }

  // Replace the old loop. Extra results of the new loop (from the added
  // iter_args) are intentionally left without replacements here.
  rewriter.replaceOp(getOperation(),
                     newLoop->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
674
676 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
677 if (!ivArg)
678 return ForOp();
679 assert(ivArg.getOwner() && "unlinked block argument");
680 auto *containingOp = ivArg.getOwner()->getParentOp();
681 return dyn_cast_or_null<ForOp>(containingOp);
682}
683
/// On entry, the operands forwarded to any successor region are the init
/// values of the loop-carried variables, regardless of which successor.
OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
  return getInitArgs();
}
687
688void ForOp::getSuccessorRegions(RegionBranchPoint point,
690 if (std::optional<APInt> tripCount = getStaticTripCount()) {
691 // The loop has a known static trip count.
692 if (point.isParent()) {
693 if (*tripCount == 0) {
694 // The loop has zero iterations. It branches directly back to the
695 // parent.
696 regions.push_back(RegionSuccessor::parent());
697 } else {
698 // The loop has at least one iteration. It branches into the body.
699 regions.push_back(RegionSuccessor(&getRegion()));
700 }
701 return;
702 } else if (*tripCount == 1) {
703 // The loop has exactly 1 iteration. Therefore, it branches from the
704 // region to the parent. (No further iteration.)
705 regions.push_back(RegionSuccessor::parent());
706 return;
707 }
708 }
709
710 // Both the operation itself and the region may be branching into the body or
711 // back into the operation itself. It is possible for loop not to enter the
712 // body.
713 regions.push_back(RegionSuccessor(&getRegion()));
714 regions.push_back(RegionSuccessor::parent());
715}
716
717ValueRange ForOp::getSuccessorInputs(RegionSuccessor successor) {
718 return successor.isParent() ? ValueRange(getResults())
719 : ValueRange(getRegionIterArgs());
720}
721
722SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
723
724/// Promotes the loop body of a forallOp to its containing block if it can be
725/// determined that the loop has a single iteration.
726LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
727 for (auto [lb, ub, step] :
728 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
729 auto tripCount =
730 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
731 if (!tripCount.has_value() || *tripCount != 1)
732 return failure();
733 }
734
735 promote(rewriter, *this);
736 return success();
737}
738
739Block::BlockArgListType ForallOp::getRegionIterArgs() {
740 return getBody()->getArguments().drop_front(getRank());
741}
742
/// LoopLikeOpInterface hook: the loop inits of scf.forall are its outputs.
MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
  return getOutputsMutable();
}
746
/// Promotes the loop body of a scf::ForallOp to its containing block.
void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
  OpBuilder::InsertionGuard g(rewriter);
  scf::InParallelOp terminator = forallOp.getTerminator();

  // Replace block arguments with lower bounds (replacements for IVs) and
  // outputs.
  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
  bbArgReplacements.append(forallOp.getOutputs().begin(),
                           forallOp.getOutputs().end());

  // Move the loop body operations to the loop's containing block.
  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
                             forallOp->getIterator(), bbArgReplacements);

  // Replace the terminator with tensor.insert_slice ops.
  rewriter.setInsertionPointAfter(forallOp);
  SmallVector<Value> results;
  results.reserve(forallOp.getResults().size());
  for (auto &yieldingOp : terminator.getYieldingOps()) {
    auto parallelInsertSliceOp =
        dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
    if (!parallelInsertSliceOp)
      continue;

    Value dst = parallelInsertSliceOp.getDest();
    Value src = parallelInsertSliceOp.getSource();
    if (llvm::isa<TensorType>(src.getType())) {
      // Each parallel insert becomes a sequential tensor.insert_slice with
      // identical offsets/sizes/strides.
      results.push_back(tensor::InsertSliceOp::create(
          rewriter, forallOp.getLoc(), dst.getType(), src, dst,
          parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
          parallelInsertSliceOp.getStrides(),
          parallelInsertSliceOp.getStaticOffsets(),
          parallelInsertSliceOp.getStaticSizes(),
          parallelInsertSliceOp.getStaticStrides()));
    } else {
      // Non-tensor sources are not expected from an scf.forall terminator.
      llvm_unreachable("unsupported terminator");
    }
  }
  rewriter.replaceAllUsesWith(forallOp.getResults(), results);

  // Erase the old terminator and the loop.
  rewriter.eraseOp(terminator);
  rewriter.eraseOp(forallOp);
}
792
794 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
795 ValueRange steps, ValueRange iterArgs,
797 bodyBuilder) {
798 assert(lbs.size() == ubs.size() &&
799 "expected the same number of lower and upper bounds");
800 assert(lbs.size() == steps.size() &&
801 "expected the same number of lower bounds and steps");
802
803 // If there are no bounds, call the body-building function and return early.
804 if (lbs.empty()) {
805 ValueVector results =
806 bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
807 : ValueVector();
808 assert(results.size() == iterArgs.size() &&
809 "loop nest body must return as many values as loop has iteration "
810 "arguments");
811 return LoopNest{{}, std::move(results)};
812 }
813
814 // First, create the loop structure iteratively using the body-builder
815 // callback of `ForOp::build`. Do not create `YieldOp`s yet.
816 OpBuilder::InsertionGuard guard(builder);
819 loops.reserve(lbs.size());
820 ivs.reserve(lbs.size());
821 ValueRange currentIterArgs = iterArgs;
822 Location currentLoc = loc;
823 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
824 auto loop = scf::ForOp::create(
825 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
826 [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
827 ValueRange args) {
828 ivs.push_back(iv);
829 // It is safe to store ValueRange args because it points to block
830 // arguments of a loop operation that we also own.
831 currentIterArgs = args;
832 currentLoc = nestedLoc;
833 });
834 // Set the builder to point to the body of the newly created loop. We don't
835 // do this in the callback because the builder is reset when the callback
836 // returns.
837 builder.setInsertionPointToStart(loop.getBody());
838 loops.push_back(loop);
839 }
840
841 // For all loops but the innermost, yield the results of the nested loop.
842 for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
843 builder.setInsertionPointToEnd(loops[i].getBody());
844 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
845 }
846
847 // In the body of the innermost loop, call the body building function if any
848 // and yield its results.
849 builder.setInsertionPointToStart(loops.back().getBody());
850 ValueVector results = bodyBuilder
851 ? bodyBuilder(builder, currentLoc, ivs,
852 loops.back().getRegionIterArgs())
853 : ValueVector();
854 assert(results.size() == iterArgs.size() &&
855 "loop nest body must return as many values as loop has iteration "
856 "arguments");
857 builder.setInsertionPointToEnd(loops.back().getBody());
858 scf::YieldOp::create(builder, loc, results);
859
860 // Return the loops.
861 ValueVector nestResults;
862 llvm::append_range(nestResults, loops.front().getResults());
863 return LoopNest{std::move(loops), std::move(nestResults)};
864}
865
867 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
868 ValueRange steps,
869 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
870 // Delegate to the main function by wrapping the body builder.
871 return buildLoopNest(builder, loc, lbs, ubs, steps, {},
872 [&bodyBuilder](OpBuilder &nestedBuilder,
873 Location nestedLoc, ValueRange ivs,
875 if (bodyBuilder)
876 bodyBuilder(nestedBuilder, nestedLoc, ivs);
877 return {};
878 });
879}
880
    OpOperand &operand, Value replacement,
    const ValueTypeCastFnTy &castFn) {
  assert(operand.getOwner() == forOp);
  Type oldType = operand.get().getType(), newType = replacement.getType();

  // 1. Create new iter operands, exactly 1 is replaced.
  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
         "expected an iter OpOperand");
  assert(operand.get().getType() != replacement.getType() &&
         "Expected a different type");
  SmallVector<Value> newIterOperands;
  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
    if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
      newIterOperands.push_back(replacement);
      continue;
    }
    newIterOperands.push_back(opOperand.get());
  }

  // 2. Create the new forOp shell, preserving bounds, step, attributes and
  // the comparison signedness of the original loop.
  scf::ForOp newForOp = scf::ForOp::create(
      rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
      forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr,
      forOp.getUnsignedCmp());
  newForOp->setAttrs(forOp->getAttrs());
  Block &newBlock = newForOp.getRegion().front();
  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
                                             newBlock.getArguments().end());

  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
  // corresponding to the `replacement` value.
  // The guard restores the rewriter's insertion point once the body casts
  // have been created.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(&newBlock);
  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
      &newForOp->getOpOperand(operand.getOperandNumber()));
  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;

  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
  Block &oldBlock = forOp.getRegion().front();
  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);

  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
  rewriter.setInsertionPoint(clonedYieldOp);
  // Map the region iter_arg back to its yield-operand position by skipping
  // the induction-variable block argument(s).
  unsigned yieldIdx =
      newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
                         clonedYieldOp.getOperand(yieldIdx));
  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
  newYieldOperands[yieldIdx] = castOut;
  scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
  rewriter.eraseOp(clonedYieldOp);

  // 6. Inject an outgoing cast op after the forOp.
  rewriter.setInsertionPointAfter(newForOp);
  SmallVector<Value> newResults = newForOp.getResults();
  newResults[yieldIdx] =
      castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);

  return newResults;
}
945
namespace {
/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
/// tensor.cast op pair, so as to pull the tensor.cast inside the scf.for:
///
/// ```
/// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
/// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
///   -> (tensor<?x?xf32>) {
/// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
/// scf.yield %2 : tensor<?x?xf32>
/// }
/// use_of(%1)
/// ```
///
/// folds into:
///
/// ```
/// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
/// -> (tensor<32x1024xf32>) {
/// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
/// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
/// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
/// scf.yield %4 : tensor<32x1024xf32>
/// }
/// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
/// use_of(%1)
/// ```
struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {

  LogicalResult matchAndRewrite(ForOp op,
                                PatternRewriter &rewriter) const override {
    for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
      OpOperand &iterOpOperand = std::get<0>(it);
      auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
      // Skip iter operands not produced by a cast, and casts that do not
      // change the type.
      if (!incomingCast ||
          incomingCast.getSource().getType() == incomingCast.getType())
        continue;
      // If the dest type of the cast does not preserve static information in
      // the source type.
          incomingCast.getDest().getType(),
          incomingCast.getSource().getType()))
        continue;
      // Only rewrite when the corresponding loop result has a single use.
      if (!std::get<1>(it).hasOneUse())
        continue;

      // Create a new ForOp with that iter operand replaced.
      rewriter.replaceOp(
          rewriter, op, iterOpOperand, incomingCast.getSource(),
          [](OpBuilder &b, Location loc, Type type, Value source) {
            return tensor::CastOp::create(b, loc, type, source);
          }));
      return success();
    }
    return failure();
  }
};
} // namespace
1006
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  // Pull tensor.cast ops that wrap an iter_arg/result pair into the loop.
  results.add<ForOpTensorCastFolder>(context);
      results, ForOp::getOperationName());
      results, ForOp::getOperationName(),
      /*replBuilderFn=*/[](OpBuilder &builder, Location loc, Value value) {
        // scf.for has only one non-successor input value: the loop induction
        // variable. In case of a single acyclic path through the op, the IV can
        // be safely replaced with the lower bound.
        auto blockArg = cast<BlockArgument>(value);
        assert(blockArg.getArgNumber() == 0 && "expected induction variable");
        auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp());
        return forOp.getLowerBound();
      });
}
1024
1025std::optional<APInt> ForOp::getConstantStep() {
1026 IntegerAttr step;
1027 if (matchPattern(getStep(), m_Constant(&step)))
1028 return step.getValue();
1029 return {};
1030}
1031
1032std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1033 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1034}
1035
Speculation::Speculatability ForOp::getSpeculatability() {
  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
  // and End.
  if (auto constantStep = getConstantStep())
    if (*constantStep == 1)
      // A unit step guarantees forward progress toward the upper bound.

  // For Step != 1, the loop may not terminate. We can add more smarts here if
  // needed.
}
1047
1048std::optional<APInt> ForOp::getStaticTripCount() {
1049 return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
1050 /*isSigned=*/!getUnsignedCmp(), computeUbMinusLb);
1051}
1052
1053//===----------------------------------------------------------------------===//
1054// ForallOp
1055//===----------------------------------------------------------------------===//
1056
1057LogicalResult ForallOp::verify() {
1058 unsigned numLoops = getRank();
1059 // Check number of outputs.
1060 if (getNumResults() != getOutputs().size())
1061 return emitOpError("produces ")
1062 << getNumResults() << " results, but has only "
1063 << getOutputs().size() << " outputs";
1064
1065 // Check that the body defines block arguments for thread indices and outputs.
1066 auto *body = getBody();
1067 if (body->getNumArguments() != numLoops + getOutputs().size())
1068 return emitOpError("region expects ") << numLoops << " arguments";
1069 for (int64_t i = 0; i < numLoops; ++i)
1070 if (!body->getArgument(i).getType().isIndex())
1071 return emitOpError("expects ")
1072 << i << "-th block argument to be an index";
1073 for (unsigned i = 0; i < getOutputs().size(); ++i)
1074 if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1075 return emitOpError("type mismatch between ")
1076 << i << "-th output and corresponding block argument";
1077 if (getMapping().has_value() && !getMapping()->empty()) {
1078 if (getDeviceMappingAttrs().size() != numLoops)
1079 return emitOpError() << "mapping attribute size must match op rank";
1080 if (failed(getDeviceMaskingAttr()))
1081 return emitOpError() << getMappingAttrName()
1082 << " supports at most one device masking attribute";
1083 }
1084
1085 // Verify mixed static/dynamic control variables.
1086 Operation *op = getOperation();
1087 if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1088 getStaticLowerBound(),
1089 getDynamicLowerBound())))
1090 return failure();
1091 if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1092 getStaticUpperBound(),
1093 getDynamicUpperBound())))
1094 return failure();
1095 if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1096 getStaticStep(), getDynamicStep())))
1097 return failure();
1098
1099 return success();
1100}
1101
void ForallOp::print(OpAsmPrinter &p) {
  Operation *op = getOperation();
  p << " (" << getInductionVars();
  // The normalized form (all-zero lbs, all-one steps) elides lbs and steps
  // and prints only the upper bounds after `in`.
  if (isNormalized()) {
    p << ") in ";
    printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
                          /*valueTypes=*/{}, /*scalables=*/{},
  } else {
    p << ") = ";
    printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
                          /*valueTypes=*/{}, /*scalables=*/{},
    p << " to ";
    printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
                          /*valueTypes=*/{}, /*scalables=*/{},
    p << " step ";
    printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
                          /*valueTypes=*/{}, /*scalables=*/{},
  }
  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
  p << " ";
  if (!getRegionOutArgs().empty())
    p << "-> (" << getResultTypes() << ") ";
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults() > 0);
  // Elide the attributes that the parser reconstructs from the custom syntax.
  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
                                           getStaticLowerBoundAttrName(),
                                           getStaticUpperBoundAttrName(),
                                           getStaticStepAttrName()});
}
1136
ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
  OpBuilder b(parser.getContext());
  auto indexType = b.getIndexType();

  // Parse an opening `(` followed by thread index variables followed by `)`
  // TODO: when we can refer to such "induction variable"-like handles from the
  // declarative assembly format, we can implement the parser as a custom hook.
  SmallVector<OpAsmParser::Argument, 4> ivs;
    return failure();

  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
      dynamicSteps;
  if (succeeded(parser.parseOptionalKeyword("in"))) {
    // Parse upper bounds.
    if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
                              /*valueTypes=*/nullptr,
        parser.resolveOperands(dynamicUbs, indexType, result.operands))
      return failure();

    // The normalized `in` form implies zero lower bounds and unit steps.
    unsigned numLoops = ivs.size();
    staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
    staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
  } else {
    // Parse lower bounds.
    if (parser.parseEqual() ||
        parseDynamicIndexList(parser, dynamicLbs, staticLbs,
                              /*valueTypes=*/nullptr,

        parser.resolveOperands(dynamicLbs, indexType, result.operands))
      return failure();

    // Parse upper bounds.
    if (parser.parseKeyword("to") ||
        parseDynamicIndexList(parser, dynamicUbs, staticUbs,
                              /*valueTypes=*/nullptr,
        parser.resolveOperands(dynamicUbs, indexType, result.operands))
      return failure();

    // Parse step values.
    if (parser.parseKeyword("step") ||
        parseDynamicIndexList(parser, dynamicSteps, staticSteps,
                              /*valueTypes=*/nullptr,
        parser.resolveOperands(dynamicSteps, indexType, result.operands))
      return failure();
  }

  // Parse out operands and results.
  SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
  SMLoc outOperandsLoc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
    // NOTE(review): both containers are still empty at this point, so this
    // size check is vacuous (0 == 0). Confirm whether it should instead run
    // after parseAssignmentList/parseOptionalArrowTypeList below.
    if (outOperands.size() != result.types.size())
      return parser.emitError(outOperandsLoc,
                              "mismatch between out operands and types");
    if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
        parser.parseOptionalArrowTypeList(result.types) ||
        parser.resolveOperands(outOperands, result.types, outOperandsLoc,
                               result.operands))
      return failure();
  }

  // Parse region.
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  std::unique_ptr<Region> region = std::make_unique<Region>();
  // Induction variables are always of index type.
  for (auto &iv : ivs) {
    iv.type = b.getIndexType();
    regionArgs.push_back(iv);
  }
  // Shared-out block arguments take the types of the corresponding results.
  for (const auto &it : llvm::enumerate(regionOutArgs)) {
    auto &out = it.value();
    out.type = result.types[it.index()];
    regionArgs.push_back(out);
  }
  if (parser.parseRegion(*region, regionArgs))
    return failure();

  // Ensure terminator and move region.
  ForallOp::ensureTerminator(*region, b, result.location);
  result.addRegion(std::move(region));

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  result.addAttribute("staticLowerBound", staticLbs);
  result.addAttribute("staticUpperBound", staticUbs);
  result.addAttribute("staticStep", staticSteps);
  result.addAttribute("operandSegmentSizes",
                          {static_cast<int32_t>(dynamicLbs.size()),
                           static_cast<int32_t>(dynamicUbs.size()),
                           static_cast<int32_t>(dynamicSteps.size()),
                           static_cast<int32_t>(outOperands.size())}));
  return success();
}
1238
1239// Builder that takes loop bounds.
1240void ForallOp::build(
1241 mlir::OpBuilder &b, mlir::OperationState &result,
1242 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1243 ArrayRef<OpFoldResult> steps, ValueRange outputs,
1244 std::optional<ArrayAttr> mapping,
1245 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1246 SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1247 SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1248 dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1249 dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1250 dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1251
1252 result.addOperands(dynamicLbs);
1253 result.addOperands(dynamicUbs);
1254 result.addOperands(dynamicSteps);
1255 result.addOperands(outputs);
1256 result.addTypes(TypeRange(outputs));
1257
1258 result.addAttribute(getStaticLowerBoundAttrName(result.name),
1259 b.getDenseI64ArrayAttr(staticLbs));
1260 result.addAttribute(getStaticUpperBoundAttrName(result.name),
1261 b.getDenseI64ArrayAttr(staticUbs));
1262 result.addAttribute(getStaticStepAttrName(result.name),
1263 b.getDenseI64ArrayAttr(staticSteps));
1264 result.addAttribute(
1265 "operandSegmentSizes",
1266 b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1267 static_cast<int32_t>(dynamicUbs.size()),
1268 static_cast<int32_t>(dynamicSteps.size()),
1269 static_cast<int32_t>(outputs.size())}));
1270 if (mapping.has_value()) {
1271 result.addAttribute(ForallOp::getMappingAttrName(result.name),
1272 mapping.value());
1273 }
1274
1275 Region *bodyRegion = result.addRegion();
1276 OpBuilder::InsertionGuard g(b);
1277 b.createBlock(bodyRegion);
1278 Block &bodyBlock = bodyRegion->front();
1279
1280 // Add block arguments for indices and outputs.
1281 bodyBlock.addArguments(
1282 SmallVector<Type>(lbs.size(), b.getIndexType()),
1283 SmallVector<Location>(staticLbs.size(), result.location));
1284 bodyBlock.addArguments(
1285 TypeRange(outputs),
1286 SmallVector<Location>(outputs.size(), result.location));
1287
1288 b.setInsertionPointToStart(&bodyBlock);
1289 if (!bodyBuilderFn) {
1290 ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1291 return;
1292 }
1293 bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1294}
1295
1296// Builder that takes loop bounds.
1297void ForallOp::build(
1298 mlir::OpBuilder &b, mlir::OperationState &result,
1299 ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1300 std::optional<ArrayAttr> mapping,
1301 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1302 unsigned numLoops = ubs.size();
1303 SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1304 SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1305 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1306}
1307
1308// Checks if the lbs are zeros and steps are ones.
1309bool ForallOp::isNormalized() {
1310 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1311 return llvm::all_of(results, [&](OpFoldResult ofr) {
1312 auto intValue = getConstantIntValue(ofr);
1313 return intValue.has_value() && intValue == val;
1314 });
1315 };
1316 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1317}
1318
1319InParallelOp ForallOp::getTerminator() {
1320 return cast<InParallelOp>(getBody()->getTerminator());
1321}
1322
1323SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1324 SmallVector<Operation *> storeOps;
1325 for (Operation *user : bbArg.getUsers()) {
1326 if (auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1327 storeOps.push_back(parallelOp);
1328 }
1329 }
1330 return storeOps;
1331}
1332
1333SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1334 SmallVector<DeviceMappingAttrInterface> res;
1335 if (!getMapping())
1336 return res;
1337 for (auto attr : getMapping()->getValue()) {
1338 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1339 if (m)
1340 res.push_back(m);
1341 }
1342 return res;
1343}
1344
1345FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1346 DeviceMaskingAttrInterface res;
1347 if (!getMapping())
1348 return res;
1349 for (auto attr : getMapping()->getValue()) {
1350 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1351 if (m && res)
1352 return failure();
1353 if (m)
1354 res = m;
1355 }
1356 return res;
1357}
1358
1359bool ForallOp::usesLinearMapping() {
1360 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1361 if (ifaces.empty())
1362 return false;
1363 return ifaces.front().isLinearMapping();
1364}
1365
1366std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1367 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1368}
1369
1370// Get lower bounds as OpFoldResult.
1371std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1372 Builder b(getOperation()->getContext());
1373 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1374}
1375
1376// Get upper bounds as OpFoldResult.
1377std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1378 Builder b(getOperation()->getContext());
1379 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1380}
1381
1382// Get steps as OpFoldResult.
1383std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1384 Builder b(getOperation()->getContext());
1385 return getMixedValues(getStaticStep(), getDynamicStep(), b);
1386}
1387
  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
  // Only block arguments can be forall thread indices.
  if (!tidxArg)
    return ForallOp();
  assert(tidxArg.getOwner() && "unlinked block argument");
  auto *containingOp = tidxArg.getOwner()->getParentOp();
  // dyn_cast yields a null op when the owner is not an scf.forall.
  return dyn_cast<ForallOp>(containingOp);
}
1396
1397namespace {
1398/// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1399struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1400 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1401
1402 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1403 PatternRewriter &rewriter) const final {
1404 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1405 if (!forallOp)
1406 return failure();
1407 Value sharedOut =
1408 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1409 ->get();
1410 rewriter.modifyOpInPlace(
1411 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1412 return success();
1413 }
1414};
1415
1416class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1417public:
1418 using OpRewritePattern<ForallOp>::OpRewritePattern;
1419
1420 LogicalResult matchAndRewrite(ForallOp op,
1421 PatternRewriter &rewriter) const override {
1422 SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1423 SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1424 SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1425 if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1426 failed(foldDynamicIndexList(mixedUpperBound)) &&
1427 failed(foldDynamicIndexList(mixedStep)))
1428 return failure();
1429
1430 rewriter.modifyOpInPlace(op, [&]() {
1431 SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1432 SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1433 dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1434 staticLowerBound);
1435 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1436 op.setStaticLowerBound(staticLowerBound);
1437
1438 dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1439 staticUpperBound);
1440 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1441 op.setStaticUpperBound(staticUpperBound);
1442
1443 dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1444 op.getDynamicStepMutable().assign(dynamicStep);
1445 op.setStaticStep(staticStep);
1446
1447 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1448 rewriter.getDenseI32ArrayAttr(
1449 {static_cast<int32_t>(dynamicLowerBound.size()),
1450 static_cast<int32_t>(dynamicUpperBound.size()),
1451 static_cast<int32_t>(dynamicStep.size()),
1452 static_cast<int32_t>(op.getNumResults())}));
1453 });
1454 return success();
1455 }
1456};
1457
1458/// The following canonicalization pattern folds the iter arguments of
1459/// scf.forall op if :-
1460/// 1. The corresponding result has zero uses.
/// 2. The iter argument is NOT being modified within the loop body.
1463///
1464/// Example of first case :-
1465/// INPUT:
1466/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1467/// {
1468/// ...
1469/// <SOME USE OF %arg0>
1470/// <SOME USE OF %arg1>
1471/// <SOME USE OF %arg2>
1472/// ...
1473/// scf.forall.in_parallel {
1474/// <STORE OP WITH DESTINATION %arg1>
1475/// <STORE OP WITH DESTINATION %arg0>
1476/// <STORE OP WITH DESTINATION %arg2>
1477/// }
1478/// }
1479/// return %res#1
1480///
1481/// OUTPUT:
1482/// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
1483/// {
1484/// ...
1485/// <SOME USE OF %a>
1486/// <SOME USE OF %new_arg0>
1487/// <SOME USE OF %c>
1488/// ...
1489/// scf.forall.in_parallel {
1490/// <STORE OP WITH DESTINATION %new_arg0>
1491/// }
1492/// }
1493/// return %res
1494///
1495/// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1496/// scf.forall is replaced by their corresponding operands.
1497/// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1498/// of the scf.forall besides within scf.forall.in_parallel terminator,
1499/// this canonicalization remains valid. For more details, please refer
1500/// to :
1501/// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1502/// 3. TODO(avarma): Generalize it for other store ops. Currently it
1503/// handles tensor.parallel_insert_slice ops only.
1504///
1505/// Example of second case :-
1506/// INPUT:
1507/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1508/// {
1509/// ...
1510/// <SOME USE OF %arg0>
1511/// <SOME USE OF %arg1>
1512/// ...
1513/// scf.forall.in_parallel {
1514/// <STORE OP WITH DESTINATION %arg1>
1515/// }
1516/// }
1517/// return %res#0, %res#1
1518///
1519/// OUTPUT:
1520/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1521/// {
1522/// ...
1523/// <SOME USE OF %a>
1524/// <SOME USE OF %new_arg0>
1525/// ...
1526/// scf.forall.in_parallel {
1527/// <STORE OP WITH DESTINATION %new_arg0>
1528/// }
1529/// }
1530/// return %a, %res
struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp forallOp,
                                PatternRewriter &rewriter) const final {
    // Step 1: For a given i-th result of scf.forall, check the following :-
    //   a. If it has any use.
    //   b. If the corresponding iter argument is being modified within
    //      the loop, i.e. has at least one store op with the iter arg as
    //      its destination operand. For this we use
    //      ForallOp::getCombiningOps(iter_arg).
    //
    // Based on the check we maintain the following :-
    //   a. op results, block arguments, outputs to delete
    //   b. new outputs (i.e., outputs to retain)
    SmallVector<Value> resultsToDelete;
    SmallVector<Value> outsToDelete;
    SmallVector<BlockArgument> blockArgsToDelete;
    SmallVector<Value> newOuts;
    BitVector resultIndicesToDelete(forallOp.getNumResults(), false);
    BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
                                   false);
    for (OpResult result : forallOp.getResults()) {
      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
      if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
        resultsToDelete.push_back(result);
        outsToDelete.push_back(opOperand->get());
        blockArgsToDelete.push_back(blockArg);
        resultIndicesToDelete[result.getResultNumber()] = true;
        blockIndicesToDelete[blockArg.getArgNumber()] = true;
      } else {
        newOuts.push_back(opOperand->get());
      }
    }

    // Return early if all results of scf.forall have at least one use and being
    // modified within the loop.
    if (resultsToDelete.empty())
      return failure();

    // Step 2: Erase combining ops and replace uses of deleted results and
    // block arguments with the corresponding outputs.
    // The combining (store) ops are erased first: they are users of the folded
    // block arguments, so removing them leaves only uses that can safely be
    // redirected to the initializing outputs.
    for (auto blockArg : blockArgsToDelete) {
      SmallVector<Operation *> combiningOps =
          forallOp.getCombiningOps(blockArg);
      for (Operation *combiningOp : combiningOps)
        rewriter.eraseOp(combiningOp);
    }
    for (auto [blockArg, result, out] :
         llvm::zip_equal(blockArgsToDelete, resultsToDelete, outsToDelete)) {
      rewriter.replaceAllUsesWith(blockArg, out);
      rewriter.replaceAllUsesWith(result, out);
    }
    // TODO: There is no rewriter API for erasing block arguments.
    rewriter.modifyOpInPlace(forallOp, [&]() {
      forallOp.getBody()->eraseArguments(blockIndicesToDelete);
    });

    // Step 3. Create a new scf.forall op with only the shared_outs/results
    // that should be retained.
    auto newForallOp = cast<scf::ForallOp>(
        rewriter.eraseOpResults(forallOp, resultIndicesToDelete));
    newForallOp.getOutputsMutable().assign(newOuts);

    return success();
  }
};
1599
/// Fold scf.forall dimensions that perform at most one iteration: a dimension
/// with zero iterations folds the whole op to its outputs, and dimensions
/// with exactly one iteration are removed after substituting their induction
/// variable with the lower bound.
struct ForallOpSingleOrZeroIterationDimsFolder
    : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp op,
                                PatternRewriter &rewriter) const override {
    // Do not fold dimensions if they are mapped to processing units.
    if (op.getMapping().has_value() && !op.getMapping()->empty())
      return failure();
    Location loc = op.getLoc();

    // Compute new loop bounds that omit all single-iteration loop dimensions.
    SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
        newMixedSteps;
    IRMapping mapping;
    for (auto [lb, ub, step, iv] :
         llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
                   op.getMixedStep(), op.getInductionVars())) {
      auto numIterations =
          constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
      if (numIterations.has_value()) {
        // Remove the loop if it performs zero iterations.
        if (*numIterations == 0) {
          rewriter.replaceOp(op, op.getOutputs());
          return success();
        }
        // Replace the loop induction variable by the lower bound if the loop
        // performs a single iteration. Otherwise, copy the loop bounds.
        if (*numIterations == 1) {
          mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
          continue;
        }
      }
      newMixedLowerBounds.push_back(lb);
      newMixedUpperBounds.push_back(ub);
      newMixedSteps.push_back(step);
    }

    // All of the loop dimensions perform a single iteration. Inline loop body.
    if (newMixedLowerBounds.empty()) {
      promote(rewriter, op);
      return success();
    }

    // Exit if none of the loop dimensions perform a single iteration.
    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
      return rewriter.notifyMatchFailure(
          op, "no dimensions have 0 or 1 iterations");
    }

    // Replace the loop by a lower-dimensional loop.
    ForallOp newOp;
    newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
                             newMixedUpperBounds, newMixedSteps,
                             op.getOutputs(), std::nullopt, nullptr);
    // Drop the builder-created body; the original region is cloned below.
    newOp.getBodyRegion().getBlocks().clear();
    // The new loop needs to keep all attributes from the old one, except for
    // "operandSegmentSizes" and static loop bound attributes which capture
    // the outdated information of the old iteration domain.
    SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
                                        newOp.getStaticLowerBoundAttrName(),
                                        newOp.getStaticUpperBoundAttrName(),
                                        newOp.getStaticStepAttrName()};
    for (const auto &namedAttr : op->getAttrs()) {
      if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
        continue;
      rewriter.modifyOpInPlace(newOp, [&]() {
        newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
      });
    }
    rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
                               newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};
1676
1677/// Replace all induction vars with a single trip count with their lower bound.
1678struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1679 using OpRewritePattern<ForallOp>::OpRewritePattern;
1680
1681 LogicalResult matchAndRewrite(ForallOp op,
1682 PatternRewriter &rewriter) const override {
1683 Location loc = op.getLoc();
1684 bool changed = false;
1685 for (auto [lb, ub, step, iv] :
1686 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1687 op.getMixedStep(), op.getInductionVars())) {
1688 if (iv.hasNUses(0))
1689 continue;
1690 auto numIterations =
1691 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
1692 if (!numIterations.has_value() || numIterations.value() != 1) {
1693 continue;
1694 }
1695 rewriter.replaceAllUsesWith(
1696 iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1697 changed = true;
1698 }
1699 return success(changed);
1700 }
1701};
1702
/// Fold tensor.cast ops on the shared outputs into the scf.forall op: the
/// loop then operates on the more static source type, and the results are
/// cast back to the original types after the loop.
struct FoldTensorCastOfOutputIntoForallOp
    : public OpRewritePattern<scf::ForallOp> {
  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;

  // Records the type change applied to one shared output.
  struct TypeCast {
    Type srcType;
    Type dstType;
  };

  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
                                PatternRewriter &rewriter) const final {
    // Maps output index -> the cast folded away at that position.
    llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
    llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
    for (auto en : llvm::enumerate(newOutputTensors)) {
      auto castOp = en.value().getDefiningOp<tensor::CastOp>();
      if (!castOp)
        continue;

      // Only casts that preserve static information, i.e. will make the
      // loop result type "more" static than before, will be folded.
      if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
                                              castOp.getSource().getType())) {
        continue;
      }

      tensorCastProducers[en.index()] =
          TypeCast{castOp.getSource().getType(), castOp.getType()};
      newOutputTensors[en.index()] = castOp.getSource();
    }

    if (tensorCastProducers.empty())
      return failure();

    // Create new loop.
    Location loc = forallOp.getLoc();
    auto newForallOp = ForallOp::create(
        rewriter, loc, forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        newOutputTensors, forallOp.getMapping(),
        [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
          // Re-introduce the casts inside the body so the old block, which
          // still expects the old types, can be merged in unchanged.
          auto castBlockArgs =
              llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
          for (auto [index, cast] : tensorCastProducers) {
            Value &oldTypeBBArg = castBlockArgs[index];
            oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
                                                  cast.dstType, oldTypeBBArg);
          }

          // Move old body into new parallel loop.
          SmallVector<Value> ivsBlockArgs =
              llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
          ivsBlockArgs.append(castBlockArgs);
          rewriter.mergeBlocks(forallOp.getBody(),
                               bbArgs.front().getParentBlock(), ivsBlockArgs);
        });

    // After `mergeBlocks` happened, the destinations in the terminator were
    // mapped to the tensor.cast old-typed results of the output bbArgs. The
    // destination have to be updated to point to the output bbArgs directly.
    auto terminator = newForallOp.getTerminator();
    for (auto [yieldingOp, outputBlockArg] : llvm::zip(
             terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
      if (auto parallelCombingingOp =
              dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
        parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
      }
    }

    // Cast results back to the original types.
    rewriter.setInsertionPointAfter(newForallOp);
    SmallVector<Value> castResults = newForallOp.getResults();
    for (auto &item : tensorCastProducers) {
      Value &oldTypeResult = castResults[item.first];
      oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
                                             oldTypeResult);
    }
    rewriter.replaceOp(forallOp, castResults);
    return success();
  }
};
1783
1784} // namespace
1785
1786void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1787 MLIRContext *context) {
1788 results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1789 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1790 ForallOpSingleOrZeroIterationDimsFolder,
1791 ForallOpReplaceConstantInductionVar>(context);
1792}
1793
1794void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1795 SmallVectorImpl<RegionSuccessor> &regions) {
1796 // There are two region branch points:
1797 // 1. "parent": entering the forall op for the first time.
1798 // 2. scf.in_parallel terminator
1799 if (point.isParent()) {
1800 // When first entering the forall op, the control flow typically branches
1801 // into the forall body. (In parallel for multiple threads.)
1802 regions.push_back(RegionSuccessor(&getRegion()));
1803 // However, when there are 0 threads, the control flow may branch back to
1804 // the parent immediately.
1805 regions.push_back(RegionSuccessor::parent());
1806 } else {
1807 // In accordance with the semantics of forall, its body is executed in
1808 // parallel by multiple threads. We should not expect to branch back into
1809 // the forall body after the region's execution is complete.
1810 regions.push_back(RegionSuccessor::parent());
1811 }
1812}
1813
1814//===----------------------------------------------------------------------===//
1815// InParallelOp
1816//===----------------------------------------------------------------------===//
1817
1818// Build a InParallelOp with mixed static and dynamic entries.
1819void InParallelOp::build(OpBuilder &b, OperationState &result) {
1820 OpBuilder::InsertionGuard g(b);
1821 Region *bodyRegion = result.addRegion();
1822 b.createBlock(bodyRegion);
1823}
1824
1825LogicalResult InParallelOp::verify() {
1826 scf::ForallOp forallOp =
1827 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1828 if (!forallOp)
1829 return this->emitOpError("expected forall op parent");
1830
1831 for (Operation &op : getRegion().front().getOperations()) {
1832 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
1833 if (!parallelCombiningOp) {
1834 return this->emitOpError("expected only ParallelCombiningOpInterface")
1835 << " ops";
1836 }
1837
1838 // Verify that inserts are into out block arguments.
1839 MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
1840 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1841 for (OpOperand &dest : dests) {
1842 if (!llvm::is_contained(regionOutArgs, dest.get()))
1843 return op.emitOpError("may only insert into an output block argument");
1844 }
1845 }
1846
1847 return success();
1848}
1849
1850void InParallelOp::print(OpAsmPrinter &p) {
1851 p << " ";
1852 p.printRegion(getRegion(),
1853 /*printEntryBlockArgs=*/false,
1854 /*printBlockTerminators=*/false);
1855 p.printOptionalAttrDict(getOperation()->getAttrs());
1856}
1857
1858ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1859 auto &builder = parser.getBuilder();
1860
1861 SmallVector<OpAsmParser::Argument, 8> regionOperands;
1862 std::unique_ptr<Region> region = std::make_unique<Region>();
1863 if (parser.parseRegion(*region, regionOperands))
1864 return failure();
1865
1866 if (region->empty())
1867 OpBuilder(builder.getContext()).createBlock(region.get());
1868 result.addRegion(std::move(region));
1869
1870 // Parse the optional attribute list.
1871 if (parser.parseOptionalAttrDict(result.attributes))
1872 return failure();
1873 return success();
1874}
1875
1876OpResult InParallelOp::getParentResult(int64_t idx) {
1877 return getOperation()->getParentOp()->getResult(idx);
1878}
1879
1880SmallVector<BlockArgument> InParallelOp::getDests() {
1881 SmallVector<BlockArgument> updatedDests;
1882 for (Operation &yieldingOp : getYieldingOps()) {
1883 auto parallelCombiningOp =
1884 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
1885 if (!parallelCombiningOp)
1886 continue;
1887 for (OpOperand &updatedOperand :
1888 parallelCombiningOp.getUpdatedDestinations())
1889 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
1890 }
1891 return updatedDests;
1892}
1893
1894llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1895 return getRegion().front().getOperations();
1896}
1897
1898//===----------------------------------------------------------------------===//
1899// IfOp
1900//===----------------------------------------------------------------------===//
1901
  assert(a && "expected non-empty operation");
  assert(b && "expected non-empty operation");

  // Walk the chain of scf.if ops enclosing `a`, looking for one that also
  // encloses `b`.
  IfOp ifOp = a->getParentOfType<IfOp>();
  while (ifOp) {
    // Check if b is inside ifOp. (We already know that a is.)
    if (ifOp->isProperAncestor(b))
      // b is contained in ifOp. a and b are in mutually exclusive branches if
      // they are in different blocks of ifOp.
      return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
             static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
    // Check next enclosing IfOp.
    ifOp = ifOp->getParentOfType<IfOp>();
  }

  // Could not find a common IfOp among a's and b's ancestors.
  return false;
}
1921
1922LogicalResult
1923IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1924 IfOp::Adaptor adaptor,
1925 SmallVectorImpl<Type> &inferredReturnTypes) {
1926 if (adaptor.getRegions().empty())
1927 return failure();
1928 Region *r = &adaptor.getThenRegion();
1929 if (r->empty())
1930 return failure();
1931 Block &b = r->front();
1932 if (b.empty())
1933 return failure();
1934 auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
1935 if (!yieldOp)
1936 return failure();
1937 TypeRange types = yieldOp.getOperandTypes();
1938 llvm::append_range(inferredReturnTypes, types);
1939 return success();
1940}
1941
1942void IfOp::build(OpBuilder &builder, OperationState &result,
1943 TypeRange resultTypes, Value cond) {
1944 return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
1945 /*addElseBlock=*/false);
1946}
1947
1948void IfOp::build(OpBuilder &builder, OperationState &result,
1949 TypeRange resultTypes, Value cond, bool addThenBlock,
1950 bool addElseBlock) {
1951 assert((!addElseBlock || addThenBlock) &&
1952 "must not create else block w/o then block");
1953 result.addTypes(resultTypes);
1954 result.addOperands(cond);
1955
1956 // Add regions and blocks.
1957 OpBuilder::InsertionGuard guard(builder);
1958 Region *thenRegion = result.addRegion();
1959 if (addThenBlock)
1960 builder.createBlock(thenRegion);
1961 Region *elseRegion = result.addRegion();
1962 if (addElseBlock)
1963 builder.createBlock(elseRegion);
1964}
1965
1966void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1967 bool withElseRegion) {
1968 build(builder, result, TypeRange{}, cond, withElseRegion);
1969}
1970
1971void IfOp::build(OpBuilder &builder, OperationState &result,
1972 TypeRange resultTypes, Value cond, bool withElseRegion) {
1973 result.addTypes(resultTypes);
1974 result.addOperands(cond);
1975
1976 // Build then region.
1977 OpBuilder::InsertionGuard guard(builder);
1978 Region *thenRegion = result.addRegion();
1979 builder.createBlock(thenRegion);
1980 if (resultTypes.empty())
1981 IfOp::ensureTerminator(*thenRegion, builder, result.location);
1982
1983 // Build else region.
1984 Region *elseRegion = result.addRegion();
1985 if (withElseRegion) {
1986 builder.createBlock(elseRegion);
1987 if (resultTypes.empty())
1988 IfOp::ensureTerminator(*elseRegion, builder, result.location);
1989 }
1990}
1991
1992void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1993 function_ref<void(OpBuilder &, Location)> thenBuilder,
1994 function_ref<void(OpBuilder &, Location)> elseBuilder) {
1995 assert(thenBuilder && "the builder callback for 'then' must be present");
1996 result.addOperands(cond);
1997
1998 // Build then region.
1999 OpBuilder::InsertionGuard guard(builder);
2000 Region *thenRegion = result.addRegion();
2001 builder.createBlock(thenRegion);
2002 thenBuilder(builder, result.location);
2003
2004 // Build else region.
2005 Region *elseRegion = result.addRegion();
2006 if (elseBuilder) {
2007 builder.createBlock(elseRegion);
2008 elseBuilder(builder, result.location);
2009 }
2010
2011 // Infer result types.
2012 SmallVector<Type> inferredReturnTypes;
2013 MLIRContext *ctx = builder.getContext();
2014 auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2015 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2016 /*properties=*/nullptr, result.regions,
2017 inferredReturnTypes))) {
2018 result.addTypes(inferredReturnTypes);
2019 }
2020}
2021
2022LogicalResult IfOp::verify() {
2023 if (getNumResults() != 0 && getElseRegion().empty())
2024 return emitOpError("must have an else block if defining values");
2025 return success();
2026}
2027
2028ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2029 // Create the regions for 'then'.
2030 result.regions.reserve(2);
2031 Region *thenRegion = result.addRegion();
2032 Region *elseRegion = result.addRegion();
2033
2034 auto &builder = parser.getBuilder();
2035 OpAsmParser::UnresolvedOperand cond;
2036 Type i1Type = builder.getIntegerType(1);
2037 if (parser.parseOperand(cond) ||
2038 parser.resolveOperand(cond, i1Type, result.operands))
2039 return failure();
2040 // Parse optional results type list.
2041 if (parser.parseOptionalArrowTypeList(result.types))
2042 return failure();
2043 // Parse the 'then' region.
2044 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2045 return failure();
2046 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2047
2048 // If we find an 'else' keyword then parse the 'else' region.
2049 if (!parser.parseOptionalKeyword("else")) {
2050 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2051 return failure();
2052 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2053 }
2054
2055 // Parse the optional attribute list.
2056 if (parser.parseOptionalAttrDict(result.attributes))
2057 return failure();
2058 return success();
2059}
2060
2061void IfOp::print(OpAsmPrinter &p) {
2062 bool printBlockTerminators = false;
2063
2064 p << " " << getCondition();
2065 if (!getResults().empty()) {
2066 p << " -> (" << getResultTypes() << ")";
2067 // Print yield explicitly if the op defines values.
2068 printBlockTerminators = true;
2069 }
2070 p << ' ';
2071 p.printRegion(getThenRegion(),
2072 /*printEntryBlockArgs=*/false,
2073 /*printBlockTerminators=*/printBlockTerminators);
2074
2075 // Print the 'else' regions if it exists and has a block.
2076 auto &elseRegion = getElseRegion();
2077 if (!elseRegion.empty()) {
2078 p << " else ";
2079 p.printRegion(elseRegion,
2080 /*printEntryBlockArgs=*/false,
2081 /*printBlockTerminators=*/printBlockTerminators);
2082 }
2083
2084 p.printOptionalAttrDict((*this)->getAttrs());
2085}
2086
2087void IfOp::getSuccessorRegions(RegionBranchPoint point,
2088 SmallVectorImpl<RegionSuccessor> &regions) {
2089 // The `then` and the `else` region branch back to the parent operation or one
2090 // of the recursive parent operations (early exit case).
2091 if (!point.isParent()) {
2092 regions.push_back(RegionSuccessor::parent());
2093 return;
2094 }
2095
2096 regions.push_back(RegionSuccessor(&getThenRegion()));
2097
2098 // Don't consider the else region if it is empty.
2099 Region *elseRegion = &this->getElseRegion();
2100 if (elseRegion->empty())
2101 regions.push_back(RegionSuccessor::parent());
2102 else
2103 regions.push_back(RegionSuccessor(elseRegion));
2104}
2105
2106ValueRange IfOp::getSuccessorInputs(RegionSuccessor successor) {
2107 return successor.isParent() ? ValueRange(getOperation()->getResults())
2108 : ValueRange();
2109}
2110
2111void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2112 SmallVectorImpl<RegionSuccessor> &regions) {
2113 FoldAdaptor adaptor(operands, *this);
2114 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2115 if (!boolAttr || boolAttr.getValue())
2116 regions.emplace_back(&getThenRegion());
2117
2118 // If the else region is empty, execution continues after the parent op.
2119 if (!boolAttr || !boolAttr.getValue()) {
2120 if (!getElseRegion().empty())
2121 regions.emplace_back(&getElseRegion());
2122 else
2123 regions.emplace_back(RegionSuccessor::parent());
2124 }
2125}
2126
LogicalResult IfOp::fold(FoldAdaptor adaptor,
                         SmallVectorImpl<OpFoldResult> &results) {
  // if (!c) then A() else B() -> if c then B() else A()
  if (getElseRegion().empty())
    return failure();

  // Only fold when the condition is produced by `xor %c, true` (a negation).
  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
  if (!xorStmt)
    return failure();

  if (!matchPattern(xorStmt.getRhs(), m_One()))
    return failure();

  // Drop the negation and swap the two branches in place: move the else
  // blocks to the front of the then region, then move the original then block
  // over to the (now empty) else region.
  getConditionMutable().assign(xorStmt.getLhs());
  Block *thenBlock = &getThenRegion().front();
  // It would be nicer to use iplist::swap, but that has no implemented
  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
                                     getElseRegion().getBlocks());
  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
                                     getThenRegion().getBlocks(), thenBlock);
  return success();
}
2150
2151void IfOp::getRegionInvocationBounds(
2152 ArrayRef<Attribute> operands,
2153 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2154 if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2155 // If the condition is known, then one region is known to be executed once
2156 // and the other zero times.
2157 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2158 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2159 } else {
2160 // Non-constant condition. Each region may be executed 0 or 1 times.
2161 invocationBounds.assign(2, {0, 1});
2162 }
2163}
2164
2165namespace {
2166/// Hoist any yielded results whose operands are defined outside
2167/// the if, to a select instruction.
/// Hoist any yielded results whose operands are defined outside
/// the if, to a select instruction.
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() == 0)
      return failure();

    auto cond = op.getCondition();
    auto thenYieldArgs = op.thenYield().getOperands();
    auto elseYieldArgs = op.elseYield().getOperands();

    // Types of results whose yielded values are computed inside one of the
    // branches; these cannot be hoisted into a select.
    SmallVector<Type> nonHoistable;
    for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
      if (&op.getThenRegion() == trueVal.getParentRegion() ||
          &op.getElseRegion() == falseVal.getParentRegion())
        nonHoistable.push_back(trueVal.getType());
    }
    // Early exit if there aren't any yielded values we can
    // hoist outside the if.
    if (nonHoistable.size() == op->getNumResults())
      return failure();

    // Re-create the scf.if keeping only the non-hoistable results, taking
    // over both bodies wholesale.
    IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
                                    /*withElseRegion=*/false);
    if (replacement.thenBlock())
      rewriter.eraseBlock(replacement.thenBlock());
    replacement.getThenRegion().takeBody(op.getThenRegion());
    replacement.getElseRegion().takeBody(op.getElseRegion());

    SmallVector<Value> results(op->getNumResults());
    assert(thenYieldArgs.size() == results.size());
    assert(elseYieldArgs.size() == results.size());

    // Partition the original results: branch-defined values keep flowing
    // through the new scf.if (collected in trueYields/falseYields),
    // identical values are forwarded directly, and all remaining pairs
    // become arith.select on the condition.
    SmallVector<Value> trueYields;
    SmallVector<Value> falseYields;
    for (const auto &it :
         llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
      Value trueVal = std::get<0>(it.value());
      Value falseVal = std::get<1>(it.value());
      if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
          &replacement.getElseRegion() == falseVal.getParentRegion()) {
        results[it.index()] = replacement.getResult(trueYields.size());
        trueYields.push_back(trueVal);
        falseYields.push_back(falseVal);
      } else if (trueVal == falseVal)
        results[it.index()] = trueVal;
      else
        results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
                                                      cond, trueVal, falseVal);
    }

    // Rewrite both yields to only produce the non-hoistable values.
    rewriter.setInsertionPointToEnd(replacement.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);

    rewriter.setInsertionPointToEnd(replacement.elseBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);

    rewriter.replaceOp(op, results);
    return success();
  }
};
2231
2232/// Allow the true region of an if to assume the condition is true
2233/// and vice versa. For example:
2234///
2235/// scf.if %cmp {
2236/// print(%cmp)
2237/// }
2238///
2239/// becomes
2240///
2241/// scf.if %cmp {
2242/// print(true)
2243/// }
2244///
struct ConditionPropagation : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  /// Kind of parent region in the ancestor cache.
  enum class Parent { Then, Else, None };

  /// Returns the kind of region ("then", "else", or "none") of the
  /// IfOp that the given region is transitively nested in. Updates
  /// the cache accordingly.
  // NOTE(review): the body below reads a `cache` that is not declared in this
  // view; a `llvm::SmallDenseMap<Region *, Parent> &cache` parameter appears
  // to have been dropped from this signature — confirm against upstream.
  static Parent getParentType(Region *toCheck, IfOp op,
                              Region *endRegion) {
    // Regions visited on this walk; all of them get the same classification
    // once the answer is known.
    SmallVector<Region *> seen;
    while (toCheck != endRegion) {
      auto found = cache.find(toCheck);
      if (found != cache.end())
        return found->second;
      seen.push_back(toCheck);
      if (&op.getThenRegion() == toCheck) {
        for (Region *region : seen)
          cache[region] = Parent::Then;
        return Parent::Then;
      }
      if (&op.getElseRegion() == toCheck) {
        for (Region *region : seen)
          cache[region] = Parent::Else;
        return Parent::Else;
      }
      toCheck = toCheck->getParentRegion();
    }

    // Reached the region holding the condition without passing through either
    // branch of `op`.
    for (Region *region : seen)
      cache[region] = Parent::None;
    return Parent::None;
  }

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    // Early exit if the condition is constant since replacing a constant
    // in the body with another constant isn't a simplification.
    if (matchPattern(op.getCondition(), m_Constant()))
      return failure();

    bool changed = false;
    mlir::Type i1Ty = rewriter.getI1Type();

    // These variables serve to prevent creating duplicate constants
    // and hold constant true or false values.
    Value constantTrue = nullptr;
    Value constantFalse = nullptr;

    // NOTE(review): the declaration of the `cache` passed below is not
    // visible in this view — confirm against upstream.
    for (OpOperand &use :
         llvm::make_early_inc_range(op.getCondition().getUses())) {
      switch (getParentType(use.getOwner()->getParentRegion(), op, cache,
                            op.getCondition().getParentRegion())) {
      case Parent::Then: {
        changed = true;

        // Inside the then-branch the condition is known to be true.
        if (!constantTrue)
          constantTrue = arith::ConstantOp::create(
              rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));

        rewriter.modifyOpInPlace(use.getOwner(),
                                 [&]() { use.set(constantTrue); });
        break;
      }
      case Parent::Else: {
        changed = true;

        // Inside the else-branch the condition is known to be false.
        if (!constantFalse)
          constantFalse = arith::ConstantOp::create(
              rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));

        rewriter.modifyOpInPlace(use.getOwner(),
                                 [&]() { use.set(constantFalse); });
        break;
      }
      case Parent::None:
        break;
      }
    }

    return success(changed);
  }
};
2331
2332/// Remove any statements from an if that are equivalent to the condition
2333/// or its negation. For example:
2334///
2335/// %res:2 = scf.if %cmp {
2336/// yield something(), true
2337/// } else {
2338/// yield something2(), false
2339/// }
2340/// print(%res#1)
2341///
2342/// becomes
2343/// %res = scf.if %cmp {
2344/// yield something()
2345/// } else {
2346/// yield something2()
2347/// }
2348/// print(%cmp)
2349///
2350/// Additionally if both branches yield the same value, replace all uses
2351/// of the result with the yielded value.
2352///
2353/// %res:2 = scf.if %cmp {
2354/// yield something(), %arg1
2355/// } else {
2356/// yield something2(), %arg1
2357/// }
2358/// print(%res#1)
2359///
2360/// becomes
2361/// %res = scf.if %cmp {
2362/// yield something()
2363/// } else {
2364/// yield something2()
2365/// }
2366/// print(%arg1)
2367///
2368struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2369 using OpRewritePattern<IfOp>::OpRewritePattern;
2370
2371 LogicalResult matchAndRewrite(IfOp op,
2372 PatternRewriter &rewriter) const override {
2373 // Early exit if there are no results that could be replaced.
2374 if (op.getNumResults() == 0)
2375 return failure();
2376
2377 auto trueYield =
2378 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2379 auto falseYield =
2380 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2381
2382 rewriter.setInsertionPoint(op->getBlock(),
2383 op.getOperation()->getIterator());
2384 bool changed = false;
2385 Type i1Ty = rewriter.getI1Type();
2386 for (auto [trueResult, falseResult, opResult] :
2387 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2388 op.getResults())) {
2389 if (trueResult == falseResult) {
2390 if (!opResult.use_empty()) {
2391 opResult.replaceAllUsesWith(trueResult);
2392 changed = true;
2393 }
2394 continue;
2395 }
2396
2397 BoolAttr trueYield, falseYield;
2398 if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2399 !matchPattern(falseResult, m_Constant(&falseYield)))
2400 continue;
2401
2402 bool trueVal = trueYield.getValue();
2403 bool falseVal = falseYield.getValue();
2404 if (!trueVal && falseVal) {
2405 if (!opResult.use_empty()) {
2406 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2407 Value notCond = arith::XOrIOp::create(
2408 rewriter, op.getLoc(), op.getCondition(),
2409 constDialect
2410 ->materializeConstant(rewriter,
2411 rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2412 op.getLoc())
2413 ->getResult(0));
2414 opResult.replaceAllUsesWith(notCond);
2415 changed = true;
2416 }
2417 }
2418 if (trueVal && !falseVal) {
2419 if (!opResult.use_empty()) {
2420 opResult.replaceAllUsesWith(op.getCondition());
2421 changed = true;
2422 }
2423 }
2424 }
2425 return success(changed);
2426 }
2427};
2428
2429/// Merge any consecutive scf.if's with the same condition.
2430///
2431/// scf.if %cond {
2432/// firstCodeTrue();...
2433/// } else {
2434/// firstCodeFalse();...
2435/// }
2436/// %res = scf.if %cond {
2437/// secondCodeTrue();...
2438/// } else {
2439/// secondCodeFalse();...
2440/// }
2441///
2442/// becomes
2443/// %res = scf.if %cmp {
2444/// firstCodeTrue();...
2445/// secondCodeTrue();...
2446/// } else {
2447/// firstCodeFalse();...
2448/// secondCodeFalse();...
2449/// }
struct CombineIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp nextIf,
                                PatternRewriter &rewriter) const override {
    Block *parent = nextIf->getBlock();
    if (nextIf == &parent->front())
      return failure();

    // Only merge with an scf.if that is the immediately preceding op.
    auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
    if (!prevIf)
      return failure();

    // Determine the logical then/else blocks when prevIf's
    // condition is used. Null means the block does not exist
    // in that case (e.g. empty else). If neither of these
    // are set, the two conditions cannot be compared.
    Block *nextThen = nullptr;
    Block *nextElse = nullptr;
    if (nextIf.getCondition() == prevIf.getCondition()) {
      nextThen = nextIf.thenBlock();
      if (!nextIf.getElseRegion().empty())
        nextElse = nextIf.elseBlock();
    }
    // nextIf's condition is `xor prevIf's condition, true`: branches swap.
    if (arith::XOrIOp notv =
            nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == prevIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }
    // Symmetric case: prevIf's condition is the negation of nextIf's.
    if (arith::XOrIOp notv =
            prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == nextIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }

    if (!nextThen && !nextElse)
      return failure();

    SmallVector<Value> prevElseYielded;
    if (!prevIf.getElseRegion().empty())
      prevElseYielded = prevIf.elseYield().getOperands();
    // Replace all uses of return values of op within nextIf with the
    // corresponding yields
    for (auto it : llvm::zip(prevIf.getResults(),
                             prevIf.thenYield().getOperands(), prevElseYielded))
      for (OpOperand &use :
           llvm::make_early_inc_range(std::get<0>(it).getUses())) {
        if (nextThen && nextThen->getParent()->isAncestor(
                            use.getOwner()->getParentRegion())) {
          rewriter.startOpModification(use.getOwner());
          use.set(std::get<1>(it));
          rewriter.finalizeOpModification(use.getOwner());
        } else if (nextElse && nextElse->getParent()->isAncestor(
                                   use.getOwner()->getParentRegion())) {
          rewriter.startOpModification(use.getOwner());
          use.set(std::get<2>(it));
          rewriter.finalizeOpModification(use.getOwner());
        }
      }

    // The combined op yields prevIf's results followed by nextIf's.
    SmallVector<Type> mergedTypes(prevIf.getResultTypes());
    llvm::append_range(mergedTypes, nextIf.getResultTypes());

    IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
                                   prevIf.getCondition(), /*hasElse=*/false);
    rewriter.eraseBlock(&combinedIf.getThenRegion().back());

    rewriter.inlineRegionBefore(prevIf.getThenRegion(),
                                combinedIf.getThenRegion(),
                                combinedIf.getThenRegion().begin());

    if (nextThen) {
      // Concatenate the two then-bodies and merge their yields.
      YieldOp thenYield = combinedIf.thenYield();
      YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
      rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
      rewriter.setInsertionPointToEnd(combinedIf.thenBlock());

      SmallVector<Value> mergedYields(thenYield.getOperands());
      llvm::append_range(mergedYields, thenYield2.getOperands());
      YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
      rewriter.eraseOp(thenYield);
      rewriter.eraseOp(thenYield2);
    }

    rewriter.inlineRegionBefore(prevIf.getElseRegion(),
                                combinedIf.getElseRegion(),
                                combinedIf.getElseRegion().begin());

    if (nextElse) {
      if (combinedIf.getElseRegion().empty()) {
        // prevIf had no else: adopt nextIf's else region wholesale.
        rewriter.inlineRegionBefore(*nextElse->getParent(),
                                    combinedIf.getElseRegion(),
                                    combinedIf.getElseRegion().begin());
      } else {
        // Both had an else: concatenate the bodies and merge the yields.
        YieldOp elseYield = combinedIf.elseYield();
        YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
        rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());

        rewriter.setInsertionPointToEnd(combinedIf.elseBlock());

        SmallVector<Value> mergedElseYields(elseYield.getOperands());
        llvm::append_range(mergedElseYields, elseYield2.getOperands());

        YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
        rewriter.eraseOp(elseYield);
        rewriter.eraseOp(elseYield2);
      }
    }

    // Split the combined results back between the two original ops.
    SmallVector<Value> prevValues;
    SmallVector<Value> nextValues;
    for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
      if (pair.index() < prevIf.getNumResults())
        prevValues.push_back(pair.value());
      else
        nextValues.push_back(pair.value());
    }
    rewriter.replaceOp(prevIf, prevValues);
    rewriter.replaceOp(nextIf, nextValues);
    return success();
  }
};
2580
2581/// Pattern to remove an empty else branch.
2582struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2583 using OpRewritePattern<IfOp>::OpRewritePattern;
2584
2585 LogicalResult matchAndRewrite(IfOp ifOp,
2586 PatternRewriter &rewriter) const override {
2587 // Cannot remove else region when there are operation results.
2588 if (ifOp.getNumResults())
2589 return failure();
2590 Block *elseBlock = ifOp.elseBlock();
2591 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2592 return failure();
2593 auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2594 rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2595 newIfOp.getThenRegion().begin());
2596 rewriter.eraseOp(ifOp);
2597 return success();
2598 }
2599};
2600
2601/// Convert nested `if`s into `arith.andi` + single `if`.
2602///
2603/// scf.if %arg0 {
2604/// scf.if %arg1 {
2605/// ...
2606/// scf.yield
2607/// }
2608/// scf.yield
2609/// }
2610/// becomes
2611///
2612/// %0 = arith.andi %arg0, %arg1
2613/// scf.if %0 {
2614/// ...
2615/// scf.yield
2616/// }
struct CombineNestedIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    auto nestedOps = op.thenBlock()->without_terminator();
    // Nested `if` must be the only op in block.
    if (!llvm::hasSingleElement(nestedOps))
      return failure();

    // If there is an else block, it can only yield
    if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
      return failure();

    auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
    if (!nestedIf)
      return failure();

    // Same restriction for the inner if: anything but a lone yield in its
    // else block would execute when the combined condition is false.
    if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
      return failure();

    // Values yielded by the outer then/else branches.
    SmallVector<Value> thenYield(op.thenYield().getOperands());
    SmallVector<Value> elseYield;
    if (op.elseBlock())
      llvm::append_range(elseYield, op.elseYield().getOperands());

    // A list of indices for which we should upgrade the value yielded
    // in the else to a select.
    SmallVector<unsigned> elseYieldsToUpgradeToSelect;

    // If the outer scf.if yields a value produced by the inner scf.if,
    // only permit combining if the value yielded when the condition
    // is false in the outer scf.if is the same value yielded when the
    // inner scf.if condition is false.
    // Note that the array access to elseYield will not go out of bounds
    // since it must have the same length as thenYield, since they both
    // come from the same scf.if.
    for (const auto &tup : llvm::enumerate(thenYield)) {
      if (tup.value().getDefiningOp() == nestedIf) {
        auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
        if (nestedIf.elseYield().getOperand(nestedIdx) !=
            elseYield[tup.index()]) {
          return failure();
        }
        // If the correctness test passes, we will yield
        // corresponding value from the inner scf.if
        thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
        continue;
      }

      // Otherwise, we need to ensure the else block of the combined
      // condition still returns the same value when the outer condition is
      // true and the inner condition is false. This can be accomplished if
      // the then value is defined outside the outer scf.if and we replace the
      // value with a select that considers just the outer condition. Since
      // the else region contains just the yield, its yielded value is
      // defined outside the scf.if, by definition.

      // If the then value is defined within the scf.if, bail.
      if (tup.value().getParentRegion() == &op.getThenRegion()) {
        return failure();
      }
      elseYieldsToUpgradeToSelect.push_back(tup.index());
    }

    Location loc = op.getLoc();
    // The fused condition holds exactly when both conditions hold.
    Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
                                               nestedIf.getCondition());
    auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
    Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());

    SmallVector<Value> results;
    llvm::append_range(results, newIf.getResults());
    // Selects must be created before the new scf.if so they dominate its
    // replacement uses.
    rewriter.setInsertionPoint(newIf);

    for (auto idx : elseYieldsToUpgradeToSelect)
      results[idx] =
          arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
                                  thenYield[idx], elseYield[idx]);

    // Move the inner then-block into the new op and terminate it with the
    // adjusted yield values computed above.
    rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
    rewriter.setInsertionPointToEnd(newIf.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
    if (!elseYield.empty()) {
      rewriter.createBlock(&newIf.getElseRegion());
      rewriter.setInsertionPointToEnd(newIf.elseBlock());
      YieldOp::create(rewriter, loc, elseYield);
    }
    rewriter.replaceOp(op, results);
    return success();
  }
};
2709
2710} // namespace
2711
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                       MLIRContext *context) {
  // Register the scf.if canonicalization patterns defined above.
  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
              ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
              ReplaceIfYieldWithConditionOrValue>(context);
  // NOTE(review): the callees of the next two registration calls (each taking
  // `results` and `IfOp::getOperationName()`) appear to have been lost during
  // extraction — restore them from upstream before building.
      results, IfOp::getOperationName());
                                   IfOp::getOperationName());
}
2722
2723Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2724YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2725Block *IfOp::elseBlock() {
2726 Region &r = getElseRegion();
2727 if (r.empty())
2728 return nullptr;
2729 return &r.back();
2730}
2731YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2732
2733//===----------------------------------------------------------------------===//
2734// ParallelOp
2735//===----------------------------------------------------------------------===//
2736
2737void ParallelOp::build(
2738 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2739 ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2740 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2741 bodyBuilderFn) {
2742 result.addOperands(lowerBounds);
2743 result.addOperands(upperBounds);
2744 result.addOperands(steps);
2745 result.addOperands(initVals);
2746 result.addAttribute(
2747 ParallelOp::getOperandSegmentSizeAttr(),
2748 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2749 static_cast<int32_t>(upperBounds.size()),
2750 static_cast<int32_t>(steps.size()),
2751 static_cast<int32_t>(initVals.size())}));
2752 result.addTypes(initVals.getTypes());
2753
2754 OpBuilder::InsertionGuard guard(builder);
2755 unsigned numIVs = steps.size();
2756 SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2757 SmallVector<Location, 8> argLocs(numIVs, result.location);
2758 Region *bodyRegion = result.addRegion();
2759 Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2760
2761 if (bodyBuilderFn) {
2762 builder.setInsertionPointToStart(bodyBlock);
2763 bodyBuilderFn(builder, result.location,
2764 bodyBlock->getArguments().take_front(numIVs),
2765 bodyBlock->getArguments().drop_front(numIVs));
2766 }
2767 // Add terminator only if there are no reductions.
2768 if (initVals.empty())
2769 ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2770}
2771
/// Builds an scf.parallel op without reductions: adapts the three-argument
/// body builder to the four-argument form expected by the main builder.
void ParallelOp::build(
    OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
    ValueRange upperBounds, ValueRange steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
  // we don't capture a reference to a temporary by constructing the lambda at
  // function level.
  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
                                           Location nestedLoc, ValueRange ivs,
                                           ValueRange) {
    bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
  };
  // function_ref does not own its callee, so `wrappedBuilderFn` must outlive
  // `wrapper`; both are locals that live until the nested build() returns.
  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
  if (bodyBuilderFn)
    wrapper = wrappedBuilderFn;

  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
        wrapper);
}
2791
LogicalResult ParallelOp::verify() {
  // Check that there is at least one value in lowerBound, upperBound and step.
  // It is sufficient to test only step, because it is ensured already that the
  // number of elements in lowerBound, upperBound and step are the same.
  Operation::operand_range stepValues = getStep();
  if (stepValues.empty())
    return emitOpError(
        "needs at least one tuple element for lowerBound, upperBound and step");

  // Check whether all constant step values are positive.
  for (Value stepValue : stepValues)
    if (auto cst = getConstantIntValue(stepValue))
      if (*cst <= 0)
        return emitOpError("constant step operand must be positive");

  // Check that the body defines the same number of block arguments as the
  // number of tuple elements in step.
  Block *body = getBody();
  if (body->getNumArguments() != stepValues.size())
    return emitOpError() << "expects the same number of induction variables: "
                         << body->getNumArguments()
                         << " as bound and step values: " << stepValues.size();
  for (auto arg : body->getArguments())
    if (!arg.getType().isIndex())
      return emitOpError(
          "expects arguments for the induction variable to be of index type");

  // Check that the terminator is an scf.reduce op.
  // NOTE(review): the line introducing `reduceOp` (presumably
  // `auto reduceOp = verifyAndGetTerminator<ReduceOp>(`) appears to have been
  // lost during extraction — restore it from upstream before building.
      *this, getRegion(), "expects body to terminate with 'scf.reduce'");
  if (!reduceOp)
    return failure();

  // Check that the number of results is the same as the number of reductions.
  auto resultsSize = getResults().size();
  auto reductionsSize = reduceOp.getReductions().size();
  auto initValsSize = getInitVals().size();
  if (resultsSize != reductionsSize)
    return emitOpError() << "expects number of results: " << resultsSize
                         << " to be the same as number of reductions: "
                         << reductionsSize;
  if (resultsSize != initValsSize)
    return emitOpError() << "expects number of results: " << resultsSize
                         << " to be the same as number of initial values: "
                         << initValsSize;
  if (reduceOp.getNumOperands() != initValsSize)
    // Delegate error reporting to ReduceOp
    return success();

  // Check that the types of the results and reductions are the same.
  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
    auto resultType = getOperation()->getResult(i).getType();
    auto reductionOperandType = reduceOp.getOperands()[i].getType();
    if (resultType != reductionOperandType)
      return reduceOp.emitOpError()
             << "expects type of " << i
             << "-th reduction operand: " << reductionOperandType
             << " to be the same as the " << i
             << "-th result type: " << resultType;
  }
  return success();
}
2854
ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  // Parse an opening `(` followed by induction variables followed by `)`
  SmallVector<OpAsmParser::Argument, 4> ivs;
  // NOTE(review): the condition guarding this early exit (presumably
  // `if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))`)
  // appears to have been lost during extraction — confirm against upstream.
    return failure();

  // Parse loop bounds.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
  if (parser.parseEqual() ||
      parser.parseOperandList(lower, ivs.size(),
      // NOTE(review): truncated — likely `OpAsmParser::Delimiter::Paren) ||`.
      parser.resolveOperands(lower, builder.getIndexType(), result.operands))
    return failure();

  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
  if (parser.parseKeyword("to") ||
      parser.parseOperandList(upper, ivs.size(),
      // NOTE(review): truncated — likely `OpAsmParser::Delimiter::Paren) ||`.
      parser.resolveOperands(upper, builder.getIndexType(), result.operands))
    return failure();

  // Parse step values.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
  if (parser.parseKeyword("step") ||
      parser.parseOperandList(steps, ivs.size(),
      // NOTE(review): truncated — likely `OpAsmParser::Delimiter::Paren) ||`.
      parser.resolveOperands(steps, builder.getIndexType(), result.operands))
    return failure();

  // Parse init values.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
  if (succeeded(parser.parseOptionalKeyword("init"))) {
    if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
      return failure();
  }

  // Parse optional results in case there is a reduce.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Now parse the body.
  Region *body = result.addRegion();
  // All induction variables are of `index` type.
  for (auto &iv : ivs)
    iv.type = builder.getIndexType();
  if (parser.parseRegion(*body, ivs))
    return failure();

  // Set `operandSegmentSizes` attribute.
  result.addAttribute(
      ParallelOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
                                    static_cast<int32_t>(upper.size()),
                                    static_cast<int32_t>(steps.size()),
                                    static_cast<int32_t>(initVals.size())}));

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Add a terminator if none was parsed.
  ParallelOp::ensureTerminator(*body, builder, result.location);
  return success();
}
2921
void ParallelOp::print(OpAsmPrinter &p) {
  // `(ivs) = (lower) to (upper) step (steps)` followed by optional inits,
  // result types, and the body region.
  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
    << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
  if (!getInitVals().empty())
    p << " init (" << getInitVals() << ")";
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
  // NOTE(review): the call printing the attribute dictionary (presumably
  // `p.printOptionalAttrDict(`) appears truncated here — restore upstream.
      (*this)->getAttrs(),
      /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
}
2934
2935SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
2936
2937std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
2938 return SmallVector<Value>{getBody()->getArguments()};
2939}
2940
/// Returns the lower bounds; each bound Value is wrapped as an OpFoldResult.
std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
  return getLowerBound();
}
2944
/// Returns the upper bounds; each bound Value is wrapped as an OpFoldResult.
std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
  return getUpperBound();
}
2948
/// Returns the steps; each step Value is wrapped as an OpFoldResult.
std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
  return getStep();
}
2952
// NOTE(review): the signature line is missing here, presumably
// `ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {`
// (dropped during extraction) — confirm against upstream.
  // Only a block argument can be an induction variable.
  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
  if (!ivArg)
    return ParallelOp();
  assert(ivArg.getOwner() && "unlinked block argument");
  auto *containingOp = ivArg.getOwner()->getParentOp();
  // Returns a null ParallelOp when the argument belongs to another op kind.
  return dyn_cast<ParallelOp>(containingOp);
}
2961
2962namespace {
2963// Collapse loop dimensions that perform a single iteration.
// Collapse loop dimensions that perform a single iteration.
struct ParallelOpSingleOrZeroIterationDimsFolder
    : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Compute new loop bounds that omit all single-iteration loop dimensions.
    SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
    IRMapping mapping;
    for (auto [lb, ub, step, iv] :
         llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
                   op.getInductionVars())) {
      auto numIterations =
          constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
      if (numIterations.has_value()) {
        // Remove the loop if it performs zero iterations.
        if (*numIterations == 0) {
          // A zero-trip loop yields its init values unchanged.
          rewriter.replaceOp(op, op.getInitVals());
          return success();
        }
        // Replace the loop induction variable by the lower bound if the loop
        // performs a single iteration. Otherwise, copy the loop bounds.
        if (*numIterations == 1) {
          mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
          continue;
        }
      }
      newLowerBounds.push_back(lb);
      newUpperBounds.push_back(ub);
      newSteps.push_back(step);
    }
    // Exit if none of the loop dimensions perform a single iteration.
    if (newLowerBounds.size() == op.getLowerBound().size())
      return failure();

    if (newLowerBounds.empty()) {
      // All of the loop dimensions perform a single iteration. Inline
      // loop body and nested ReduceOp's
      SmallVector<Value> results;
      results.reserve(op.getInitVals().size());
      for (auto &bodyOp : op.getBody()->without_terminator())
        rewriter.clone(bodyOp, mapping);
      auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
      for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
        Block &reduceBlock = reduceOp.getReductions()[i].front();
        auto initValIndex = results.size();
        // Block argument 0 is mapped to the init value, argument 1 to the
        // (possibly remapped) reduced operand.
        mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
        mapping.map(reduceBlock.getArgument(1),
                    mapping.lookupOrDefault(reduceOp.getOperands()[i]));
        for (auto &reduceBodyOp : reduceBlock.without_terminator())
          rewriter.clone(reduceBodyOp, mapping);

        // The value fed to scf.reduce.return becomes the op's result.
        auto result = mapping.lookupOrDefault(
            cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
        results.push_back(result);
      }

      rewriter.replaceOp(op, results);
      return success();
    }
    // Replace the parallel loop by lower-dimensional parallel loop.
    auto newOp =
        ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
                           newUpperBounds, newSteps, op.getInitVals(), nullptr);
    // Erase the empty block that was inserted by the builder.
    rewriter.eraseBlock(newOp.getBody());
    // Clone the loop body and remap the block arguments of the collapsed loops
    // (inlining does not support a cancellable block argument mapping).
    rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
                               newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};
3040
/// Fuses a perfectly nested pair of scf.parallel ops into a single op over
/// the concatenated iteration space.
struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    Block &outerBody = *op.getBody();
    // The inner parallel op must be the only op in the outer body.
    if (!llvm::hasSingleElement(outerBody.without_terminator()))
      return failure();

    auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
    if (!innerOp)
      return failure();

    // The inner bounds/steps must not use the outer induction variables,
    // otherwise the iteration spaces cannot be concatenated.
    for (auto val : outerBody.getArguments())
      if (llvm::is_contained(innerOp.getLowerBound(), val) ||
          llvm::is_contained(innerOp.getUpperBound(), val) ||
          llvm::is_contained(innerOp.getStep(), val))
        return failure();

    // Reductions are not supported yet.
    if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
      return failure();

    // Clones the inner body into the fused loop, remapping both the outer and
    // the inner induction variables to the fused op's block arguments.
    auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
                           ValueRange iterVals, ValueRange) {
      Block &innerBody = *innerOp.getBody();
      assert(iterVals.size() ==
             (outerBody.getNumArguments() + innerBody.getNumArguments()));
      IRMapping mapping;
      // Outer IVs come first, inner IVs follow.
      mapping.map(outerBody.getArguments(),
                  iterVals.take_front(outerBody.getNumArguments()));
      mapping.map(innerBody.getArguments(),
                  iterVals.take_back(innerBody.getNumArguments()));
      for (Operation &op : innerBody.without_terminator())
        builder.clone(op, mapping);
    };

    // Helper concatenating two value ranges into one vector.
    auto concatValues = [](const auto &first, const auto &second) {
      SmallVector<Value> ret;
      ret.reserve(first.size() + second.size());
      ret.assign(first.begin(), first.end());
      ret.append(second.begin(), second.end());
      return ret;
    };

    auto newLowerBounds =
        concatValues(op.getLowerBound(), innerOp.getLowerBound());
    auto newUpperBounds =
        concatValues(op.getUpperBound(), innerOp.getUpperBound());
    auto newSteps = concatValues(op.getStep(), innerOp.getStep());

    rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
                                            newSteps, ValueRange(),
                                            bodyBuilder);
    return success();
  }
};
3098
3099} // namespace
3100
3101void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3102 MLIRContext *context) {
3103 results
3104 .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3105 context);
3106}
3107
/// Returns the regions that may be selected during the flow of control for
/// this op: control can enter the body region, or transfer directly back to
/// the parent operation (e.g. when the loop performs zero iterations).
3113void ParallelOp::getSuccessorRegions(
3114 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3115 // Both the operation itself and the region may be branching into the body or
3116 // back into the operation itself. It is possible for loop not to enter the
3117 // body.
3118 regions.push_back(RegionSuccessor(&getRegion()));
3119 regions.push_back(RegionSuccessor::parent());
3120}
3121
3122//===----------------------------------------------------------------------===//
3123// ReduceOp
3124//===----------------------------------------------------------------------===//
3125
// Builds a reduce op with no operands and no reduction regions.
void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3127
3128void ReduceOp::build(OpBuilder &builder, OperationState &result,
3129 ValueRange operands) {
3130 result.addOperands(operands);
3131 for (Value v : operands) {
3132 OpBuilder::InsertionGuard guard(builder);
3133 Region *bodyRegion = result.addRegion();
3134 builder.createBlock(bodyRegion, {},
3135 ArrayRef<Type>{v.getType(), v.getType()},
3136 {result.location, result.location});
3137 }
3138}
3139
3140LogicalResult ReduceOp::verifyRegions() {
3141 if (getReductions().size() != getOperands().size())
3142 return emitOpError() << "expects number of reduction regions: "
3143 << getReductions().size()
3144 << " to be the same as number of reduction operands: "
3145 << getOperands().size();
3146 // The region of a ReduceOp has two arguments of the same type as its
3147 // corresponding operand.
3148 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3149 auto type = getOperands()[i].getType();
3150 Block &block = getReductions()[i].front();
3151 if (block.empty())
3152 return emitOpError() << i << "-th reduction has an empty body";
3153 if (block.getNumArguments() != 2 ||
3154 llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3155 return arg.getType() != type;
3156 }))
3157 return emitOpError() << "expected two block arguments with type " << type
3158 << " in the " << i << "-th reduction region";
3159
3160 // Check that the block is terminated by a ReduceReturnOp.
3161 if (!isa<ReduceReturnOp>(block.getTerminator()))
3162 return emitOpError("reduction bodies must be terminated with an "
3163 "'scf.reduce.return' op");
3164 }
3165
3166 return success();
3167}
3168
/// scf.reduce forwards none of its operands to successor regions, so the
/// successor-operand range is always empty.
MutableOperandRange
ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
  // No operands are forwarded to the next iteration.
  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
}
3174
3175//===----------------------------------------------------------------------===//
3176// ReduceReturnOp
3177//===----------------------------------------------------------------------===//
3178
3179LogicalResult ReduceReturnOp::verify() {
3180 // The type of the return value should be the same type as the types of the
3181 // block arguments of the reduction body.
3182 Block *reductionBody = getOperation()->getBlock();
3183 // Should already be verified by an op trait.
3184 assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3185 Type expectedResultType = reductionBody->getArgument(0).getType();
3186 if (expectedResultType != getResult().getType())
3187 return emitOpError() << "must have type " << expectedResultType
3188 << " (the type of the reduction inputs)";
3189 return success();
3190}
3191
3192//===----------------------------------------------------------------------===//
3193// WhileOp
3194//===----------------------------------------------------------------------===//
3195
3196void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3197 ::mlir::OperationState &odsState, TypeRange resultTypes,
3198 ValueRange inits, BodyBuilderFn beforeBuilder,
3199 BodyBuilderFn afterBuilder) {
3200 odsState.addOperands(inits);
3201 odsState.addTypes(resultTypes);
3202
3203 OpBuilder::InsertionGuard guard(odsBuilder);
3204
3205 // Build before region.
3206 SmallVector<Location, 4> beforeArgLocs;
3207 beforeArgLocs.reserve(inits.size());
3208 for (Value operand : inits) {
3209 beforeArgLocs.push_back(operand.getLoc());
3210 }
3211
3212 Region *beforeRegion = odsState.addRegion();
3213 Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3214 inits.getTypes(), beforeArgLocs);
3215 if (beforeBuilder)
3216 beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3217
3218 // Build after region.
3219 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3220
3221 Region *afterRegion = odsState.addRegion();
3222 Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3223 resultTypes, afterArgLocs);
3224
3225 if (afterBuilder)
3226 afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3227}
3228
3229ConditionOp WhileOp::getConditionOp() {
3230 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3231}
3232
3233YieldOp WhileOp::getYieldOp() {
3234 return cast<YieldOp>(getAfterBody()->getTerminator());
3235}
3236
3237std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3238 return getYieldOp().getResultsMutable();
3239}
3240
/// Returns the entry block arguments of the "before" region.
Block::BlockArgListType WhileOp::getBeforeArguments() {
  return getBeforeBody()->getArguments();
}
3244
/// Returns the entry block arguments of the "after" region.
Block::BlockArgListType WhileOp::getAfterArguments() {
  return getAfterBody()->getArguments();
}
3248
3249Block::BlockArgListType WhileOp::getRegionIterArgs() {
3250 return getBeforeArguments();
3251}
3252
/// Returns the operands forwarded on entry: the init values feed the
/// "before" region's block arguments.
OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
  assert(successor.getSuccessor() == &getBefore() &&
         "WhileOp is expected to branch only to the first region");
  return getInits();
}
3258
3259void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3260 SmallVectorImpl<RegionSuccessor> &regions) {
3261 // The parent op always branches to the condition region.
3262 if (point.isParent()) {
3263 regions.emplace_back(&getBefore());
3264 return;
3265 }
3266
3267 assert(llvm::is_contained(
3268 {&getAfter(), &getBefore()},
3269 point.getTerminatorPredecessorOrNull()->getParentRegion()) &&
3270 "there are only two regions in a WhileOp");
3271 // The body region always branches back to the condition region.
3272 if (point.getTerminatorPredecessorOrNull()->getParentRegion() ==
3273 &getAfter()) {
3274 regions.emplace_back(&getBefore());
3275 return;
3276 }
3277
3278 regions.push_back(RegionSuccessor::parent());
3279 regions.emplace_back(&getAfter());
3280}
3281
3282ValueRange WhileOp::getSuccessorInputs(RegionSuccessor successor) {
3283 if (successor.isParent())
3284 return getOperation()->getResults();
3285 if (successor == &getBefore())
3286 return getBefore().getArguments();
3287 if (successor == &getAfter())
3288 return getAfter().getArguments();
3289 llvm_unreachable("invalid region successor");
3290}
3291
3292SmallVector<Region *> WhileOp::getLoopRegions() {
3293 return {&getBefore(), &getAfter()};
3294}
3295
/// Parses a `while` op.
///
/// op ::= `scf.while` initializer `:` function-type region `do` region
///         `attributes` attribute-dict
/// initializer ::= /* empty */ | `(` assignment-list `)`
/// assignment-list ::= assignment | assignment `,` assignment-list
/// assignment ::= ssa-value `=` ssa-value
3303ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3304 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3305 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3306 Region *before = result.addRegion();
3307 Region *after = result.addRegion();
3308
3309 OptionalParseResult listResult =
3310 parser.parseOptionalAssignmentList(regionArgs, operands);
3311 if (listResult.has_value() && failed(listResult.value()))
3312 return failure();
3313
3314 FunctionType functionType;
3315 SMLoc typeLoc = parser.getCurrentLocation();
3316 if (failed(parser.parseColonType(functionType)))
3317 return failure();
3318
3319 result.addTypes(functionType.getResults());
3320
3321 if (functionType.getNumInputs() != operands.size()) {
3322 return parser.emitError(typeLoc)
3323 << "expected as many input types as operands " << "(expected "
3324 << operands.size() << " got " << functionType.getNumInputs() << ")";
3325 }
3326
3327 // Resolve input operands.
3328 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3329 parser.getCurrentLocation(),
3330 result.operands)))
3331 return failure();
3332
3333 // Propagate the types into the region arguments.
3334 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3335 regionArgs[i].type = functionType.getInput(i);
3336
3337 return failure(parser.parseRegion(*before, regionArgs) ||
3338 parser.parseKeyword("do") || parser.parseRegion(*after) ||
3339 parser.parseOptionalAttrDictWithKeyword(result.attributes));
3340}
3341
3342/// Prints a `while` op.
/// Prints a `while` op.
void scf::WhileOp::print(OpAsmPrinter &p) {
  // `(%arg = %init, ...)` assignment list.
  printInitializationList(p, getBeforeArguments(), getInits(), " ");
  p << " : ";
  // Functional type: init types -> result types.
  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
  p << ' ';
  // Entry block args of the "before" region already appear in the list above.
  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
  p << " do ";
  p.printRegion(getAfter());
  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}
3353
3354/// Verifies that two ranges of types match, i.e. have the same number of
3355/// entries and that types are pairwise equals. Reports errors on the given
3356/// operation in case of mismatch.
3357template <typename OpTy>
3358static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3359 TypeRange right, StringRef message) {
3360 if (left.size() != right.size())
3361 return op.emitOpError("expects the same number of ") << message;
3362
3363 for (unsigned i = 0, e = left.size(); i < e; ++i) {
3364 if (left[i] != right[i]) {
3365 InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3366 << message;
3367 diag.attachNote() << "for argument " << i << ", found " << left[i]
3368 << " and " << right[i];
3369 return diag;
3370 }
3371 }
3372
3373 return success();
3374}
3375
3376LogicalResult scf::WhileOp::verify() {
3377 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3378 *this, getBefore(),
3379 "expects the 'before' region to terminate with 'scf.condition'");
3380 if (!beforeTerminator)
3381 return failure();
3382
3383 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3384 *this, getAfter(),
3385 "expects the 'after' region to terminate with 'scf.yield'");
3386 return success(afterTerminator != nullptr);
3387}
3388
3389namespace {
3390/// Move a scf.if op that is directly before the scf.condition op in the while
3391/// before region, and whose condition matches the condition of the
3392/// scf.condition op, down into the while after region.
3393///
3394/// scf.while (..) : (...) -> ... {
3395/// %additional_used_values = ...
3396/// %cond = ...
3397/// ...
3398/// %res = scf.if %cond -> (...) {
3399/// use(%additional_used_values)
3400/// ... // then block
3401/// scf.yield %then_value
3402/// } else {
3403/// scf.yield %else_value
3404/// }
3405/// scf.condition(%cond) %res, ...
3406/// } do {
3407/// ^bb0(%res_arg, ...):
3408/// use(%res_arg)
3409/// ...
3410///
3411/// becomes
3412/// scf.while (..) : (...) -> ... {
3413/// %additional_used_values = ...
3414/// %cond = ...
3415/// ...
3416/// scf.condition(%cond) %else_value, ..., %additional_used_values
3417/// } do {
/// ^bb0(%res_arg, ..., %additional_args):
3419/// use(%additional_args)
3420/// ... // if then block
3421/// use(%then_value)
3422/// ...
struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::WhileOp op,
                                PatternRewriter &rewriter) const override {
    auto conditionOp = op.getConditionOp();

    // Only support ifOp right before the condition at the moment. Relaxing this
    // would require to:
    // - check that the body does not have side-effects conflicting with
    //   operations between the if and the condition.
    // - check that results of the if operation are only used as arguments to
    //   the condition.
    auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());

    // Check that the ifOp is directly before the conditionOp and that it
    // matches the condition of the conditionOp. Also ensure that the ifOp has
    // no else block with content, as that would complicate the transformation.
    // TODO: support else blocks with content.
    if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
        (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
      return failure();

    assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
                                  *ifOp->user_begin() == conditionOp)) &&
           "ifOp has unexpected uses");

    Location loc = op.getLoc();

    // Replace uses of ifOp results in the conditionOp with the yielded values
    // from the ifOp branches.
    for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
      auto it = llvm::find(ifOp->getResults(), arg);
      if (it != ifOp->getResults().end()) {
        size_t ifOpIdx = it.getIndex();
        Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
        Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);

        // When the shared condition is false the loop exits, so the value the
        // condition forwards is the else value; inside the body (condition
        // true) the corresponding after-argument is the then value.
        rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
        rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
      }
    }

    // Collect additional used values from before region.
    SetVector<Value> additionalUsedValuesSet;
    visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
      if (&op.getBefore() == operand->get().getParentRegion())
        additionalUsedValuesSet.insert(operand->get());
    });

    // Create new whileOp with additional used values as results.
    auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
    auto additionalValueTypes = llvm::map_to_vector(
        additionalUsedValues, [](Value val) { return val.getType(); });
    size_t additionalValueSize = additionalUsedValues.size();
    SmallVector<Type> newResultTypes(op.getResultTypes());
    newResultTypes.append(additionalValueTypes);

    auto newWhileOp =
        scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());

    // Move both regions into the new op and add one after-argument per
    // additional threaded value.
    rewriter.modifyOpInPlace(newWhileOp, [&] {
      newWhileOp.getBefore().takeBody(op.getBefore());
      newWhileOp.getAfter().takeBody(op.getAfter());
      newWhileOp.getAfter().addArguments(
          additionalValueTypes,
          SmallVector<Location>(additionalValueSize, loc));
    });

    rewriter.modifyOpInPlace(conditionOp, [&] {
      conditionOp.getArgsMutable().append(additionalUsedValues);
    });

    // Replace uses of additional used values inside the ifOp then region with
    // the whileOp after region arguments.
    rewriter.replaceUsesWithIf(
        additionalUsedValues,
        newWhileOp.getAfterArguments().take_back(additionalValueSize),
        [&](OpOperand &use) {
          return ifOp.getThenRegion().isAncestor(
              use.getOwner()->getParentRegion());
        });

    // Inline ifOp then region into new whileOp after region.
    rewriter.eraseOp(ifOp.thenYield());
    rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
                               newWhileOp.getAfterBody()->begin());
    rewriter.eraseOp(ifOp);
    // The extra results only existed to thread values into the body; callers
    // keep seeing the original result list.
    rewriter.replaceOp(op,
                       newWhileOp->getResults().drop_back(additionalValueSize));
    return success();
  }
};
3516
3517/// Replace uses of the condition within the do block with true, since otherwise
3518/// the block would not be evaluated.
3519///
3520/// scf.while (..) : (i1, ...) -> ... {
3521/// %condition = call @evaluate_condition() : () -> i1
3522/// scf.condition(%condition) %condition : i1, ...
3523/// } do {
3524/// ^bb0(%arg0: i1, ...):
3525/// use(%arg0)
3526/// ...
3527///
3528/// becomes
3529/// scf.while (..) : (i1, ...) -> ... {
3530/// %condition = call @evaluate_condition() : () -> i1
3531/// scf.condition(%condition) %condition : i1, ...
3532/// } do {
3533/// ^bb0(%arg0: i1, ...):
3534/// use(%true)
3535/// ...
3536struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3537 using OpRewritePattern<WhileOp>::OpRewritePattern;
3538
3539 LogicalResult matchAndRewrite(WhileOp op,
3540 PatternRewriter &rewriter) const override {
3541 auto term = op.getConditionOp();
3542
3543 // These variables serve to prevent creating duplicate constants
3544 // and hold constant true or false values.
3545 Value constantTrue = nullptr;
3546
3547 bool replaced = false;
3548 for (auto yieldedAndBlockArgs :
3549 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3550 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3551 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3552 if (!constantTrue)
3553 constantTrue = arith::ConstantOp::create(
3554 rewriter, op.getLoc(), term.getCondition().getType(),
3555 rewriter.getBoolAttr(true));
3556
3557 rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3558 constantTrue);
3559 replaced = true;
3560 }
3561 }
3562 }
3563 return success(replaced);
3564 }
3565};
3566
3567/// Replace operations equivalent to the condition in the do block with true,
3568/// since otherwise the block would not be evaluated.
3569///
3570/// scf.while (..) : (i32, ...) -> ... {
3571/// %z = ... : i32
3572/// %condition = cmpi pred %z, %a
3573/// scf.condition(%condition) %z : i32, ...
3574/// } do {
3575/// ^bb0(%arg0: i32, ...):
3576/// %condition2 = cmpi pred %arg0, %a
3577/// use(%condition2)
3578/// ...
3579///
3580/// becomes
3581/// scf.while (..) : (i32, ...) -> ... {
3582/// %z = ... : i32
3583/// %condition = cmpi pred %z, %a
3584/// scf.condition(%condition) %z : i32, ...
3585/// } do {
3586/// ^bb0(%arg0: i32, ...):
3587/// use(%true)
3588/// ...
3589struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3590 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3591
3592 LogicalResult matchAndRewrite(scf::WhileOp op,
3593 PatternRewriter &rewriter) const override {
3594 using namespace scf;
3595 auto cond = op.getConditionOp();
3596 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3597 if (!cmp)
3598 return failure();
3599 bool changed = false;
3600 for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3601 for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3602 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3603 continue;
3604 for (OpOperand &u :
3605 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3606 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3607 if (!cmp2)
3608 continue;
3609 // For a binary operator 1-opIdx gets the other side.
3610 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3611 continue;
3612 bool samePredicate;
3613 if (cmp2.getPredicate() == cmp.getPredicate())
3614 samePredicate = true;
3615 else if (cmp2.getPredicate() ==
3616 arith::invertPredicate(cmp.getPredicate()))
3617 samePredicate = false;
3618 else
3619 continue;
3620
3621 rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3622 1);
3623 changed = true;
3624 }
3625 }
3626 }
3627 return success(changed);
3628 }
3629};
3630
3631/// If both ranges contain same values return mappping indices from args2 to
3632/// args1. Otherwise return std::nullopt.
3633static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
3634 ValueRange args2) {
3635 if (args1.size() != args2.size())
3636 return std::nullopt;
3637
3638 SmallVector<unsigned> ret(args1.size());
3639 for (auto &&[i, arg1] : llvm::enumerate(args1)) {
3640 auto it = llvm::find(args2, arg1);
3641 if (it == args2.end())
3642 return std::nullopt;
3643
3644 ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
3645 }
3646
3647 return ret;
3648}
3649
3650static bool hasDuplicates(ValueRange args) {
3651 llvm::SmallDenseSet<Value> set;
3652 for (Value arg : args) {
3653 if (!set.insert(arg).second)
3654 return true;
3655 }
3656 return false;
3657}
3658
3659/// If `before` block args are directly forwarded to `scf.condition`, rearrange
3660/// `scf.condition` args into same order as block args. Update `after` block
3661/// args and op result values accordingly.
3662/// Needed to simplify `scf.while` -> `scf.for` uplifting.
3663struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
3665
3666 LogicalResult matchAndRewrite(WhileOp loop,
3667 PatternRewriter &rewriter) const override {
3668 auto *oldBefore = loop.getBeforeBody();
3669 ConditionOp oldTerm = loop.getConditionOp();
3670 ValueRange beforeArgs = oldBefore->getArguments();
3671 ValueRange termArgs = oldTerm.getArgs();
3672 if (beforeArgs == termArgs)
3673 return failure();
3674
3675 if (hasDuplicates(termArgs))
3676 return failure();
3677
3678 auto mapping = getArgsMapping(beforeArgs, termArgs);
3679 if (!mapping)
3680 return failure();
3681
3682 {
3683 OpBuilder::InsertionGuard g(rewriter);
3684 rewriter.setInsertionPoint(oldTerm);
3685 rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
3686 beforeArgs);
3687 }
3688
3689 auto *oldAfter = loop.getAfterBody();
3690
3691 SmallVector<Type> newResultTypes(beforeArgs.size());
3692 for (auto &&[i, j] : llvm::enumerate(*mapping))
3693 newResultTypes[j] = loop.getResult(i).getType();
3694
3695 auto newLoop = WhileOp::create(
3696 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
3697 /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
3698 auto *newBefore = newLoop.getBeforeBody();
3699 auto *newAfter = newLoop.getAfterBody();
3700
3701 SmallVector<Value> newResults(beforeArgs.size());
3702 SmallVector<Value> newAfterArgs(beforeArgs.size());
3703 for (auto &&[i, j] : llvm::enumerate(*mapping)) {
3704 newResults[i] = newLoop.getResult(j);
3705 newAfterArgs[i] = newAfter->getArgument(j);
3706 }
3707
3708 rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
3709 newBefore->getArguments());
3710 rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
3711 newAfterArgs);
3712
3713 rewriter.replaceOp(loop, newResults);
3714 return success();
3715 }
3716};
3717} // namespace
3718
// Registers the scf.while canonicalization patterns defined above.
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<WhileConditionTruth, WhileCmpCond, WhileOpAlignBeforeArgs,
              WhileMoveIfDown>(context);
  // NOTE(review): the callee names of the two populate-style calls below were
  // lost when this file was extracted (only trailing argument lines remain);
  // restore them from the upstream source before building.
      results, WhileOp::getOperationName());
          WhileOp::getOperationName());
}
3728
3729//===----------------------------------------------------------------------===//
3730// IndexSwitchOp
3731//===----------------------------------------------------------------------===//
3732
3733/// Parse the case regions and values.
3734static ParseResult
3736 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
3737 SmallVector<int64_t> caseValues;
3738 while (succeeded(p.parseOptionalKeyword("case"))) {
3739 int64_t value;
3740 Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
3741 if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
3742 return failure();
3743 caseValues.push_back(value);
3744 }
3745 cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
3746 return success();
3747}
3748
3749/// Print the case regions and values.
3751 DenseI64ArrayAttr cases, RegionRange caseRegions) {
3752 for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
3753 p.printNewline();
3754 p << "case " << value << ' ';
3755 p.printRegion(*region, /*printEntryBlockArgs=*/false);
3756 }
3757}
3758
// Verifies an scf.index_switch op: case count matches region count, case
// values are unique, and every region yields values matching the op results.
LogicalResult scf::IndexSwitchOp::verify() {
  // One region per case value is required.
  if (getCases().size() != getCaseRegions().size()) {
    return emitOpError("has ")
           << getCaseRegions().size() << " case regions but "
           << getCases().size() << " case values";
  }

  // Case values must be pairwise distinct.
  DenseSet<int64_t> valueSet;
  for (int64_t value : getCases())
    if (!valueSet.insert(value).second)
      return emitOpError("has duplicate case value: ") << value;
  // Checks a single region: it must end in an scf.yield whose operand count
  // and types match this op's results exactly. `name` is used in diagnostics.
  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
    auto yield = dyn_cast<YieldOp>(region.front().back());
    if (!yield)
      return emitOpError("expected region to end with scf.yield, but got ")
             << region.front().back().getName();

    if (yield.getNumOperands() != getNumResults()) {
      return (emitOpError("expected each region to return ")
              << getNumResults() << " values, but " << name << " returns "
              << yield.getNumOperands())
                 .attachNote(yield.getLoc())
             << "see yield operation here";
    }
    // Compare each yielded operand type against the corresponding result.
    for (auto [idx, result, operand] :
         llvm::enumerate(getResultTypes(), yield.getOperands())) {
      if (!operand)
        return yield.emitOpError() << "operand " << idx << " is null\n";
      if (result == operand.getType())
        continue;
      return (emitOpError("expected result #")
              << idx << " of each region to be " << result)
                 .attachNote(yield.getLoc())
             << name << " returns " << operand.getType() << " here";
    }
    return success();
  };

  // Check the default region first, then each case region in order.
  if (failed(verifyRegion(getDefaultRegion(), "default region")))
    return failure();
  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
    if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
      return failure();

  return success();
}
3805
3806unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
3807
3808Block &scf::IndexSwitchOp::getDefaultBlock() {
3809 return getDefaultRegion().front();
3810}
3811
3812Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
3813 assert(idx < getNumCases() && "case index out-of-bounds");
3814 return getCaseRegions()[idx].front();
3815}
3816
3817void IndexSwitchOp::getSuccessorRegions(
3818 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
3819 // All regions branch back to the parent op.
3820 if (!point.isParent()) {
3821 successors.push_back(RegionSuccessor::parent());
3822 return;
3823 }
3824
3825 llvm::append_range(successors, getRegions());
3826}
3827
3828ValueRange IndexSwitchOp::getSuccessorInputs(RegionSuccessor successor) {
3829 return successor.isParent() ? ValueRange(getOperation()->getResults())
3830 : ValueRange();
3831}
3832
3833void IndexSwitchOp::getEntrySuccessorRegions(
3834 ArrayRef<Attribute> operands,
3835 SmallVectorImpl<RegionSuccessor> &successors) {
3836 FoldAdaptor adaptor(operands, *this);
3837
3838 // If a constant was not provided, all regions are possible successors.
3839 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
3840 if (!arg) {
3841 llvm::append_range(successors, getRegions());
3842 return;
3843 }
3844
3845 // Otherwise, try to find a case with a matching value. If not, the
3846 // default region is the only successor.
3847 for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
3848 if (caseValue == arg.getInt()) {
3849 successors.emplace_back(&caseRegion);
3850 return;
3851 }
3852 }
3853 successors.emplace_back(&getDefaultRegion());
3854}
3855
3856void IndexSwitchOp::getRegionInvocationBounds(
3857 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
3858 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
3859 if (!operandValue) {
3860 // All regions are invoked at most once.
3861 bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
3862 return;
3863 }
3864
3865 unsigned liveIndex = getNumRegions() - 1;
3866 const auto *it = llvm::find(getCases(), operandValue.getInt());
3867 if (it != getCases().end())
3868 liveIndex = std::distance(getCases().begin(), it);
3869 for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
3870 bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
3871}
3872
// Registers the scf.index_switch canonicalization patterns.
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  // NOTE(review): the callee names of the two populate-style calls below were
  // lost when this file was extracted (only trailing argument lines remain);
  // restore them from the upstream source before building.
      results, IndexSwitchOp::getOperationName());
      results, IndexSwitchOp::getOperationName());
}
3880
3881//===----------------------------------------------------------------------===//
3882// TableGen'd op method definitions
3883//===----------------------------------------------------------------------===//
3884
3885#define GET_OP_CLASSES
3886#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1497
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
Definition EmitC.cpp:1472
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition EmitC.cpp:1488
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition SCF.cpp:3358
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:496
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition SCF.cpp:101
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
@ None
static std::string diag(const llvm::Value &value)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
bool getValue() const
Return the boolean value of this attribute.
UnitAttr getUnitAttr()
Definition Builders.cpp:102
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
IntegerType getI1Type()
Definition Builders.cpp:57
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition Builders.cpp:593
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition Builders.h:596
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1140
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:541
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
OperandRange operand_range
Definition Operation.h:400
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:259
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:251
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:233
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
BlockArgument getArgument(unsigned i)
Definition Region.h:124
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isIndex() const
Definition Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
static Value defaultReplBuilderFn(OpBuilder &builder, Location loc, Value value)
Default implementation of the non-successor-input replacement builder function.
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition SCF.cpp:2953
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition SCF.cpp:94
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition SCF.cpp:793
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition SCF.cpp:1902
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition SCF.cpp:748
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition SCF.cpp:115
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition SCF.cpp:675
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition SCF.h:64
llvm::function_ref< Value(OpBuilder &, Location loc, Type, Value)> ValueTypeCastFnTy
Perform a replacement of one iter OpOperand of an scf.for to the replacement value with a different t...
Definition SCF.h:107
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition SCF.cpp:1388
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition SCF.cpp:882
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
void populateRegionBranchOpInterfaceInliningPattern(RewritePatternSet &patterns, StringRef opName, NonSuccessorInputReplacementBuilderFn replBuilderFn=detail::defaultReplBuilderFn, PatternMatcherFn matcherFn=detail::defaultMatcherFn, PatternBenefit benefit=1)
Populate a pattern that inlines the body of region branch ops when there is a single acyclic path thr...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
void populateRegionBranchOpInterfaceCanonicalizationPatterns(RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit=1)
Populate canonicalization patterns that simplify successor operands/inputs of region branch operation...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step,...
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:222
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.