MLIR 23.0.0git
SCF.cpp
Go to the documentation of this file.
1//===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/Operation.h"
30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/DebugLog.h"
35#include <optional>
36
37using namespace mlir;
38using namespace mlir::scf;
39
40#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
41
42//===----------------------------------------------------------------------===//
43// SCFDialect Dialect Interfaces
44//===----------------------------------------------------------------------===//
45
46namespace {
47struct SCFInlinerInterface : public DialectInlinerInterface {
48 using DialectInlinerInterface::DialectInlinerInterface;
49 // We don't have any special restrictions on what can be inlined into
50 // destination regions (e.g. while/conditional bodies). Always allow it.
51 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
52 IRMapping &valueMapping) const final {
53 return true;
54 }
55 // Operations in scf dialect are always legal to inline since they are
56 // pure.
57 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
58 return true;
59 }
60 // Handle the given inlined terminator by replacing it with a new operation
61 // as necessary. Required when the region has only one block.
62 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
63 auto retValOp = dyn_cast<scf::YieldOp>(op);
64 if (!retValOp)
65 return;
66
67 for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
68 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
69 }
70 }
71};
72} // namespace
73
74//===----------------------------------------------------------------------===//
75// SCFDialect
76//===----------------------------------------------------------------------===//
77
/// Registers SCF operations, the inliner interface, and promised external
/// model interfaces (materialized lazily by other libraries).
void SCFDialect::initialize() {
  // Operation list is generated by TableGen.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
      >();
  addInterfaces<SCFInlinerInterface>();
  // Promised interfaces: declared here, implemented in separate libraries
  // (e.g. bufferization, EmitC conversion) to avoid hard link dependencies.
  declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
                            InParallelOp, ReduceReturnOp>();
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
                            ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
                            ForallOp, InParallelOp, WhileOp, YieldOp>();
  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
}
92
93/// Default callback for IfOp builders. Inserts a yield without arguments.
95 scf::YieldOp::create(builder, loc);
96}
97
98/// Verifies that the first block of the given `region` is terminated by a
99/// TerminatorTy. Reports errors on the given operation if it is not the case.
100template <typename TerminatorTy>
101static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
102 StringRef errorMessage) {
103 Operation *terminatorOperation = nullptr;
104 if (!region.empty() && !region.front().empty()) {
105 terminatorOperation = &region.front().back();
106 if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
107 return yield;
108 }
109 auto diag = op->emitOpError(errorMessage);
110 if (terminatorOperation)
111 diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
112 return nullptr;
113}
114
115std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
116 bool isSigned) {
117 llvm::APSInt diff;
118 auto addOp = ub.getDefiningOp<arith::AddIOp>();
119 if (!addOp)
120 return std::nullopt;
121 if ((isSigned && !addOp.hasNoSignedWrap()) ||
122 (!isSigned && !addOp.hasNoUnsignedWrap()))
123 return std::nullopt;
124
125 if (addOp.getLhs() != lb ||
126 !matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
127 return std::nullopt;
128 return diff;
129}
130
131//===----------------------------------------------------------------------===//
132// ExecuteRegionOp
133//===----------------------------------------------------------------------===//
134
135///
136/// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
137/// block+
138/// `}`
139///
140/// Example:
141/// scf.execute_region -> i32 {
142/// %idx = load %rI[%i] : memref<128xi32>
143/// return %idx : i32
144/// }
145///
146ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
148 if (parser.parseOptionalArrowTypeList(result.types))
149 return failure();
150
151 if (succeeded(parser.parseOptionalKeyword("no_inline")))
152 result.addAttribute("no_inline", parser.getBuilder().getUnitAttr());
153
154 // Introduce the body region and parse it.
155 Region *body = result.addRegion();
156 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
157 parser.parseOptionalAttrDict(result.attributes))
158 return failure();
159
160 return success();
161}
162
163void ExecuteRegionOp::print(OpAsmPrinter &p) {
164 p.printOptionalArrowTypeList(getResultTypes());
165 p << ' ';
166 if (getNoInline())
167 p << "no_inline ";
168 p.printRegion(getRegion(),
169 /*printEntryBlockArgs=*/false,
170 /*printBlockTerminators=*/true);
171 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"no_inline"});
172}
173
174LogicalResult ExecuteRegionOp::verify() {
175 if (getRegion().empty())
176 return emitOpError("region needs to have at least one block");
177 if (getRegion().front().getNumArguments() > 0)
178 return emitOpError("region cannot have any arguments");
179 return success();
180}
181
182// Inline an ExecuteRegionOp if its parent can contain multiple blocks.
183// TODO generalize the conditions for operations which can be inlined into.
184// func @func_execute_region_elim() {
185// "test.foo"() : () -> ()
186// %v = scf.execute_region -> i64 {
187// %c = "test.cmp"() : () -> i1
188// cf.cond_br %c, ^bb2, ^bb3
189// ^bb2:
190// %x = "test.val1"() : () -> i64
191// cf.br ^bb4(%x : i64)
192// ^bb3:
193// %y = "test.val2"() : () -> i64
194// cf.br ^bb4(%y : i64)
195// ^bb4(%z : i64):
196// scf.yield %z : i64
197// }
198// "test.bar"(%v) : (i64) -> ()
199// return
200// }
201//
202// becomes
203//
204// func @func_execute_region_elim() {
205// "test.foo"() : () -> ()
206// %c = "test.cmp"() : () -> i1
207// cf.cond_br %c, ^bb1, ^bb2
208// ^bb1: // pred: ^bb0
209// %x = "test.val1"() : () -> i64
210// cf.br ^bb3(%x : i64)
211// ^bb2: // pred: ^bb0
212// %y = "test.val2"() : () -> i64
213// cf.br ^bb3(%y : i64)
214// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
215// "test.bar"(%z) : (i64) -> ()
216// return
217// }
218//
/// Pattern that inlines a (possibly multi-block) scf.execute_region into its
/// parent by splitting the surrounding block and lowering yields to branches.
struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExecuteRegionOp op,
                                PatternRewriter &rewriter) const override {
    // Respect an explicit request to keep the region outlined.
    if (op.getNoInline())
      return failure();
    // Only inline into parents known to tolerate multiple blocks (see the
    // TODO above about generalizing this condition).
    if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
      return failure();

    // Split the containing block at the op; `prevBlock` keeps everything
    // before it, `postBlock` everything after (the continuation).
    Block *prevBlock = op->getBlock();
    Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
    rewriter.setInsertionPointToEnd(prevBlock);

    // Branch from the prefix into the region's entry block.
    cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());

    // Rewrite each scf.yield into a branch to the continuation block,
    // forwarding the yielded values as block arguments.
    for (Block &blk : op.getRegion()) {
      if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
        rewriter.setInsertionPoint(yieldOp);
        cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
                             yieldOp.getResults());
        rewriter.eraseOp(yieldOp);
      }
    }

    // Splice the region's blocks in front of the continuation block.
    rewriter.inlineRegionBefore(op.getRegion(), postBlock);
    SmallVector<Value> blockArgs;

    // The op's results become block arguments of the continuation block.
    for (auto res : op.getResults())
      blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));

    rewriter.replaceOp(op, blockArgs);
    return success();
  }
};
254
/// Registers canonicalization patterns for scf.execute_region.
/// NOTE(review): the source listing appears truncated here — the two calls
/// that receive `results, ExecuteRegionOp::getOperationName()` are missing
/// their opening lines (helper-population calls). Restore from the upstream
/// MLIR sources before building.
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<MultiBlockExecuteInliner>(context);
      results, ExecuteRegionOp::getOperationName());
  // Inline ops with a single block that are not marked as "no_inline".
      results, ExecuteRegionOp::getOperationName(),
        return failure(cast<ExecuteRegionOp>(op).getNoInline());
      });
}
267
268void ExecuteRegionOp::getSuccessorRegions(
270 // If the predecessor is the ExecuteRegionOp, branch into the body.
271 if (point.isParent()) {
272 regions.push_back(RegionSuccessor(&getRegion()));
273 return;
274 }
275
276 // Otherwise, the region branches back to the parent operation.
277 regions.push_back(RegionSuccessor::parent());
278}
279
280ValueRange ExecuteRegionOp::getSuccessorInputs(RegionSuccessor successor) {
281 return successor.isParent() ? ValueRange(getOperation()->getResults())
282 : ValueRange();
283}
284
285//===----------------------------------------------------------------------===//
286// ConditionOp
287//===----------------------------------------------------------------------===//
288
290ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
291 assert(
292 (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) &&
293 "condition op can only exit the loop or branch to the after"
294 "region");
295 // Pass all operands except the condition to the successor region.
296 return getArgsMutable();
297}
298
299void ConditionOp::getSuccessorRegions(
301 FoldAdaptor adaptor(operands, *this);
302
303 WhileOp whileOp = getParentOp();
304
305 // Condition can either lead to the after region or back to the parent op
306 // depending on whether the condition is true or not.
307 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
308 if (!boolAttr || boolAttr.getValue())
309 regions.emplace_back(&whileOp.getAfter());
310 if (!boolAttr || !boolAttr.getValue())
311 regions.push_back(RegionSuccessor::parent());
312}
313
314//===----------------------------------------------------------------------===//
315// ForOp
316//===----------------------------------------------------------------------===//
317
/// Builds an scf.for with bounds `lb`/`ub`, `step`, loop-carried `initArgs`,
/// and optionally populates the body via `bodyBuilder`. `unsignedCmp` selects
/// unsigned bound comparison.
void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
                  Value ub, Value step, ValueRange initArgs,
                  BodyBuilderFn bodyBuilder, bool unsignedCmp) {
  OpBuilder::InsertionGuard guard(builder);

  // Unsigned comparison is encoded as a unit attribute.
  if (unsignedCmp)
    result.addAttribute(getUnsignedCmpAttrName(result.name),
                        builder.getUnitAttr());
  result.addOperands({lb, ub, step});
  result.addOperands(initArgs);
  // One op result per loop-carried init value.
  for (Value v : initArgs)
    result.addTypes(v.getType());
  // The induction variable takes the type of the lower bound.
  Type t = lb.getType();
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(bodyRegion);
  bodyBlock->addArgument(t, result.location);
  for (Value v : initArgs)
    bodyBlock->addArgument(v.getType(), v.getLoc());

  // Create the default terminator if the builder is not provided and if the
  // iteration arguments are not provided. Otherwise, leave this to the caller
  // because we don't know which values to return from the loop.
  if (initArgs.empty() && !bodyBuilder) {
    ForOp::ensureTerminator(*bodyRegion, builder, result.location);
  } else if (bodyBuilder) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(bodyBlock);
    bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
                bodyBlock->getArguments().drop_front());
  }
}
349
350LogicalResult ForOp::verify() {
351 // Check that the body block has at least the induction variable argument.
352 // This must be checked before verifyRegions() and before any region trait
353 // verifiers (e.g. LoopLikeOpInterface) that call getRegionIterArgs(), to
354 // avoid crashing with an out-of-bounds drop_front on an empty arg list.
355 if (getBody()->getNumArguments() < getNumInductionVars())
356 return emitOpError("expected body to have at least ")
357 << getNumInductionVars()
358 << " argument(s) for the induction variable, but got "
359 << getBody()->getNumArguments();
360
361 // Check that the number of init args and op results is the same.
362 if (getInitArgs().size() != getNumResults())
363 return emitOpError(
364 "mismatch in number of loop-carried values and defined values");
365
366 return success();
367}
368
369LogicalResult ForOp::verifyRegions() {
370 // Check that the body block has at least the induction variable argument.
371 if (getBody()->getNumArguments() < getNumInductionVars())
372 return emitOpError("expected body to have at least ")
373 << getNumInductionVars() << " argument(s) for the induction "
374 << "variable, but got " << getBody()->getNumArguments();
375
376 // Check that the body defines as single block argument for the induction
377 // variable.
378 if (getInductionVar().getType() != getLowerBound().getType())
379 return emitOpError(
380 "expected induction variable to be same type as bounds and step");
381
382 if (getNumRegionIterArgs() != getNumResults())
383 return emitOpError(
384 "mismatch in number of basic block args and defined values");
385
386 auto initArgs = getInitArgs();
387 auto iterArgs = getRegionIterArgs();
388 auto opResults = getResults();
389 unsigned i = 0;
390 for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
391 if (std::get<0>(e).getType() != std::get<2>(e).getType())
392 return emitOpError() << "types mismatch between " << i
393 << "th iter operand and defined value";
394 if (std::get<1>(e).getType() != std::get<2>(e).getType())
395 return emitOpError() << "types mismatch between " << i
396 << "th iter region arg and defined value";
397
398 ++i;
399 }
400 return success();
401}
402
403std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
404 return SmallVector<Value>{getInductionVar()};
405}
406
407std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
409}
410
411std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
412 return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
413}
414
415std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
417}
418
419bool ForOp::isValidInductionVarType(Type type) {
420 return type.isIndex() || type.isSignlessInteger();
421}
422
423LogicalResult ForOp::setLoopLowerBounds(ArrayRef<OpFoldResult> bounds) {
424 if (bounds.size() != 1)
425 return failure();
426 if (auto val = dyn_cast<Value>(bounds[0])) {
427 setLowerBound(val);
428 return success();
429 }
430 return failure();
431}
432
433LogicalResult ForOp::setLoopUpperBounds(ArrayRef<OpFoldResult> bounds) {
434 if (bounds.size() != 1)
435 return failure();
436 if (auto val = dyn_cast<Value>(bounds[0])) {
437 setUpperBound(val);
438 return success();
439 }
440 return failure();
441}
442
443LogicalResult ForOp::setLoopSteps(ArrayRef<OpFoldResult> steps) {
444 if (steps.size() != 1)
445 return failure();
446 if (auto val = dyn_cast<Value>(steps[0])) {
447 setStep(val);
448 return success();
449 }
450 return failure();
451}
452
453std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
454
/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop has at most a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
  std::optional<APInt> tripCount = getStaticTripCount();
  LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
         << " for loop "
         << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
  // Only loops with a statically-known trip count of 0 or 1 can be promoted.
  if (!tripCount.has_value() || tripCount->getZExtValue() > 1)
    return failure();

  // Zero iterations: results are just the forwarded init values.
  if (*tripCount == 0) {
    rewriter.replaceAllUsesWith(getResults(), getInitArgs());
    rewriter.eraseOp(*this);
    return success();
  }

  // Replace all results with the yielded values.
  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());

  // Replace block arguments with lower bound (replacement for IV) and
  // iter_args.
  SmallVector<Value> bbArgReplacements;
  bbArgReplacements.push_back(getLowerBound());
  llvm::append_range(bbArgReplacements, getInitArgs());

  // Move the loop body operations to the loop's containing block.
  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
                             getOperation()->getIterator(), bbArgReplacements);

  // Erase the old terminator and the loop.
  rewriter.eraseOp(yieldOp);
  rewriter.eraseOp(*this);

  return success();
}
491
492/// Prints the initialization list in the form of
493/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
494/// where 'inner' values are assumed to be region arguments and 'outer' values
495/// are regular SSA values.
497 Block::BlockArgListType blocksArgs,
498 ValueRange initializers,
499 StringRef prefix = "") {
500 assert(blocksArgs.size() == initializers.size() &&
501 "expected same length of arguments and initializers");
502 if (initializers.empty())
503 return;
504
505 p << prefix << '(';
506 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
507 p << std::get<0>(it) << " = " << std::get<1>(it);
508 });
509 p << ")";
510}
511
512void ForOp::print(OpAsmPrinter &p) {
513 if (getUnsignedCmp())
514 p << " unsigned";
515
516 p << " " << getInductionVar() << " = " << getLowerBound() << " to "
517 << getUpperBound() << " step " << getStep();
518
519 printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
520 if (!getInitArgs().empty())
521 p << " -> (" << getInitArgs().getTypes() << ')';
522 p << ' ';
523 if (Type t = getInductionVar().getType(); !t.isIndex())
524 p << " : " << t << ' ';
525 p.printRegion(getRegion(),
526 /*printEntryBlockArgs=*/false,
527 /*printBlockTerminators=*/!getInitArgs().empty());
528 p.printOptionalAttrDict((*this)->getAttrs(),
529 /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
530}
531
532ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
533 auto &builder = parser.getBuilder();
534 Type type;
535
536 OpAsmParser::Argument inductionVariable;
538
539 if (succeeded(parser.parseOptionalKeyword("unsigned")))
540 result.addAttribute(getUnsignedCmpAttrName(result.name),
541 builder.getUnitAttr());
542
543 // Parse the induction variable followed by '='.
544 if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
545 // Parse loop bounds.
546 parser.parseOperand(lb) || parser.parseKeyword("to") ||
547 parser.parseOperand(ub) || parser.parseKeyword("step") ||
548 parser.parseOperand(step))
549 return failure();
550
551 // Parse the optional initial iteration arguments.
554 regionArgs.push_back(inductionVariable);
555
556 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
557 if (hasIterArgs) {
558 // Parse assignment list and results type list.
559 if (parser.parseAssignmentList(regionArgs, operands) ||
560 parser.parseArrowTypeList(result.types))
561 return failure();
562 }
563
564 if (regionArgs.size() != result.types.size() + 1)
565 return parser.emitError(
566 parser.getNameLoc(),
567 "mismatch in number of loop-carried values and defined values");
568
569 // Parse optional type, else assume Index.
570 if (parser.parseOptionalColon())
571 type = builder.getIndexType();
572 else if (parser.parseType(type))
573 return failure();
574
575 // Set block argument types, so that they are known when parsing the region.
576 regionArgs.front().type = type;
577 for (auto [iterArg, type] :
578 llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
579 iterArg.type = type;
580
581 // Parse the body region.
582 Region *body = result.addRegion();
583 if (parser.parseRegion(*body, regionArgs))
584 return failure();
585 ForOp::ensureTerminator(*body, builder, result.location);
586
587 // Resolve input operands. This should be done after parsing the region to
588 // catch invalid IR where operands were defined inside of the region.
589 if (parser.resolveOperand(lb, type, result.operands) ||
590 parser.resolveOperand(ub, type, result.operands) ||
591 parser.resolveOperand(step, type, result.operands))
592 return failure();
593 if (hasIterArgs) {
594 for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
595 operands, result.types)) {
596 Type type = std::get<2>(argOperandType);
597 std::get<0>(argOperandType).type = type;
598 if (parser.resolveOperand(std::get<1>(argOperandType), type,
599 result.operands))
600 return failure();
601 }
602 }
603
604 // Parse the optional attribute list.
605 if (parser.parseOptionalAttrDict(result.attributes))
606 return failure();
607
608 return success();
609}
610
611SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
612
613Block::BlockArgListType ForOp::getRegionIterArgs() {
614 return getBody()->getArguments().drop_front(getNumInductionVars());
615}
616
/// LoopLikeOpInterface hook: the loop-carried inits are the init_args
/// operands.
MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
  return getInitArgsMutable();
}
620
/// Replaces this loop with a new scf.for that carries `newInitOperands` in
/// addition to the existing iter_args. The extra yielded values are produced
/// by `newYieldValuesFn`; if `replaceInitOperandUsesInLoop` is set, in-loop
/// uses of the new init operands are redirected to the new block arguments.
FailureOr<LoopLikeOpInterface>
ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
                                   ValueRange newInitOperands,
                                   bool replaceInitOperandUsesInLoop,
                                   const NewYieldValuesFn &newYieldValuesFn) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(getOperation());
  auto inits = llvm::to_vector(getInitArgs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  // Empty body builder: the body is stolen from the old loop below.
  scf::ForOp newLoop = scf::ForOp::create(
      rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
      [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp());
  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));

  // Generate the new yield values and append them to the scf.yield operation.
  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
  ArrayRef<BlockArgument> newIterArgs =
      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
    assert(newInitOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands");
    rewriter.modifyOpInPlace(yieldOp, [&]() {
      yieldOp.getResultsMutable().append(newYieldedValues);
    });
  }

  // Move the loop body to the new op.
  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                       newLoop.getBody()->getArguments().take_front(
                           getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    // Replace all uses of `newInitOperands` with the corresponding basic block
    // arguments.
    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
      rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
                                 [&](OpOperand &use) {
                                   Operation *user = use.getOwner();
                                   return newLoop->isProperAncestor(user);
                                 });
    }
  }

  // Replace the old loop; the old results map to the leading new results.
  rewriter.replaceOp(getOperation(),
                     newLoop->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
674
676 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
677 if (!ivArg)
678 return ForOp();
679 assert(ivArg.getOwner() && "unlinked block argument");
680 auto *containingOp = ivArg.getOwner()->getParentOp();
681 return dyn_cast_or_null<ForOp>(containingOp);
682}
683
/// On first entry into the body, the loop-carried values are seeded from the
/// init operands.
OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
  return getInitArgs();
}
687
688void ForOp::getSuccessorRegions(RegionBranchPoint point,
690 if (std::optional<APInt> tripCount = getStaticTripCount()) {
691 // The loop has a known static trip count.
692 if (point.isParent()) {
693 if (*tripCount == 0) {
694 // The loop has zero iterations. It branches directly back to the
695 // parent.
696 regions.push_back(RegionSuccessor::parent());
697 } else {
698 // The loop has at least one iteration. It branches into the body.
699 regions.push_back(RegionSuccessor(&getRegion()));
700 }
701 return;
702 } else if (*tripCount == 1) {
703 // The loop has exactly 1 iteration. Therefore, it branches from the
704 // region to the parent. (No further iteration.)
705 regions.push_back(RegionSuccessor::parent());
706 return;
707 }
708 }
709
710 // Both the operation itself and the region may be branching into the body or
711 // back into the operation itself. It is possible for loop not to enter the
712 // body.
713 regions.push_back(RegionSuccessor(&getRegion()));
714 regions.push_back(RegionSuccessor::parent());
715}
716
717ValueRange ForOp::getSuccessorInputs(RegionSuccessor successor) {
718 return successor.isParent() ? ValueRange(getResults())
719 : ValueRange(getRegionIterArgs());
720}
721
722SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
723
724/// Promotes the loop body of a forallOp to its containing block if it can be
725/// determined that the loop has a single iteration.
726LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
727 for (auto [lb, ub, step] :
728 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
729 auto tripCount =
730 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
731 if (!tripCount.has_value() || *tripCount != 1)
732 return failure();
733 }
734
735 promote(rewriter, *this);
736 return success();
737}
738
739Block::BlockArgListType ForallOp::getRegionIterArgs() {
740 return getBody()->getArguments().drop_front(getRank());
741}
742
/// LoopLikeOpInterface hook: for scf.forall, the shared outputs play the role
/// of loop-carried inits.
MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
  return getOutputsMutable();
}
746
/// Promotes the loop body of a scf::ForallOp to its containing block.
/// The caller must have established that the loop runs exactly once; tensor
/// results are rebuilt from the terminator's parallel_insert_slice ops.
void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
  OpBuilder::InsertionGuard g(rewriter);
  scf::InParallelOp terminator = forallOp.getTerminator();

  // Replace block arguments with lower bounds (replacements for IVs) and
  // outputs.
  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
  bbArgReplacements.append(forallOp.getOutputs().begin(),
                           forallOp.getOutputs().end());

  // Move the loop body operations to the loop's containing block.
  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
                             forallOp->getIterator(), bbArgReplacements);

  // Replace the terminator with tensor.insert_slice ops.
  rewriter.setInsertionPointAfter(forallOp);
  SmallVector<Value> results;
  results.reserve(forallOp.getResults().size());
  for (auto &yieldingOp : terminator.getYieldingOps()) {
    auto parallelInsertSliceOp =
        dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
    if (!parallelInsertSliceOp)
      continue;

    // Each parallel_insert_slice becomes a plain insert_slice now that there
    // is no parallelism left.
    Value dst = parallelInsertSliceOp.getDest();
    Value src = parallelInsertSliceOp.getSource();
    if (llvm::isa<TensorType>(src.getType())) {
      results.push_back(tensor::InsertSliceOp::create(
          rewriter, forallOp.getLoc(), dst.getType(), src, dst,
          parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
          parallelInsertSliceOp.getStrides(),
          parallelInsertSliceOp.getStaticOffsets(),
          parallelInsertSliceOp.getStaticSizes(),
          parallelInsertSliceOp.getStaticStrides()));
    } else {
      llvm_unreachable("unsupported terminator");
    }
  }
  rewriter.replaceAllUsesWith(forallOp.getResults(), results);

  // Erase the old terminator and the loop.
  rewriter.eraseOp(terminator);
  rewriter.eraseOp(forallOp);
}
792
794 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
795 ValueRange steps, ValueRange iterArgs,
797 bodyBuilder) {
798 assert(lbs.size() == ubs.size() &&
799 "expected the same number of lower and upper bounds");
800 assert(lbs.size() == steps.size() &&
801 "expected the same number of lower bounds and steps");
802
803 // If there are no bounds, call the body-building function and return early.
804 if (lbs.empty()) {
805 ValueVector results =
806 bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
807 : ValueVector();
808 assert(results.size() == iterArgs.size() &&
809 "loop nest body must return as many values as loop has iteration "
810 "arguments");
811 return LoopNest{{}, std::move(results)};
812 }
813
814 // First, create the loop structure iteratively using the body-builder
815 // callback of `ForOp::build`. Do not create `YieldOp`s yet.
816 OpBuilder::InsertionGuard guard(builder);
819 loops.reserve(lbs.size());
820 ivs.reserve(lbs.size());
821 ValueRange currentIterArgs = iterArgs;
822 Location currentLoc = loc;
823 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
824 auto loop = scf::ForOp::create(
825 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
826 [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
827 ValueRange args) {
828 ivs.push_back(iv);
829 // It is safe to store ValueRange args because it points to block
830 // arguments of a loop operation that we also own.
831 currentIterArgs = args;
832 currentLoc = nestedLoc;
833 });
834 // Set the builder to point to the body of the newly created loop. We don't
835 // do this in the callback because the builder is reset when the callback
836 // returns.
837 builder.setInsertionPointToStart(loop.getBody());
838 loops.push_back(loop);
839 }
840
841 // For all loops but the innermost, yield the results of the nested loop.
842 for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
843 builder.setInsertionPointToEnd(loops[i].getBody());
844 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
845 }
846
847 // In the body of the innermost loop, call the body building function if any
848 // and yield its results.
849 builder.setInsertionPointToStart(loops.back().getBody());
850 ValueVector results = bodyBuilder
851 ? bodyBuilder(builder, currentLoc, ivs,
852 loops.back().getRegionIterArgs())
853 : ValueVector();
854 assert(results.size() == iterArgs.size() &&
855 "loop nest body must return as many values as loop has iteration "
856 "arguments");
857 builder.setInsertionPointToEnd(loops.back().getBody());
858 scf::YieldOp::create(builder, loc, results);
859
860 // Return the loops.
861 ValueVector nestResults;
862 llvm::append_range(nestResults, loops.front().getResults());
863 return LoopNest{std::move(loops), std::move(nestResults)};
864}
865
867 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
868 ValueRange steps,
869 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
870 // Delegate to the main function by wrapping the body builder.
871 return buildLoopNest(builder, loc, lbs, ubs, steps, {},
872 [&bodyBuilder](OpBuilder &nestedBuilder,
873 Location nestedLoc, ValueRange ivs,
875 if (bodyBuilder)
876 bodyBuilder(nestedBuilder, nestedLoc, ivs);
877 return {};
878 });
879}
880
883 OpOperand &operand, Value replacement,
884 const ValueTypeCastFnTy &castFn) {
885 assert(operand.getOwner() == forOp);
886 Type oldType = operand.get().getType(), newType = replacement.getType();
887
888 // 1. Create new iter operands, exactly 1 is replaced.
889 assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
890 "expected an iter OpOperand");
891 assert(operand.get().getType() != replacement.getType() &&
892 "Expected a different type");
893 SmallVector<Value> newIterOperands;
894 for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
895 if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
896 newIterOperands.push_back(replacement);
897 continue;
898 }
899 newIterOperands.push_back(opOperand.get());
900 }
901
902 // 2. Create the new forOp shell.
903 scf::ForOp newForOp = scf::ForOp::create(
904 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
905 forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr,
906 forOp.getUnsignedCmp());
907 newForOp->setAttrs(forOp->getAttrs());
908 Block &newBlock = newForOp.getRegion().front();
909 SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
910 newBlock.getArguments().end());
911
912 // 3. Inject an incoming cast op at the beginning of the block for the bbArg
913 // corresponding to the `replacement` value.
914 OpBuilder::InsertionGuard g(rewriter);
915 rewriter.setInsertionPointToStart(&newBlock);
916 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
917 &newForOp->getOpOperand(operand.getOperandNumber()));
918 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
919 newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
920
921 // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
922 Block &oldBlock = forOp.getRegion().front();
923 rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
924
925 // 5. Inject an outgoing cast op at the end of the block and yield it instead.
926 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
927 rewriter.setInsertionPoint(clonedYieldOp);
928 unsigned yieldIdx =
929 newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
930 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
931 clonedYieldOp.getOperand(yieldIdx));
932 SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
933 newYieldOperands[yieldIdx] = castOut;
934 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
935 rewriter.eraseOp(clonedYieldOp);
936
937 // 6. Inject an outgoing cast op after the forOp.
938 rewriter.setInsertionPointAfter(newForOp);
939 SmallVector<Value> newResults = newForOp.getResults();
940 newResults[yieldIdx] =
941 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
942
943 return newResults;
944}
945
946namespace {
947/// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
948/// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
949///
950/// ```
951/// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
952/// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
953/// -> (tensor<?x?xf32>) {
954/// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
955/// scf.yield %2 : tensor<?x?xf32>
956/// }
957/// use_of(%1)
958/// ```
959///
960/// folds into:
961///
962/// ```
963/// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
964/// -> (tensor<32x1024xf32>) {
965/// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
966/// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
967/// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
968/// scf.yield %4 : tensor<32x1024xf32>
969/// }
970/// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
971/// use_of(%1)
972/// ```
973struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
975
976 LogicalResult matchAndRewrite(ForOp op,
977 PatternRewriter &rewriter) const override {
978 for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
979 OpOperand &iterOpOperand = std::get<0>(it);
980 auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
981 if (!incomingCast ||
982 incomingCast.getSource().getType() == incomingCast.getType())
983 continue;
984 // If the dest type of the cast does not preserve static information in
985 // the source type.
987 incomingCast.getDest().getType(),
988 incomingCast.getSource().getType()))
989 continue;
990 if (!std::get<1>(it).hasOneUse())
991 continue;
992
993 // Create a new ForOp with that iter operand replaced.
994 rewriter.replaceOp(
996 rewriter, op, iterOpOperand, incomingCast.getSource(),
997 [](OpBuilder &b, Location loc, Type type, Value source) {
998 return tensor::CastOp::create(b, loc, type, source);
999 }));
1000 return success();
1001 }
1002 return failure();
1003 }
1004};
1005} // namespace
1006
/// Registers the canonicalization patterns for scf.for.
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<ForOpTensorCastFolder>(context);
  // NOTE(review): the two statements below appear truncated by extraction --
  // the callee names registering generic loop canonicalization helpers for
  // scf.for are missing. Verify against the upstream sources before relying
  // on this listing.
      results, ForOp::getOperationName());
      results, ForOp::getOperationName(),
      /*replBuilderFn=*/[](OpBuilder &builder, Location loc, Value value) {
        // scf.for has only one non-successor input value: the loop induction
        // variable. In case of a single acyclic path through the op, the IV can
        // be safely replaced with the lower bound.
        auto blockArg = cast<BlockArgument>(value);
        assert(blockArg.getArgNumber() == 0 && "expected induction variable");
        auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp());
        return forOp.getLowerBound();
      });
}
1024
1025std::optional<APInt> ForOp::getConstantStep() {
1026 IntegerAttr step;
1027 if (matchPattern(getStep(), m_Constant(&step)))
1028 return step.getValue();
1029 return {};
1030}
1031
1032std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1033 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1034}
1035
1036Speculation::Speculatability ForOp::getSpeculatability() {
1037 // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1038 // and End.
1039 if (auto constantStep = getConstantStep())
1040 if (*constantStep == 1)
1042
1043 // For Step != 1, the loop may not terminate. We can add more smarts here if
1044 // needed.
1046}
1047
1048std::optional<APInt> ForOp::getStaticTripCount() {
1049 return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
1050 /*isSigned=*/!getUnsignedCmp(), computeUbMinusLb);
1051}
1052
1053//===----------------------------------------------------------------------===//
1054// ForallOp
1055//===----------------------------------------------------------------------===//
1056
/// Verifies the structural invariants of scf.forall: result/output agreement,
/// region block-argument layout (rank index args followed by one bbArg per
/// shared output), mapping attribute arity, and the mixed static/dynamic
/// bound and step operand lists.
LogicalResult ForallOp::verify() {
  unsigned numLoops = getRank();
  // Check number of outputs.
  if (getNumResults() != getOutputs().size())
    return emitOpError("produces ")
           << getNumResults() << " results, but has only "
           << getOutputs().size() << " outputs";

  // Check that the body defines block arguments for thread indices and outputs.
  auto *body = getBody();
  if (body->getNumArguments() != numLoops + getOutputs().size())
    return emitOpError("region expects ") << numLoops << " arguments";
  // The first numLoops bbArgs are the induction variables and must be index.
  for (int64_t i = 0; i < numLoops; ++i)
    if (!body->getArgument(i).getType().isIndex())
      return emitOpError("expects ")
             << i << "-th block argument to be an index";
  // The remaining bbArgs mirror the shared_outs types one-to-one.
  for (unsigned i = 0; i < getOutputs().size(); ++i)
    if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
      return emitOpError("type mismatch between ")
             << i << "-th output and corresponding block argument";
  if (getMapping().has_value() && !getMapping()->empty()) {
    // One device mapping attribute per loop dimension; getDeviceMaskingAttr()
    // fails when more than one masking attribute is present.
    if (getDeviceMappingAttrs().size() != numLoops)
      return emitOpError() << "mapping attribute size must match op rank";
    if (failed(getDeviceMaskingAttr()))
      return emitOpError() << getMappingAttrName()
                           << " supports at most one device masking attribute";
  }

  // Verify mixed static/dynamic control variables.
  Operation *op = getOperation();
  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
                                            getStaticLowerBound(),
                                            getDynamicLowerBound())))
    return failure();
  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
                                            getStaticUpperBound(),
                                            getDynamicUpperBound())))
    return failure();
  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
                                            getStaticStep(), getDynamicStep())))
    return failure();

  return success();
}
1101
1102void ForallOp::print(OpAsmPrinter &p) {
1103 Operation *op = getOperation();
1104 p << " (" << getInductionVars();
1105 if (isNormalized()) {
1106 p << ") in ";
1107 printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1108 /*valueTypes=*/{}, /*scalables=*/{},
1110 } else {
1111 p << ") = ";
1112 printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1113 /*valueTypes=*/{}, /*scalables=*/{},
1115 p << " to ";
1116 printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1117 /*valueTypes=*/{}, /*scalables=*/{},
1119 p << " step ";
1120 printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1121 /*valueTypes=*/{}, /*scalables=*/{},
1123 }
1124 printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1125 p << " ";
1126 if (!getRegionOutArgs().empty())
1127 p << "-> (" << getResultTypes() << ") ";
1128 p.printRegion(getRegion(),
1129 /*printEntryBlockArgs=*/false,
1130 /*printBlockTerminators=*/getNumResults() > 0);
1131 p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1132 getStaticLowerBoundAttrName(),
1133 getStaticUpperBoundAttrName(),
1134 getStaticStepAttrName()});
1135}
1136
1137ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1138 OpBuilder b(parser.getContext());
1139 auto indexType = b.getIndexType();
1140
1141 // Parse an opening `(` followed by thread index variables followed by `)`
1142 // TODO: when we can refer to such "induction variable"-like handles from the
1143 // declarative assembly format, we can implement the parser as a custom hook.
1144 SmallVector<OpAsmParser::Argument, 4> ivs;
1146 return failure();
1147
1148 DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1149 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1150 dynamicSteps;
1151 if (succeeded(parser.parseOptionalKeyword("in"))) {
1152 // Parse upper bounds.
1153 if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1154 /*valueTypes=*/nullptr,
1156 parser.resolveOperands(dynamicUbs, indexType, result.operands))
1157 return failure();
1158
1159 unsigned numLoops = ivs.size();
1160 staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1161 staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1162 } else {
1163 // Parse lower bounds.
1164 if (parser.parseEqual() ||
1165 parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1166 /*valueTypes=*/nullptr,
1168
1169 parser.resolveOperands(dynamicLbs, indexType, result.operands))
1170 return failure();
1171
1172 // Parse upper bounds.
1173 if (parser.parseKeyword("to") ||
1174 parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1175 /*valueTypes=*/nullptr,
1177 parser.resolveOperands(dynamicUbs, indexType, result.operands))
1178 return failure();
1179
1180 // Parse step values.
1181 if (parser.parseKeyword("step") ||
1182 parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1183 /*valueTypes=*/nullptr,
1185 parser.resolveOperands(dynamicSteps, indexType, result.operands))
1186 return failure();
1187 }
1188
1189 // Parse out operands and results.
1190 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1191 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1192 SMLoc outOperandsLoc = parser.getCurrentLocation();
1193 if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1194 if (outOperands.size() != result.types.size())
1195 return parser.emitError(outOperandsLoc,
1196 "mismatch between out operands and types");
1197 if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1198 parser.parseOptionalArrowTypeList(result.types) ||
1199 parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1200 result.operands))
1201 return failure();
1202 }
1203
1204 // Parse region.
1205 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1206 std::unique_ptr<Region> region = std::make_unique<Region>();
1207 for (auto &iv : ivs) {
1208 iv.type = b.getIndexType();
1209 regionArgs.push_back(iv);
1210 }
1211 for (const auto &it : llvm::enumerate(regionOutArgs)) {
1212 auto &out = it.value();
1213 out.type = result.types[it.index()];
1214 regionArgs.push_back(out);
1215 }
1216 if (parser.parseRegion(*region, regionArgs))
1217 return failure();
1218
1219 // Ensure terminator and move region.
1220 ForallOp::ensureTerminator(*region, b, result.location);
1221 result.addRegion(std::move(region));
1222
1223 // Parse the optional attribute list.
1224 if (parser.parseOptionalAttrDict(result.attributes))
1225 return failure();
1226
1227 result.addAttribute("staticLowerBound", staticLbs);
1228 result.addAttribute("staticUpperBound", staticUbs);
1229 result.addAttribute("staticStep", staticSteps);
1230 result.addAttribute("operandSegmentSizes",
1232 {static_cast<int32_t>(dynamicLbs.size()),
1233 static_cast<int32_t>(dynamicUbs.size()),
1234 static_cast<int32_t>(dynamicSteps.size()),
1235 static_cast<int32_t>(outOperands.size())}));
1236 return success();
1237}
1238
1239// Builder that takes loop bounds.
// Builder that takes loop bounds.
void ForallOp::build(
    mlir::OpBuilder &b, mlir::OperationState &result,
    ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
    ArrayRef<OpFoldResult> steps, ValueRange outputs,
    std::optional<ArrayAttr> mapping,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  // Split each mixed bound/step list into dynamic SSA values and their static
  // integer counterparts.
  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);

  // Operand order must match the operandSegmentSizes attribute set below.
  result.addOperands(dynamicLbs);
  result.addOperands(dynamicUbs);
  result.addOperands(dynamicSteps);
  result.addOperands(outputs);
  // One result per shared output, of the same type.
  result.addTypes(TypeRange(outputs));

  result.addAttribute(getStaticLowerBoundAttrName(result.name),
                      b.getDenseI64ArrayAttr(staticLbs));
  result.addAttribute(getStaticUpperBoundAttrName(result.name),
                      b.getDenseI64ArrayAttr(staticUbs));
  result.addAttribute(getStaticStepAttrName(result.name),
                      b.getDenseI64ArrayAttr(staticSteps));
  result.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
                              static_cast<int32_t>(dynamicUbs.size()),
                              static_cast<int32_t>(dynamicSteps.size()),
                              static_cast<int32_t>(outputs.size())}));
  if (mapping.has_value()) {
    result.addAttribute(ForallOp::getMappingAttrName(result.name),
                        mapping.value());
  }

  Region *bodyRegion = result.addRegion();
  OpBuilder::InsertionGuard g(b);
  b.createBlock(bodyRegion);
  Block &bodyBlock = bodyRegion->front();

  // Add block arguments for indices and outputs.
  bodyBlock.addArguments(
      SmallVector<Type>(lbs.size(), b.getIndexType()),
      SmallVector<Location>(staticLbs.size(), result.location));
  bodyBlock.addArguments(
      TypeRange(outputs),
      SmallVector<Location>(outputs.size(), result.location));

  b.setInsertionPointToStart(&bodyBlock);
  // Without a custom body builder, just make sure a terminator exists.
  if (!bodyBuilderFn) {
    ForallOp::ensureTerminator(*bodyRegion, b, result.location);
    return;
  }
  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
}
1295
1296// Builder that takes loop bounds.
1297void ForallOp::build(
1298 mlir::OpBuilder &b, mlir::OperationState &result,
1299 ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1300 std::optional<ArrayAttr> mapping,
1301 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1302 unsigned numLoops = ubs.size();
1303 SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1304 SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1305 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1306}
1307
1308// Checks if the lbs are zeros and steps are ones.
1309bool ForallOp::isNormalized() {
1310 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1311 return llvm::all_of(results, [&](OpFoldResult ofr) {
1312 auto intValue = getConstantIntValue(ofr);
1313 return intValue.has_value() && intValue == val;
1314 });
1315 };
1316 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1317}
1318
1319InParallelOp ForallOp::getTerminator() {
1320 return cast<InParallelOp>(getBody()->getTerminator());
1321}
1322
1323SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1324 SmallVector<Operation *> storeOps;
1325 for (Operation *user : bbArg.getUsers()) {
1326 if (auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1327 storeOps.push_back(parallelOp);
1328 }
1329 }
1330 return storeOps;
1331}
1332
1333SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1334 SmallVector<DeviceMappingAttrInterface> res;
1335 if (!getMapping())
1336 return res;
1337 for (auto attr : getMapping()->getValue()) {
1338 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1339 if (m)
1340 res.push_back(m);
1341 }
1342 return res;
1343}
1344
1345FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1346 DeviceMaskingAttrInterface res;
1347 if (!getMapping())
1348 return res;
1349 for (auto attr : getMapping()->getValue()) {
1350 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1351 if (m && res)
1352 return failure();
1353 if (m)
1354 res = m;
1355 }
1356 return res;
1357}
1358
1359bool ForallOp::usesLinearMapping() {
1360 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1361 if (ifaces.empty())
1362 return false;
1363 return ifaces.front().isLinearMapping();
1364}
1365
1366std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1367 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1368}
1369
// Get lower bounds as OpFoldResult, merging the static attribute entries
// with the dynamic SSA operands.
std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
  Builder b(getOperation()->getContext());
  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
}
1375
// Get upper bounds as OpFoldResult, merging the static attribute entries
// with the dynamic SSA operands.
std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
  Builder b(getOperation()->getContext());
  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
}
1381
// Get steps as OpFoldResult, merging the static attribute entries with the
// dynamic SSA operands.
std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
  Builder b(getOperation()->getContext());
  return getMixedValues(getStaticStep(), getDynamicStep(), b);
}
1387
1389 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1390 if (!tidxArg)
1391 return ForallOp();
1392 assert(tidxArg.getOwner() && "unlinked block argument");
1393 auto *containingOp = tidxArg.getOwner()->getParentOp();
1394 return dyn_cast<ForallOp>(containingOp);
1395}
1396
1397namespace {
1398/// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1399struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1400 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1401
1402 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1403 PatternRewriter &rewriter) const final {
1404 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1405 if (!forallOp)
1406 return failure();
1407 Value sharedOut =
1408 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1409 ->get();
1410 rewriter.modifyOpInPlace(
1411 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1412 return success();
1413 }
1414};
1415
/// Folds dynamic lower/upper bounds and steps of scf.forall that are defined
/// by constants into the corresponding static integer arrays.
class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
public:
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
    SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
    SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
    // Fail only when none of the three lists could be folded at all.
    if (failed(foldDynamicIndexList(mixedLowerBound)) &&
        failed(foldDynamicIndexList(mixedUpperBound)) &&
        failed(foldDynamicIndexList(mixedStep)))
      return failure();

    rewriter.modifyOpInPlace(op, [&]() {
      // Re-split each (possibly folded) mixed list into dynamic operands and
      // static attribute entries, then write both back onto the op.
      SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
      SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
      dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
                                 staticLowerBound);
      op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
      op.setStaticLowerBound(staticLowerBound);

      dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
                                 staticUpperBound);
      op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
      op.setStaticUpperBound(staticUpperBound);

      dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
      op.getDynamicStepMutable().assign(dynamicStep);
      op.setStaticStep(staticStep);

      // Keep operandSegmentSizes consistent with the new operand counts.
      op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
                  rewriter.getDenseI32ArrayAttr(
                      {static_cast<int32_t>(dynamicLowerBound.size()),
                       static_cast<int32_t>(dynamicUpperBound.size()),
                       static_cast<int32_t>(dynamicStep.size()),
                       static_cast<int32_t>(op.getNumResults())}));
    });
    return success();
  }
};
1457
1458/// The following canonicalization pattern folds the iter arguments of
1459/// scf.forall op if :-
1460/// 1. The corresponding result has zero uses.
1461/// 2. The iter argument is NOT being modified within the loop body.
1462/// uses.
1463///
1464/// Example of first case :-
1465/// INPUT:
1466/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1467/// {
1468/// ...
1469/// <SOME USE OF %arg0>
1470/// <SOME USE OF %arg1>
1471/// <SOME USE OF %arg2>
1472/// ...
1473/// scf.forall.in_parallel {
1474/// <STORE OP WITH DESTINATION %arg1>
1475/// <STORE OP WITH DESTINATION %arg0>
1476/// <STORE OP WITH DESTINATION %arg2>
1477/// }
1478/// }
1479/// return %res#1
1480///
1481/// OUTPUT:
1482/// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
1483/// {
1484/// ...
1485/// <SOME USE OF %a>
1486/// <SOME USE OF %new_arg0>
1487/// <SOME USE OF %c>
1488/// ...
1489/// scf.forall.in_parallel {
1490/// <STORE OP WITH DESTINATION %new_arg0>
1491/// }
1492/// }
1493/// return %res
1494///
1495/// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1496/// scf.forall is replaced by their corresponding operands.
1497/// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1498/// of the scf.forall besides within scf.forall.in_parallel terminator,
1499/// this canonicalization remains valid. For more details, please refer
1500/// to :
1501/// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1502/// 3. TODO(avarma): Generalize it for other store ops. Currently it
1503/// handles tensor.parallel_insert_slice ops only.
1504///
1505/// Example of second case :-
1506/// INPUT:
1507/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1508/// {
1509/// ...
1510/// <SOME USE OF %arg0>
1511/// <SOME USE OF %arg1>
1512/// ...
1513/// scf.forall.in_parallel {
1514/// <STORE OP WITH DESTINATION %arg1>
1515/// }
1516/// }
1517/// return %res#0, %res#1
1518///
1519/// OUTPUT:
1520/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1521/// {
1522/// ...
1523/// <SOME USE OF %a>
1524/// <SOME USE OF %new_arg0>
1525/// ...
1526/// scf.forall.in_parallel {
1527/// <STORE OP WITH DESTINATION %new_arg0>
1528/// }
1529/// }
1530/// return %a, %res
struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp forallOp,
                                PatternRewriter &rewriter) const final {
    // Step 1: For a given i-th result of scf.forall, check the following :-
    //   a. If it has any use.
    //   b. If the corresponding iter argument is being modified within
    //      the loop, i.e. has at least one store op with the iter arg as
    //      its destination operand. For this we use
    //      ForallOp::getCombiningOps(iter_arg).
    //
    // Based on the check we maintain the following :-
    //   a. op results, block arguments, outputs to delete
    //   b. new outputs (i.e., outputs to retain)
    SmallVector<Value> resultsToDelete;
    SmallVector<Value> outsToDelete;
    SmallVector<BlockArgument> blockArgsToDelete;
    SmallVector<Value> newOuts;
    BitVector resultIndicesToDelete(forallOp.getNumResults(), false);
    BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
                                   false);
    for (OpResult result : forallOp.getResults()) {
      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
      // Foldable when the result is unused OR the iter arg is never stored to.
      if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
        resultsToDelete.push_back(result);
        outsToDelete.push_back(opOperand->get());
        blockArgsToDelete.push_back(blockArg);
        resultIndicesToDelete[result.getResultNumber()] = true;
        blockIndicesToDelete[blockArg.getArgNumber()] = true;
      } else {
        newOuts.push_back(opOperand->get());
      }
    }

    // Return early if all results of scf.forall have at least one use and being
    // modified within the loop.
    if (resultsToDelete.empty())
      return failure();

    // Step 2: Erase combining ops and replace uses of deleted results and
    // block arguments with the corresponding outputs.
    for (auto blockArg : blockArgsToDelete) {
      SmallVector<Operation *> combiningOps =
          forallOp.getCombiningOps(blockArg);
      for (Operation *combiningOp : combiningOps)
        rewriter.eraseOp(combiningOp);
    }
    for (auto [blockArg, result, out] :
         llvm::zip_equal(blockArgsToDelete, resultsToDelete, outsToDelete)) {
      rewriter.replaceAllUsesWith(blockArg, out);
      rewriter.replaceAllUsesWith(result, out);
    }
    // TODO: There is no rewriter API for erasing block arguments.
    rewriter.modifyOpInPlace(forallOp, [&]() {
      forallOp.getBody()->eraseArguments(blockIndicesToDelete);
    });

    // Step 3. Create a new scf.forall op with only the shared_outs/results
    // that should be retained.
    auto newForallOp = cast<scf::ForallOp>(
        rewriter.eraseOpResults(forallOp, resultIndicesToDelete));
    newForallOp.getOutputsMutable().assign(newOuts);

    return success();
  }
};
1599
/// Folds away scf.forall dimensions with at most one iteration: a zero-trip
/// loop is replaced by its outputs, single-trip dimensions are dropped with
/// the IV mapped to its lower bound, and a loop whose dimensions are all
/// single-trip is promoted (its body inlined).
struct ForallOpSingleOrZeroIterationDimsFolder
    : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp op,
                                PatternRewriter &rewriter) const override {
    // Do not fold dimensions if they are mapped to processing units.
    if (op.getMapping().has_value() && !op.getMapping()->empty())
      return failure();
    Location loc = op.getLoc();

    // Compute new loop bounds that omit all single-iteration loop dimensions.
    SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
        newMixedSteps;
    IRMapping mapping;
    for (auto [lb, ub, step, iv] :
         llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
                   op.getMixedStep(), op.getInductionVars())) {
      auto numIterations =
          constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
      if (numIterations.has_value()) {
        // Remove the loop if it performs zero iterations.
        if (*numIterations == 0) {
          rewriter.replaceOp(op, op.getOutputs());
          return success();
        }
        // Replace the loop induction variable by the lower bound if the loop
        // performs a single iteration. Otherwise, copy the loop bounds.
        if (*numIterations == 1) {
          mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
          continue;
        }
      }
      newMixedLowerBounds.push_back(lb);
      newMixedUpperBounds.push_back(ub);
      newMixedSteps.push_back(step);
    }

    // All of the loop dimensions perform a single iteration. Inline loop body.
    if (newMixedLowerBounds.empty()) {
      promote(rewriter, op);
      return success();
    }

    // Exit if none of the loop dimensions perform a single iteration.
    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
      return rewriter.notifyMatchFailure(
          op, "no dimensions have 0 or 1 iterations");
    }

    // Replace the loop by a lower-dimensional loop.
    ForallOp newOp;
    newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
                             newMixedUpperBounds, newMixedSteps,
                             op.getOutputs(), std::nullopt, nullptr);
    newOp.getBodyRegion().getBlocks().clear();
    // The new loop needs to keep all attributes from the old one, except for
    // "operandSegmentSizes" and static loop bound attributes which capture
    // the outdated information of the old iteration domain.
    SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
                                        newOp.getStaticLowerBoundAttrName(),
                                        newOp.getStaticUpperBoundAttrName(),
                                        newOp.getStaticStepAttrName()};
    for (const auto &namedAttr : op->getAttrs()) {
      if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
        continue;
      rewriter.modifyOpInPlace(newOp, [&]() {
        newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
      });
    }
    // Clone the body with the dropped IVs remapped to their lower bounds.
    rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
                               newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};
1676
1677/// Replace all induction vars with a single trip count with their lower bound.
1678struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1679 using OpRewritePattern<ForallOp>::OpRewritePattern;
1680
1681 LogicalResult matchAndRewrite(ForallOp op,
1682 PatternRewriter &rewriter) const override {
1683 Location loc = op.getLoc();
1684 bool changed = false;
1685 for (auto [lb, ub, step, iv] :
1686 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1687 op.getMixedStep(), op.getInductionVars())) {
1688 if (iv.hasNUses(0))
1689 continue;
1690 auto numIterations =
1691 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
1692 if (!numIterations.has_value() || numIterations.value() != 1) {
1693 continue;
1694 }
1695 rewriter.replaceAllUsesWith(
1696 iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1697 changed = true;
1698 }
1699 return success(changed);
1700 }
1701};
1702
/// Folds tensor.cast ops feeding the shared_outs of scf.forall into the loop
/// when the cast source type carries more static information: the loop then
/// iterates on the source tensors directly, with compensating casts inside
/// the body and on the results.
struct FoldTensorCastOfOutputIntoForallOp
    : public OpRewritePattern<scf::ForallOp> {
  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;

  /// Source/destination type pair of a folded cast.
  struct TypeCast {
    Type srcType;
    Type dstType;
  };

  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
                                PatternRewriter &rewriter) const final {
    // Maps output index -> the cast folded at that position.
    llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
    llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
    for (auto en : llvm::enumerate(newOutputTensors)) {
      auto castOp = en.value().getDefiningOp<tensor::CastOp>();
      if (!castOp)
        continue;

      // Only casts that that preserve static information, i.e. will make the
      // loop result type "more" static than before, will be folded.
      if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
                                              castOp.getSource().getType())) {
        continue;
      }

      tensorCastProducers[en.index()] =
          TypeCast{castOp.getSource().getType(), castOp.getType()};
      newOutputTensors[en.index()] = castOp.getSource();
    }

    if (tensorCastProducers.empty())
      return failure();

    // Create new loop.
    Location loc = forallOp.getLoc();
    auto newForallOp = ForallOp::create(
        rewriter, loc, forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        newOutputTensors, forallOp.getMapping(),
        [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
          // Re-introduce casts on the affected bbArgs so the old body sees
          // values of the original (less static) types.
          auto castBlockArgs =
              llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
          for (auto [index, cast] : tensorCastProducers) {
            Value &oldTypeBBArg = castBlockArgs[index];
            oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
                                                  cast.dstType, oldTypeBBArg);
          }

          // Move old body into new parallel loop.
          SmallVector<Value> ivsBlockArgs =
              llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
          ivsBlockArgs.append(castBlockArgs);
          rewriter.mergeBlocks(forallOp.getBody(),
                               bbArgs.front().getParentBlock(), ivsBlockArgs);
        });

    // After `mergeBlocks` happened, the destinations in the terminator may be
    // mapped to tensor.cast values wrapping the new output bbArgs (introduced
    // for indices in `tensorCastProducers`). Update those destinations to
    // point directly to the output bbArgs, bypassing the casts.
    //
    // Note: we cannot zip yieldingOps with regionIterArgs by position because
    // a parallel_insert_slice inside in_parallel may write to any shared
    // output, not necessarily the one at the same position.
    llvm::SmallDenseSet<Value> newIterArgSet(
        newForallOp.getRegionIterArgs().begin(),
        newForallOp.getRegionIterArgs().end());
    auto terminator = newForallOp.getTerminator();
    for (auto &yieldingOp : terminator.getYieldingOps()) {
      auto parallelCombiningOp =
          dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
      if (!parallelCombiningOp)
        continue;
      for (OpOperand &dest : parallelCombiningOp.getUpdatedDestinations()) {
        auto castOp = dest.get().getDefiningOp<tensor::CastOp>();
        if (castOp && newIterArgSet.contains(castOp.getSource()))
          dest.set(castOp.getSource());
      }
    }

    // Cast results back to the original types.
    rewriter.setInsertionPointAfter(newForallOp);
    SmallVector<Value> castResults = newForallOp.getResults();
    for (auto &item : tensorCastProducers) {
      Value &oldTypeResult = castResults[item.first];
      oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
                                             oldTypeResult);
    }
    rewriter.replaceOp(forallOp, castResults);
    return success();
  }
};
1795
1796} // namespace
1797
1798void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1799 MLIRContext *context) {
1800 results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1801 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1802 ForallOpSingleOrZeroIterationDimsFolder,
1803 ForallOpReplaceConstantInductionVar>(context);
1804}
1805
1806void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1807 SmallVectorImpl<RegionSuccessor> &regions) {
1808 // There are two region branch points:
1809 // 1. "parent": entering the forall op for the first time.
1810 // 2. scf.in_parallel terminator
1811 if (point.isParent()) {
1812 // When first entering the forall op, the control flow typically branches
1813 // into the forall body. (In parallel for multiple threads.)
1814 regions.push_back(RegionSuccessor(&getRegion()));
1815 // However, when there are 0 threads, the control flow may branch back to
1816 // the parent immediately.
1817 regions.push_back(RegionSuccessor::parent());
1818 } else {
1819 // In accordance with the semantics of forall, its body is executed in
1820 // parallel by multiple threads. We should not expect to branch back into
1821 // the forall body after the region's execution is complete.
1822 regions.push_back(RegionSuccessor::parent());
1823 }
1824}
1825
1826//===----------------------------------------------------------------------===//
1827// InParallelOp
1828//===----------------------------------------------------------------------===//
1829
// Build an InParallelOp with an empty entry block in its region.
1831void InParallelOp::build(OpBuilder &b, OperationState &result) {
1832 OpBuilder::InsertionGuard g(b);
1833 Region *bodyRegion = result.addRegion();
1834 b.createBlock(bodyRegion);
1835}
1836
1837LogicalResult InParallelOp::verify() {
1838 scf::ForallOp forallOp =
1839 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1840 if (!forallOp)
1841 return this->emitOpError("expected forall op parent");
1842
1843 for (Operation &op : getRegion().front().getOperations()) {
1844 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
1845 if (!parallelCombiningOp) {
1846 return this->emitOpError("expected only ParallelCombiningOpInterface")
1847 << " ops";
1848 }
1849
1850 // Verify that inserts are into out block arguments.
1851 MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
1852 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1853 for (OpOperand &dest : dests) {
1854 if (!llvm::is_contained(regionOutArgs, dest.get()))
1855 return op.emitOpError("may only insert into an output block argument");
1856 }
1857 }
1858
1859 return success();
1860}
1861
1862void InParallelOp::print(OpAsmPrinter &p) {
1863 p << " ";
1864 p.printRegion(getRegion(),
1865 /*printEntryBlockArgs=*/false,
1866 /*printBlockTerminators=*/false);
1867 p.printOptionalAttrDict(getOperation()->getAttrs());
1868}
1869
1870ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1871 auto &builder = parser.getBuilder();
1872
1873 SmallVector<OpAsmParser::Argument, 8> regionOperands;
1874 std::unique_ptr<Region> region = std::make_unique<Region>();
1875 if (parser.parseRegion(*region, regionOperands))
1876 return failure();
1877
1878 if (region->empty())
1879 OpBuilder(builder.getContext()).createBlock(region.get());
1880 result.addRegion(std::move(region));
1881
1882 // Parse the optional attribute list.
1883 if (parser.parseOptionalAttrDict(result.attributes))
1884 return failure();
1885 return success();
1886}
1887
1888OpResult InParallelOp::getParentResult(int64_t idx) {
1889 return getOperation()->getParentOp()->getResult(idx);
1890}
1891
1892SmallVector<BlockArgument> InParallelOp::getDests() {
1893 SmallVector<BlockArgument> updatedDests;
1894 for (Operation &yieldingOp : getYieldingOps()) {
1895 auto parallelCombiningOp =
1896 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
1897 if (!parallelCombiningOp)
1898 continue;
1899 for (OpOperand &updatedOperand :
1900 parallelCombiningOp.getUpdatedDestinations())
1901 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
1902 }
1903 return updatedDests;
1904}
1905
1906llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1907 return getRegion().front().getOperations();
1908}
1909
1910//===----------------------------------------------------------------------===//
1911// IfOp
1912//===----------------------------------------------------------------------===//
1913
1915 assert(a && "expected non-empty operation");
1916 assert(b && "expected non-empty operation");
1917
1918 IfOp ifOp = a->getParentOfType<IfOp>();
1919 while (ifOp) {
1920 // Check if b is inside ifOp. (We already know that a is.)
1921 if (ifOp->isProperAncestor(b))
1922 // b is contained in ifOp. a and b are in mutually exclusive branches if
1923 // they are in different blocks of ifOp.
1924 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1925 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1926 // Check next enclosing IfOp.
1927 ifOp = ifOp->getParentOfType<IfOp>();
1928 }
1929
1930 // Could not find a common IfOp among a's and b's ancestors.
1931 return false;
1932}
1933
1934LogicalResult
1935IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1936 IfOp::Adaptor adaptor,
1937 SmallVectorImpl<Type> &inferredReturnTypes) {
1938 if (adaptor.getRegions().empty())
1939 return failure();
1940 Region *r = &adaptor.getThenRegion();
1941 if (r->empty())
1942 return failure();
1943 Block &b = r->front();
1944 if (b.empty())
1945 return failure();
1946 auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
1947 if (!yieldOp)
1948 return failure();
1949 TypeRange types = yieldOp.getOperandTypes();
1950 llvm::append_range(inferredReturnTypes, types);
1951 return success();
1952}
1953
1954void IfOp::build(OpBuilder &builder, OperationState &result,
1955 TypeRange resultTypes, Value cond) {
1956 return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
1957 /*addElseBlock=*/false);
1958}
1959
1960void IfOp::build(OpBuilder &builder, OperationState &result,
1961 TypeRange resultTypes, Value cond, bool addThenBlock,
1962 bool addElseBlock) {
1963 assert((!addElseBlock || addThenBlock) &&
1964 "must not create else block w/o then block");
1965 result.addTypes(resultTypes);
1966 result.addOperands(cond);
1967
1968 // Add regions and blocks.
1969 OpBuilder::InsertionGuard guard(builder);
1970 Region *thenRegion = result.addRegion();
1971 if (addThenBlock)
1972 builder.createBlock(thenRegion);
1973 Region *elseRegion = result.addRegion();
1974 if (addElseBlock)
1975 builder.createBlock(elseRegion);
1976}
1977
// Build a zero-result scf.if, optionally with an else region.
void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
                 bool withElseRegion) {
  build(builder, result, TypeRange{}, cond, withElseRegion);
}
1982
1983void IfOp::build(OpBuilder &builder, OperationState &result,
1984 TypeRange resultTypes, Value cond, bool withElseRegion) {
1985 result.addTypes(resultTypes);
1986 result.addOperands(cond);
1987
1988 // Build then region.
1989 OpBuilder::InsertionGuard guard(builder);
1990 Region *thenRegion = result.addRegion();
1991 builder.createBlock(thenRegion);
1992 if (resultTypes.empty())
1993 IfOp::ensureTerminator(*thenRegion, builder, result.location);
1994
1995 // Build else region.
1996 Region *elseRegion = result.addRegion();
1997 if (withElseRegion) {
1998 builder.createBlock(elseRegion);
1999 if (resultTypes.empty())
2000 IfOp::ensureTerminator(*elseRegion, builder, result.location);
2001 }
2002}
2003
/// Build an scf.if from then/else body-builder callbacks. `elseBuilder` may
/// be null, in which case no else block is created. Result types are inferred
/// from the yield of the freshly built then region, when inference succeeds.
void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
                 function_ref<void(OpBuilder &, Location)> thenBuilder,
                 function_ref<void(OpBuilder &, Location)> elseBuilder) {
  assert(thenBuilder && "the builder callback for 'then' must be present");
  result.addOperands(cond);

  // Build then region.
  OpBuilder::InsertionGuard guard(builder);
  Region *thenRegion = result.addRegion();
  builder.createBlock(thenRegion);
  thenBuilder(builder, result.location);

  // Build else region.
  Region *elseRegion = result.addRegion();
  if (elseBuilder) {
    builder.createBlock(elseRegion);
    elseBuilder(builder, result.location);
  }

  // Infer result types from the regions built above; on failure the op is
  // left without result types.
  SmallVector<Type> inferredReturnTypes;
  MLIRContext *ctx = builder.getContext();
  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
                                 /*properties=*/nullptr, result.regions,
                                 inferredReturnTypes))) {
    result.addTypes(inferredReturnTypes);
  }
}
2033
2034LogicalResult IfOp::verify() {
2035 if (getNumResults() != 0 && getElseRegion().empty())
2036 return emitOpError("must have an else block if defining values");
2037 return success();
2038}
2039
2040ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2041 // Create the regions for 'then'.
2042 result.regions.reserve(2);
2043 Region *thenRegion = result.addRegion();
2044 Region *elseRegion = result.addRegion();
2045
2046 auto &builder = parser.getBuilder();
2047 OpAsmParser::UnresolvedOperand cond;
2048 Type i1Type = builder.getIntegerType(1);
2049 if (parser.parseOperand(cond) ||
2050 parser.resolveOperand(cond, i1Type, result.operands))
2051 return failure();
2052 // Parse optional results type list.
2053 if (parser.parseOptionalArrowTypeList(result.types))
2054 return failure();
2055 // Parse the 'then' region.
2056 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2057 return failure();
2058 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2059
2060 // If we find an 'else' keyword then parse the 'else' region.
2061 if (!parser.parseOptionalKeyword("else")) {
2062 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2063 return failure();
2064 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2065 }
2066
2067 // Parse the optional attribute list.
2068 if (parser.parseOptionalAttrDict(result.attributes))
2069 return failure();
2070 return success();
2071}
2072
2073void IfOp::print(OpAsmPrinter &p) {
2074 bool printBlockTerminators = false;
2075
2076 p << " " << getCondition();
2077 if (!getResults().empty()) {
2078 p << " -> (" << getResultTypes() << ")";
2079 // Print yield explicitly if the op defines values.
2080 printBlockTerminators = true;
2081 }
2082 p << ' ';
2083 p.printRegion(getThenRegion(),
2084 /*printEntryBlockArgs=*/false,
2085 /*printBlockTerminators=*/printBlockTerminators);
2086
2087 // Print the 'else' regions if it exists and has a block.
2088 auto &elseRegion = getElseRegion();
2089 if (!elseRegion.empty()) {
2090 p << " else ";
2091 p.printRegion(elseRegion,
2092 /*printEntryBlockArgs=*/false,
2093 /*printBlockTerminators=*/printBlockTerminators);
2094 }
2095
2096 p.printOptionalAttrDict((*this)->getAttrs());
2097}
2098
2099void IfOp::getSuccessorRegions(RegionBranchPoint point,
2100 SmallVectorImpl<RegionSuccessor> &regions) {
2101 // The `then` and the `else` region branch back to the parent operation or one
2102 // of the recursive parent operations (early exit case).
2103 if (!point.isParent()) {
2104 regions.push_back(RegionSuccessor::parent());
2105 return;
2106 }
2107
2108 regions.push_back(RegionSuccessor(&getThenRegion()));
2109
2110 // Don't consider the else region if it is empty.
2111 Region *elseRegion = &this->getElseRegion();
2112 if (elseRegion->empty())
2113 regions.push_back(RegionSuccessor::parent());
2114 else
2115 regions.push_back(RegionSuccessor(elseRegion));
2116}
2117
2118ValueRange IfOp::getSuccessorInputs(RegionSuccessor successor) {
2119 return successor.isParent() ? ValueRange(getOperation()->getResults())
2120 : ValueRange();
2121}
2122
2123void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2124 SmallVectorImpl<RegionSuccessor> &regions) {
2125 FoldAdaptor adaptor(operands, *this);
2126 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2127 if (!boolAttr || boolAttr.getValue())
2128 regions.emplace_back(&getThenRegion());
2129
2130 // If the else region is empty, execution continues after the parent op.
2131 if (!boolAttr || !boolAttr.getValue()) {
2132 if (!getElseRegion().empty())
2133 regions.emplace_back(&getElseRegion());
2134 else
2135 regions.emplace_back(RegionSuccessor::parent());
2136 }
2137}
2138
/// Folds a negated condition by stripping the `xor ..., true` and swapping
/// the two branches.
LogicalResult IfOp::fold(FoldAdaptor adaptor,
                         SmallVectorImpl<OpFoldResult> &results) {
  // if (!c) then A() else B() -> if c then B() else A()
  if (getElseRegion().empty())
    return failure();

  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
  if (!xorStmt)
    return failure();

  // Only match a true negation, i.e. xor with the constant 1.
  if (!matchPattern(xorStmt.getRhs(), m_One()))
    return failure();

  // Use the un-negated value as the condition and swap the region bodies by
  // splicing the blocks through the then region.
  getConditionMutable().assign(xorStmt.getLhs());
  Block *thenBlock = &getThenRegion().front();
  // It would be nicer to use iplist::swap, but that has no implemented
  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
                                     getElseRegion().getBlocks());
  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
                                     getThenRegion().getBlocks(), thenBlock);
  return success();
}
2162
2163void IfOp::getRegionInvocationBounds(
2164 ArrayRef<Attribute> operands,
2165 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2166 if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2167 // If the condition is known, then one region is known to be executed once
2168 // and the other zero times.
2169 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2170 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2171 } else {
2172 // Non-constant condition. Each region may be executed 0 or 1 times.
2173 invocationBounds.assign(2, {0, 1});
2174 }
2175}
2176
2177namespace {
2178/// Hoist any yielded results whose operands are defined outside
2179/// the if, to a select instruction.
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() == 0)
      return failure();

    auto cond = op.getCondition();
    auto thenYieldArgs = op.thenYield().getOperands();
    auto elseYieldArgs = op.elseYield().getOperands();

    // Types of results that must stay on the scf.if because at least one of
    // their yielded values is defined inside a branch.
    SmallVector<Type> nonHoistable;
    for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
      if (&op.getThenRegion() == trueVal.getParentRegion() ||
          &op.getElseRegion() == falseVal.getParentRegion())
        nonHoistable.push_back(trueVal.getType());
    }
    // Early exit if there aren't any yielded values we can
    // hoist outside the if.
    if (nonHoistable.size() == op->getNumResults())
      return failure();

    // Rebuild the scf.if with only the non-hoistable result types, reusing
    // the original regions.
    IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
                                    /*withElseRegion=*/false);
    if (replacement.thenBlock())
      rewriter.eraseBlock(replacement.thenBlock());
    replacement.getThenRegion().takeBody(op.getThenRegion());
    replacement.getElseRegion().takeBody(op.getElseRegion());

    SmallVector<Value> results(op->getNumResults());
    assert(thenYieldArgs.size() == results.size());
    assert(elseYieldArgs.size() == results.size());

    // Values each branch of the reduced scf.if will yield.
    SmallVector<Value> trueYields;
    SmallVector<Value> falseYields;
    for (const auto &it :
         llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
      Value trueVal = std::get<0>(it.value());
      Value falseVal = std::get<1>(it.value());
      if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
          &replacement.getElseRegion() == falseVal.getParentRegion()) {
        // Defined inside a branch: keep yielding it from the reduced scf.if.
        results[it.index()] = replacement.getResult(trueYields.size());
        trueYields.push_back(trueVal);
        falseYields.push_back(falseVal);
      } else if (trueVal == falseVal)
        // Both branches yield the same external value: use it directly.
        results[it.index()] = trueVal;
      else
        // Both values defined outside the op: replace with arith.select.
        results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
                                                      cond, trueVal, falseVal);
    }

    // Rewrite both terminators to yield only the non-hoisted values.
    rewriter.setInsertionPointToEnd(replacement.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);

    rewriter.setInsertionPointToEnd(replacement.elseBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);

    rewriter.replaceOp(op, results);
    return success();
  }
};
2243
2244/// Allow the true region of an if to assume the condition is true
2245/// and vice versa. For example:
2246///
2247/// scf.if %cmp {
2248/// print(%cmp)
2249/// }
2250///
2251/// becomes
2252///
2253/// scf.if %cmp {
2254/// print(true)
2255/// }
2256///
struct ConditionPropagation : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  /// Kind of parent region in the ancestor cache.
  enum class Parent { Then, Else, None };

  /// Returns the kind of region ("then", "else", or "none") of the
  /// IfOp that the given region is transitively nested in. Updates
  /// the cache accordingly.
  // NOTE(review): the body refers to a `cache` map that is not visible in
  // this excerpt — presumably a memoization parameter/member; confirm against
  // the full source.
  static Parent getParentType(Region *toCheck, IfOp op,
                              Region *endRegion) {
    // Regions visited on this walk; the final answer is cached for all of
    // them at once.
    SmallVector<Region *> seen;
    while (toCheck != endRegion) {
      auto found = cache.find(toCheck);
      if (found != cache.end())
        return found->second;
      seen.push_back(toCheck);
      if (&op.getThenRegion() == toCheck) {
        for (Region *region : seen)
          cache[region] = Parent::Then;
        return Parent::Then;
      }
      if (&op.getElseRegion() == toCheck) {
        for (Region *region : seen)
          cache[region] = Parent::Else;
        return Parent::Else;
      }
      toCheck = toCheck->getParentRegion();
    }

    // Reached the region holding the condition without crossing either branch
    // of `op`.
    for (Region *region : seen)
      cache[region] = Parent::None;
    return Parent::None;
  }

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    // Early exit if the condition is constant since replacing a constant
    // in the body with another constant isn't a simplification.
    if (matchPattern(op.getCondition(), m_Constant()))
      return failure();

    bool changed = false;
    mlir::Type i1Ty = rewriter.getI1Type();

    // These variables serve to prevent creating duplicate constants
    // and hold constant true or false values.
    Value constantTrue = nullptr;
    Value constantFalse = nullptr;

    for (OpOperand &use :
         llvm::make_early_inc_range(op.getCondition().getUses())) {
      switch (getParentType(use.getOwner()->getParentRegion(), op, cache,
                            op.getCondition().getParentRegion())) {
      case Parent::Then: {
        // Inside the then branch the condition is known to be true.
        changed = true;

        if (!constantTrue)
          constantTrue = arith::ConstantOp::create(
              rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));

        rewriter.modifyOpInPlace(use.getOwner(),
                                 [&]() { use.set(constantTrue); });
        break;
      }
      case Parent::Else: {
        // Inside the else branch the condition is known to be false.
        changed = true;

        if (!constantFalse)
          constantFalse = arith::ConstantOp::create(
              rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));

        rewriter.modifyOpInPlace(use.getOwner(),
                                 [&]() { use.set(constantFalse); });
        break;
      }
      case Parent::None:
        break;
      }
    }

    return success(changed);
  }
};
2343
2344/// Remove any statements from an if that are equivalent to the condition
2345/// or its negation. For example:
2346///
2347/// %res:2 = scf.if %cmp {
2348/// yield something(), true
2349/// } else {
2350/// yield something2(), false
2351/// }
2352/// print(%res#1)
2353///
2354/// becomes
2355/// %res = scf.if %cmp {
2356/// yield something()
2357/// } else {
2358/// yield something2()
2359/// }
2360/// print(%cmp)
2361///
2362/// Additionally if both branches yield the same value, replace all uses
2363/// of the result with the yielded value.
2364///
2365/// %res:2 = scf.if %cmp {
2366/// yield something(), %arg1
2367/// } else {
2368/// yield something2(), %arg1
2369/// }
2370/// print(%res#1)
2371///
2372/// becomes
2373/// %res = scf.if %cmp {
2374/// yield something()
2375/// } else {
2376/// yield something2()
2377/// }
2378/// print(%arg1)
2379///
2380struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2381 using OpRewritePattern<IfOp>::OpRewritePattern;
2382
2383 LogicalResult matchAndRewrite(IfOp op,
2384 PatternRewriter &rewriter) const override {
2385 // Early exit if there are no results that could be replaced.
2386 if (op.getNumResults() == 0)
2387 return failure();
2388
2389 auto trueYield =
2390 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2391 auto falseYield =
2392 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2393
2394 rewriter.setInsertionPoint(op->getBlock(),
2395 op.getOperation()->getIterator());
2396 bool changed = false;
2397 Type i1Ty = rewriter.getI1Type();
2398 for (auto [trueResult, falseResult, opResult] :
2399 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2400 op.getResults())) {
2401 if (trueResult == falseResult) {
2402 if (!opResult.use_empty()) {
2403 opResult.replaceAllUsesWith(trueResult);
2404 changed = true;
2405 }
2406 continue;
2407 }
2408
2409 BoolAttr trueYield, falseYield;
2410 if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2411 !matchPattern(falseResult, m_Constant(&falseYield)))
2412 continue;
2413
2414 bool trueVal = trueYield.getValue();
2415 bool falseVal = falseYield.getValue();
2416 if (!trueVal && falseVal) {
2417 if (!opResult.use_empty()) {
2418 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2419 Value notCond = arith::XOrIOp::create(
2420 rewriter, op.getLoc(), op.getCondition(),
2421 constDialect
2422 ->materializeConstant(rewriter,
2423 rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2424 op.getLoc())
2425 ->getResult(0));
2426 opResult.replaceAllUsesWith(notCond);
2427 changed = true;
2428 }
2429 }
2430 if (trueVal && !falseVal) {
2431 if (!opResult.use_empty()) {
2432 opResult.replaceAllUsesWith(op.getCondition());
2433 changed = true;
2434 }
2435 }
2436 }
2437 return success(changed);
2438 }
2439};
2440
2441/// Merge any consecutive scf.if's with the same condition.
2442///
2443/// scf.if %cond {
2444/// firstCodeTrue();...
2445/// } else {
2446/// firstCodeFalse();...
2447/// }
2448/// %res = scf.if %cond {
2449/// secondCodeTrue();...
2450/// } else {
2451/// secondCodeFalse();...
2452/// }
2453///
2454/// becomes
/// %res = scf.if %cond {
2456/// firstCodeTrue();...
2457/// secondCodeTrue();...
2458/// } else {
2459/// firstCodeFalse();...
2460/// secondCodeFalse();...
2461/// }
struct CombineIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp nextIf,
                                PatternRewriter &rewriter) const override {
    Block *parent = nextIf->getBlock();
    if (nextIf == &parent->front())
      return failure();

    // Only fuse with an scf.if that immediately precedes this one.
    auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
    if (!prevIf)
      return failure();

    // Determine the logical then/else blocks when prevIf's
    // condition is used. Null means the block does not exist
    // in that case (e.g. empty else). If neither of these
    // are set, the two conditions cannot be compared.
    Block *nextThen = nullptr;
    Block *nextElse = nullptr;
    if (nextIf.getCondition() == prevIf.getCondition()) {
      // Identical conditions: the blocks map over directly.
      nextThen = nextIf.thenBlock();
      if (!nextIf.getElseRegion().empty())
        nextElse = nextIf.elseBlock();
    }
    if (arith::XOrIOp notv =
            nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      // nextIf's condition is the negation of prevIf's: swap the roles.
      if (notv.getLhs() == prevIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }
    if (arith::XOrIOp notv =
            prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      // prevIf's condition is the negation of nextIf's: also swap the roles.
      if (notv.getLhs() == nextIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }

    if (!nextThen && !nextElse)
      return failure();

    SmallVector<Value> prevElseYielded;
    if (!prevIf.getElseRegion().empty())
      prevElseYielded = prevIf.elseYield().getOperands();
    // Replace all uses of return values of op within nextIf with the
    // corresponding yields
    for (auto it : llvm::zip(prevIf.getResults(),
                             prevIf.thenYield().getOperands(), prevElseYielded))
      for (OpOperand &use :
           llvm::make_early_inc_range(std::get<0>(it).getUses())) {
        if (nextThen && nextThen->getParent()->isAncestor(
                            use.getOwner()->getParentRegion())) {
          rewriter.startOpModification(use.getOwner());
          use.set(std::get<1>(it));
          rewriter.finalizeOpModification(use.getOwner());
        } else if (nextElse && nextElse->getParent()->isAncestor(
                                   use.getOwner()->getParentRegion())) {
          rewriter.startOpModification(use.getOwner());
          use.set(std::get<2>(it));
          rewriter.finalizeOpModification(use.getOwner());
        }
      }

    // The combined op returns prevIf's results followed by nextIf's.
    SmallVector<Type> mergedTypes(prevIf.getResultTypes());
    llvm::append_range(mergedTypes, nextIf.getResultTypes());

    IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
                                   prevIf.getCondition(), /*hasElse=*/false);
    rewriter.eraseBlock(&combinedIf.getThenRegion().back());

    // Steal prevIf's then region wholesale.
    rewriter.inlineRegionBefore(prevIf.getThenRegion(),
                                combinedIf.getThenRegion(),
                                combinedIf.getThenRegion().begin());

    if (nextThen) {
      // Append nextIf's logical then block and merge the two yields.
      YieldOp thenYield = combinedIf.thenYield();
      YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
      rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
      rewriter.setInsertionPointToEnd(combinedIf.thenBlock());

      SmallVector<Value> mergedYields(thenYield.getOperands());
      llvm::append_range(mergedYields, thenYield2.getOperands());
      YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
      rewriter.eraseOp(thenYield);
      rewriter.eraseOp(thenYield2);
    }

    rewriter.inlineRegionBefore(prevIf.getElseRegion(),
                                combinedIf.getElseRegion(),
                                combinedIf.getElseRegion().begin());

    if (nextElse) {
      if (combinedIf.getElseRegion().empty()) {
        // prevIf had no else: adopt nextIf's logical else region directly.
        rewriter.inlineRegionBefore(*nextElse->getParent(),
                                    combinedIf.getElseRegion(),
                                    combinedIf.getElseRegion().begin());
      } else {
        // Both have else blocks: merge the blocks and their yields.
        YieldOp elseYield = combinedIf.elseYield();
        YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
        rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());

        rewriter.setInsertionPointToEnd(combinedIf.elseBlock());

        SmallVector<Value> mergedElseYields(elseYield.getOperands());
        llvm::append_range(mergedElseYields, elseYield2.getOperands());

        YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
        rewriter.eraseOp(elseYield);
        rewriter.eraseOp(elseYield2);
      }
    }

    // Split the combined results back between the two original ops.
    SmallVector<Value> prevValues;
    SmallVector<Value> nextValues;
    for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
      if (pair.index() < prevIf.getNumResults())
        prevValues.push_back(pair.value());
      else
        nextValues.push_back(pair.value());
    }
    rewriter.replaceOp(prevIf, prevValues);
    rewriter.replaceOp(nextIf, nextValues);
    return success();
  }
};
2592
2593/// Pattern to remove an empty else branch.
2594struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2595 using OpRewritePattern<IfOp>::OpRewritePattern;
2596
2597 LogicalResult matchAndRewrite(IfOp ifOp,
2598 PatternRewriter &rewriter) const override {
2599 // Cannot remove else region when there are operation results.
2600 if (ifOp.getNumResults())
2601 return failure();
2602 Block *elseBlock = ifOp.elseBlock();
2603 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2604 return failure();
2605 auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2606 rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2607 newIfOp.getThenRegion().begin());
2608 rewriter.eraseOp(ifOp);
2609 return success();
2610 }
2611};
2612
/// Convert nested `if`s into `arith.andi` + single `if`.
///
///    scf.if %arg0 {
///      scf.if %arg1 {
///        ...
///        scf.yield
///      }
///      scf.yield
///    }
///  becomes
///
///    %0 = arith.andi %arg0, %arg1
///    scf.if %0 {
///      ...
///      scf.yield
///    }
///
/// Results of the outer `if` that come from the inner `if` are forwarded
/// directly; any other result yielded from the then-branch is turned into an
/// `arith.select` on the outer condition.
struct CombineNestedIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    auto nestedOps = op.thenBlock()->without_terminator();
    // Nested `if` must be the only op in block.
    if (!llvm::hasSingleElement(nestedOps))
      return failure();

    // If there is an else block, it can only yield
    if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
      return failure();

    auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
    if (!nestedIf)
      return failure();

    // Same restriction for the inner `if`: its else branch may only yield.
    if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
      return failure();

    SmallVector<Value> thenYield(op.thenYield().getOperands());
    SmallVector<Value> elseYield;
    if (op.elseBlock())
      llvm::append_range(elseYield, op.elseYield().getOperands());

    // A list of indices for which we should upgrade the value yielded
    // in the else to a select.
    SmallVector<unsigned> elseYieldsToUpgradeToSelect;

    // If the outer scf.if yields a value produced by the inner scf.if,
    // only permit combining if the value yielded when the condition
    // is false in the outer scf.if is the same value yielded when the
    // inner scf.if condition is false.
    // Note that the array access to elseYield will not go out of bounds
    // since it must have the same length as thenYield, since they both
    // come from the same scf.if.
    for (const auto &tup : llvm::enumerate(thenYield)) {
      if (tup.value().getDefiningOp() == nestedIf) {
        auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
        if (nestedIf.elseYield().getOperand(nestedIdx) !=
            elseYield[tup.index()]) {
          return failure();
        }
        // If the correctness test passes, we will yield
        // corresponding value from the inner scf.if
        thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
        continue;
      }

      // Otherwise, we need to ensure the else block of the combined
      // condition still returns the same value when the outer condition is
      // true and the inner condition is false. This can be accomplished if
      // the then value is defined outside the outer scf.if and we replace the
      // value with a select that considers just the outer condition. Since
      // the else region contains just the yield, its yielded value is
      // defined outside the scf.if, by definition.

      // If the then value is defined within the scf.if, bail.
      if (tup.value().getParentRegion() == &op.getThenRegion()) {
        return failure();
      }
      elseYieldsToUpgradeToSelect.push_back(tup.index());
    }

    // Build the combined condition and the replacement `if` with an empty
    // then-block; the inner `if`'s then-block is merged into it below.
    Location loc = op.getLoc();
    Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
                                               nestedIf.getCondition());
    auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
    Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());

    SmallVector<Value> results;
    llvm::append_range(results, newIf.getResults());
    rewriter.setInsertionPoint(newIf);

    // Selects are created in front of the new `if` so they are available as
    // replacement values for the outer op's results.
    for (auto idx : elseYieldsToUpgradeToSelect)
      results[idx] =
          arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
                                  thenYield[idx], elseYield[idx]);

    rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
    rewriter.setInsertionPointToEnd(newIf.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
    if (!elseYield.empty()) {
      rewriter.createBlock(&newIf.getElseRegion());
      rewriter.setInsertionPointToEnd(newIf.elseBlock());
      YieldOp::create(rewriter, loc, elseYield);
    }
    rewriter.replaceOp(op, results);
    return success();
  }
};
2721
2722} // namespace
2723
/// Registers the scf.if canonicalization patterns defined in this file.
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                       MLIRContext *context) {
  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
              ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
              ReplaceIfYieldWithConditionOrValue>(context);
  // NOTE(review): the two lines below look truncated — the call expressions
  // that should precede each `results, ...)` argument list are missing.
  // Verify against the original source; this appears to be a text-extraction
  // artifact rather than intentional code.
      results, IfOp::getOperationName());
      IfOp::getOperationName());
}
2734
2735Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2736YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2737Block *IfOp::elseBlock() {
2738 Region &r = getElseRegion();
2739 if (r.empty())
2740 return nullptr;
2741 return &r.back();
2742}
2743YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2744
2745//===----------------------------------------------------------------------===//
2746// ParallelOp
2747//===----------------------------------------------------------------------===//
2748
2749void ParallelOp::build(
2750 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2751 ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2752 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2753 bodyBuilderFn) {
2754 result.addOperands(lowerBounds);
2755 result.addOperands(upperBounds);
2756 result.addOperands(steps);
2757 result.addOperands(initVals);
2758 result.addAttribute(
2759 ParallelOp::getOperandSegmentSizeAttr(),
2760 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2761 static_cast<int32_t>(upperBounds.size()),
2762 static_cast<int32_t>(steps.size()),
2763 static_cast<int32_t>(initVals.size())}));
2764 result.addTypes(initVals.getTypes());
2765
2766 OpBuilder::InsertionGuard guard(builder);
2767 unsigned numIVs = steps.size();
2768 SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2769 SmallVector<Location, 8> argLocs(numIVs, result.location);
2770 Region *bodyRegion = result.addRegion();
2771 Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2772
2773 if (bodyBuilderFn) {
2774 builder.setInsertionPointToStart(bodyBlock);
2775 bodyBuilderFn(builder, result.location,
2776 bodyBlock->getArguments().take_front(numIVs),
2777 bodyBlock->getArguments().drop_front(numIVs));
2778 }
2779 // Add terminator only if there are no reductions.
2780 if (initVals.empty())
2781 ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2782}
2783
2784void ParallelOp::build(
2785 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2786 ValueRange upperBounds, ValueRange steps,
2787 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2788 // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2789 // we don't capture a reference to a temporary by constructing the lambda at
2790 // function level.
2791 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2792 Location nestedLoc, ValueRange ivs,
2793 ValueRange) {
2794 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2795 };
2796 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2797 if (bodyBuilderFn)
2798 wrapper = wrappedBuilderFn;
2799
2800 build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2801 wrapper);
2802}
2803
2804LogicalResult ParallelOp::verify() {
2805 // Check that there is at least one value in lowerBound, upperBound and step.
2806 // It is sufficient to test only step, because it is ensured already that the
2807 // number of elements in lowerBound, upperBound and step are the same.
2808 Operation::operand_range stepValues = getStep();
2809 if (stepValues.empty())
2810 return emitOpError(
2811 "needs at least one tuple element for lowerBound, upperBound and step");
2812
2813 // Check whether all constant step values are positive.
2814 for (Value stepValue : stepValues)
2815 if (auto cst = getConstantIntValue(stepValue))
2816 if (*cst <= 0)
2817 return emitOpError("constant step operand must be positive");
2818
2819 // Check that the body defines the same number of block arguments as the
2820 // number of tuple elements in step.
2821 Block *body = getBody();
2822 if (body->getNumArguments() != stepValues.size())
2823 return emitOpError() << "expects the same number of induction variables: "
2824 << body->getNumArguments()
2825 << " as bound and step values: " << stepValues.size();
2826 for (auto arg : body->getArguments())
2827 if (!arg.getType().isIndex())
2828 return emitOpError(
2829 "expects arguments for the induction variable to be of index type");
2830
2831 // Check that the terminator is an scf.reduce op.
2833 *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2834 if (!reduceOp)
2835 return failure();
2836
2837 // Check that the number of results is the same as the number of reductions.
2838 auto resultsSize = getResults().size();
2839 auto reductionsSize = reduceOp.getReductions().size();
2840 auto initValsSize = getInitVals().size();
2841 if (resultsSize != reductionsSize)
2842 return emitOpError() << "expects number of results: " << resultsSize
2843 << " to be the same as number of reductions: "
2844 << reductionsSize;
2845 if (resultsSize != initValsSize)
2846 return emitOpError() << "expects number of results: " << resultsSize
2847 << " to be the same as number of initial values: "
2848 << initValsSize;
2849 if (reduceOp.getNumOperands() != initValsSize)
2850 // Delegate error reporting to ReduceOp
2851 return success();
2852
2853 // Check that the types of the results and reductions are the same.
2854 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2855 auto resultType = getOperation()->getResult(i).getType();
2856 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2857 if (resultType != reductionOperandType)
2858 return reduceOp.emitOpError()
2859 << "expects type of " << i
2860 << "-th reduction operand: " << reductionOperandType
2861 << " to be the same as the " << i
2862 << "-th result type: " << resultType;
2863 }
2864 return success();
2865}
2866
2867ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2868 auto &builder = parser.getBuilder();
2869 // Parse an opening `(` followed by induction variables followed by `)`
2870 SmallVector<OpAsmParser::Argument, 4> ivs;
2872 return failure();
2873
2874 // Parse loop bounds.
2875 SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2876 if (parser.parseEqual() ||
2877 parser.parseOperandList(lower, ivs.size(),
2879 parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2880 return failure();
2881
2882 SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2883 if (parser.parseKeyword("to") ||
2884 parser.parseOperandList(upper, ivs.size(),
2886 parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2887 return failure();
2888
2889 // Parse step values.
2890 SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2891 if (parser.parseKeyword("step") ||
2892 parser.parseOperandList(steps, ivs.size(),
2894 parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2895 return failure();
2896
2897 // Parse init values.
2898 SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2899 if (succeeded(parser.parseOptionalKeyword("init"))) {
2900 if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2901 return failure();
2902 }
2903
2904 // Parse optional results in case there is a reduce.
2905 if (parser.parseOptionalArrowTypeList(result.types))
2906 return failure();
2907
2908 // Now parse the body.
2909 Region *body = result.addRegion();
2910 for (auto &iv : ivs)
2911 iv.type = builder.getIndexType();
2912 if (parser.parseRegion(*body, ivs))
2913 return failure();
2914
2915 // Set `operandSegmentSizes` attribute.
2916 result.addAttribute(
2917 ParallelOp::getOperandSegmentSizeAttr(),
2918 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
2919 static_cast<int32_t>(upper.size()),
2920 static_cast<int32_t>(steps.size()),
2921 static_cast<int32_t>(initVals.size())}));
2922
2923 // Parse attributes.
2924 if (parser.parseOptionalAttrDict(result.attributes) ||
2925 parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2926 result.operands))
2927 return failure();
2928
2929 // Add a terminator if none was parsed.
2930 ParallelOp::ensureTerminator(*body, builder, result.location);
2931 return success();
2932}
2933
2934void ParallelOp::print(OpAsmPrinter &p) {
2935 p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
2936 << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
2937 if (!getInitVals().empty())
2938 p << " init (" << getInitVals() << ")";
2939 p.printOptionalArrowTypeList(getResultTypes());
2940 p << ' ';
2941 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
2943 (*this)->getAttrs(),
2944 /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
2945}
2946
2947SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
2948
2949std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
2950 return SmallVector<Value>{getBody()->getArguments()};
2951}
2952
2953std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
2954 return getLowerBound();
2955}
2956
2957std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
2958 return getUpperBound();
2959}
2960
2961std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
2962 return getStep();
2963}
2964
2966 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2967 if (!ivArg)
2968 return ParallelOp();
2969 assert(ivArg.getOwner() && "unlinked block argument");
2970 auto *containingOp = ivArg.getOwner()->getParentOp();
2971 return dyn_cast<ParallelOp>(containingOp);
2972}
2973
2974namespace {
// Collapse loop dimensions that perform a single iteration.
//
// Dimensions with a constant trip count of zero cause the whole op to be
// replaced by its init values; dimensions with a trip count of one are
// removed, with the induction variable replaced by the lower bound. If every
// dimension collapses, the body (including the reductions) is inlined.
struct ParallelOpSingleOrZeroIterationDimsFolder
    : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Compute new loop bounds that omit all single-iteration loop dimensions.
    SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
    IRMapping mapping;
    for (auto [lb, ub, step, iv] :
         llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
                   op.getInductionVars())) {
      auto numIterations =
          constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
      if (numIterations.has_value()) {
        // Remove the loop if it performs zero iterations.
        if (*numIterations == 0) {
          rewriter.replaceOp(op, op.getInitVals());
          return success();
        }
        // Replace the loop induction variable by the lower bound if the loop
        // performs a single iteration. Otherwise, copy the loop bounds.
        if (*numIterations == 1) {
          mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
          continue;
        }
      }
      newLowerBounds.push_back(lb);
      newUpperBounds.push_back(ub);
      newSteps.push_back(step);
    }
    // Exit if none of the loop dimensions perform a single iteration.
    if (newLowerBounds.size() == op.getLowerBound().size())
      return failure();

    if (newLowerBounds.empty()) {
      // All of the loop dimensions perform a single iteration. Inline
      // loop body and nested ReduceOp's
      SmallVector<Value> results;
      results.reserve(op.getInitVals().size());
      for (auto &bodyOp : op.getBody()->without_terminator())
        rewriter.clone(bodyOp, mapping);
      auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
      for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
        Block &reduceBlock = reduceOp.getReductions()[i].front();
        auto initValIndex = results.size();
        // Wire the reduction block arguments to the init value and the
        // (possibly remapped) reduced operand, then inline the block body.
        mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
        mapping.map(reduceBlock.getArgument(1),
                    mapping.lookupOrDefault(reduceOp.getOperands()[i]));
        for (auto &reduceBodyOp : reduceBlock.without_terminator())
          rewriter.clone(reduceBodyOp, mapping);

        // The value handed to scf.reduce.return becomes the op's result.
        auto result = mapping.lookupOrDefault(
            cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
        results.push_back(result);
      }

      rewriter.replaceOp(op, results);
      return success();
    }
    // Replace the parallel loop by lower-dimensional parallel loop.
    auto newOp =
        ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
                           newUpperBounds, newSteps, op.getInitVals(), nullptr);
    // Erase the empty block that was inserted by the builder.
    rewriter.eraseBlock(newOp.getBody());
    // Clone the loop body and remap the block arguments of the collapsed loops
    // (inlining does not support a cancellable block argument mapping).
    rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
                               newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};
3052
/// Fuses a perfectly-nested pair of scf.parallel ops into a single op whose
/// iteration space is the concatenation of the outer and inner spaces. Only
/// applies when the inner loop is the sole op in the outer body, its bounds
/// and steps do not depend on the outer induction variables, and neither loop
/// performs reductions.
struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    Block &outerBody = *op.getBody();
    if (!llvm::hasSingleElement(outerBody.without_terminator()))
      return failure();

    auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
    if (!innerOp)
      return failure();

    // The inner bounds/steps must be computable outside the outer loop.
    for (auto val : outerBody.getArguments())
      if (llvm::is_contained(innerOp.getLowerBound(), val) ||
          llvm::is_contained(innerOp.getUpperBound(), val) ||
          llvm::is_contained(innerOp.getStep(), val))
        return failure();

    // Reductions are not supported yet.
    if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
      return failure();

    // Clone the inner body into the fused loop, remapping outer then inner
    // induction variables onto the combined argument list.
    auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
                           ValueRange iterVals, ValueRange) {
      Block &innerBody = *innerOp.getBody();
      assert(iterVals.size() ==
             (outerBody.getNumArguments() + innerBody.getNumArguments()));
      IRMapping mapping;
      mapping.map(outerBody.getArguments(),
                  iterVals.take_front(outerBody.getNumArguments()));
      mapping.map(innerBody.getArguments(),
                  iterVals.take_back(innerBody.getNumArguments()));
      for (Operation &op : innerBody.without_terminator())
        builder.clone(op, mapping);
    };

    auto concatValues = [](const auto &first, const auto &second) {
      SmallVector<Value> ret;
      ret.reserve(first.size() + second.size());
      ret.assign(first.begin(), first.end());
      ret.append(second.begin(), second.end());
      return ret;
    };

    // Combined iteration space: outer dimensions first, inner second.
    auto newLowerBounds =
        concatValues(op.getLowerBound(), innerOp.getLowerBound());
    auto newUpperBounds =
        concatValues(op.getUpperBound(), innerOp.getUpperBound());
    auto newSteps = concatValues(op.getStep(), innerOp.getStep());

    rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
                                            newSteps, ValueRange(),
                                            bodyBuilder);
    return success();
  }
};
3110
3111} // namespace
3112
3113void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3114 MLIRContext *context) {
3115 results
3116 .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3117 context);
3118}
3119
3120/// Given the region at `index`, or the parent operation if `index` is None,
3121/// return the successor regions. These are the regions that may be selected
3122/// during the flow of control. `operands` is a set of optional attributes that
3123/// correspond to a constant value for each operand, or null if that operand is
3124/// not a constant.
3125void ParallelOp::getSuccessorRegions(
3126 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3127 // Both the operation itself and the region may be branching into the body or
3128 // back into the operation itself. It is possible for loop not to enter the
3129 // body.
3130 regions.push_back(RegionSuccessor(&getRegion()));
3131 regions.push_back(RegionSuccessor::parent());
3132}
3133
3134//===----------------------------------------------------------------------===//
3135// ReduceOp
3136//===----------------------------------------------------------------------===//
3137
/// Builds an scf.reduce with no operands and no reduction regions.
void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3139
3140void ReduceOp::build(OpBuilder &builder, OperationState &result,
3141 ValueRange operands) {
3142 result.addOperands(operands);
3143 for (Value v : operands) {
3144 OpBuilder::InsertionGuard guard(builder);
3145 Region *bodyRegion = result.addRegion();
3146 builder.createBlock(bodyRegion, {},
3147 ArrayRef<Type>{v.getType(), v.getType()},
3148 {result.location, result.location});
3149 }
3150}
3151
3152LogicalResult ReduceOp::verifyRegions() {
3153 if (getReductions().size() != getOperands().size())
3154 return emitOpError() << "expects number of reduction regions: "
3155 << getReductions().size()
3156 << " to be the same as number of reduction operands: "
3157 << getOperands().size();
3158 // The region of a ReduceOp has two arguments of the same type as its
3159 // corresponding operand.
3160 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3161 auto type = getOperands()[i].getType();
3162 Block &block = getReductions()[i].front();
3163 if (block.empty())
3164 return emitOpError() << i << "-th reduction has an empty body";
3165 if (block.getNumArguments() != 2 ||
3166 llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3167 return arg.getType() != type;
3168 }))
3169 return emitOpError() << "expected two block arguments with type " << type
3170 << " in the " << i << "-th reduction region";
3171
3172 // Check that the block is terminated by a ReduceReturnOp.
3173 if (!isa<ReduceReturnOp>(block.getTerminator()))
3174 return emitOpError("reduction bodies must be terminated with an "
3175 "'scf.reduce.return' op");
3176 }
3177
3178 return success();
3179}
3180
/// scf.reduce forwards no operands to any successor region — its operands are
/// consumed by the reduction regions, so return an empty range.
MutableOperandRange
ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
  // No operands are forwarded to the next iteration.
  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
}
3186
3187//===----------------------------------------------------------------------===//
3188// ReduceReturnOp
3189//===----------------------------------------------------------------------===//
3190
3191LogicalResult ReduceReturnOp::verify() {
3192 // The type of the return value should be the same type as the types of the
3193 // block arguments of the reduction body.
3194 Block *reductionBody = getOperation()->getBlock();
3195 // Should already be verified by an op trait.
3196 assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3197 Type expectedResultType = reductionBody->getArgument(0).getType();
3198 if (expectedResultType != getResult().getType())
3199 return emitOpError() << "must have type " << expectedResultType
3200 << " (the type of the reduction inputs)";
3201 return success();
3202}
3203
3204//===----------------------------------------------------------------------===//
3205// WhileOp
3206//===----------------------------------------------------------------------===//
3207
3208void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3209 ::mlir::OperationState &odsState, TypeRange resultTypes,
3210 ValueRange inits, BodyBuilderFn beforeBuilder,
3211 BodyBuilderFn afterBuilder) {
3212 odsState.addOperands(inits);
3213 odsState.addTypes(resultTypes);
3214
3215 OpBuilder::InsertionGuard guard(odsBuilder);
3216
3217 // Build before region.
3218 SmallVector<Location, 4> beforeArgLocs;
3219 beforeArgLocs.reserve(inits.size());
3220 for (Value operand : inits) {
3221 beforeArgLocs.push_back(operand.getLoc());
3222 }
3223
3224 Region *beforeRegion = odsState.addRegion();
3225 Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3226 inits.getTypes(), beforeArgLocs);
3227 if (beforeBuilder)
3228 beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3229
3230 // Build after region.
3231 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3232
3233 Region *afterRegion = odsState.addRegion();
3234 Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3235 resultTypes, afterArgLocs);
3236
3237 if (afterBuilder)
3238 afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3239}
3240
3241ConditionOp WhileOp::getConditionOp() {
3242 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3243}
3244
3245YieldOp WhileOp::getYieldOp() {
3246 return cast<YieldOp>(getAfterBody()->getTerminator());
3247}
3248
3249std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3250 return getYieldOp().getResultsMutable();
3251}
3252
3253Block::BlockArgListType WhileOp::getBeforeArguments() {
3254 return getBeforeBody()->getArguments();
3255}
3256
3257Block::BlockArgListType WhileOp::getAfterArguments() {
3258 return getAfterBody()->getArguments();
3259}
3260
3261Block::BlockArgListType WhileOp::getRegionIterArgs() {
3262 return getBeforeArguments();
3263}
3264
/// Returns the operands forwarded on entry: the init values, which feed the
/// "before" region's block arguments (the only region entered from outside).
OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
  assert(successor.getSuccessor() == &getBefore() &&
         "WhileOp is expected to branch only to the first region");
  return getInits();
}
3270
3271void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3272 SmallVectorImpl<RegionSuccessor> &regions) {
3273 // The parent op always branches to the condition region.
3274 if (point.isParent()) {
3275 regions.emplace_back(&getBefore());
3276 return;
3277 }
3278
3279 assert(llvm::is_contained(
3280 {&getAfter(), &getBefore()},
3281 point.getTerminatorPredecessorOrNull()->getParentRegion()) &&
3282 "there are only two regions in a WhileOp");
3283 // The body region always branches back to the condition region.
3284 if (point.getTerminatorPredecessorOrNull()->getParentRegion() ==
3285 &getAfter()) {
3286 regions.emplace_back(&getBefore());
3287 return;
3288 }
3289
3290 regions.push_back(RegionSuccessor::parent());
3291 regions.emplace_back(&getAfter());
3292}
3293
3294ValueRange WhileOp::getSuccessorInputs(RegionSuccessor successor) {
3295 if (successor.isParent())
3296 return getOperation()->getResults();
3297 if (successor == &getBefore())
3298 return getBefore().getArguments();
3299 if (successor == &getAfter())
3300 return getAfter().getArguments();
3301 llvm_unreachable("invalid region successor");
3302}
3303
3304SmallVector<Region *> WhileOp::getLoopRegions() {
3305 return {&getBefore(), &getAfter()};
3306}
3307
3308/// Parses a `while` op.
3309///
3310/// op ::= `scf.while` assignments `:` function-type region `do` region
3311/// `attributes` attribute-dict
3312/// initializer ::= /* empty */ | `(` assignment-list `)`
3313/// assignment-list ::= assignment | assignment `,` assignment-list
3314/// assignment ::= ssa-value `=` ssa-value
3315ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3316 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3317 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3318 Region *before = result.addRegion();
3319 Region *after = result.addRegion();
3320
3321 OptionalParseResult listResult =
3322 parser.parseOptionalAssignmentList(regionArgs, operands);
3323 if (listResult.has_value() && failed(listResult.value()))
3324 return failure();
3325
3326 FunctionType functionType;
3327 SMLoc typeLoc = parser.getCurrentLocation();
3328 if (failed(parser.parseColonType(functionType)))
3329 return failure();
3330
3331 result.addTypes(functionType.getResults());
3332
3333 if (functionType.getNumInputs() != operands.size()) {
3334 return parser.emitError(typeLoc)
3335 << "expected as many input types as operands " << "(expected "
3336 << operands.size() << " got " << functionType.getNumInputs() << ")";
3337 }
3338
3339 // Resolve input operands.
3340 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3341 parser.getCurrentLocation(),
3342 result.operands)))
3343 return failure();
3344
3345 // Propagate the types into the region arguments.
3346 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3347 regionArgs[i].type = functionType.getInput(i);
3348
3349 return failure(parser.parseRegion(*before, regionArgs) ||
3350 parser.parseKeyword("do") || parser.parseRegion(*after) ||
3351 parser.parseOptionalAttrDictWithKeyword(result.attributes));
3352}
3353
/// Prints a `while` op in the form parsed by `parse` above:
///   (arg = init, ...) : (init-types) -> (result-types) before-region
///   `do` after-region [`attributes` attr-dict]
void scf::WhileOp::print(OpAsmPrinter &p) {
  // `(arg = init, ...)` pairing of before-arguments with init operands.
  printInitializationList(p, getBeforeArguments(), getInits(), " ");
  p << " : ";
  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
  p << ' ';
  // Entry block args were printed in the initialization list; elide them.
  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
  p << " do ";
  p.printRegion(getAfter());
  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}
3365
3366/// Verifies that two ranges of types match, i.e. have the same number of
3367/// entries and that types are pairwise equals. Reports errors on the given
3368/// operation in case of mismatch.
3369template <typename OpTy>
3370static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3371 TypeRange right, StringRef message) {
3372 if (left.size() != right.size())
3373 return op.emitOpError("expects the same number of ") << message;
3374
3375 for (unsigned i = 0, e = left.size(); i < e; ++i) {
3376 if (left[i] != right[i]) {
3377 InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3378 << message;
3379 diag.attachNote() << "for argument " << i << ", found " << left[i]
3380 << " and " << right[i];
3381 return diag;
3382 }
3383 }
3384
3385 return success();
3386}
3387
3388LogicalResult scf::WhileOp::verify() {
3389 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3390 *this, getBefore(),
3391 "expects the 'before' region to terminate with 'scf.condition'");
3392 if (!beforeTerminator)
3393 return failure();
3394
3395 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3396 *this, getAfter(),
3397 "expects the 'after' region to terminate with 'scf.yield'");
3398 return success(afterTerminator != nullptr);
3399}
3400
3401namespace {
/// Move a scf.if op that is directly before the scf.condition op in the while
/// before region, and whose condition matches the condition of the
/// scf.condition op, down into the while after region.
///
/// scf.while (..) : (...) -> ... {
///   %additional_used_values = ...
///   %cond = ...
///   ...
///   %res = scf.if %cond -> (...) {
///     use(%additional_used_values)
///     ... // then block
///     scf.yield %then_value
///   } else {
///     scf.yield %else_value
///   }
///   scf.condition(%cond) %res, ...
/// } do {
/// ^bb0(%res_arg, ...):
///   use(%res_arg)
///   ...
///
/// becomes
/// scf.while (..) : (...) -> ... {
///   %additional_used_values = ...
///   %cond = ...
///   ...
///   scf.condition(%cond) %else_value, ..., %additional_used_values
/// } do {
/// ^bb0(%res_arg ..., %additional_args): :
///   use(%additional_args)
///   ... // if then block
///   use(%then_value)
///   ...
struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::WhileOp op,
                                PatternRewriter &rewriter) const override {
    auto conditionOp = op.getConditionOp();

    // Only support ifOp right before the condition at the moment. Relaxing this
    // would require to:
    // - check that the body does not have side-effects conflicting with
    //   operations between the if and the condition.
    // - check that results of the if operation are only used as arguments to
    //   the condition.
    auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());

    // Check that the ifOp is directly before the conditionOp and that it
    // matches the condition of the conditionOp. Also ensure that the ifOp has
    // no else block with content, as that would complicate the transformation.
    // TODO: support else blocks with content.
    if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
        (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
      return failure();

    // With the checks above, the if results can only feed the condition op.
    assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
                                  *ifOp->user_begin() == conditionOp)) &&
           "ifOp has unexpected uses");

    Location loc = op.getLoc();

    // Replace uses of ifOp results in the conditionOp with the yielded values
    // from the ifOp branches: the else-value is what gets forwarded when the
    // condition is false; the then-value replaces the after-region argument.
    for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
      auto it = llvm::find(ifOp->getResults(), arg);
      if (it != ifOp->getResults().end()) {
        size_t ifOpIdx = it.getIndex();
        Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
        Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);

        rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
        rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
      }
    }

    // Collect additional used values from before region.
    SetVector<Value> additionalUsedValuesSet;
    visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
      if (&op.getBefore() == operand->get().getParentRegion())
        additionalUsedValuesSet.insert(operand->get());
    });

    // Create new whileOp with additional used values as results.
    auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
    auto additionalValueTypes = llvm::map_to_vector(
        additionalUsedValues, [](Value val) { return val.getType(); });
    size_t additionalValueSize = additionalUsedValues.size();
    SmallVector<Type> newResultTypes(op.getResultTypes());
    newResultTypes.append(additionalValueTypes);

    auto newWhileOp =
        scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());

    // Move both regions over and add after-arguments for the extra values.
    rewriter.modifyOpInPlace(newWhileOp, [&] {
      newWhileOp.getBefore().takeBody(op.getBefore());
      newWhileOp.getAfter().takeBody(op.getAfter());
      newWhileOp.getAfter().addArguments(
          additionalValueTypes,
          SmallVector<Location>(additionalValueSize, loc));
    });

    rewriter.modifyOpInPlace(conditionOp, [&] {
      conditionOp.getArgsMutable().append(additionalUsedValues);
    });

    // Replace uses of additional used values inside the ifOp then region with
    // the whileOp after region arguments.
    rewriter.replaceUsesWithIf(
        additionalUsedValues,
        newWhileOp.getAfterArguments().take_back(additionalValueSize),
        [&](OpOperand &use) {
          return ifOp.getThenRegion().isAncestor(
              use.getOwner()->getParentRegion());
        });

    // Inline ifOp then region into new whileOp after region.
    rewriter.eraseOp(ifOp.thenYield());
    rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
                               newWhileOp.getAfterBody()->begin());
    rewriter.eraseOp(ifOp);
    // The extra results only exist to thread values into the after region;
    // callers of the original op see the original result list.
    rewriter.replaceOp(op,
                       newWhileOp->getResults().drop_back(additionalValueSize));
    return success();
  }
};
3528
3529/// Replace uses of the condition within the do block with true, since otherwise
3530/// the block would not be evaluated.
3531///
3532/// scf.while (..) : (i1, ...) -> ... {
3533/// %condition = call @evaluate_condition() : () -> i1
3534/// scf.condition(%condition) %condition : i1, ...
3535/// } do {
3536/// ^bb0(%arg0: i1, ...):
3537/// use(%arg0)
3538/// ...
3539///
3540/// becomes
3541/// scf.while (..) : (i1, ...) -> ... {
3542/// %condition = call @evaluate_condition() : () -> i1
3543/// scf.condition(%condition) %condition : i1, ...
3544/// } do {
3545/// ^bb0(%arg0: i1, ...):
3546/// use(%true)
3547/// ...
3548struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3549 using OpRewritePattern<WhileOp>::OpRewritePattern;
3550
3551 LogicalResult matchAndRewrite(WhileOp op,
3552 PatternRewriter &rewriter) const override {
3553 auto term = op.getConditionOp();
3554
3555 // These variables serve to prevent creating duplicate constants
3556 // and hold constant true or false values.
3557 Value constantTrue = nullptr;
3558
3559 bool replaced = false;
3560 for (auto yieldedAndBlockArgs :
3561 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3562 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3563 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3564 if (!constantTrue)
3565 constantTrue = arith::ConstantOp::create(
3566 rewriter, op.getLoc(), term.getCondition().getType(),
3567 rewriter.getBoolAttr(true));
3568
3569 rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3570 constantTrue);
3571 replaced = true;
3572 }
3573 }
3574 }
3575 return success(replaced);
3576 }
3577};
3578
3579/// Replace operations equivalent to the condition in the do block with true,
3580/// since otherwise the block would not be evaluated.
3581///
3582/// scf.while (..) : (i32, ...) -> ... {
3583/// %z = ... : i32
3584/// %condition = cmpi pred %z, %a
3585/// scf.condition(%condition) %z : i32, ...
3586/// } do {
3587/// ^bb0(%arg0: i32, ...):
3588/// %condition2 = cmpi pred %arg0, %a
3589/// use(%condition2)
3590/// ...
3591///
3592/// becomes
3593/// scf.while (..) : (i32, ...) -> ... {
3594/// %z = ... : i32
3595/// %condition = cmpi pred %z, %a
3596/// scf.condition(%condition) %z : i32, ...
3597/// } do {
3598/// ^bb0(%arg0: i32, ...):
3599/// use(%true)
3600/// ...
3601struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3602 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3603
3604 LogicalResult matchAndRewrite(scf::WhileOp op,
3605 PatternRewriter &rewriter) const override {
3606 using namespace scf;
3607 auto cond = op.getConditionOp();
3608 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3609 if (!cmp)
3610 return failure();
3611 bool changed = false;
3612 for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3613 for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3614 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3615 continue;
3616 for (OpOperand &u :
3617 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3618 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3619 if (!cmp2)
3620 continue;
3621 // For a binary operator 1-opIdx gets the other side.
3622 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3623 continue;
3624 bool samePredicate;
3625 if (cmp2.getPredicate() == cmp.getPredicate())
3626 samePredicate = true;
3627 else if (cmp2.getPredicate() ==
3628 arith::invertPredicate(cmp.getPredicate()))
3629 samePredicate = false;
3630 else
3631 continue;
3632
3633 rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3634 1);
3635 changed = true;
3636 }
3637 }
3638 }
3639 return success(changed);
3640 }
3641};
3642
3643/// If both ranges contain same values return mappping indices from args2 to
3644/// args1. Otherwise return std::nullopt.
3645static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
3646 ValueRange args2) {
3647 if (args1.size() != args2.size())
3648 return std::nullopt;
3649
3650 SmallVector<unsigned> ret(args1.size());
3651 for (auto &&[i, arg1] : llvm::enumerate(args1)) {
3652 auto it = llvm::find(args2, arg1);
3653 if (it == args2.end())
3654 return std::nullopt;
3655
3656 ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
3657 }
3658
3659 return ret;
3660}
3661
3662static bool hasDuplicates(ValueRange args) {
3663 llvm::SmallDenseSet<Value> set;
3664 for (Value arg : args) {
3665 if (!set.insert(arg).second)
3666 return true;
3667 }
3668 return false;
3669}
3670
3671/// If `before` block args are directly forwarded to `scf.condition`, rearrange
3672/// `scf.condition` args into same order as block args. Update `after` block
3673/// args and op result values accordingly.
3674/// Needed to simplify `scf.while` -> `scf.for` uplifting.
3675struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
3677
3678 LogicalResult matchAndRewrite(WhileOp loop,
3679 PatternRewriter &rewriter) const override {
3680 auto *oldBefore = loop.getBeforeBody();
3681 ConditionOp oldTerm = loop.getConditionOp();
3682 ValueRange beforeArgs = oldBefore->getArguments();
3683 ValueRange termArgs = oldTerm.getArgs();
3684 if (beforeArgs == termArgs)
3685 return failure();
3686
3687 if (hasDuplicates(termArgs))
3688 return failure();
3689
3690 auto mapping = getArgsMapping(beforeArgs, termArgs);
3691 if (!mapping)
3692 return failure();
3693
3694 {
3695 OpBuilder::InsertionGuard g(rewriter);
3696 rewriter.setInsertionPoint(oldTerm);
3697 rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
3698 beforeArgs);
3699 }
3700
3701 auto *oldAfter = loop.getAfterBody();
3702
3703 SmallVector<Type> newResultTypes(beforeArgs.size());
3704 for (auto &&[i, j] : llvm::enumerate(*mapping))
3705 newResultTypes[j] = loop.getResult(i).getType();
3706
3707 auto newLoop = WhileOp::create(
3708 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
3709 /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
3710 auto *newBefore = newLoop.getBeforeBody();
3711 auto *newAfter = newLoop.getAfterBody();
3712
3713 SmallVector<Value> newResults(beforeArgs.size());
3714 SmallVector<Value> newAfterArgs(beforeArgs.size());
3715 for (auto &&[i, j] : llvm::enumerate(*mapping)) {
3716 newResults[i] = newLoop.getResult(j);
3717 newAfterArgs[i] = newAfter->getArgument(j);
3718 }
3719
3720 rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
3721 newBefore->getArguments());
3722 rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
3723 newAfterArgs);
3724
3725 rewriter.replaceOp(loop, newResults);
3726 return success();
3727 }
3728};
3729} // namespace
3730
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<WhileConditionTruth, WhileCmpCond, WhileOpAlignBeforeArgs,
              WhileMoveIfDown>(context);
  // NOTE(review): the callee names of the two populate*Patterns calls below
  // were lost in this extraction (their argument lines survive); restore the
  // function names from upstream before building.
      results, WhileOp::getOperationName());
      WhileOp::getOperationName());
}
3740
3741//===----------------------------------------------------------------------===//
3742// IndexSwitchOp
3743//===----------------------------------------------------------------------===//
3744
3745/// Parse the case regions and values.
3746static ParseResult
3748 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
3749 SmallVector<int64_t> caseValues;
3750 while (succeeded(p.parseOptionalKeyword("case"))) {
3751 int64_t value;
3752 Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
3753 if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
3754 return failure();
3755 caseValues.push_back(value);
3756 }
3757 cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
3758 return success();
3759}
3760
3761/// Print the case regions and values.
3763 DenseI64ArrayAttr cases, RegionRange caseRegions) {
3764 for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
3765 p.printNewline();
3766 p << "case " << value << ' ';
3767 p.printRegion(*region, /*printEntryBlockArgs=*/false);
3768 }
3769}
3770
/// Verifies case-value/region consistency and that every region yields values
/// matching this op's result types.
LogicalResult scf::IndexSwitchOp::verify() {
  // Case values and case regions are index-aligned lists; their sizes must
  // agree.
  if (getCases().size() != getCaseRegions().size()) {
    return emitOpError("has ")
           << getCaseRegions().size() << " case regions but "
           << getCases().size() << " case values";
  }

  // Case values must be unique.
  DenseSet<int64_t> valueSet;
  for (int64_t value : getCases())
    if (!valueSet.insert(value).second)
      return emitOpError("has duplicate case value: ") << value;
  // Checks a single region: it must terminate in scf.yield whose operand
  // types match this op's result types exactly. `name` is used in
  // diagnostics only.
  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
    auto yield = dyn_cast<YieldOp>(region.front().back());
    if (!yield)
      return emitOpError("expected region to end with scf.yield, but got ")
             << region.front().back().getName();

    if (yield.getNumOperands() != getNumResults()) {
      return (emitOpError("expected each region to return ")
              << getNumResults() << " values, but " << name << " returns "
              << yield.getNumOperands())
                 .attachNote(yield.getLoc())
             << "see yield operation here";
    }
    // Compare each yielded operand's type against the corresponding result.
    for (auto [idx, result, operand] :
         llvm::enumerate(getResultTypes(), yield.getOperands())) {
      if (!operand)
        return yield.emitOpError() << "operand " << idx << " is null\n";
      if (result == operand.getType())
        continue;
      return (emitOpError("expected result #")
              << idx << " of each region to be " << result)
                 .attachNote(yield.getLoc())
             << name << " returns " << operand.getType() << " here";
    }
    return success();
  };

  // The default region and all case regions obey the same rules.
  if (failed(verifyRegion(getDefaultRegion(), "default region")))
    return failure();
  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
    if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
      return failure();

  return success();
}
3817
/// Returns the number of case values (== number of case regions).
unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
3819
/// Returns the single block of the default region.
Block &scf::IndexSwitchOp::getDefaultBlock() {
  return getDefaultRegion().front();
}
3823
/// Returns the single block of the `idx`-th case region. `idx` must be a
/// valid case index (asserted in debug builds).
Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
  assert(idx < getNumCases() && "case index out-of-bounds");
  return getCaseRegions()[idx].front();
}
3828
3829void IndexSwitchOp::getSuccessorRegions(
3830 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
3831 // All regions branch back to the parent op.
3832 if (!point.isParent()) {
3833 successors.push_back(RegionSuccessor::parent());
3834 return;
3835 }
3836
3837 llvm::append_range(successors, getRegions());
3838}
3839
3840ValueRange IndexSwitchOp::getSuccessorInputs(RegionSuccessor successor) {
3841 return successor.isParent() ? ValueRange(getOperation()->getResults())
3842 : ValueRange();
3843}
3844
3845void IndexSwitchOp::getEntrySuccessorRegions(
3846 ArrayRef<Attribute> operands,
3847 SmallVectorImpl<RegionSuccessor> &successors) {
3848 FoldAdaptor adaptor(operands, *this);
3849
3850 // If a constant was not provided, all regions are possible successors.
3851 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
3852 if (!arg) {
3853 llvm::append_range(successors, getRegions());
3854 return;
3855 }
3856
3857 // Otherwise, try to find a case with a matching value. If not, the
3858 // default region is the only successor.
3859 for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
3860 if (caseValue == arg.getInt()) {
3861 successors.emplace_back(&caseRegion);
3862 return;
3863 }
3864 }
3865 successors.emplace_back(&getDefaultRegion());
3866}
3867
3868void IndexSwitchOp::getRegionInvocationBounds(
3869 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
3870 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
3871 if (!operandValue) {
3872 // All regions are invoked at most once.
3873 bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
3874 return;
3875 }
3876
3877 unsigned liveIndex = getNumRegions() - 1;
3878 const auto *it = llvm::find(getCases(), operandValue.getInt());
3879 if (it != getCases().end())
3880 liveIndex = std::distance(getCases().begin(), it);
3881 for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
3882 bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
3883}
3884
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  // NOTE(review): the callee names of the two populate*Patterns calls below
  // were lost in this extraction (their argument lines survive); restore the
  // function names from upstream before building.
      results, IndexSwitchOp::getOperationName());
      results, IndexSwitchOp::getOperationName());
}
3892
3893//===----------------------------------------------------------------------===//
3894// TableGen'd op method definitions
3895//===----------------------------------------------------------------------===//
3896
3897#define GET_OP_CLASSES
3898#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1497
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
Definition EmitC.cpp:1472
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition EmitC.cpp:1488
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition SCF.cpp:3370
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:496
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition SCF.cpp:101
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
@ None
static std::string diag(const llvm::Value &value)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
This class represents an argument of a Block.
Definition Value.h:306
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
bool getValue() const
Return the boolean value of this attribute.
UnitAttr getUnitAttr()
Definition Builders.cpp:102
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
IntegerType getI1Type()
Definition Builders.cpp:57
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition Builders.cpp:593
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition Builders.h:596
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1140
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:541
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
OperandRange operand_range
Definition Operation.h:400
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:259
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:251
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:233
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
BlockArgument getArgument(unsigned i)
Definition Region.h:124
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isIndex() const
Definition Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
static Value defaultReplBuilderFn(OpBuilder &builder, Location loc, Value value)
Default implementation of the non-successor-input replacement builder function.
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition SCF.cpp:2965
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition SCF.cpp:94
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition SCF.cpp:793
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition SCF.cpp:1914
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition SCF.cpp:748
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition SCF.cpp:115
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition SCF.cpp:675
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition SCF.h:65
llvm::function_ref< Value(OpBuilder &, Location loc, Type, Value)> ValueTypeCastFnTy
Perform a replacement of one iter OpOperand of an scf.for to the replacement value with a different t...
Definition SCF.h:108
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition SCF.cpp:1388
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition SCF.cpp:882
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
void populateRegionBranchOpInterfaceInliningPattern(RewritePatternSet &patterns, StringRef opName, NonSuccessorInputReplacementBuilderFn replBuilderFn=detail::defaultReplBuilderFn, PatternMatcherFn matcherFn=detail::defaultMatcherFn, PatternBenefit benefit=1)
Populate a pattern that inlines the body of region branch ops when there is a single acyclic path thr...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that the values range has as many elements as the number of entries in attr for which isDynamic ev...
void populateRegionBranchOpInterfaceCanonicalizationPatterns(RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit=1)
Populate canonicalization patterns that simplify successor operands/inputs of region branch operation...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step,...
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:222
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.