MLIR 22.0.0git
SCF.cpp
Go to the documentation of this file.
1//===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/Operation.h"
30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/DebugLog.h"
35#include <optional>
36
37using namespace mlir;
38using namespace mlir::scf;
39
40#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
41
42//===----------------------------------------------------------------------===//
43// SCFDialect Dialect Interfaces
44//===----------------------------------------------------------------------===//
45
46namespace {
47struct SCFInlinerInterface : public DialectInlinerInterface {
48 using DialectInlinerInterface::DialectInlinerInterface;
49 // We don't have any special restrictions on what can be inlined into
50 // destination regions (e.g. while/conditional bodies). Always allow it.
51 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
52 IRMapping &valueMapping) const final {
53 return true;
54 }
55 // Operations in scf dialect are always legal to inline since they are
56 // pure.
57 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
58 return true;
59 }
60 // Handle the given inlined terminator by replacing it with a new operation
61 // as necessary. Required when the region has only one block.
62 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
63 auto retValOp = dyn_cast<scf::YieldOp>(op);
64 if (!retValOp)
65 return;
66
67 for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
68 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
69 }
70 }
71};
72} // namespace
73
74//===----------------------------------------------------------------------===//
75// SCFDialect
76//===----------------------------------------------------------------------===//
77
78void SCFDialect::initialize() {
79 addOperations<
80#define GET_OP_LIST
81#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
82 >();
83 addInterfaces<SCFInlinerInterface>();
84 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
85 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
86 InParallelOp, ReduceReturnOp>();
87 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
88 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
89 ForallOp, InParallelOp, WhileOp, YieldOp>();
90 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
91}
92
93/// Default callback for IfOp builders. Inserts a yield without arguments.
95 scf::YieldOp::create(builder, loc);
96}
97
98/// Verifies that the first block of the given `region` is terminated by a
99/// TerminatorTy. Reports errors on the given operation if it is not the case.
100template <typename TerminatorTy>
101static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
102 StringRef errorMessage) {
103 Operation *terminatorOperation = nullptr;
104 if (!region.empty() && !region.front().empty()) {
105 terminatorOperation = &region.front().back();
106 if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
107 return yield;
108 }
109 auto diag = op->emitOpError(errorMessage);
110 if (terminatorOperation)
111 diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
112 return nullptr;
113}
114
115std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
116 bool isSigned) {
117 llvm::APSInt diff;
118 auto addOp = ub.getDefiningOp<arith::AddIOp>();
119 if (!addOp)
120 return std::nullopt;
121 if ((isSigned && !addOp.hasNoSignedWrap()) ||
122 (!isSigned && !addOp.hasNoUnsignedWrap()))
123 return std::nullopt;
124
125 if (addOp.getLhs() != lb ||
126 !matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
127 return std::nullopt;
128 return diff;
129}
130
131//===----------------------------------------------------------------------===//
132// ExecuteRegionOp
133//===----------------------------------------------------------------------===//
134
135/// Replaces the given op with the contents of the given single-block region,
136/// using the operands of the block terminator to replace operation results.
138 Region &region, ValueRange blockArgs = {}) {
139 assert(region.hasOneBlock() && "expected single-block region");
140 Block *block = &region.front();
141 Operation *terminator = block->getTerminator();
142 ValueRange results = terminator->getOperands();
143 rewriter.inlineBlockBefore(block, op, blockArgs);
144 rewriter.replaceOp(op, results);
145 rewriter.eraseOp(terminator);
146}
147
148///
149/// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
150/// block+
151/// `}`
152///
153/// Example:
154/// scf.execute_region -> i32 {
155/// %idx = load %rI[%i] : memref<128xi32>
156/// return %idx : i32
157/// }
158///
159ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
161 if (parser.parseOptionalArrowTypeList(result.types))
162 return failure();
163
164 if (succeeded(parser.parseOptionalKeyword("no_inline")))
165 result.addAttribute("no_inline", parser.getBuilder().getUnitAttr());
166
167 // Introduce the body region and parse it.
168 Region *body = result.addRegion();
169 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
170 parser.parseOptionalAttrDict(result.attributes))
171 return failure();
172
173 return success();
174}
175
176void ExecuteRegionOp::print(OpAsmPrinter &p) {
177 p.printOptionalArrowTypeList(getResultTypes());
178 p << ' ';
179 if (getNoInline())
180 p << "no_inline ";
181 p.printRegion(getRegion(),
182 /*printEntryBlockArgs=*/false,
183 /*printBlockTerminators=*/true);
184 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"no_inline"});
185}
186
187LogicalResult ExecuteRegionOp::verify() {
188 if (getRegion().empty())
189 return emitOpError("region needs to have at least one block");
190 if (getRegion().front().getNumArguments() > 0)
191 return emitOpError("region cannot have any arguments");
192 return success();
193}
194
195// Inline an ExecuteRegionOp if it only contains one block.
196// "test.foo"() : () -> ()
197// %v = scf.execute_region -> i64 {
198// %x = "test.val"() : () -> i64
199// scf.yield %x : i64
200// }
201// "test.bar"(%v) : (i64) -> ()
202//
203// becomes
204//
205// "test.foo"() : () -> ()
206// %x = "test.val"() : () -> i64
207// "test.bar"(%x) : (i64) -> ()
208//
209struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
210 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
211
212 LogicalResult matchAndRewrite(ExecuteRegionOp op,
213 PatternRewriter &rewriter) const override {
214 if (!op.getRegion().hasOneBlock() || op.getNoInline())
215 return failure();
216 replaceOpWithRegion(rewriter, op, op.getRegion());
217 return success();
218 }
219};
220
221// Inline an ExecuteRegionOp if its parent can contain multiple blocks.
222// TODO generalize the conditions for operations which can be inlined into.
223// func @func_execute_region_elim() {
224// "test.foo"() : () -> ()
225// %v = scf.execute_region -> i64 {
226// %c = "test.cmp"() : () -> i1
227// cf.cond_br %c, ^bb2, ^bb3
228// ^bb2:
229// %x = "test.val1"() : () -> i64
230// cf.br ^bb4(%x : i64)
231// ^bb3:
232// %y = "test.val2"() : () -> i64
233// cf.br ^bb4(%y : i64)
234// ^bb4(%z : i64):
235// scf.yield %z : i64
236// }
237// "test.bar"(%v) : (i64) -> ()
238// return
239// }
240//
241// becomes
242//
243// func @func_execute_region_elim() {
244// "test.foo"() : () -> ()
245// %c = "test.cmp"() : () -> i1
246// cf.cond_br %c, ^bb1, ^bb2
247// ^bb1: // pred: ^bb0
248// %x = "test.val1"() : () -> i64
249// cf.br ^bb3(%x : i64)
250// ^bb2: // pred: ^bb0
251// %y = "test.val2"() : () -> i64
252// cf.br ^bb3(%y : i64)
253// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
254// "test.bar"(%z) : (i64) -> ()
255// return
256// }
257//
258struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
259 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
260
261 LogicalResult matchAndRewrite(ExecuteRegionOp op,
262 PatternRewriter &rewriter) const override {
263 if (op.getNoInline())
264 return failure();
265 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
266 return failure();
267
268 Block *prevBlock = op->getBlock();
269 Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
270 rewriter.setInsertionPointToEnd(prevBlock);
271
272 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
273
274 for (Block &blk : op.getRegion()) {
275 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
276 rewriter.setInsertionPoint(yieldOp);
277 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
278 yieldOp.getResults());
279 rewriter.eraseOp(yieldOp);
280 }
281 }
282
283 rewriter.inlineRegionBefore(op.getRegion(), postBlock);
284 SmallVector<Value> blockArgs;
285
286 for (auto res : op.getResults())
287 blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
288
289 rewriter.replaceOp(op, blockArgs);
290 return success();
291 }
292};
293
294// Pattern to eliminate ExecuteRegionOp results which forward external
295// values from the region. In case there are multiple yield operations,
296// all of them must have the same operands in order for the pattern to be
297// applicable.
299 : public OpRewritePattern<ExecuteRegionOp> {
300 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
301
302 LogicalResult matchAndRewrite(ExecuteRegionOp op,
303 PatternRewriter &rewriter) const override {
304 if (op.getNumResults() == 0)
305 return failure();
306
308 for (Block &block : op.getRegion()) {
309 if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
310 yieldOps.push_back(yield.getOperation());
311 }
312
313 if (yieldOps.empty())
314 return failure();
315
316 // Check if all yield operations have the same operands.
317 auto yieldOpsOperands = yieldOps[0]->getOperands();
318 for (auto *yieldOp : yieldOps) {
319 if (yieldOp->getOperands() != yieldOpsOperands)
320 return failure();
321 }
322
323 SmallVector<Value> externalValues;
324 SmallVector<Value> internalValues;
325 SmallVector<Value> opResultsToReplaceWithExternalValues;
326 SmallVector<Value> opResultsToKeep;
327 for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
328 if (isValueFromInsideRegion(yieldedValue, op)) {
329 internalValues.push_back(yieldedValue);
330 opResultsToKeep.push_back(op.getResult(index));
331 } else {
332 externalValues.push_back(yieldedValue);
333 opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
334 }
335 }
336 // No yielded external values - nothing to do.
337 if (externalValues.empty())
338 return failure();
339
340 // There are yielded external values - create a new execute_region returning
341 // just the internal values.
342 SmallVector<Type> resultTypes;
343 for (Value value : internalValues)
344 resultTypes.push_back(value.getType());
345 auto newOp =
346 ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
347 newOp->setAttrs(op->getAttrs());
348
349 // Move old op's region to the new operation.
350 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
351 newOp.getRegion().end());
352
353 // Replace all yield operations with a new yield operation with updated
354 // results. scf.execute_region must have at least one yield operation.
355 for (auto *yieldOp : yieldOps) {
356 rewriter.setInsertionPoint(yieldOp);
357 rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
358 ValueRange(internalValues));
359 }
360
361 // Replace the old operation with the external values directly.
362 rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
363 externalValues);
364 // Replace the old operation's remaining results with the new operation's
365 // results.
366 rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
367 rewriter.eraseOp(op);
368 return success();
369 }
370
371private:
372 bool isValueFromInsideRegion(Value value,
373 ExecuteRegionOp executeRegionOp) const {
374 // Check if the value is defined within the execute_region
375 if (Operation *defOp = value.getDefiningOp())
376 return &executeRegionOp.getRegion() == defOp->getParentRegion();
377
378 // If it's a block argument, check if it's from within the region
379 if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
380 return &executeRegionOp.getRegion() == blockArg.getParentRegion();
381
382 return false; // Value is from outside the region
383 }
384};
385
386void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
387 MLIRContext *context) {
390}
391
392void ExecuteRegionOp::getSuccessorRegions(
394 // If the predecessor is the ExecuteRegionOp, branch into the body.
395 if (point.isParent()) {
396 regions.push_back(RegionSuccessor(&getRegion()));
397 return;
398 }
399
400 // Otherwise, the region branches back to the parent operation.
401 regions.push_back(RegionSuccessor(getOperation(), getResults()));
402}
403
404//===----------------------------------------------------------------------===//
405// ConditionOp
406//===----------------------------------------------------------------------===//
407
409ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
410 assert(
411 (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) &&
412 "condition op can only exit the loop or branch to the after"
413 "region");
414 // Pass all operands except the condition to the successor region.
415 return getArgsMutable();
416}
417
418void ConditionOp::getSuccessorRegions(
420 FoldAdaptor adaptor(operands, *this);
421
422 WhileOp whileOp = getParentOp();
423
424 // Condition can either lead to the after region or back to the parent op
425 // depending on whether the condition is true or not.
426 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
427 if (!boolAttr || boolAttr.getValue())
428 regions.emplace_back(&whileOp.getAfter(),
429 whileOp.getAfter().getArguments());
430 if (!boolAttr || !boolAttr.getValue())
431 regions.emplace_back(whileOp.getOperation(), whileOp.getResults());
432}
433
434//===----------------------------------------------------------------------===//
435// ForOp
436//===----------------------------------------------------------------------===//
437
438void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
439 Value ub, Value step, ValueRange initArgs,
440 BodyBuilderFn bodyBuilder, bool unsignedCmp) {
441 OpBuilder::InsertionGuard guard(builder);
442
443 if (unsignedCmp)
444 result.addAttribute(getUnsignedCmpAttrName(result.name),
445 builder.getUnitAttr());
446 result.addOperands({lb, ub, step});
447 result.addOperands(initArgs);
448 for (Value v : initArgs)
449 result.addTypes(v.getType());
450 Type t = lb.getType();
451 Region *bodyRegion = result.addRegion();
452 Block *bodyBlock = builder.createBlock(bodyRegion);
453 bodyBlock->addArgument(t, result.location);
454 for (Value v : initArgs)
455 bodyBlock->addArgument(v.getType(), v.getLoc());
456
457 // Create the default terminator if the builder is not provided and if the
458 // iteration arguments are not provided. Otherwise, leave this to the caller
459 // because we don't know which values to return from the loop.
460 if (initArgs.empty() && !bodyBuilder) {
461 ForOp::ensureTerminator(*bodyRegion, builder, result.location);
462 } else if (bodyBuilder) {
463 OpBuilder::InsertionGuard guard(builder);
464 builder.setInsertionPointToStart(bodyBlock);
465 bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
466 bodyBlock->getArguments().drop_front());
467 }
468}
469
470LogicalResult ForOp::verify() {
471 // Check that the number of init args and op results is the same.
472 if (getInitArgs().size() != getNumResults())
473 return emitOpError(
474 "mismatch in number of loop-carried values and defined values");
475
476 return success();
477}
478
479LogicalResult ForOp::verifyRegions() {
480 // Check that the body defines as single block argument for the induction
481 // variable.
482 if (getInductionVar().getType() != getLowerBound().getType())
483 return emitOpError(
484 "expected induction variable to be same type as bounds and step");
485
486 if (getNumRegionIterArgs() != getNumResults())
487 return emitOpError(
488 "mismatch in number of basic block args and defined values");
489
490 auto initArgs = getInitArgs();
491 auto iterArgs = getRegionIterArgs();
492 auto opResults = getResults();
493 unsigned i = 0;
494 for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
495 if (std::get<0>(e).getType() != std::get<2>(e).getType())
496 return emitOpError() << "types mismatch between " << i
497 << "th iter operand and defined value";
498 if (std::get<1>(e).getType() != std::get<2>(e).getType())
499 return emitOpError() << "types mismatch between " << i
500 << "th iter region arg and defined value";
501
502 ++i;
503 }
504 return success();
505}
506
507std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
508 return SmallVector<Value>{getInductionVar()};
509}
510
511std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
513}
514
515std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
516 return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
517}
518
519std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
521}
522
523std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
524
525/// Promotes the loop body of a forOp to its containing block if the forOp
526/// it can be determined that the loop has a single iteration.
527LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
528 std::optional<APInt> tripCount = getStaticTripCount();
529 LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
530 << " for loop "
531 << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
532 if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
533 return failure();
534
535 if (*tripCount == 0) {
536 rewriter.replaceAllUsesWith(getResults(), getInitArgs());
537 rewriter.eraseOp(*this);
538 return success();
539 }
540
541 // Replace all results with the yielded values.
542 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
543 rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
544
545 // Replace block arguments with lower bound (replacement for IV) and
546 // iter_args.
547 SmallVector<Value> bbArgReplacements;
548 bbArgReplacements.push_back(getLowerBound());
549 llvm::append_range(bbArgReplacements, getInitArgs());
550
551 // Move the loop body operations to the loop's containing block.
552 rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
553 getOperation()->getIterator(), bbArgReplacements);
554
555 // Erase the old terminator and the loop.
556 rewriter.eraseOp(yieldOp);
557 rewriter.eraseOp(*this);
558
559 return success();
560}
561
562/// Prints the initialization list in the form of
563/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
564/// where 'inner' values are assumed to be region arguments and 'outer' values
565/// are regular SSA values.
567 Block::BlockArgListType blocksArgs,
568 ValueRange initializers,
569 StringRef prefix = "") {
570 assert(blocksArgs.size() == initializers.size() &&
571 "expected same length of arguments and initializers");
572 if (initializers.empty())
573 return;
574
575 p << prefix << '(';
576 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
577 p << std::get<0>(it) << " = " << std::get<1>(it);
578 });
579 p << ")";
580}
581
582void ForOp::print(OpAsmPrinter &p) {
583 if (getUnsignedCmp())
584 p << " unsigned";
585
586 p << " " << getInductionVar() << " = " << getLowerBound() << " to "
587 << getUpperBound() << " step " << getStep();
588
589 printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
590 if (!getInitArgs().empty())
591 p << " -> (" << getInitArgs().getTypes() << ')';
592 p << ' ';
593 if (Type t = getInductionVar().getType(); !t.isIndex())
594 p << " : " << t << ' ';
595 p.printRegion(getRegion(),
596 /*printEntryBlockArgs=*/false,
597 /*printBlockTerminators=*/!getInitArgs().empty());
598 p.printOptionalAttrDict((*this)->getAttrs(),
599 /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
600}
601
602ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
603 auto &builder = parser.getBuilder();
604 Type type;
605
606 OpAsmParser::Argument inductionVariable;
608
609 if (succeeded(parser.parseOptionalKeyword("unsigned")))
610 result.addAttribute(getUnsignedCmpAttrName(result.name),
611 builder.getUnitAttr());
612
613 // Parse the induction variable followed by '='.
614 if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
615 // Parse loop bounds.
616 parser.parseOperand(lb) || parser.parseKeyword("to") ||
617 parser.parseOperand(ub) || parser.parseKeyword("step") ||
618 parser.parseOperand(step))
619 return failure();
620
621 // Parse the optional initial iteration arguments.
624 regionArgs.push_back(inductionVariable);
625
626 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
627 if (hasIterArgs) {
628 // Parse assignment list and results type list.
629 if (parser.parseAssignmentList(regionArgs, operands) ||
630 parser.parseArrowTypeList(result.types))
631 return failure();
632 }
633
634 if (regionArgs.size() != result.types.size() + 1)
635 return parser.emitError(
636 parser.getNameLoc(),
637 "mismatch in number of loop-carried values and defined values");
638
639 // Parse optional type, else assume Index.
640 if (parser.parseOptionalColon())
641 type = builder.getIndexType();
642 else if (parser.parseType(type))
643 return failure();
644
645 // Set block argument types, so that they are known when parsing the region.
646 regionArgs.front().type = type;
647 for (auto [iterArg, type] :
648 llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
649 iterArg.type = type;
650
651 // Parse the body region.
652 Region *body = result.addRegion();
653 if (parser.parseRegion(*body, regionArgs))
654 return failure();
655 ForOp::ensureTerminator(*body, builder, result.location);
656
657 // Resolve input operands. This should be done after parsing the region to
658 // catch invalid IR where operands were defined inside of the region.
659 if (parser.resolveOperand(lb, type, result.operands) ||
660 parser.resolveOperand(ub, type, result.operands) ||
661 parser.resolveOperand(step, type, result.operands))
662 return failure();
663 if (hasIterArgs) {
664 for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
665 operands, result.types)) {
666 Type type = std::get<2>(argOperandType);
667 std::get<0>(argOperandType).type = type;
668 if (parser.resolveOperand(std::get<1>(argOperandType), type,
669 result.operands))
670 return failure();
671 }
672 }
673
674 // Parse the optional attribute list.
675 if (parser.parseOptionalAttrDict(result.attributes))
676 return failure();
677
678 return success();
679}
680
681SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
682
683Block::BlockArgListType ForOp::getRegionIterArgs() {
684 return getBody()->getArguments().drop_front(getNumInductionVars());
685}
686
687MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
688 return getInitArgsMutable();
689}
690
691FailureOr<LoopLikeOpInterface>
692ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
693 ValueRange newInitOperands,
694 bool replaceInitOperandUsesInLoop,
695 const NewYieldValuesFn &newYieldValuesFn) {
696 // Create a new loop before the existing one, with the extra operands.
697 OpBuilder::InsertionGuard g(rewriter);
698 rewriter.setInsertionPoint(getOperation());
699 auto inits = llvm::to_vector(getInitArgs());
700 inits.append(newInitOperands.begin(), newInitOperands.end());
701 scf::ForOp newLoop = scf::ForOp::create(
702 rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
703 [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp());
704 newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
705
706 // Generate the new yield values and append them to the scf.yield operation.
707 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
708 ArrayRef<BlockArgument> newIterArgs =
709 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
710 {
711 OpBuilder::InsertionGuard g(rewriter);
712 rewriter.setInsertionPoint(yieldOp);
713 SmallVector<Value> newYieldedValues =
714 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
715 assert(newInitOperands.size() == newYieldedValues.size() &&
716 "expected as many new yield values as new iter operands");
717 rewriter.modifyOpInPlace(yieldOp, [&]() {
718 yieldOp.getResultsMutable().append(newYieldedValues);
719 });
720 }
721
722 // Move the loop body to the new op.
723 rewriter.mergeBlocks(getBody(), newLoop.getBody(),
724 newLoop.getBody()->getArguments().take_front(
725 getBody()->getNumArguments()));
726
727 if (replaceInitOperandUsesInLoop) {
728 // Replace all uses of `newInitOperands` with the corresponding basic block
729 // arguments.
730 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
731 rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
732 [&](OpOperand &use) {
733 Operation *user = use.getOwner();
734 return newLoop->isProperAncestor(user);
735 });
736 }
737 }
738
739 // Replace the old loop.
740 rewriter.replaceOp(getOperation(),
741 newLoop->getResults().take_front(getNumResults()));
742 return cast<LoopLikeOpInterface>(newLoop.getOperation());
743}
744
746 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
747 if (!ivArg)
748 return ForOp();
749 assert(ivArg.getOwner() && "unlinked block argument");
750 auto *containingOp = ivArg.getOwner()->getParentOp();
751 return dyn_cast_or_null<ForOp>(containingOp);
752}
753
754OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
755 return getInitArgs();
756}
757
758void ForOp::getSuccessorRegions(RegionBranchPoint point,
760 // Both the operation itself and the region may be branching into the body or
761 // back into the operation itself. It is possible for loop not to enter the
762 // body.
763 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
764 regions.push_back(RegionSuccessor(getOperation(), getResults()));
765}
766
767SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
768
769/// Promotes the loop body of a forallOp to its containing block if it can be
770/// determined that the loop has a single iteration.
771LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
772 for (auto [lb, ub, step] :
773 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
774 auto tripCount =
775 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
776 if (!tripCount.has_value() || *tripCount != 1)
777 return failure();
778 }
779
780 promote(rewriter, *this);
781 return success();
782}
783
784Block::BlockArgListType ForallOp::getRegionIterArgs() {
785 return getBody()->getArguments().drop_front(getRank());
786}
787
788MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
789 return getOutputsMutable();
790}
791
792/// Promotes the loop body of a scf::ForallOp to its containing block.
793void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
794 OpBuilder::InsertionGuard g(rewriter);
795 scf::InParallelOp terminator = forallOp.getTerminator();
796
797 // Replace block arguments with lower bounds (replacements for IVs) and
798 // outputs.
799 SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
800 bbArgReplacements.append(forallOp.getOutputs().begin(),
801 forallOp.getOutputs().end());
802
803 // Move the loop body operations to the loop's containing block.
804 rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
805 forallOp->getIterator(), bbArgReplacements);
806
807 // Replace the terminator with tensor.insert_slice ops.
808 rewriter.setInsertionPointAfter(forallOp);
809 SmallVector<Value> results;
810 results.reserve(forallOp.getResults().size());
811 for (auto &yieldingOp : terminator.getYieldingOps()) {
812 auto parallelInsertSliceOp =
813 dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
814 if (!parallelInsertSliceOp)
815 continue;
816
817 Value dst = parallelInsertSliceOp.getDest();
818 Value src = parallelInsertSliceOp.getSource();
819 if (llvm::isa<TensorType>(src.getType())) {
820 results.push_back(tensor::InsertSliceOp::create(
821 rewriter, forallOp.getLoc(), dst.getType(), src, dst,
822 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
823 parallelInsertSliceOp.getStrides(),
824 parallelInsertSliceOp.getStaticOffsets(),
825 parallelInsertSliceOp.getStaticSizes(),
826 parallelInsertSliceOp.getStaticStrides()));
827 } else {
828 llvm_unreachable("unsupported terminator");
829 }
830 }
831 rewriter.replaceAllUsesWith(forallOp.getResults(), results);
832
833 // Erase the old terminator and the loop.
834 rewriter.eraseOp(terminator);
835 rewriter.eraseOp(forallOp);
836}
837
839 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
840 ValueRange steps, ValueRange iterArgs,
842 bodyBuilder) {
843 assert(lbs.size() == ubs.size() &&
844 "expected the same number of lower and upper bounds");
845 assert(lbs.size() == steps.size() &&
846 "expected the same number of lower bounds and steps");
847
848 // If there are no bounds, call the body-building function and return early.
849 if (lbs.empty()) {
850 ValueVector results =
851 bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
852 : ValueVector();
853 assert(results.size() == iterArgs.size() &&
854 "loop nest body must return as many values as loop has iteration "
855 "arguments");
856 return LoopNest{{}, std::move(results)};
857 }
858
859 // First, create the loop structure iteratively using the body-builder
860 // callback of `ForOp::build`. Do not create `YieldOp`s yet.
861 OpBuilder::InsertionGuard guard(builder);
864 loops.reserve(lbs.size());
865 ivs.reserve(lbs.size());
866 ValueRange currentIterArgs = iterArgs;
867 Location currentLoc = loc;
868 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
869 auto loop = scf::ForOp::create(
870 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
871 [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
872 ValueRange args) {
873 ivs.push_back(iv);
874 // It is safe to store ValueRange args because it points to block
875 // arguments of a loop operation that we also own.
876 currentIterArgs = args;
877 currentLoc = nestedLoc;
878 });
879 // Set the builder to point to the body of the newly created loop. We don't
880 // do this in the callback because the builder is reset when the callback
881 // returns.
882 builder.setInsertionPointToStart(loop.getBody());
883 loops.push_back(loop);
884 }
885
886 // For all loops but the innermost, yield the results of the nested loop.
887 for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
888 builder.setInsertionPointToEnd(loops[i].getBody());
889 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
890 }
891
892 // In the body of the innermost loop, call the body building function if any
893 // and yield its results.
894 builder.setInsertionPointToStart(loops.back().getBody());
895 ValueVector results = bodyBuilder
896 ? bodyBuilder(builder, currentLoc, ivs,
897 loops.back().getRegionIterArgs())
898 : ValueVector();
899 assert(results.size() == iterArgs.size() &&
900 "loop nest body must return as many values as loop has iteration "
901 "arguments");
902 builder.setInsertionPointToEnd(loops.back().getBody());
903 scf::YieldOp::create(builder, loc, results);
904
905 // Return the loops.
906 ValueVector nestResults;
907 llvm::append_range(nestResults, loops.front().getResults());
908 return LoopNest{std::move(loops), std::move(nestResults)};
909}
910
912 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
913 ValueRange steps,
914 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
915 // Delegate to the main function by wrapping the body builder.
916 return buildLoopNest(builder, loc, lbs, ubs, steps, {},
917 [&bodyBuilder](OpBuilder &nestedBuilder,
918 Location nestedLoc, ValueRange ivs,
920 if (bodyBuilder)
921 bodyBuilder(nestedBuilder, nestedLoc, ivs);
922 return {};
923 });
924}
925
928 OpOperand &operand, Value replacement,
929 const ValueTypeCastFnTy &castFn) {
930 assert(operand.getOwner() == forOp);
931 Type oldType = operand.get().getType(), newType = replacement.getType();
932
933 // 1. Create new iter operands, exactly 1 is replaced.
934 assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
935 "expected an iter OpOperand");
936 assert(operand.get().getType() != replacement.getType() &&
937 "Expected a different type");
938 SmallVector<Value> newIterOperands;
939 for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
940 if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
941 newIterOperands.push_back(replacement);
942 continue;
943 }
944 newIterOperands.push_back(opOperand.get());
945 }
946
947 // 2. Create the new forOp shell.
948 scf::ForOp newForOp = scf::ForOp::create(
949 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
950 forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr,
951 forOp.getUnsignedCmp());
952 newForOp->setAttrs(forOp->getAttrs());
953 Block &newBlock = newForOp.getRegion().front();
954 SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
955 newBlock.getArguments().end());
956
957 // 3. Inject an incoming cast op at the beginning of the block for the bbArg
958 // corresponding to the `replacement` value.
959 OpBuilder::InsertionGuard g(rewriter);
960 rewriter.setInsertionPointToStart(&newBlock);
961 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
962 &newForOp->getOpOperand(operand.getOperandNumber()));
963 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
964 newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
965
966 // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
967 Block &oldBlock = forOp.getRegion().front();
968 rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
969
970 // 5. Inject an outgoing cast op at the end of the block and yield it instead.
971 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
972 rewriter.setInsertionPoint(clonedYieldOp);
973 unsigned yieldIdx =
974 newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
975 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
976 clonedYieldOp.getOperand(yieldIdx));
977 SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
978 newYieldOperands[yieldIdx] = castOut;
979 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
980 rewriter.eraseOp(clonedYieldOp);
981
982 // 6. Inject an outgoing cast op after the forOp.
983 rewriter.setInsertionPointAfter(newForOp);
984 SmallVector<Value> newResults = newForOp.getResults();
985 newResults[yieldIdx] =
986 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
987
988 return newResults;
989}
990
991namespace {
992// Fold away ForOp iter arguments when:
993// 1) The op yields the iter arguments.
994// 2) The argument's corresponding outer region iterators (inputs) are yielded.
995// 3) The iter arguments have no use and the corresponding (operation) results
996// have no use.
997//
998// These arguments must be defined outside of the ForOp region and can just be
999// forwarded after simplifying the op inits, yields and returns.
1000//
1001// The implementation uses `inlineBlockBefore` to steal the content of the
1002// original ForOp and avoid cloning.
1003struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
1004 using OpRewritePattern<scf::ForOp>::OpRewritePattern;
1005
1006 LogicalResult matchAndRewrite(scf::ForOp forOp,
1007 PatternRewriter &rewriter) const final {
1008 bool canonicalize = false;
1009
1010 // An internal flat vector of block transfer
1011 // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
1012 // transformed block argument mappings. This plays the role of a
1013 // IRMapping for the particular use case of calling into
1014 // `inlineBlockBefore`.
1015 int64_t numResults = forOp.getNumResults();
1016 SmallVector<bool, 4> keepMask;
1017 keepMask.reserve(numResults);
1018 SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
1019 newResultValues;
1020 newBlockTransferArgs.reserve(1 + numResults);
1021 newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
1022 newIterArgs.reserve(forOp.getInitArgs().size());
1023 newYieldValues.reserve(numResults);
1024 newResultValues.reserve(numResults);
1025 DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
1026 for (auto [init, arg, result, yielded] :
1027 llvm::zip(forOp.getInitArgs(), // iter from outside
1028 forOp.getRegionIterArgs(), // iter inside region
1029 forOp.getResults(), // op results
1030 forOp.getYieldedValues() // iter yield
1031 )) {
1032 // Forwarded is `true` when:
1033 // 1) The region `iter` argument is yielded.
1034 // 2) The region `iter` argument the corresponding input is yielded.
1035 // 3) The region `iter` argument has no use, and the corresponding op
1036 // result has no use.
1037 bool forwarded = (arg == yielded) || (init == yielded) ||
1038 (arg.use_empty() && result.use_empty());
1039 if (forwarded) {
1040 canonicalize = true;
1041 keepMask.push_back(false);
1042 newBlockTransferArgs.push_back(init);
1043 newResultValues.push_back(init);
1044 continue;
1045 }
1046
1047 // Check if a previous kept argument always has the same values for init
1048 // and yielded values.
1049 if (auto it = initYieldToArg.find({init, yielded});
1050 it != initYieldToArg.end()) {
1051 canonicalize = true;
1052 keepMask.push_back(false);
1053 auto [sameArg, sameResult] = it->second;
1054 rewriter.replaceAllUsesWith(arg, sameArg);
1055 rewriter.replaceAllUsesWith(result, sameResult);
1056 // The replacement value doesn't matter because there are no uses.
1057 newBlockTransferArgs.push_back(init);
1058 newResultValues.push_back(init);
1059 continue;
1060 }
1061
1062 // This value is kept.
1063 initYieldToArg.insert({{init, yielded}, {arg, result}});
1064 keepMask.push_back(true);
1065 newIterArgs.push_back(init);
1066 newYieldValues.push_back(yielded);
1067 newBlockTransferArgs.push_back(Value()); // placeholder with null value
1068 newResultValues.push_back(Value()); // placeholder with null value
1069 }
1070
1071 if (!canonicalize)
1072 return failure();
1073
1074 scf::ForOp newForOp =
1075 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
1076 forOp.getUpperBound(), forOp.getStep(), newIterArgs,
1077 /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
1078 newForOp->setAttrs(forOp->getAttrs());
1079 Block &newBlock = newForOp.getRegion().front();
1080
1081 // Replace the null placeholders with newly constructed values.
1082 newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
1083 for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
1084 idx != e; ++idx) {
1085 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
1086 Value &newResultVal = newResultValues[idx];
1087 assert((blockTransferArg && newResultVal) ||
1088 (!blockTransferArg && !newResultVal));
1089 if (!blockTransferArg) {
1090 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
1091 newResultVal = newForOp.getResult(collapsedIdx++);
1092 }
1093 }
1094
1095 Block &oldBlock = forOp.getRegion().front();
1096 assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
1097 "unexpected argument size mismatch");
1098
1099 // No results case: the scf::ForOp builder already created a zero
1100 // result terminator. Merge before this terminator and just get rid of the
1101 // original terminator that has been merged in.
1102 if (newIterArgs.empty()) {
1103 auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
1104 rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
1105 rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
1106 rewriter.replaceOp(forOp, newResultValues);
1107 return success();
1108 }
1109
1110 // No terminator case: merge and rewrite the merged terminator.
1111 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
1112 OpBuilder::InsertionGuard g(rewriter);
1113 rewriter.setInsertionPoint(mergedTerminator);
1114 SmallVector<Value, 4> filteredOperands;
1115 filteredOperands.reserve(newResultValues.size());
1116 for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
1117 if (keepMask[idx])
1118 filteredOperands.push_back(mergedTerminator.getOperand(idx));
1119 scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
1120 filteredOperands);
1121 };
1122
1123 rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1124 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
1125 cloneFilteredTerminator(mergedYieldOp);
1126 rewriter.eraseOp(mergedYieldOp);
1127 rewriter.replaceOp(forOp, newResultValues);
1128 return success();
1129 }
1130};
1131
1132/// Rewriting pattern that erases loops that are known not to iterate, replaces
1133/// single-iteration loops with their bodies, and removes empty loops that
1134/// iterate at least once and only return values defined outside of the loop.
1135struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
1136 using OpRewritePattern<ForOp>::OpRewritePattern;
1137
1138 LogicalResult matchAndRewrite(ForOp op,
1139 PatternRewriter &rewriter) const override {
1140 std::optional<APInt> tripCount = op.getStaticTripCount();
1141 if (!tripCount.has_value())
1142 return rewriter.notifyMatchFailure(op,
1143 "can't compute constant trip count");
1144
1145 if (tripCount->isZero()) {
1146 LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop "
1147 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1148 rewriter.replaceOp(op, op.getInitArgs());
1149 return success();
1150 }
1151
1152 if (tripCount->getSExtValue() == 1) {
1153 LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop "
1154 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1155 SmallVector<Value, 4> blockArgs;
1156 blockArgs.reserve(op.getInitArgs().size() + 1);
1157 blockArgs.push_back(op.getLowerBound());
1158 llvm::append_range(blockArgs, op.getInitArgs());
1159 replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1160 return success();
1161 }
1162
1163 // Now we are left with loops that have more than 1 iterations.
1164 Block &block = op.getRegion().front();
1165 if (!llvm::hasSingleElement(block))
1166 return failure();
1167 // The loop is empty and iterates at least once, if it only returns values
1168 // defined outside of the loop, remove it and replace it with yield values.
1169 if (llvm::any_of(op.getYieldedValues(),
1170 [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1171 return failure();
1172 LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with "
1173 "yield operands for loop "
1174 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1175 rewriter.replaceOp(op, op.getYieldedValues());
1176 return success();
1177 }
1178};
1179
/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
/// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1182///
1183/// ```
1184/// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1185/// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1186/// -> (tensor<?x?xf32>) {
1187/// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1188/// scf.yield %2 : tensor<?x?xf32>
1189/// }
1190/// use_of(%1)
1191/// ```
1192///
1193/// folds into:
1194///
1195/// ```
1196/// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1197/// -> (tensor<32x1024xf32>) {
1198/// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1199/// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1200/// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1201/// scf.yield %4 : tensor<32x1024xf32>
1202/// }
1203/// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1204/// use_of(%1)
1205/// ```
1206struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1207 using OpRewritePattern<ForOp>::OpRewritePattern;
1208
1209 LogicalResult matchAndRewrite(ForOp op,
1210 PatternRewriter &rewriter) const override {
1211 for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1212 OpOperand &iterOpOperand = std::get<0>(it);
1213 auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1214 if (!incomingCast ||
1215 incomingCast.getSource().getType() == incomingCast.getType())
1216 continue;
1217 // If the dest type of the cast does not preserve static information in
1218 // the source type.
1220 incomingCast.getDest().getType(),
1221 incomingCast.getSource().getType()))
1222 continue;
1223 if (!std::get<1>(it).hasOneUse())
1224 continue;
1225
1226 // Create a new ForOp with that iter operand replaced.
1227 rewriter.replaceOp(
1229 rewriter, op, iterOpOperand, incomingCast.getSource(),
1230 [](OpBuilder &b, Location loc, Type type, Value source) {
1231 return tensor::CastOp::create(b, loc, type, source);
1232 }));
1233 return success();
1234 }
1235 return failure();
1236 }
1237};
1238
1239} // namespace
1240
1241void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1242 MLIRContext *context) {
1243 results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1244 context);
1245}
1246
1247std::optional<APInt> ForOp::getConstantStep() {
1248 IntegerAttr step;
1249 if (matchPattern(getStep(), m_Constant(&step)))
1250 return step.getValue();
1251 return {};
1252}
1253
1254std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1255 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1256}
1257
1258Speculation::Speculatability ForOp::getSpeculatability() {
1259 // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1260 // and End.
1261 if (auto constantStep = getConstantStep())
1262 if (*constantStep == 1)
1264
1265 // For Step != 1, the loop may not terminate. We can add more smarts here if
1266 // needed.
1268}
1269
1270std::optional<APInt> ForOp::getStaticTripCount() {
1271 return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
1272 /*isSigned=*/!getUnsignedCmp(), computeUbMinusLb);
1273}
1274
1275//===----------------------------------------------------------------------===//
1276// ForallOp
1277//===----------------------------------------------------------------------===//
1278
1279LogicalResult ForallOp::verify() {
1280 unsigned numLoops = getRank();
1281 // Check number of outputs.
1282 if (getNumResults() != getOutputs().size())
1283 return emitOpError("produces ")
1284 << getNumResults() << " results, but has only "
1285 << getOutputs().size() << " outputs";
1286
1287 // Check that the body defines block arguments for thread indices and outputs.
1288 auto *body = getBody();
1289 if (body->getNumArguments() != numLoops + getOutputs().size())
1290 return emitOpError("region expects ") << numLoops << " arguments";
1291 for (int64_t i = 0; i < numLoops; ++i)
1292 if (!body->getArgument(i).getType().isIndex())
1293 return emitOpError("expects ")
1294 << i << "-th block argument to be an index";
1295 for (unsigned i = 0; i < getOutputs().size(); ++i)
1296 if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1297 return emitOpError("type mismatch between ")
1298 << i << "-th output and corresponding block argument";
1299 if (getMapping().has_value() && !getMapping()->empty()) {
1300 if (getDeviceMappingAttrs().size() != numLoops)
1301 return emitOpError() << "mapping attribute size must match op rank";
1302 if (failed(getDeviceMaskingAttr()))
1303 return emitOpError() << getMappingAttrName()
1304 << " supports at most one device masking attribute";
1305 }
1306
1307 // Verify mixed static/dynamic control variables.
1308 Operation *op = getOperation();
1309 if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1310 getStaticLowerBound(),
1311 getDynamicLowerBound())))
1312 return failure();
1313 if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1314 getStaticUpperBound(),
1315 getDynamicUpperBound())))
1316 return failure();
1317 if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1318 getStaticStep(), getDynamicStep())))
1319 return failure();
1320
1321 return success();
1322}
1323
1324void ForallOp::print(OpAsmPrinter &p) {
1325 Operation *op = getOperation();
1326 p << " (" << getInductionVars();
1327 if (isNormalized()) {
1328 p << ") in ";
1329 printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1330 /*valueTypes=*/{}, /*scalables=*/{},
1332 } else {
1333 p << ") = ";
1334 printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1335 /*valueTypes=*/{}, /*scalables=*/{},
1337 p << " to ";
1338 printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1339 /*valueTypes=*/{}, /*scalables=*/{},
1341 p << " step ";
1342 printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1343 /*valueTypes=*/{}, /*scalables=*/{},
1345 }
1346 printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1347 p << " ";
1348 if (!getRegionOutArgs().empty())
1349 p << "-> (" << getResultTypes() << ") ";
1350 p.printRegion(getRegion(),
1351 /*printEntryBlockArgs=*/false,
1352 /*printBlockTerminators=*/getNumResults() > 0);
1353 p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1354 getStaticLowerBoundAttrName(),
1355 getStaticUpperBoundAttrName(),
1356 getStaticStepAttrName()});
1357}
1358
1359ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1360 OpBuilder b(parser.getContext());
1361 auto indexType = b.getIndexType();
1362
1363 // Parse an opening `(` followed by thread index variables followed by `)`
1364 // TODO: when we can refer to such "induction variable"-like handles from the
1365 // declarative assembly format, we can implement the parser as a custom hook.
1366 SmallVector<OpAsmParser::Argument, 4> ivs;
1368 return failure();
1369
1370 DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1371 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1372 dynamicSteps;
1373 if (succeeded(parser.parseOptionalKeyword("in"))) {
1374 // Parse upper bounds.
1375 if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1376 /*valueTypes=*/nullptr,
1378 parser.resolveOperands(dynamicUbs, indexType, result.operands))
1379 return failure();
1380
1381 unsigned numLoops = ivs.size();
1382 staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1383 staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1384 } else {
1385 // Parse lower bounds.
1386 if (parser.parseEqual() ||
1387 parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1388 /*valueTypes=*/nullptr,
1390
1391 parser.resolveOperands(dynamicLbs, indexType, result.operands))
1392 return failure();
1393
1394 // Parse upper bounds.
1395 if (parser.parseKeyword("to") ||
1396 parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1397 /*valueTypes=*/nullptr,
1399 parser.resolveOperands(dynamicUbs, indexType, result.operands))
1400 return failure();
1401
1402 // Parse step values.
1403 if (parser.parseKeyword("step") ||
1404 parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1405 /*valueTypes=*/nullptr,
1407 parser.resolveOperands(dynamicSteps, indexType, result.operands))
1408 return failure();
1409 }
1410
1411 // Parse out operands and results.
1412 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1413 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1414 SMLoc outOperandsLoc = parser.getCurrentLocation();
1415 if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1416 if (outOperands.size() != result.types.size())
1417 return parser.emitError(outOperandsLoc,
1418 "mismatch between out operands and types");
1419 if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1420 parser.parseOptionalArrowTypeList(result.types) ||
1421 parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1422 result.operands))
1423 return failure();
1424 }
1425
1426 // Parse region.
1427 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1428 std::unique_ptr<Region> region = std::make_unique<Region>();
1429 for (auto &iv : ivs) {
1430 iv.type = b.getIndexType();
1431 regionArgs.push_back(iv);
1432 }
1433 for (const auto &it : llvm::enumerate(regionOutArgs)) {
1434 auto &out = it.value();
1435 out.type = result.types[it.index()];
1436 regionArgs.push_back(out);
1437 }
1438 if (parser.parseRegion(*region, regionArgs))
1439 return failure();
1440
1441 // Ensure terminator and move region.
1442 ForallOp::ensureTerminator(*region, b, result.location);
1443 result.addRegion(std::move(region));
1444
1445 // Parse the optional attribute list.
1446 if (parser.parseOptionalAttrDict(result.attributes))
1447 return failure();
1448
1449 result.addAttribute("staticLowerBound", staticLbs);
1450 result.addAttribute("staticUpperBound", staticUbs);
1451 result.addAttribute("staticStep", staticSteps);
1452 result.addAttribute("operandSegmentSizes",
1454 {static_cast<int32_t>(dynamicLbs.size()),
1455 static_cast<int32_t>(dynamicUbs.size()),
1456 static_cast<int32_t>(dynamicSteps.size()),
1457 static_cast<int32_t>(outOperands.size())}));
1458 return success();
1459}
1460
1461// Builder that takes loop bounds.
1462void ForallOp::build(
1463 mlir::OpBuilder &b, mlir::OperationState &result,
1464 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1465 ArrayRef<OpFoldResult> steps, ValueRange outputs,
1466 std::optional<ArrayAttr> mapping,
1467 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1468 SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1469 SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1470 dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1471 dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1472 dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1473
1474 result.addOperands(dynamicLbs);
1475 result.addOperands(dynamicUbs);
1476 result.addOperands(dynamicSteps);
1477 result.addOperands(outputs);
1478 result.addTypes(TypeRange(outputs));
1479
1480 result.addAttribute(getStaticLowerBoundAttrName(result.name),
1481 b.getDenseI64ArrayAttr(staticLbs));
1482 result.addAttribute(getStaticUpperBoundAttrName(result.name),
1483 b.getDenseI64ArrayAttr(staticUbs));
1484 result.addAttribute(getStaticStepAttrName(result.name),
1485 b.getDenseI64ArrayAttr(staticSteps));
1486 result.addAttribute(
1487 "operandSegmentSizes",
1488 b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1489 static_cast<int32_t>(dynamicUbs.size()),
1490 static_cast<int32_t>(dynamicSteps.size()),
1491 static_cast<int32_t>(outputs.size())}));
1492 if (mapping.has_value()) {
1493 result.addAttribute(ForallOp::getMappingAttrName(result.name),
1494 mapping.value());
1495 }
1496
1497 Region *bodyRegion = result.addRegion();
1498 OpBuilder::InsertionGuard g(b);
1499 b.createBlock(bodyRegion);
1500 Block &bodyBlock = bodyRegion->front();
1501
1502 // Add block arguments for indices and outputs.
1503 bodyBlock.addArguments(
1504 SmallVector<Type>(lbs.size(), b.getIndexType()),
1505 SmallVector<Location>(staticLbs.size(), result.location));
1506 bodyBlock.addArguments(
1507 TypeRange(outputs),
1508 SmallVector<Location>(outputs.size(), result.location));
1509
1510 b.setInsertionPointToStart(&bodyBlock);
1511 if (!bodyBuilderFn) {
1512 ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1513 return;
1514 }
1515 bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1516}
1517
1518// Builder that takes loop bounds.
1519void ForallOp::build(
1520 mlir::OpBuilder &b, mlir::OperationState &result,
1521 ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1522 std::optional<ArrayAttr> mapping,
1523 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1524 unsigned numLoops = ubs.size();
1525 SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1526 SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1527 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1528}
1529
1530// Checks if the lbs are zeros and steps are ones.
1531bool ForallOp::isNormalized() {
1532 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1533 return llvm::all_of(results, [&](OpFoldResult ofr) {
1534 auto intValue = getConstantIntValue(ofr);
1535 return intValue.has_value() && intValue == val;
1536 });
1537 };
1538 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1539}
1540
1541InParallelOp ForallOp::getTerminator() {
1542 return cast<InParallelOp>(getBody()->getTerminator());
1543}
1544
1545SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1546 SmallVector<Operation *> storeOps;
1547 for (Operation *user : bbArg.getUsers()) {
1548 if (auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1549 storeOps.push_back(parallelOp);
1550 }
1551 }
1552 return storeOps;
1553}
1554
1555SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1556 SmallVector<DeviceMappingAttrInterface> res;
1557 if (!getMapping())
1558 return res;
1559 for (auto attr : getMapping()->getValue()) {
1560 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1561 if (m)
1562 res.push_back(m);
1563 }
1564 return res;
1565}
1566
1567FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1568 DeviceMaskingAttrInterface res;
1569 if (!getMapping())
1570 return res;
1571 for (auto attr : getMapping()->getValue()) {
1572 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1573 if (m && res)
1574 return failure();
1575 if (m)
1576 res = m;
1577 }
1578 return res;
1579}
1580
1581bool ForallOp::usesLinearMapping() {
1582 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1583 if (ifaces.empty())
1584 return false;
1585 return ifaces.front().isLinearMapping();
1586}
1587
1588std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1589 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1590}
1591
1592// Get lower bounds as OpFoldResult.
1593std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1594 Builder b(getOperation()->getContext());
1595 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1596}
1597
1598// Get upper bounds as OpFoldResult.
1599std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1600 Builder b(getOperation()->getContext());
1601 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1602}
1603
1604// Get steps as OpFoldResult.
1605std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1606 Builder b(getOperation()->getContext());
1607 return getMixedValues(getStaticStep(), getDynamicStep(), b);
1608}
1609
1611 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1612 if (!tidxArg)
1613 return ForallOp();
1614 assert(tidxArg.getOwner() && "unlinked block argument");
1615 auto *containingOp = tidxArg.getOwner()->getParentOp();
1616 return dyn_cast<ForallOp>(containingOp);
1617}
1618
1619namespace {
1620/// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1621struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1622 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1623
1624 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1625 PatternRewriter &rewriter) const final {
1626 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1627 if (!forallOp)
1628 return failure();
1629 Value sharedOut =
1630 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1631 ->get();
1632 rewriter.modifyOpInPlace(
1633 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1634 return success();
1635 }
1636};
1637
1638class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1639public:
1640 using OpRewritePattern<ForallOp>::OpRewritePattern;
1641
1642 LogicalResult matchAndRewrite(ForallOp op,
1643 PatternRewriter &rewriter) const override {
1644 SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1645 SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1646 SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1647 if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1648 failed(foldDynamicIndexList(mixedUpperBound)) &&
1649 failed(foldDynamicIndexList(mixedStep)))
1650 return failure();
1651
1652 rewriter.modifyOpInPlace(op, [&]() {
1653 SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1654 SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1655 dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1656 staticLowerBound);
1657 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1658 op.setStaticLowerBound(staticLowerBound);
1659
1660 dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1661 staticUpperBound);
1662 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1663 op.setStaticUpperBound(staticUpperBound);
1664
1665 dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1666 op.getDynamicStepMutable().assign(dynamicStep);
1667 op.setStaticStep(staticStep);
1668
1669 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1670 rewriter.getDenseI32ArrayAttr(
1671 {static_cast<int32_t>(dynamicLowerBound.size()),
1672 static_cast<int32_t>(dynamicUpperBound.size()),
1673 static_cast<int32_t>(dynamicStep.size()),
1674 static_cast<int32_t>(op.getNumResults())}));
1675 });
1676 return success();
1677 }
1678};
1679
/// The following canonicalization pattern folds the iter arguments of an
/// scf.forall op if either:
/// 1. The corresponding result has zero uses, or
/// 2. The iter argument is NOT being modified within the loop body, i.e. it
///    has no combining (store) op using it as the destination.
1685///
1686/// Example of first case :-
1687/// INPUT:
1688/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1689/// {
1690/// ...
1691/// <SOME USE OF %arg0>
1692/// <SOME USE OF %arg1>
1693/// <SOME USE OF %arg2>
1694/// ...
1695/// scf.forall.in_parallel {
1696/// <STORE OP WITH DESTINATION %arg1>
1697/// <STORE OP WITH DESTINATION %arg0>
1698/// <STORE OP WITH DESTINATION %arg2>
1699/// }
1700/// }
1701/// return %res#1
1702///
1703/// OUTPUT:
/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1705/// {
1706/// ...
1707/// <SOME USE OF %a>
1708/// <SOME USE OF %new_arg0>
1709/// <SOME USE OF %c>
1710/// ...
1711/// scf.forall.in_parallel {
1712/// <STORE OP WITH DESTINATION %new_arg0>
1713/// }
1714/// }
1715/// return %res
1716///
/// NOTE: 1. All uses of the folded shared_outs (iter arguments) within the
/// scf.forall are replaced by their corresponding operands.
1719/// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1720/// of the scf.forall besides within scf.forall.in_parallel terminator,
1721/// this canonicalization remains valid. For more details, please refer
1722/// to :
1723/// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1724/// 3. TODO(avarma): Generalize it for other store ops. Currently it
1725/// handles tensor.parallel_insert_slice ops only.
1726///
1727/// Example of second case :-
1728/// INPUT:
1729/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1730/// {
1731/// ...
1732/// <SOME USE OF %arg0>
1733/// <SOME USE OF %arg1>
1734/// ...
1735/// scf.forall.in_parallel {
1736/// <STORE OP WITH DESTINATION %arg1>
1737/// }
1738/// }
1739/// return %res#0, %res#1
1740///
1741/// OUTPUT:
1742/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1743/// {
1744/// ...
1745/// <SOME USE OF %a>
1746/// <SOME USE OF %new_arg0>
1747/// ...
1748/// scf.forall.in_parallel {
1749/// <STORE OP WITH DESTINATION %new_arg0>
1750/// }
1751/// }
1752/// return %a, %res
1753struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1754 using OpRewritePattern<ForallOp>::OpRewritePattern;
1755
1756 LogicalResult matchAndRewrite(ForallOp forallOp,
1757 PatternRewriter &rewriter) const final {
1758 // Step 1: For a given i-th result of scf.forall, check the following :-
1759 // a. If it has any use.
1760 // b. If the corresponding iter argument is being modified within
1761 // the loop, i.e. has at least one store op with the iter arg as
1762 // its destination operand. For this we use
1763 // ForallOp::getCombiningOps(iter_arg).
1764 //
1765 // Based on the check we maintain the following :-
1766 // a. `resultToDelete` - i-th result of scf.forall that'll be
1767 // deleted.
1768 // b. `resultToReplace` - i-th result of the old scf.forall
1769 // whose uses will be replaced by the new scf.forall.
1770 // c. `newOuts` - the shared_outs' operand of the new scf.forall
1771 // corresponding to the i-th result with at least one use.
1772 SetVector<OpResult> resultToDelete;
1773 SmallVector<Value> resultToReplace;
1774 SmallVector<Value> newOuts;
1775 for (OpResult result : forallOp.getResults()) {
1776 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1777 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1778 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1779 resultToDelete.insert(result);
1780 } else {
1781 resultToReplace.push_back(result);
1782 newOuts.push_back(opOperand->get());
1783 }
1784 }
1785
1786 // Return early if all results of scf.forall have at least one use and being
1787 // modified within the loop.
1788 if (resultToDelete.empty())
1789 return failure();
1790
1791 // Step 2: For the the i-th result, do the following :-
1792 // a. Fetch the corresponding BlockArgument.
1793 // b. Look for store ops (currently tensor.parallel_insert_slice)
1794 // with the BlockArgument as its destination operand.
1795 // c. Remove the operations fetched in b.
1796 for (OpResult result : resultToDelete) {
1797 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1798 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1799 SmallVector<Operation *> combiningOps =
1800 forallOp.getCombiningOps(blockArg);
1801 for (Operation *combiningOp : combiningOps)
1802 rewriter.eraseOp(combiningOp);
1803 }
1804
1805 // Step 3. Create a new scf.forall op with the new shared_outs' operands
1806 // fetched earlier
1807 auto newForallOp = scf::ForallOp::create(
1808 rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
1809 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1810 forallOp.getMapping(),
1811 /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1812
1813 // Step 4. Merge the block of the old scf.forall into the newly created
1814 // scf.forall using the new set of arguments.
1815 Block *loopBody = forallOp.getBody();
1816 Block *newLoopBody = newForallOp.getBody();
1817 ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1818 // Form initial new bbArg list with just the control operands of the new
1819 // scf.forall op.
1820 SmallVector<Value> newBlockArgs =
1821 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1822 [](BlockArgument b) -> Value { return b; });
1823 Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1824 unsigned index = 0;
1825 // Take the new corresponding bbArg if the old bbArg was used as a
1826 // destination in the in_parallel op. For all other bbArgs, use the
1827 // corresponding init_arg from the old scf.forall op.
1828 for (OpResult result : forallOp.getResults()) {
1829 if (resultToDelete.count(result)) {
1830 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1831 } else {
1832 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1833 }
1834 }
1835 rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1836
1837 // Step 5. Replace the uses of result of old scf.forall with that of the new
1838 // scf.forall.
1839 for (auto &&[oldResult, newResult] :
1840 llvm::zip(resultToReplace, newForallOp->getResults()))
1841 rewriter.replaceAllUsesWith(oldResult, newResult);
1842
1843 // Step 6. Replace the uses of those values that either has no use or are
1844 // not being modified within the loop with the corresponding
1845 // OpOperand.
1846 for (OpResult oldResult : resultToDelete)
1847 rewriter.replaceAllUsesWith(oldResult,
1848 forallOp.getTiedOpOperand(oldResult)->get());
1849 return success();
1850 }
1851};
1852
1853struct ForallOpSingleOrZeroIterationDimsFolder
1854 : public OpRewritePattern<ForallOp> {
1855 using OpRewritePattern<ForallOp>::OpRewritePattern;
1856
1857 LogicalResult matchAndRewrite(ForallOp op,
1858 PatternRewriter &rewriter) const override {
1859 // Do not fold dimensions if they are mapped to processing units.
1860 if (op.getMapping().has_value() && !op.getMapping()->empty())
1861 return failure();
1862 Location loc = op.getLoc();
1863
1864 // Compute new loop bounds that omit all single-iteration loop dimensions.
1865 SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1866 newMixedSteps;
1867 IRMapping mapping;
1868 for (auto [lb, ub, step, iv] :
1869 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1870 op.getMixedStep(), op.getInductionVars())) {
1871 auto numIterations =
1872 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
1873 if (numIterations.has_value()) {
1874 // Remove the loop if it performs zero iterations.
1875 if (*numIterations == 0) {
1876 rewriter.replaceOp(op, op.getOutputs());
1877 return success();
1878 }
1879 // Replace the loop induction variable by the lower bound if the loop
1880 // performs a single iteration. Otherwise, copy the loop bounds.
1881 if (*numIterations == 1) {
1882 mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1883 continue;
1884 }
1885 }
1886 newMixedLowerBounds.push_back(lb);
1887 newMixedUpperBounds.push_back(ub);
1888 newMixedSteps.push_back(step);
1889 }
1890
1891 // All of the loop dimensions perform a single iteration. Inline loop body.
1892 if (newMixedLowerBounds.empty()) {
1893 promote(rewriter, op);
1894 return success();
1895 }
1896
1897 // Exit if none of the loop dimensions perform a single iteration.
1898 if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1899 return rewriter.notifyMatchFailure(
1900 op, "no dimensions have 0 or 1 iterations");
1901 }
1902
1903 // Replace the loop by a lower-dimensional loop.
1904 ForallOp newOp;
1905 newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1906 newMixedUpperBounds, newMixedSteps,
1907 op.getOutputs(), std::nullopt, nullptr);
1908 newOp.getBodyRegion().getBlocks().clear();
1909 // The new loop needs to keep all attributes from the old one, except for
1910 // "operandSegmentSizes" and static loop bound attributes which capture
1911 // the outdated information of the old iteration domain.
1912 SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1913 newOp.getStaticLowerBoundAttrName(),
1914 newOp.getStaticUpperBoundAttrName(),
1915 newOp.getStaticStepAttrName()};
1916 for (const auto &namedAttr : op->getAttrs()) {
1917 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1918 continue;
1919 rewriter.modifyOpInPlace(newOp, [&]() {
1920 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1921 });
1922 }
1923 rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1924 newOp.getRegion().begin(), mapping);
1925 rewriter.replaceOp(op, newOp.getResults());
1926 return success();
1927 }
1928};
1929
1930/// Replace all induction vars with a single trip count with their lower bound.
1931struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1932 using OpRewritePattern<ForallOp>::OpRewritePattern;
1933
1934 LogicalResult matchAndRewrite(ForallOp op,
1935 PatternRewriter &rewriter) const override {
1936 Location loc = op.getLoc();
1937 bool changed = false;
1938 for (auto [lb, ub, step, iv] :
1939 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1940 op.getMixedStep(), op.getInductionVars())) {
1941 if (iv.hasNUses(0))
1942 continue;
1943 auto numIterations =
1944 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
1945 if (!numIterations.has_value() || numIterations.value() != 1) {
1946 continue;
1947 }
1948 rewriter.replaceAllUsesWith(
1949 iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1950 changed = true;
1951 }
1952 return success(changed);
1953 }
1954};
1955
1956struct FoldTensorCastOfOutputIntoForallOp
1957 : public OpRewritePattern<scf::ForallOp> {
1958 using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1959
1960 struct TypeCast {
1961 Type srcType;
1962 Type dstType;
1963 };
1964
1965 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1966 PatternRewriter &rewriter) const final {
1967 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1968 llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1969 for (auto en : llvm::enumerate(newOutputTensors)) {
1970 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1971 if (!castOp)
1972 continue;
1973
1974 // Only casts that that preserve static information, i.e. will make the
1975 // loop result type "more" static than before, will be folded.
1976 if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1977 castOp.getSource().getType())) {
1978 continue;
1979 }
1980
1981 tensorCastProducers[en.index()] =
1982 TypeCast{castOp.getSource().getType(), castOp.getType()};
1983 newOutputTensors[en.index()] = castOp.getSource();
1984 }
1985
1986 if (tensorCastProducers.empty())
1987 return failure();
1988
1989 // Create new loop.
1990 Location loc = forallOp.getLoc();
1991 auto newForallOp = ForallOp::create(
1992 rewriter, loc, forallOp.getMixedLowerBound(),
1993 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1994 newOutputTensors, forallOp.getMapping(),
1995 [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1996 auto castBlockArgs =
1997 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1998 for (auto [index, cast] : tensorCastProducers) {
1999 Value &oldTypeBBArg = castBlockArgs[index];
2000 oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
2001 cast.dstType, oldTypeBBArg);
2002 }
2003
2004 // Move old body into new parallel loop.
2005 SmallVector<Value> ivsBlockArgs =
2006 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
2007 ivsBlockArgs.append(castBlockArgs);
2008 rewriter.mergeBlocks(forallOp.getBody(),
2009 bbArgs.front().getParentBlock(), ivsBlockArgs);
2010 });
2011
2012 // After `mergeBlocks` happened, the destinations in the terminator were
2013 // mapped to the tensor.cast old-typed results of the output bbArgs. The
2014 // destination have to be updated to point to the output bbArgs directly.
2015 auto terminator = newForallOp.getTerminator();
2016 for (auto [yieldingOp, outputBlockArg] : llvm::zip(
2017 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
2018 if (auto parallelCombingingOp =
2019 dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
2020 parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
2021 }
2022 }
2023
2024 // Cast results back to the original types.
2025 rewriter.setInsertionPointAfter(newForallOp);
2026 SmallVector<Value> castResults = newForallOp.getResults();
2027 for (auto &item : tensorCastProducers) {
2028 Value &oldTypeResult = castResults[item.first];
2029 oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
2030 oldTypeResult);
2031 }
2032 rewriter.replaceOp(forallOp, castResults);
2033 return success();
2034 }
2035};
2036
2037} // namespace
2038
2039void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
2040 MLIRContext *context) {
2041 results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
2042 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
2043 ForallOpSingleOrZeroIterationDimsFolder,
2044 ForallOpReplaceConstantInductionVar>(context);
2045}
2046
2047/// Given the region at `index`, or the parent operation if `index` is None,
2048/// return the successor regions. These are the regions that may be selected
2049/// during the flow of control. `operands` is a set of optional attributes that
2050/// correspond to a constant value for each operand, or null if that operand is
2051/// not a constant.
2052void ForallOp::getSuccessorRegions(RegionBranchPoint point,
2053 SmallVectorImpl<RegionSuccessor> &regions) {
2054 // In accordance with the semantics of forall, its body is executed in
2055 // parallel by multiple threads. We should not expect to branch back into
2056 // the forall body after the region's execution is complete.
2057 if (point.isParent())
2058 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2059 else
2060 regions.push_back(
2061 RegionSuccessor(getOperation(), getOperation()->getResults()));
2062}
2063
2064//===----------------------------------------------------------------------===//
2065// InParallelOp
2066//===----------------------------------------------------------------------===//
2067
2068// Build a InParallelOp with mixed static and dynamic entries.
2069void InParallelOp::build(OpBuilder &b, OperationState &result) {
2070 OpBuilder::InsertionGuard g(b);
2071 Region *bodyRegion = result.addRegion();
2072 b.createBlock(bodyRegion);
2073}
2074
2075LogicalResult InParallelOp::verify() {
2076 scf::ForallOp forallOp =
2077 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
2078 if (!forallOp)
2079 return this->emitOpError("expected forall op parent");
2080
2081 for (Operation &op : getRegion().front().getOperations()) {
2082 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
2083 if (!parallelCombiningOp) {
2084 return this->emitOpError("expected only ParallelCombiningOpInterface")
2085 << " ops";
2086 }
2087
2088 // Verify that inserts are into out block arguments.
2089 MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
2090 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
2091 for (OpOperand &dest : dests) {
2092 if (!llvm::is_contained(regionOutArgs, dest.get()))
2093 return op.emitOpError("may only insert into an output block argument");
2094 }
2095 }
2096
2097 return success();
2098}
2099
2100void InParallelOp::print(OpAsmPrinter &p) {
2101 p << " ";
2102 p.printRegion(getRegion(),
2103 /*printEntryBlockArgs=*/false,
2104 /*printBlockTerminators=*/false);
2105 p.printOptionalAttrDict(getOperation()->getAttrs());
2106}
2107
2108ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2109 auto &builder = parser.getBuilder();
2110
2111 SmallVector<OpAsmParser::Argument, 8> regionOperands;
2112 std::unique_ptr<Region> region = std::make_unique<Region>();
2113 if (parser.parseRegion(*region, regionOperands))
2114 return failure();
2115
2116 if (region->empty())
2117 OpBuilder(builder.getContext()).createBlock(region.get());
2118 result.addRegion(std::move(region));
2119
2120 // Parse the optional attribute list.
2121 if (parser.parseOptionalAttrDict(result.attributes))
2122 return failure();
2123 return success();
2124}
2125
2126OpResult InParallelOp::getParentResult(int64_t idx) {
2127 return getOperation()->getParentOp()->getResult(idx);
2128}
2129
2130SmallVector<BlockArgument> InParallelOp::getDests() {
2131 SmallVector<BlockArgument> updatedDests;
2132 for (Operation &yieldingOp : getYieldingOps()) {
2133 auto parallelCombiningOp =
2134 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
2135 if (!parallelCombiningOp)
2136 continue;
2137 for (OpOperand &updatedOperand :
2138 parallelCombiningOp.getUpdatedDestinations())
2139 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
2140 }
2141 return updatedDests;
2142}
2143
2144llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
2145 return getRegion().front().getOperations();
2146}
2147
2148//===----------------------------------------------------------------------===//
2149// IfOp
2150//===----------------------------------------------------------------------===//
2151
2153 assert(a && "expected non-empty operation");
2154 assert(b && "expected non-empty operation");
2155
// Walk the scf.if ancestors of `a`, innermost first. The first one that also
// contains `b` decides the answer: true iff exactly one of `a` and `b` lies in
// its "then" block.
2156 IfOp ifOp = a->getParentOfType<IfOp>();
2157 while (ifOp) {
2158 // Check if b is inside ifOp. (We already know that a is.)
2159 if (ifOp->isProperAncestor(b))
2160 // b is contained in ifOp. a and b are in mutually exclusive branches if
2161 // they are in different blocks of ifOp.
2162 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2163 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
2164 // Check next enclosing IfOp.
2165 ifOp = ifOp->getParentOfType<IfOp>();
2166 }
2167
2168 // Could not find a common IfOp among a's and b's ancestors.
2169 return false;
2170}
2171
2172LogicalResult
2173IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2174 IfOp::Adaptor adaptor,
2175 SmallVectorImpl<Type> &inferredReturnTypes) {
2176 if (adaptor.getRegions().empty())
2177 return failure();
2178 Region *r = &adaptor.getThenRegion();
2179 if (r->empty())
2180 return failure();
2181 Block &b = r->front();
2182 if (b.empty())
2183 return failure();
2184 auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2185 if (!yieldOp)
2186 return failure();
2187 TypeRange types = yieldOp.getOperandTypes();
2188 llvm::append_range(inferredReturnTypes, types);
2189 return success();
2190}
2191
2192void IfOp::build(OpBuilder &builder, OperationState &result,
2193 TypeRange resultTypes, Value cond) {
2194 return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2195 /*addElseBlock=*/false);
2196}
2197
2198void IfOp::build(OpBuilder &builder, OperationState &result,
2199 TypeRange resultTypes, Value cond, bool addThenBlock,
2200 bool addElseBlock) {
2201 assert((!addElseBlock || addThenBlock) &&
2202 "must not create else block w/o then block");
2203 result.addTypes(resultTypes);
2204 result.addOperands(cond);
2205
2206 // Add regions and blocks.
2207 OpBuilder::InsertionGuard guard(builder);
2208 Region *thenRegion = result.addRegion();
2209 if (addThenBlock)
2210 builder.createBlock(thenRegion);
2211 Region *elseRegion = result.addRegion();
2212 if (addElseBlock)
2213 builder.createBlock(elseRegion);
2214}
2215
2216void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2217 bool withElseRegion) {
2218 build(builder, result, TypeRange{}, cond, withElseRegion);
2219}
2220
2221void IfOp::build(OpBuilder &builder, OperationState &result,
2222 TypeRange resultTypes, Value cond, bool withElseRegion) {
2223 result.addTypes(resultTypes);
2224 result.addOperands(cond);
2225
2226 // Build then region.
2227 OpBuilder::InsertionGuard guard(builder);
2228 Region *thenRegion = result.addRegion();
2229 builder.createBlock(thenRegion);
2230 if (resultTypes.empty())
2231 IfOp::ensureTerminator(*thenRegion, builder, result.location);
2232
2233 // Build else region.
2234 Region *elseRegion = result.addRegion();
2235 if (withElseRegion) {
2236 builder.createBlock(elseRegion);
2237 if (resultTypes.empty())
2238 IfOp::ensureTerminator(*elseRegion, builder, result.location);
2239 }
2240}
2241
2242void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2243 function_ref<void(OpBuilder &, Location)> thenBuilder,
2244 function_ref<void(OpBuilder &, Location)> elseBuilder) {
2245 assert(thenBuilder && "the builder callback for 'then' must be present");
2246 result.addOperands(cond);
2247
2248 // Build then region.
2249 OpBuilder::InsertionGuard guard(builder);
2250 Region *thenRegion = result.addRegion();
2251 builder.createBlock(thenRegion);
2252 thenBuilder(builder, result.location);
2253
2254 // Build else region.
2255 Region *elseRegion = result.addRegion();
2256 if (elseBuilder) {
2257 builder.createBlock(elseRegion);
2258 elseBuilder(builder, result.location);
2259 }
2260
2261 // Infer result types.
2262 SmallVector<Type> inferredReturnTypes;
2263 MLIRContext *ctx = builder.getContext();
2264 auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2265 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2266 /*properties=*/nullptr, result.regions,
2267 inferredReturnTypes))) {
2268 result.addTypes(inferredReturnTypes);
2269 }
2270}
2271
2272LogicalResult IfOp::verify() {
2273 if (getNumResults() != 0 && getElseRegion().empty())
2274 return emitOpError("must have an else block if defining values");
2275 return success();
2276}
2277
2278ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2279 // Create the regions for 'then'.
2280 result.regions.reserve(2);
2281 Region *thenRegion = result.addRegion();
2282 Region *elseRegion = result.addRegion();
2283
2284 auto &builder = parser.getBuilder();
2285 OpAsmParser::UnresolvedOperand cond;
2286 Type i1Type = builder.getIntegerType(1);
2287 if (parser.parseOperand(cond) ||
2288 parser.resolveOperand(cond, i1Type, result.operands))
2289 return failure();
2290 // Parse optional results type list.
2291 if (parser.parseOptionalArrowTypeList(result.types))
2292 return failure();
2293 // Parse the 'then' region.
2294 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2295 return failure();
2296 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2297
2298 // If we find an 'else' keyword then parse the 'else' region.
2299 if (!parser.parseOptionalKeyword("else")) {
2300 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2301 return failure();
2302 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2303 }
2304
2305 // Parse the optional attribute list.
2306 if (parser.parseOptionalAttrDict(result.attributes))
2307 return failure();
2308 return success();
2309}
2310
2311void IfOp::print(OpAsmPrinter &p) {
2312 bool printBlockTerminators = false;
2313
2314 p << " " << getCondition();
2315 if (!getResults().empty()) {
2316 p << " -> (" << getResultTypes() << ")";
2317 // Print yield explicitly if the op defines values.
2318 printBlockTerminators = true;
2319 }
2320 p << ' ';
2321 p.printRegion(getThenRegion(),
2322 /*printEntryBlockArgs=*/false,
2323 /*printBlockTerminators=*/printBlockTerminators);
2324
2325 // Print the 'else' regions if it exists and has a block.
2326 auto &elseRegion = getElseRegion();
2327 if (!elseRegion.empty()) {
2328 p << " else ";
2329 p.printRegion(elseRegion,
2330 /*printEntryBlockArgs=*/false,
2331 /*printBlockTerminators=*/printBlockTerminators);
2332 }
2333
2334 p.printOptionalAttrDict((*this)->getAttrs());
2335}
2336
2337void IfOp::getSuccessorRegions(RegionBranchPoint point,
2338 SmallVectorImpl<RegionSuccessor> &regions) {
2339 // The `then` and the `else` region branch back to the parent operation or one
2340 // of the recursive parent operations (early exit case).
2341 if (!point.isParent()) {
2342 regions.push_back(RegionSuccessor(getOperation(), getResults()));
2343 return;
2344 }
2345
2346 regions.push_back(RegionSuccessor(&getThenRegion()));
2347
2348 // Don't consider the else region if it is empty.
2349 Region *elseRegion = &this->getElseRegion();
2350 if (elseRegion->empty())
2351 regions.push_back(
2352 RegionSuccessor(getOperation(), getOperation()->getResults()));
2353 else
2354 regions.push_back(RegionSuccessor(elseRegion));
2355}
2356
2357void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2358 SmallVectorImpl<RegionSuccessor> &regions) {
2359 FoldAdaptor adaptor(operands, *this);
2360 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2361 if (!boolAttr || boolAttr.getValue())
2362 regions.emplace_back(&getThenRegion());
2363
2364 // If the else region is empty, execution continues after the parent op.
2365 if (!boolAttr || !boolAttr.getValue()) {
2366 if (!getElseRegion().empty())
2367 regions.emplace_back(&getElseRegion());
2368 else
2369 regions.emplace_back(getOperation(), getResults());
2370 }
2371}
2372
2373LogicalResult IfOp::fold(FoldAdaptor adaptor,
2374 SmallVectorImpl<OpFoldResult> &results) {
2375 // if (!c) then A() else B() -> if c then B() else A()
2376 if (getElseRegion().empty())
2377 return failure();
2378
2379 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2380 if (!xorStmt)
2381 return failure();
2382
2383 if (!matchPattern(xorStmt.getRhs(), m_One()))
2384 return failure();
2385
2386 getConditionMutable().assign(xorStmt.getLhs());
2387 Block *thenBlock = &getThenRegion().front();
2388 // It would be nicer to use iplist::swap, but that has no implemented
2389 // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2390 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2391 getElseRegion().getBlocks());
2392 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2393 getThenRegion().getBlocks(), thenBlock);
2394 return success();
2395}
2396
2397void IfOp::getRegionInvocationBounds(
2398 ArrayRef<Attribute> operands,
2399 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2400 if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2401 // If the condition is known, then one region is known to be executed once
2402 // and the other zero times.
2403 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2404 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2405 } else {
2406 // Non-constant condition. Each region may be executed 0 or 1 times.
2407 invocationBounds.assign(2, {0, 1});
2408 }
2409}
2410
2411namespace {
2412// Pattern to remove unused IfOp results.
2413struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2414 using OpRewritePattern<IfOp>::OpRewritePattern;
2415
2416 void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2417 PatternRewriter &rewriter) const {
2418 // Move all operations to the destination block.
2419 rewriter.mergeBlocks(source, dest);
2420 // Replace the yield op by one that returns only the used values.
2421 auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2422 SmallVector<Value, 4> usedOperands;
2423 llvm::transform(usedResults, std::back_inserter(usedOperands),
2424 [&](OpResult result) {
2425 return yieldOp.getOperand(result.getResultNumber());
2426 });
2427 rewriter.modifyOpInPlace(yieldOp,
2428 [&]() { yieldOp->setOperands(usedOperands); });
2429 }
2430
2431 LogicalResult matchAndRewrite(IfOp op,
2432 PatternRewriter &rewriter) const override {
2433 // Compute the list of used results.
2434 SmallVector<OpResult, 4> usedResults;
2435 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2436 [](OpResult result) { return !result.use_empty(); });
2437
2438 // Replace the operation if only a subset of its results have uses.
2439 if (usedResults.size() == op.getNumResults())
2440 return failure();
2441
2442 // Compute the result types of the replacement operation.
2443 SmallVector<Type, 4> newTypes;
2444 llvm::transform(usedResults, std::back_inserter(newTypes),
2445 [](OpResult result) { return result.getType(); });
2446
2447 // Create a replacement operation with empty then and else regions.
2448 auto newOp =
2449 IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
2450 rewriter.createBlock(&newOp.getThenRegion());
2451 rewriter.createBlock(&newOp.getElseRegion());
2452
2453 // Move the bodies and replace the terminators (note there is a then and
2454 // an else region since the operation returns results).
2455 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2456 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2457
2458 // Replace the operation by the new one.
2459 SmallVector<Value, 4> repResults(op.getNumResults());
2460 for (const auto &en : llvm::enumerate(usedResults))
2461 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2462 rewriter.replaceOp(op, repResults);
2463 return success();
2464 }
2465};
2466
2467struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2468 using OpRewritePattern<IfOp>::OpRewritePattern;
2469
2470 LogicalResult matchAndRewrite(IfOp op,
2471 PatternRewriter &rewriter) const override {
2472 BoolAttr condition;
2473 if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2474 return failure();
2475
2476 if (condition.getValue())
2477 replaceOpWithRegion(rewriter, op, op.getThenRegion());
2478 else if (!op.getElseRegion().empty())
2479 replaceOpWithRegion(rewriter, op, op.getElseRegion());
2480 else
2481 rewriter.eraseOp(op);
2482
2483 return success();
2484 }
2485};
2486
2487/// Hoist any yielded results whose operands are defined outside
2488/// the if, to a select instruction.
/// A result pair is kept inside the if when either yielded value is defined
/// directly in its branch region. Identical pairs are forwarded as-is, and the
/// remaining pairs become `arith.select` ops on the condition. The kept pairs
/// are yielded by a new, narrower scf.if that takes over both original regions.
2489struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2490 using OpRewritePattern<IfOp>::OpRewritePattern;
2491
2492 LogicalResult matchAndRewrite(IfOp op,
2493 PatternRewriter &rewriter) const override {
2494 if (op->getNumResults() == 0)
2495 return failure();
2496
2497 auto cond = op.getCondition();
2498 auto thenYieldArgs = op.thenYield().getOperands();
2499 auto elseYieldArgs = op.elseYield().getOperands();
2500
// Collect the types of result pairs that must stay inside the if.
2501 SmallVector<Type> nonHoistable;
2502 for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2503 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2504 &op.getElseRegion() == falseVal.getParentRegion())
2505 nonHoistable.push_back(trueVal.getType());
2506 }
2507 // Early exit if there aren't any yielded values we can
2508 // hoist outside the if.
2509 if (nonHoistable.size() == op->getNumResults())
2510 return failure();
2511
// Build the narrower if, drop its default "then" block, and move both
// original regions into it.
2512 IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2513 /*withElseRegion=*/false);
2514 if (replacement.thenBlock())
2515 rewriter.eraseBlock(replacement.thenBlock());
2516 replacement.getThenRegion().takeBody(op.getThenRegion());
2517 replacement.getElseRegion().takeBody(op.getElseRegion());
2518
2519 SmallVector<Value> results(op->getNumResults());
2520 assert(thenYieldArgs.size() == results.size());
2521 assert(elseYieldArgs.size() == results.size());
2522
2523 SmallVector<Value> trueYields;
2524 SmallVector<Value> falseYields;
// Compute the replacement for each original result: the next result of
// `replacement` for kept pairs, the shared value for identical pairs, and an
// arith.select on `cond` otherwise.
2526 for (const auto &it :
2527 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2528 Value trueVal = std::get<0>(it.value());
2529 Value falseVal = std::get<1>(it.value());
2530 if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2531 &replacement.getElseRegion() == falseVal.getParentRegion()) {
2532 results[it.index()] = replacement.getResult(trueYields.size());
2533 trueYields.push_back(trueVal);
2534 falseYields.push_back(falseVal);
2535 } else if (trueVal == falseVal)
2536 results[it.index()] = trueVal;
2537 else
2538 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2539 cond, trueVal, falseVal);
2540 }
2541
// Rewrite both yields so they return only the kept values.
2542 rewriter.setInsertionPointToEnd(replacement.thenBlock());
2543 rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2544
2545 rewriter.setInsertionPointToEnd(replacement.elseBlock());
2546 rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2547
2548 rewriter.replaceOp(op, results);
2549 return success();
2550 }
2551};
2552
2553/// Allow the true region of an if to assume the condition is true
2554/// and vice versa. For example:
2555///
2556/// scf.if %cmp {
2557/// print(%cmp)
2558/// }
2559///
2560/// becomes
2561///
2562/// scf.if %cmp {
2563/// print(true)
2564/// }
2565///
2566struct ConditionPropagation : public OpRewritePattern<IfOp> {
2567 using OpRewritePattern<IfOp>::OpRewritePattern;
2568
2569 /// Kind of parent region in the ancestor cache.
2570 enum class Parent { Then, Else, None };
2571
2572 /// Returns the kind of region ("then", "else", or "none") of the
2573 /// IfOp that the given region is transitively nested in. Updates
2574 /// the cache accordingly.
2575 static Parent getParentType(Region *toCheck, IfOp op,
2577 Region *endRegion) {
2578 SmallVector<Region *> seen;
2579 while (toCheck != endRegion) {
2580 auto found = cache.find(toCheck);
2581 if (found != cache.end())
2582 return found->second;
2583 seen.push_back(toCheck);
2584 if (&op.getThenRegion() == toCheck) {
2585 for (Region *region : seen)
2586 cache[region] = Parent::Then;
2587 return Parent::Then;
2588 }
2589 if (&op.getElseRegion() == toCheck) {
2590 for (Region *region : seen)
2591 cache[region] = Parent::Else;
2592 return Parent::Else;
2593 }
2594 toCheck = toCheck->getParentRegion();
2595 }
2596
2597 for (Region *region : seen)
2598 cache[region] = Parent::None;
2599 return Parent::None;
2600 }
2601
2602 LogicalResult matchAndRewrite(IfOp op,
2603 PatternRewriter &rewriter) const override {
2604 // Early exit if the condition is constant since replacing a constant
2605 // in the body with another constant isn't a simplification.
2606 if (matchPattern(op.getCondition(), m_Constant()))
2607 return failure();
2608
2609 bool changed = false;
2610 mlir::Type i1Ty = rewriter.getI1Type();
2611
2612 // These variables serve to prevent creating duplicate constants
2613 // and hold constant true or false values.
2614 Value constantTrue = nullptr;
2615 Value constantFalse = nullptr;
2616
2618 for (OpOperand &use :
2619 llvm::make_early_inc_range(op.getCondition().getUses())) {
2620 switch (getParentType(use.getOwner()->getParentRegion(), op, cache,
2621 op.getCondition().getParentRegion())) {
2622 case Parent::Then: {
2623 changed = true;
2624
2625 if (!constantTrue)
2626 constantTrue = arith::ConstantOp::create(
2627 rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2628
2629 rewriter.modifyOpInPlace(use.getOwner(),
2630 [&]() { use.set(constantTrue); });
2631 break;
2632 }
2633 case Parent::Else: {
2634 changed = true;
2635
2636 if (!constantFalse)
2637 constantFalse = arith::ConstantOp::create(
2638 rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2639
2640 rewriter.modifyOpInPlace(use.getOwner(),
2641 [&]() { use.set(constantFalse); });
2642 break;
2643 }
2644 case Parent::None:
2645 break;
2646 }
2647 }
2648
2649 return success(changed);
2650 }
2651};
2652
2653/// Remove any statements from an if that are equivalent to the condition
2654/// or its negation. For example:
2655///
2656/// %res:2 = scf.if %cmp {
2657/// yield something(), true
2658/// } else {
2659/// yield something2(), false
2660/// }
2661/// print(%res#1)
2662///
2663/// becomes
2664/// %res = scf.if %cmp {
2665/// yield something()
2666/// } else {
2667/// yield something2()
2668/// }
2669/// print(%cmp)
2670///
2671/// Additionally if both branches yield the same value, replace all uses
2672/// of the result with the yielded value.
2673///
2674/// %res:2 = scf.if %cmp {
2675/// yield something(), %arg1
2676/// } else {
2677/// yield something2(), %arg1
2678/// }
2679/// print(%res#1)
2680///
2681/// becomes
2682/// %res = scf.if %cmp {
2683/// yield something()
2684/// } else {
2685/// yield something2()
2686/// }
2687/// print(%arg1)
2688///
2689struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2690 using OpRewritePattern<IfOp>::OpRewritePattern;
2691
2692 LogicalResult matchAndRewrite(IfOp op,
2693 PatternRewriter &rewriter) const override {
2694 // Early exit if there are no results that could be replaced.
2695 if (op.getNumResults() == 0)
2696 return failure();
2697
2698 auto trueYield =
2699 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2700 auto falseYield =
2701 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2702
2703 rewriter.setInsertionPoint(op->getBlock(),
2704 op.getOperation()->getIterator());
2705 bool changed = false;
2706 Type i1Ty = rewriter.getI1Type();
2707 for (auto [trueResult, falseResult, opResult] :
2708 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2709 op.getResults())) {
2710 if (trueResult == falseResult) {
2711 if (!opResult.use_empty()) {
2712 opResult.replaceAllUsesWith(trueResult);
2713 changed = true;
2714 }
2715 continue;
2716 }
2717
2718 BoolAttr trueYield, falseYield;
2719 if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2720 !matchPattern(falseResult, m_Constant(&falseYield)))
2721 continue;
2722
2723 bool trueVal = trueYield.getValue();
2724 bool falseVal = falseYield.getValue();
2725 if (!trueVal && falseVal) {
2726 if (!opResult.use_empty()) {
2727 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2728 Value notCond = arith::XOrIOp::create(
2729 rewriter, op.getLoc(), op.getCondition(),
2730 constDialect
2731 ->materializeConstant(rewriter,
2732 rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2733 op.getLoc())
2734 ->getResult(0));
2735 opResult.replaceAllUsesWith(notCond);
2736 changed = true;
2737 }
2738 }
2739 if (trueVal && !falseVal) {
2740 if (!opResult.use_empty()) {
2741 opResult.replaceAllUsesWith(op.getCondition());
2742 changed = true;
2743 }
2744 }
2745 }
2746 return success(changed);
2747 }
2748};
2749
2750/// Merge any consecutive scf.if's with the same condition.
2751///
2752/// scf.if %cond {
2753/// firstCodeTrue();...
2754/// } else {
2755/// firstCodeFalse();...
2756/// }
2757/// %res = scf.if %cond {
2758/// secondCodeTrue();...
2759/// } else {
2760/// secondCodeFalse();...
2761/// }
2762///
2763/// becomes
2764/// %res = scf.if %cmp {
2765/// firstCodeTrue();...
2766/// secondCodeTrue();...
2767/// } else {
2768/// firstCodeFalse();...
2769/// secondCodeFalse();...
2770/// }
2771struct CombineIfs : public OpRewritePattern<IfOp> {
2772 using OpRewritePattern<IfOp>::OpRewritePattern;
2773
2774 LogicalResult matchAndRewrite(IfOp nextIf,
2775 PatternRewriter &rewriter) const override {
2776 Block *parent = nextIf->getBlock();
2777 if (nextIf == &parent->front())
2778 return failure();
2779
2780 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2781 if (!prevIf)
2782 return failure();
2783
2784 // Determine the logical then/else blocks when prevIf's
2785 // condition is used. Null means the block does not exist
2786 // in that case (e.g. empty else). If neither of these
2787 // are set, the two conditions cannot be compared.
2788 Block *nextThen = nullptr;
2789 Block *nextElse = nullptr;
2790 if (nextIf.getCondition() == prevIf.getCondition()) {
2791 nextThen = nextIf.thenBlock();
2792 if (!nextIf.getElseRegion().empty())
2793 nextElse = nextIf.elseBlock();
2794 }
2795 if (arith::XOrIOp notv =
2796 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2797 if (notv.getLhs() == prevIf.getCondition() &&
2798 matchPattern(notv.getRhs(), m_One())) {
2799 nextElse = nextIf.thenBlock();
2800 if (!nextIf.getElseRegion().empty())
2801 nextThen = nextIf.elseBlock();
2802 }
2803 }
2804 if (arith::XOrIOp notv =
2805 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2806 if (notv.getLhs() == nextIf.getCondition() &&
2807 matchPattern(notv.getRhs(), m_One())) {
2808 nextElse = nextIf.thenBlock();
2809 if (!nextIf.getElseRegion().empty())
2810 nextThen = nextIf.elseBlock();
2811 }
2812 }
2813
2814 if (!nextThen && !nextElse)
2815 return failure();
2816
2817 SmallVector<Value> prevElseYielded;
2818 if (!prevIf.getElseRegion().empty())
2819 prevElseYielded = prevIf.elseYield().getOperands();
2820 // Replace all uses of return values of op within nextIf with the
2821 // corresponding yields
2822 for (auto it : llvm::zip(prevIf.getResults(),
2823 prevIf.thenYield().getOperands(), prevElseYielded))
2824 for (OpOperand &use :
2825 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2826 if (nextThen && nextThen->getParent()->isAncestor(
2827 use.getOwner()->getParentRegion())) {
2828 rewriter.startOpModification(use.getOwner());
2829 use.set(std::get<1>(it));
2830 rewriter.finalizeOpModification(use.getOwner());
2831 } else if (nextElse && nextElse->getParent()->isAncestor(
2832 use.getOwner()->getParentRegion())) {
2833 rewriter.startOpModification(use.getOwner());
2834 use.set(std::get<2>(it));
2835 rewriter.finalizeOpModification(use.getOwner());
2836 }
2837 }
2838
2839 SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2840 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2841
2842 IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2843 prevIf.getCondition(), /*hasElse=*/false);
2844 rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2845
2846 rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2847 combinedIf.getThenRegion(),
2848 combinedIf.getThenRegion().begin());
2849
2850 if (nextThen) {
2851 YieldOp thenYield = combinedIf.thenYield();
2852 YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2853 rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2854 rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2855
2856 SmallVector<Value> mergedYields(thenYield.getOperands());
2857 llvm::append_range(mergedYields, thenYield2.getOperands());
2858 YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2859 rewriter.eraseOp(thenYield);
2860 rewriter.eraseOp(thenYield2);
2861 }
2862
2863 rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2864 combinedIf.getElseRegion(),
2865 combinedIf.getElseRegion().begin());
2866
2867 if (nextElse) {
2868 if (combinedIf.getElseRegion().empty()) {
2869 rewriter.inlineRegionBefore(*nextElse->getParent(),
2870 combinedIf.getElseRegion(),
2871 combinedIf.getElseRegion().begin());
2872 } else {
2873 YieldOp elseYield = combinedIf.elseYield();
2874 YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2875 rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2876
2877 rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2878
2879 SmallVector<Value> mergedElseYields(elseYield.getOperands());
2880 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2881
2882 YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2883 rewriter.eraseOp(elseYield);
2884 rewriter.eraseOp(elseYield2);
2885 }
2886 }
2887
2888 SmallVector<Value> prevValues;
2889 SmallVector<Value> nextValues;
2890 for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2891 if (pair.index() < prevIf.getNumResults())
2892 prevValues.push_back(pair.value());
2893 else
2894 nextValues.push_back(pair.value());
2895 }
2896 rewriter.replaceOp(prevIf, prevValues);
2897 rewriter.replaceOp(nextIf, nextValues);
2898 return success();
2899 }
2900};
2901
2902/// Pattern to remove an empty else branch.
2903struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2904 using OpRewritePattern<IfOp>::OpRewritePattern;
2905
2906 LogicalResult matchAndRewrite(IfOp ifOp,
2907 PatternRewriter &rewriter) const override {
2908 // Cannot remove else region when there are operation results.
2909 if (ifOp.getNumResults())
2910 return failure();
2911 Block *elseBlock = ifOp.elseBlock();
2912 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2913 return failure();
2914 auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2915 rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2916 newIfOp.getThenRegion().begin());
2917 rewriter.eraseOp(ifOp);
2918 return success();
2919 }
2920};
2921
2922/// Convert nested `if`s into `arith.andi` + single `if`.
2923///
2924/// scf.if %arg0 {
2925/// scf.if %arg1 {
2926/// ...
2927/// scf.yield
2928/// }
2929/// scf.yield
2930/// }
2931/// becomes
2932///
2933/// %0 = arith.andi %arg0, %arg1
2934/// scf.if %0 {
2935/// ...
2936/// scf.yield
2937/// }
2938struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2939 using OpRewritePattern<IfOp>::OpRewritePattern;
2940
2941 LogicalResult matchAndRewrite(IfOp op,
2942 PatternRewriter &rewriter) const override {
2943 auto nestedOps = op.thenBlock()->without_terminator();
2944 // Nested `if` must be the only op in block.
2945 if (!llvm::hasSingleElement(nestedOps))
2946 return failure();
2947
2948 // If there is an else block, it can only yield
2949 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2950 return failure();
2951
2952 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2953 if (!nestedIf)
2954 return failure();
2955
2956 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2957 return failure();
2958
2959 SmallVector<Value> thenYield(op.thenYield().getOperands());
2960 SmallVector<Value> elseYield;
2961 if (op.elseBlock())
2962 llvm::append_range(elseYield, op.elseYield().getOperands());
2963
2964 // A list of indices for which we should upgrade the value yielded
2965 // in the else to a select.
2966 SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2967
2968 // If the outer scf.if yields a value produced by the inner scf.if,
2969 // only permit combining if the value yielded when the condition
2970 // is false in the outer scf.if is the same value yielded when the
2971 // inner scf.if condition is false.
2972 // Note that the array access to elseYield will not go out of bounds
2973 // since it must have the same length as thenYield, since they both
2974 // come from the same scf.if.
2975 for (const auto &tup : llvm::enumerate(thenYield)) {
2976 if (tup.value().getDefiningOp() == nestedIf) {
2977 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2978 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2979 elseYield[tup.index()]) {
2980 return failure();
2981 }
2982 // If the correctness test passes, we will yield
2983 // corresponding value from the inner scf.if
2984 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2985 continue;
2986 }
2987
2988 // Otherwise, we need to ensure the else block of the combined
2989 // condition still returns the same value when the outer condition is
2990 // true and the inner condition is false. This can be accomplished if
2991 // the then value is defined outside the outer scf.if and we replace the
2992 // value with a select that considers just the outer condition. Since
2993 // the else region contains just the yield, its yielded value is
2994 // defined outside the scf.if, by definition.
2995
2996 // If the then value is defined within the scf.if, bail.
2997 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2998 return failure();
2999 }
3000 elseYieldsToUpgradeToSelect.push_back(tup.index());
3001 }
3002
3003 Location loc = op.getLoc();
3004 Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
3005 nestedIf.getCondition());
3006 auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
3007 Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
3008
3009 SmallVector<Value> results;
3010 llvm::append_range(results, newIf.getResults());
3011 rewriter.setInsertionPoint(newIf);
3012
3013 for (auto idx : elseYieldsToUpgradeToSelect)
3014 results[idx] =
3015 arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
3016 thenYield[idx], elseYield[idx]);
3017
3018 rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
3019 rewriter.setInsertionPointToEnd(newIf.thenBlock());
3020 rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
3021 if (!elseYield.empty()) {
3022 rewriter.createBlock(&newIf.getElseRegion());
3023 rewriter.setInsertionPointToEnd(newIf.elseBlock());
3024 YieldOp::create(rewriter, loc, elseYield);
3025 }
3026 rewriter.replaceOp(op, results);
3027 return success();
3028 }
3029};
3030
3031} // namespace
3032
3033void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
3034 MLIRContext *context) {
3035 results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
3036 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
3037 RemoveStaticCondition, RemoveUnusedResults,
3038 ReplaceIfYieldWithConditionOrValue>(context);
3039}
3040
3041Block *IfOp::thenBlock() { return &getThenRegion().back(); }
3042YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
3043Block *IfOp::elseBlock() {
3044 Region &r = getElseRegion();
3045 if (r.empty())
3046 return nullptr;
3047 return &r.back();
3048}
3049YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
3050
3051//===----------------------------------------------------------------------===//
3052// ParallelOp
3053//===----------------------------------------------------------------------===//
3054
3055void ParallelOp::build(
3056 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
3057 ValueRange upperBounds, ValueRange steps, ValueRange initVals,
3058 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
3059 bodyBuilderFn) {
3060 result.addOperands(lowerBounds);
3061 result.addOperands(upperBounds);
3062 result.addOperands(steps);
3063 result.addOperands(initVals);
3064 result.addAttribute(
3065 ParallelOp::getOperandSegmentSizeAttr(),
3066 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
3067 static_cast<int32_t>(upperBounds.size()),
3068 static_cast<int32_t>(steps.size()),
3069 static_cast<int32_t>(initVals.size())}));
3070 result.addTypes(initVals.getTypes());
3071
3072 OpBuilder::InsertionGuard guard(builder);
3073 unsigned numIVs = steps.size();
3074 SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
3075 SmallVector<Location, 8> argLocs(numIVs, result.location);
3076 Region *bodyRegion = result.addRegion();
3077 Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
3078
3079 if (bodyBuilderFn) {
3080 builder.setInsertionPointToStart(bodyBlock);
3081 bodyBuilderFn(builder, result.location,
3082 bodyBlock->getArguments().take_front(numIVs),
3083 bodyBlock->getArguments().drop_front(numIVs));
3084 }
3085 // Add terminator only if there are no reductions.
3086 if (initVals.empty())
3087 ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
3088}
3089
3090void ParallelOp::build(
3091 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
3092 ValueRange upperBounds, ValueRange steps,
3093 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
3094 // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
3095 // we don't capture a reference to a temporary by constructing the lambda at
3096 // function level.
3097 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
3098 Location nestedLoc, ValueRange ivs,
3099 ValueRange) {
3100 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
3101 };
3102 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
3103 if (bodyBuilderFn)
3104 wrapper = wrappedBuilderFn;
3105
3106 build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
3107 wrapper);
3108}
3109
3110LogicalResult ParallelOp::verify() {
3111 // Check that there is at least one value in lowerBound, upperBound and step.
3112 // It is sufficient to test only step, because it is ensured already that the
3113 // number of elements in lowerBound, upperBound and step are the same.
3114 Operation::operand_range stepValues = getStep();
3115 if (stepValues.empty())
3116 return emitOpError(
3117 "needs at least one tuple element for lowerBound, upperBound and step");
3118
3119 // Check whether all constant step values are positive.
3120 for (Value stepValue : stepValues)
3121 if (auto cst = getConstantIntValue(stepValue))
3122 if (*cst <= 0)
3123 return emitOpError("constant step operand must be positive");
3124
3125 // Check that the body defines the same number of block arguments as the
3126 // number of tuple elements in step.
3127 Block *body = getBody();
3128 if (body->getNumArguments() != stepValues.size())
3129 return emitOpError() << "expects the same number of induction variables: "
3130 << body->getNumArguments()
3131 << " as bound and step values: " << stepValues.size();
3132 for (auto arg : body->getArguments())
3133 if (!arg.getType().isIndex())
3134 return emitOpError(
3135 "expects arguments for the induction variable to be of index type");
3136
3137 // Check that the terminator is an scf.reduce op.
3139 *this, getRegion(), "expects body to terminate with 'scf.reduce'");
3140 if (!reduceOp)
3141 return failure();
3142
3143 // Check that the number of results is the same as the number of reductions.
3144 auto resultsSize = getResults().size();
3145 auto reductionsSize = reduceOp.getReductions().size();
3146 auto initValsSize = getInitVals().size();
3147 if (resultsSize != reductionsSize)
3148 return emitOpError() << "expects number of results: " << resultsSize
3149 << " to be the same as number of reductions: "
3150 << reductionsSize;
3151 if (resultsSize != initValsSize)
3152 return emitOpError() << "expects number of results: " << resultsSize
3153 << " to be the same as number of initial values: "
3154 << initValsSize;
3155 if (reduceOp.getNumOperands() != initValsSize)
3156 // Delegate error reporting to ReduceOp
3157 return success();
3158
3159 // Check that the types of the results and reductions are the same.
3160 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
3161 auto resultType = getOperation()->getResult(i).getType();
3162 auto reductionOperandType = reduceOp.getOperands()[i].getType();
3163 if (resultType != reductionOperandType)
3164 return reduceOp.emitOpError()
3165 << "expects type of " << i
3166 << "-th reduction operand: " << reductionOperandType
3167 << " to be the same as the " << i
3168 << "-th result type: " << resultType;
3169 }
3170 return success();
3171}
3172
3173ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
3174 auto &builder = parser.getBuilder();
3175 // Parse an opening `(` followed by induction variables followed by `)`
3176 SmallVector<OpAsmParser::Argument, 4> ivs;
3178 return failure();
3179
3180 // Parse loop bounds.
3181 SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
3182 if (parser.parseEqual() ||
3183 parser.parseOperandList(lower, ivs.size(),
3185 parser.resolveOperands(lower, builder.getIndexType(), result.operands))
3186 return failure();
3187
3188 SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
3189 if (parser.parseKeyword("to") ||
3190 parser.parseOperandList(upper, ivs.size(),
3192 parser.resolveOperands(upper, builder.getIndexType(), result.operands))
3193 return failure();
3194
3195 // Parse step values.
3196 SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
3197 if (parser.parseKeyword("step") ||
3198 parser.parseOperandList(steps, ivs.size(),
3200 parser.resolveOperands(steps, builder.getIndexType(), result.operands))
3201 return failure();
3202
3203 // Parse init values.
3204 SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
3205 if (succeeded(parser.parseOptionalKeyword("init"))) {
3206 if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
3207 return failure();
3208 }
3209
3210 // Parse optional results in case there is a reduce.
3211 if (parser.parseOptionalArrowTypeList(result.types))
3212 return failure();
3213
3214 // Now parse the body.
3215 Region *body = result.addRegion();
3216 for (auto &iv : ivs)
3217 iv.type = builder.getIndexType();
3218 if (parser.parseRegion(*body, ivs))
3219 return failure();
3220
3221 // Set `operandSegmentSizes` attribute.
3222 result.addAttribute(
3223 ParallelOp::getOperandSegmentSizeAttr(),
3224 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
3225 static_cast<int32_t>(upper.size()),
3226 static_cast<int32_t>(steps.size()),
3227 static_cast<int32_t>(initVals.size())}));
3228
3229 // Parse attributes.
3230 if (parser.parseOptionalAttrDict(result.attributes) ||
3231 parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
3232 result.operands))
3233 return failure();
3234
3235 // Add a terminator if none was parsed.
3236 ParallelOp::ensureTerminator(*body, builder, result.location);
3237 return success();
3238}
3239
3240void ParallelOp::print(OpAsmPrinter &p) {
3241 p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3242 << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3243 if (!getInitVals().empty())
3244 p << " init (" << getInitVals() << ")";
3245 p.printOptionalArrowTypeList(getResultTypes());
3246 p << ' ';
3247 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3249 (*this)->getAttrs(),
3250 /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3251}
3252
3253SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3254
3255std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3256 return SmallVector<Value>{getBody()->getArguments()};
3257}
3258
3259std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3260 return getLowerBound();
3261}
3262
3263std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3264 return getUpperBound();
3265}
3266
3267std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3268 return getStep();
3269}
3270
3272 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3273 if (!ivArg)
3274 return ParallelOp();
3275 assert(ivArg.getOwner() && "unlinked block argument");
3276 auto *containingOp = ivArg.getOwner()->getParentOp();
3277 return dyn_cast<ParallelOp>(containingOp);
3278}
3279
3280namespace {
3281// Collapse loop dimensions that perform a single iteration.
3282struct ParallelOpSingleOrZeroIterationDimsFolder
3283 : public OpRewritePattern<ParallelOp> {
3284 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3285
3286 LogicalResult matchAndRewrite(ParallelOp op,
3287 PatternRewriter &rewriter) const override {
3288 Location loc = op.getLoc();
3289
3290 // Compute new loop bounds that omit all single-iteration loop dimensions.
3291 SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3292 IRMapping mapping;
3293 for (auto [lb, ub, step, iv] :
3294 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3295 op.getInductionVars())) {
3296 auto numIterations =
3297 constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
3298 if (numIterations.has_value()) {
3299 // Remove the loop if it performs zero iterations.
3300 if (*numIterations == 0) {
3301 rewriter.replaceOp(op, op.getInitVals());
3302 return success();
3303 }
3304 // Replace the loop induction variable by the lower bound if the loop
3305 // performs a single iteration. Otherwise, copy the loop bounds.
3306 if (*numIterations == 1) {
3307 mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3308 continue;
3309 }
3310 }
3311 newLowerBounds.push_back(lb);
3312 newUpperBounds.push_back(ub);
3313 newSteps.push_back(step);
3314 }
3315 // Exit if none of the loop dimensions perform a single iteration.
3316 if (newLowerBounds.size() == op.getLowerBound().size())
3317 return failure();
3318
3319 if (newLowerBounds.empty()) {
3320 // All of the loop dimensions perform a single iteration. Inline
3321 // loop body and nested ReduceOp's
3322 SmallVector<Value> results;
3323 results.reserve(op.getInitVals().size());
3324 for (auto &bodyOp : op.getBody()->without_terminator())
3325 rewriter.clone(bodyOp, mapping);
3326 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3327 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3328 Block &reduceBlock = reduceOp.getReductions()[i].front();
3329 auto initValIndex = results.size();
3330 mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3331 mapping.map(reduceBlock.getArgument(1),
3332 mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3333 for (auto &reduceBodyOp : reduceBlock.without_terminator())
3334 rewriter.clone(reduceBodyOp, mapping);
3335
3336 auto result = mapping.lookupOrDefault(
3337 cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3338 results.push_back(result);
3339 }
3340
3341 rewriter.replaceOp(op, results);
3342 return success();
3343 }
3344 // Replace the parallel loop by lower-dimensional parallel loop.
3345 auto newOp =
3346 ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3347 newUpperBounds, newSteps, op.getInitVals(), nullptr);
3348 // Erase the empty block that was inserted by the builder.
3349 rewriter.eraseBlock(newOp.getBody());
3350 // Clone the loop body and remap the block arguments of the collapsed loops
3351 // (inlining does not support a cancellable block argument mapping).
3352 rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3353 newOp.getRegion().begin(), mapping);
3354 rewriter.replaceOp(op, newOp.getResults());
3355 return success();
3356 }
3357};
3358
3359struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3360 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3361
3362 LogicalResult matchAndRewrite(ParallelOp op,
3363 PatternRewriter &rewriter) const override {
3364 Block &outerBody = *op.getBody();
3365 if (!llvm::hasSingleElement(outerBody.without_terminator()))
3366 return failure();
3367
3368 auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3369 if (!innerOp)
3370 return failure();
3371
3372 for (auto val : outerBody.getArguments())
3373 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3374 llvm::is_contained(innerOp.getUpperBound(), val) ||
3375 llvm::is_contained(innerOp.getStep(), val))
3376 return failure();
3377
3378 // Reductions are not supported yet.
3379 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3380 return failure();
3381
3382 auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3383 ValueRange iterVals, ValueRange) {
3384 Block &innerBody = *innerOp.getBody();
3385 assert(iterVals.size() ==
3386 (outerBody.getNumArguments() + innerBody.getNumArguments()));
3387 IRMapping mapping;
3388 mapping.map(outerBody.getArguments(),
3389 iterVals.take_front(outerBody.getNumArguments()));
3390 mapping.map(innerBody.getArguments(),
3391 iterVals.take_back(innerBody.getNumArguments()));
3392 for (Operation &op : innerBody.without_terminator())
3393 builder.clone(op, mapping);
3394 };
3395
3396 auto concatValues = [](const auto &first, const auto &second) {
3397 SmallVector<Value> ret;
3398 ret.reserve(first.size() + second.size());
3399 ret.assign(first.begin(), first.end());
3400 ret.append(second.begin(), second.end());
3401 return ret;
3402 };
3403
3404 auto newLowerBounds =
3405 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3406 auto newUpperBounds =
3407 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3408 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3409
3410 rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3411 newSteps, ValueRange(),
3412 bodyBuilder);
3413 return success();
3414 }
3415};
3416
3417} // namespace
3418
3419void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3420 MLIRContext *context) {
3421 results
3422 .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3423 context);
3424}
3425
3426/// Given the region at `index`, or the parent operation if `index` is None,
3427/// return the successor regions. These are the regions that may be selected
3428/// during the flow of control. `operands` is a set of optional attributes that
3429/// correspond to a constant value for each operand, or null if that operand is
3430/// not a constant.
3431void ParallelOp::getSuccessorRegions(
3432 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3433 // Both the operation itself and the region may be branching into the body or
3434 // back into the operation itself. It is possible for loop not to enter the
3435 // body.
3436 regions.push_back(RegionSuccessor(&getRegion()));
3437 regions.push_back(RegionSuccessor(
3438 getOperation(), ResultRange{getResults().end(), getResults().end()}));
3439}
3440
3441//===----------------------------------------------------------------------===//
3442// ReduceOp
3443//===----------------------------------------------------------------------===//
3444
3445void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3446
3447void ReduceOp::build(OpBuilder &builder, OperationState &result,
3448 ValueRange operands) {
3449 result.addOperands(operands);
3450 for (Value v : operands) {
3451 OpBuilder::InsertionGuard guard(builder);
3452 Region *bodyRegion = result.addRegion();
3453 builder.createBlock(bodyRegion, {},
3454 ArrayRef<Type>{v.getType(), v.getType()},
3455 {result.location, result.location});
3456 }
3457}
3458
3459LogicalResult ReduceOp::verifyRegions() {
3460 if (getReductions().size() != getOperands().size())
3461 return emitOpError() << "expects number of reduction regions: "
3462 << getReductions().size()
3463 << " to be the same as number of reduction operands: "
3464 << getOperands().size();
3465 // The region of a ReduceOp has two arguments of the same type as its
3466 // corresponding operand.
3467 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3468 auto type = getOperands()[i].getType();
3469 Block &block = getReductions()[i].front();
3470 if (block.empty())
3471 return emitOpError() << i << "-th reduction has an empty body";
3472 if (block.getNumArguments() != 2 ||
3473 llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3474 return arg.getType() != type;
3475 }))
3476 return emitOpError() << "expected two block arguments with type " << type
3477 << " in the " << i << "-th reduction region";
3478
3479 // Check that the block is terminated by a ReduceReturnOp.
3480 if (!isa<ReduceReturnOp>(block.getTerminator()))
3481 return emitOpError("reduction bodies must be terminated with an "
3482 "'scf.reduce.return' op");
3483 }
3484
3485 return success();
3486}
3487
3488MutableOperandRange
3489ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
3490 // No operands are forwarded to the next iteration.
3491 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3492}
3493
3494//===----------------------------------------------------------------------===//
3495// ReduceReturnOp
3496//===----------------------------------------------------------------------===//
3497
3498LogicalResult ReduceReturnOp::verify() {
3499 // The type of the return value should be the same type as the types of the
3500 // block arguments of the reduction body.
3501 Block *reductionBody = getOperation()->getBlock();
3502 // Should already be verified by an op trait.
3503 assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3504 Type expectedResultType = reductionBody->getArgument(0).getType();
3505 if (expectedResultType != getResult().getType())
3506 return emitOpError() << "must have type " << expectedResultType
3507 << " (the type of the reduction inputs)";
3508 return success();
3509}
3510
3511//===----------------------------------------------------------------------===//
3512// WhileOp
3513//===----------------------------------------------------------------------===//
3514
3515void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3516 ::mlir::OperationState &odsState, TypeRange resultTypes,
3517 ValueRange inits, BodyBuilderFn beforeBuilder,
3518 BodyBuilderFn afterBuilder) {
3519 odsState.addOperands(inits);
3520 odsState.addTypes(resultTypes);
3521
3522 OpBuilder::InsertionGuard guard(odsBuilder);
3523
3524 // Build before region.
3525 SmallVector<Location, 4> beforeArgLocs;
3526 beforeArgLocs.reserve(inits.size());
3527 for (Value operand : inits) {
3528 beforeArgLocs.push_back(operand.getLoc());
3529 }
3530
3531 Region *beforeRegion = odsState.addRegion();
3532 Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3533 inits.getTypes(), beforeArgLocs);
3534 if (beforeBuilder)
3535 beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3536
3537 // Build after region.
3538 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3539
3540 Region *afterRegion = odsState.addRegion();
3541 Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3542 resultTypes, afterArgLocs);
3543
3544 if (afterBuilder)
3545 afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3546}
3547
3548ConditionOp WhileOp::getConditionOp() {
3549 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3550}
3551
3552YieldOp WhileOp::getYieldOp() {
3553 return cast<YieldOp>(getAfterBody()->getTerminator());
3554}
3555
3556std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3557 return getYieldOp().getResultsMutable();
3558}
3559
3560Block::BlockArgListType WhileOp::getBeforeArguments() {
3561 return getBeforeBody()->getArguments();
3562}
3563
3564Block::BlockArgListType WhileOp::getAfterArguments() {
3565 return getAfterBody()->getArguments();
3566}
3567
3568Block::BlockArgListType WhileOp::getRegionIterArgs() {
3569 return getBeforeArguments();
3570}
3571
3572OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
3573 assert(successor.getSuccessor() == &getBefore() &&
3574 "WhileOp is expected to branch only to the first region");
3575 return getInits();
3576}
3577
3578void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3579 SmallVectorImpl<RegionSuccessor> &regions) {
3580 // The parent op always branches to the condition region.
3581 if (point.isParent()) {
3582 regions.emplace_back(&getBefore(), getBefore().getArguments());
3583 return;
3584 }
3585
3586 assert(llvm::is_contained(
3587 {&getAfter(), &getBefore()},
3588 point.getTerminatorPredecessorOrNull()->getParentRegion()) &&
3589 "there are only two regions in a WhileOp");
3590 // The body region always branches back to the condition region.
3592 &getAfter()) {
3593 regions.emplace_back(&getBefore(), getBefore().getArguments());
3594 return;
3595 }
3596
3597 regions.emplace_back(getOperation(), getResults());
3598 regions.emplace_back(&getAfter(), getAfter().getArguments());
3599}
3600
3601SmallVector<Region *> WhileOp::getLoopRegions() {
3602 return {&getBefore(), &getAfter()};
3603}
3604
3605/// Parses a `while` op.
3606///
3607/// op ::= `scf.while` assignments `:` function-type region `do` region
3608/// `attributes` attribute-dict
3609/// initializer ::= /* empty */ | `(` assignment-list `)`
3610/// assignment-list ::= assignment | assignment `,` assignment-list
3611/// assignment ::= ssa-value `=` ssa-value
3612ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3613 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3614 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3615 Region *before = result.addRegion();
3616 Region *after = result.addRegion();
3617
3618 OptionalParseResult listResult =
3619 parser.parseOptionalAssignmentList(regionArgs, operands);
3620 if (listResult.has_value() && failed(listResult.value()))
3621 return failure();
3622
3623 FunctionType functionType;
3624 SMLoc typeLoc = parser.getCurrentLocation();
3625 if (failed(parser.parseColonType(functionType)))
3626 return failure();
3627
3628 result.addTypes(functionType.getResults());
3629
3630 if (functionType.getNumInputs() != operands.size()) {
3631 return parser.emitError(typeLoc)
3632 << "expected as many input types as operands " << "(expected "
3633 << operands.size() << " got " << functionType.getNumInputs() << ")";
3634 }
3635
3636 // Resolve input operands.
3637 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3638 parser.getCurrentLocation(),
3639 result.operands)))
3640 return failure();
3641
3642 // Propagate the types into the region arguments.
3643 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3644 regionArgs[i].type = functionType.getInput(i);
3645
3646 return failure(parser.parseRegion(*before, regionArgs) ||
3647 parser.parseKeyword("do") || parser.parseRegion(*after) ||
3648 parser.parseOptionalAttrDictWithKeyword(result.attributes));
3649}
3650
3651/// Prints a `while` op.
3652void scf::WhileOp::print(OpAsmPrinter &p) {
3653 printInitializationList(p, getBeforeArguments(), getInits(), " ");
3654 p << " : ";
3655 p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3656 p << ' ';
3657 p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3658 p << " do ";
3659 p.printRegion(getAfter());
3660 p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3661}
3662
3663/// Verifies that two ranges of types match, i.e. have the same number of
3664/// entries and that types are pairwise equals. Reports errors on the given
3665/// operation in case of mismatch.
3666template <typename OpTy>
3667static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3668 TypeRange right, StringRef message) {
3669 if (left.size() != right.size())
3670 return op.emitOpError("expects the same number of ") << message;
3671
3672 for (unsigned i = 0, e = left.size(); i < e; ++i) {
3673 if (left[i] != right[i]) {
3674 InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3675 << message;
3676 diag.attachNote() << "for argument " << i << ", found " << left[i]
3677 << " and " << right[i];
3678 return diag;
3679 }
3680 }
3681
3682 return success();
3683}
3684
3685LogicalResult scf::WhileOp::verify() {
3686 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3687 *this, getBefore(),
3688 "expects the 'before' region to terminate with 'scf.condition'");
3689 if (!beforeTerminator)
3690 return failure();
3691
3692 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3693 *this, getAfter(),
3694 "expects the 'after' region to terminate with 'scf.yield'");
3695 return success(afterTerminator != nullptr);
3696}
3697
3698namespace {
3699/// Move a scf.if op that is directly before the scf.condition op in the while
3700/// before region, and whose condition matches the condition of the
3701/// scf.condition op, down into the while after region.
3702///
3703/// scf.while (..) : (...) -> ... {
3704/// %additional_used_values = ...
3705/// %cond = ...
3706/// ...
3707/// %res = scf.if %cond -> (...) {
3708/// use(%additional_used_values)
3709/// ... // then block
3710/// scf.yield %then_value
3711/// } else {
3712/// scf.yield %else_value
3713/// }
3714/// scf.condition(%cond) %res, ...
3715/// } do {
3716/// ^bb0(%res_arg, ...):
3717/// use(%res_arg)
3718/// ...
3719///
3720/// becomes
3721/// scf.while (..) : (...) -> ... {
3722/// %additional_used_values = ...
3723/// %cond = ...
3724/// ...
3725/// scf.condition(%cond) %else_value, ..., %additional_used_values
3726/// } do {
3727/// ^bb0(%res_arg ..., %additional_args): :
3728/// use(%additional_args)
3729/// ... // if then block
3730/// use(%then_value)
3731/// ...
3732struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
3733 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3734
3735 LogicalResult matchAndRewrite(scf::WhileOp op,
3736 PatternRewriter &rewriter) const override {
3737 auto conditionOp = op.getConditionOp();
3738
3739 // Only support ifOp right before the condition at the moment. Relaxing this
3740 // would require to:
3741 // - check that the body does not have side-effects conflicting with
3742 // operations between the if and the condition.
3743 // - check that results of the if operation are only used as arguments to
3744 // the condition.
3745 auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
3746
3747 // Check that the ifOp is directly before the conditionOp and that it
3748 // matches the condition of the conditionOp. Also ensure that the ifOp has
3749 // no else block with content, as that would complicate the transformation.
3750 // TODO: support else blocks with content.
3751 if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
3752 (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
3753 return failure();
3754
3755 assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
3756 *ifOp->user_begin() == conditionOp)) &&
3757 "ifOp has unexpected uses");
3758
3759 Location loc = op.getLoc();
3760
3761 // Replace uses of ifOp results in the conditionOp with the yielded values
3762 // from the ifOp branches.
3763 for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
3764 auto it = llvm::find(ifOp->getResults(), arg);
3765 if (it != ifOp->getResults().end()) {
3766 size_t ifOpIdx = it.getIndex();
3767 Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
3768 Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
3769
3770 rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
3771 rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
3772 }
3773 }
3774
3775 // Collect additional used values from before region.
3776 SetVector<Value> additionalUsedValuesSet;
3777 visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
3778 if (&op.getBefore() == operand->get().getParentRegion())
3779 additionalUsedValuesSet.insert(operand->get());
3780 });
3781
3782 // Create new whileOp with additional used values as results.
3783 auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
3784 auto additionalValueTypes = llvm::map_to_vector(
3785 additionalUsedValues, [](Value val) { return val.getType(); });
3786 size_t additionalValueSize = additionalUsedValues.size();
3787 SmallVector<Type> newResultTypes(op.getResultTypes());
3788 newResultTypes.append(additionalValueTypes);
3789
3790 auto newWhileOp =
3791 scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
3792
3793 rewriter.modifyOpInPlace(newWhileOp, [&] {
3794 newWhileOp.getBefore().takeBody(op.getBefore());
3795 newWhileOp.getAfter().takeBody(op.getAfter());
3796 newWhileOp.getAfter().addArguments(
3797 additionalValueTypes,
3798 SmallVector<Location>(additionalValueSize, loc));
3799 });
3800
3801 rewriter.modifyOpInPlace(conditionOp, [&] {
3802 conditionOp.getArgsMutable().append(additionalUsedValues);
3803 });
3804
3805 // Replace uses of additional used values inside the ifOp then region with
3806 // the whileOp after region arguments.
3807 rewriter.replaceUsesWithIf(
3808 additionalUsedValues,
3809 newWhileOp.getAfterArguments().take_back(additionalValueSize),
3810 [&](OpOperand &use) {
3811 return ifOp.getThenRegion().isAncestor(
3812 use.getOwner()->getParentRegion());
3813 });
3814
3815 // Inline ifOp then region into new whileOp after region.
3816 rewriter.eraseOp(ifOp.thenYield());
3817 rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
3818 newWhileOp.getAfterBody()->begin());
3819 rewriter.eraseOp(ifOp);
3820 rewriter.replaceOp(op,
3821 newWhileOp->getResults().drop_back(additionalValueSize));
3822 return success();
3823 }
3824};
3825
3826/// Replace uses of the condition within the do block with true, since otherwise
3827/// the block would not be evaluated.
3828///
3829/// scf.while (..) : (i1, ...) -> ... {
3830/// %condition = call @evaluate_condition() : () -> i1
3831/// scf.condition(%condition) %condition : i1, ...
3832/// } do {
3833/// ^bb0(%arg0: i1, ...):
3834/// use(%arg0)
3835/// ...
3836///
3837/// becomes
3838/// scf.while (..) : (i1, ...) -> ... {
3839/// %condition = call @evaluate_condition() : () -> i1
3840/// scf.condition(%condition) %condition : i1, ...
3841/// } do {
3842/// ^bb0(%arg0: i1, ...):
3843/// use(%true)
3844/// ...
3845struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3846 using OpRewritePattern<WhileOp>::OpRewritePattern;
3847
3848 LogicalResult matchAndRewrite(WhileOp op,
3849 PatternRewriter &rewriter) const override {
3850 auto term = op.getConditionOp();
3851
3852 // These variables serve to prevent creating duplicate constants
3853 // and hold constant true or false values.
3854 Value constantTrue = nullptr;
3855
3856 bool replaced = false;
3857 for (auto yieldedAndBlockArgs :
3858 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3859 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3860 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3861 if (!constantTrue)
3862 constantTrue = arith::ConstantOp::create(
3863 rewriter, op.getLoc(), term.getCondition().getType(),
3864 rewriter.getBoolAttr(true));
3865
3866 rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3867 constantTrue);
3868 replaced = true;
3869 }
3870 }
3871 }
3872 return success(replaced);
3873 }
3874};
3875
3876/// Remove loop invariant arguments from `before` block of scf.while.
3877/// A before block argument is considered loop invariant if :-
3878/// 1. i-th yield operand is equal to the i-th while operand.
3879/// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3880/// condition operand AND this (k+1)-th condition operand is equal to i-th
3881/// iter argument/while operand.
3882/// For the arguments which are removed, their uses inside scf.while
3883/// are replaced with their corresponding initial value.
3884///
3885/// Eg:
3886/// INPUT :-
3887/// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3888/// ..., %argN_before = %N)
3889/// {
3890/// ...
3891/// scf.condition(%cond) %arg1_before, %arg0_before,
3892/// %arg2_before, %arg0_before, ...
3893/// } do {
3894/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3895/// ..., %argK_after):
3896/// ...
3897/// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3898/// }
3899///
3900/// OUTPUT :-
3901/// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3902/// %N)
3903/// {
3904/// ...
3905/// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3906/// } do {
3907/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3908/// ..., %argK_after):
3909/// ...
3910/// scf.yield %arg1_after, ..., %argN
3911/// }
3912///
3913/// EXPLANATION:
3914/// We iterate over each yield operand.
3915/// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3916/// %arg0_before, which in turn is the 0-th iter argument. So we
3917/// remove 0-th before block argument and yield operand, and replace
3918/// all uses of the 0-th before block argument with its initial value
3919/// %a.
3920/// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3921/// value. So we remove this operand and the corresponding before
3922/// block argument and replace all uses of 1-th before block argument
3923/// with %b.
3924struct RemoveLoopInvariantArgsFromBeforeBlock
3925 : public OpRewritePattern<WhileOp> {
3926 using OpRewritePattern<WhileOp>::OpRewritePattern;
3927
3928 LogicalResult matchAndRewrite(WhileOp op,
3929 PatternRewriter &rewriter) const override {
3930 Block &afterBlock = *op.getAfterBody();
3931 Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3932 ConditionOp condOp = op.getConditionOp();
3933 OperandRange condOpArgs = condOp.getArgs();
3934 Operation *yieldOp = afterBlock.getTerminator();
3935 ValueRange yieldOpArgs = yieldOp->getOperands();
3936
3937 bool canSimplify = false;
3938 for (const auto &it :
3939 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3940 auto index = static_cast<unsigned>(it.index());
3941 auto [initVal, yieldOpArg] = it.value();
3942 // If i-th yield operand is equal to the i-th operand of the scf.while,
3943 // the i-th before block argument is a loop invariant.
3944 if (yieldOpArg == initVal) {
3945 canSimplify = true;
3946 break;
3947 }
3948 // If the i-th yield operand is k-th after block argument, then we check
3949 // if the (k+1)-th condition op operand is equal to either the i-th before
3950 // block argument or the initial value of i-th before block argument. If
3951 // the comparison results `true`, i-th before block argument is a loop
3952 // invariant.
3953 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3954 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3955 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3956 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3957 canSimplify = true;
3958 break;
3959 }
3960 }
3961 }
3962
3963 if (!canSimplify)
3964 return failure();
3965
3966 SmallVector<Value> newInitArgs, newYieldOpArgs;
3967 DenseMap<unsigned, Value> beforeBlockInitValMap;
3968 SmallVector<Location> newBeforeBlockArgLocs;
3969 for (const auto &it :
3970 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3971 auto index = static_cast<unsigned>(it.index());
3972 auto [initVal, yieldOpArg] = it.value();
3973
3974 // If i-th yield operand is equal to the i-th operand of the scf.while,
3975 // the i-th before block argument is a loop invariant.
3976 if (yieldOpArg == initVal) {
3977 beforeBlockInitValMap.insert({index, initVal});
3978 continue;
3979 } else {
3980 // If the i-th yield operand is k-th after block argument, then we check
3981 // if the (k+1)-th condition op operand is equal to either the i-th
3982 // before block argument or the initial value of i-th before block
3983 // argument. If the comparison results `true`, i-th before block
3984 // argument is a loop invariant.
3985 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3986 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3987 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3988 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3989 beforeBlockInitValMap.insert({index, initVal});
3990 continue;
3991 }
3992 }
3993 }
3994 newInitArgs.emplace_back(initVal);
3995 newYieldOpArgs.emplace_back(yieldOpArg);
3996 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3997 }
3998
3999 {
4000 OpBuilder::InsertionGuard g(rewriter);
4001 rewriter.setInsertionPoint(yieldOp);
4002 rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
4003 }
4004
4005 auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
4006 newInitArgs);
4007
4008 Block &newBeforeBlock = *rewriter.createBlock(
4009 &newWhile.getBefore(), /*insertPt*/ {},
4010 ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
4011
4012 Block &beforeBlock = *op.getBeforeBody();
4013 SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
4014 // For each i-th before block argument we find it's replacement value as :-
4015 // 1. If i-th before block argument is a loop invariant, we fetch it's
4016 // initial value from `beforeBlockInitValMap` by querying for key `i`.
4017 // 2. Else we fetch j-th new before block argument as the replacement
4018 // value of i-th before block argument.
4019 for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
4020 // If the index 'i' argument was a loop invariant we fetch it's initial
4021 // value from `beforeBlockInitValMap`.
4022 if (beforeBlockInitValMap.count(i) != 0)
4023 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
4024 else
4025 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
4026 }
4027
4028 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
4029 rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
4030 newWhile.getAfter().begin());
4031
4032 rewriter.replaceOp(op, newWhile.getResults());
4033 return success();
4034 }
4035};
4036
4037/// Remove loop invariant value from result (condition op) of scf.while.
4038/// A value is considered loop invariant if the final value yielded by
4039/// scf.condition is defined outside of the `before` block. We remove the
4040/// corresponding argument in `after` block and replace the use with the value.
4041/// We also replace the use of the corresponding result of scf.while with the
4042/// value.
4043///
4044/// Eg:
4045/// INPUT :-
4046/// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
4047/// %argN_before = %N) {
4048/// ...
4049/// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
4050/// } do {
4051/// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
4052/// ...
4053/// some_func(%arg1_after)
4054/// ...
4055/// scf.yield %arg0_after, %arg2_after, ..., %argN_after
4056/// }
4057///
4058/// OUTPUT :-
4059/// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
4060/// ...
4061/// scf.condition(%cond) %arg0, %arg1, ..., %argM
4062/// } do {
4063/// ^bb0(%arg0, %arg3, ..., %argM):
4064/// ...
4065/// some_func(%a)
4066/// ...
4067/// scf.yield %arg0, %b, ..., %argN
4068/// }
4069///
4070/// EXPLANATION:
4071/// 1. The 1-th and 2-th operand of scf.condition are defined outside the
4072/// before block of scf.while, so they get removed.
4073/// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
4074/// replaced by %b.
4075/// 3. The corresponding after block argument %arg1_after's uses are
4076/// replaced by %a and %arg2_after's uses are replaced by %b.
4077struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
4078 using OpRewritePattern<WhileOp>::OpRewritePattern;
4079
4080 LogicalResult matchAndRewrite(WhileOp op,
4081 PatternRewriter &rewriter) const override {
4082 Block &beforeBlock = *op.getBeforeBody();
4083 ConditionOp condOp = op.getConditionOp();
4084 OperandRange condOpArgs = condOp.getArgs();
4085
4086 bool canSimplify = false;
4087 for (Value condOpArg : condOpArgs) {
4088 // Those values not defined within `before` block will be considered as
4089 // loop invariant values. We map the corresponding `index` with their
4090 // value.
4091 if (condOpArg.getParentBlock() != &beforeBlock) {
4092 canSimplify = true;
4093 break;
4094 }
4095 }
4096
4097 if (!canSimplify)
4098 return failure();
4099
4100 Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
4101
4102 SmallVector<Value> newCondOpArgs;
4103 SmallVector<Type> newAfterBlockType;
4104 DenseMap<unsigned, Value> condOpInitValMap;
4105 SmallVector<Location> newAfterBlockArgLocs;
4106 for (const auto &it : llvm::enumerate(condOpArgs)) {
4107 auto index = static_cast<unsigned>(it.index());
4108 Value condOpArg = it.value();
4109 // Those values not defined within `before` block will be considered as
4110 // loop invariant values. We map the corresponding `index` with their
4111 // value.
4112 if (condOpArg.getParentBlock() != &beforeBlock) {
4113 condOpInitValMap.insert({index, condOpArg});
4114 } else {
4115 newCondOpArgs.emplace_back(condOpArg);
4116 newAfterBlockType.emplace_back(condOpArg.getType());
4117 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
4118 }
4119 }
4120
4121 {
4122 OpBuilder::InsertionGuard g(rewriter);
4123 rewriter.setInsertionPoint(condOp);
4124 rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4125 newCondOpArgs);
4126 }
4127
4128 auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
4129 op.getOperands());
4130
4131 Block &newAfterBlock =
4132 *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
4133 newAfterBlockType, newAfterBlockArgLocs);
4134
4135 Block &afterBlock = *op.getAfterBody();
4136 // Since a new scf.condition op was created, we need to fetch the new
4137 // `after` block arguments which will be used while replacing operations of
4138 // previous scf.while's `after` blocks. We'd also be fetching new result
4139 // values too.
4140 SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
4141 SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
4142 for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
4143 Value afterBlockArg, result;
4144 // If index 'i' argument was loop invariant we fetch it's value from the
4145 // `condOpInitMap` map.
4146 if (condOpInitValMap.count(i) != 0) {
4147 afterBlockArg = condOpInitValMap[i];
4148 result = afterBlockArg;
4149 } else {
4150 afterBlockArg = newAfterBlock.getArgument(j);
4151 result = newWhile.getResult(j);
4152 j++;
4153 }
4154 newAfterBlockArgs[i] = afterBlockArg;
4155 newWhileResults[i] = result;
4156 }
4157
4158 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
4159 rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
4160 newWhile.getBefore().begin());
4161
4162 rewriter.replaceOp(op, newWhileResults);
4163 return success();
4164 }
4165};
4166
4167/// Remove WhileOp results that are also unused in 'after' block.
4168///
4169/// %0:2 = scf.while () : () -> (i32, i64) {
4170/// %condition = "test.condition"() : () -> i1
4171/// %v1 = "test.get_some_value"() : () -> i32
4172/// %v2 = "test.get_some_value"() : () -> i64
4173/// scf.condition(%condition) %v1, %v2 : i32, i64
4174/// } do {
4175/// ^bb0(%arg0: i32, %arg1: i64):
4176/// "test.use"(%arg0) : (i32) -> ()
4177/// scf.yield
4178/// }
4179/// return %0#0 : i32
4180///
4181/// becomes
4182/// %0 = scf.while () : () -> (i32) {
4183/// %condition = "test.condition"() : () -> i1
4184/// %v1 = "test.get_some_value"() : () -> i32
4185/// %v2 = "test.get_some_value"() : () -> i64
4186/// scf.condition(%condition) %v1 : i32
4187/// } do {
4188/// ^bb0(%arg0: i32):
4189/// "test.use"(%arg0) : (i32) -> ()
4190/// scf.yield
4191/// }
4192/// return %0 : i32
4193struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
4194 using OpRewritePattern<WhileOp>::OpRewritePattern;
4195
4196 LogicalResult matchAndRewrite(WhileOp op,
4197 PatternRewriter &rewriter) const override {
4198 auto term = op.getConditionOp();
4199 auto afterArgs = op.getAfterArguments();
4200 auto termArgs = term.getArgs();
4201
4202 // Collect results mapping, new terminator args and new result types.
4203 SmallVector<unsigned> newResultsIndices;
4204 SmallVector<Type> newResultTypes;
4205 SmallVector<Value> newTermArgs;
4206 SmallVector<Location> newArgLocs;
4207 bool needUpdate = false;
4208 for (const auto &it :
4209 llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
4210 auto i = static_cast<unsigned>(it.index());
4211 Value result = std::get<0>(it.value());
4212 Value afterArg = std::get<1>(it.value());
4213 Value termArg = std::get<2>(it.value());
4214 if (result.use_empty() && afterArg.use_empty()) {
4215 needUpdate = true;
4216 } else {
4217 newResultsIndices.emplace_back(i);
4218 newTermArgs.emplace_back(termArg);
4219 newResultTypes.emplace_back(result.getType());
4220 newArgLocs.emplace_back(result.getLoc());
4221 }
4222 }
4223
4224 if (!needUpdate)
4225 return failure();
4226
4227 {
4228 OpBuilder::InsertionGuard g(rewriter);
4229 rewriter.setInsertionPoint(term);
4230 rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
4231 newTermArgs);
4232 }
4233
4234 auto newWhile =
4235 WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
4236
4237 Block &newAfterBlock = *rewriter.createBlock(
4238 &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
4239
4240 // Build new results list and new after block args (unused entries will be
4241 // null).
4242 SmallVector<Value> newResults(op.getNumResults());
4243 SmallVector<Value> newAfterBlockArgs(op.getNumResults());
4244 for (const auto &it : llvm::enumerate(newResultsIndices)) {
4245 newResults[it.value()] = newWhile.getResult(it.index());
4246 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
4247 }
4248
4249 rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
4250 newWhile.getBefore().begin());
4251
4252 Block &afterBlock = *op.getAfterBody();
4253 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
4254
4255 rewriter.replaceOp(op, newResults);
4256 return success();
4257 }
4258};
4259
4260/// Replace operations equivalent to the condition in the do block with true,
4261/// since otherwise the block would not be evaluated.
4262///
4263/// scf.while (..) : (i32, ...) -> ... {
4264/// %z = ... : i32
4265/// %condition = cmpi pred %z, %a
4266/// scf.condition(%condition) %z : i32, ...
4267/// } do {
4268/// ^bb0(%arg0: i32, ...):
4269/// %condition2 = cmpi pred %arg0, %a
4270/// use(%condition2)
4271/// ...
4272///
4273/// becomes
4274/// scf.while (..) : (i32, ...) -> ... {
4275/// %z = ... : i32
4276/// %condition = cmpi pred %z, %a
4277/// scf.condition(%condition) %z : i32, ...
4278/// } do {
4279/// ^bb0(%arg0: i32, ...):
4280/// use(%true)
4281/// ...
4282struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
4283 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
4284
4285 LogicalResult matchAndRewrite(scf::WhileOp op,
4286 PatternRewriter &rewriter) const override {
4287 using namespace scf;
4288 auto cond = op.getConditionOp();
4289 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
4290 if (!cmp)
4291 return failure();
4292 bool changed = false;
4293 for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
4294 for (size_t opIdx = 0; opIdx < 2; opIdx++) {
4295 if (std::get<0>(tup) != cmp.getOperand(opIdx))
4296 continue;
4297 for (OpOperand &u :
4298 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
4299 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
4300 if (!cmp2)
4301 continue;
4302 // For a binary operator 1-opIdx gets the other side.
4303 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
4304 continue;
4305 bool samePredicate;
4306 if (cmp2.getPredicate() == cmp.getPredicate())
4307 samePredicate = true;
4308 else if (cmp2.getPredicate() ==
4309 arith::invertPredicate(cmp.getPredicate()))
4310 samePredicate = false;
4311 else
4312 continue;
4313
4314 rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
4315 1);
4316 changed = true;
4317 }
4318 }
4319 }
4320 return success(changed);
4321 }
4322};
4323
4324/// Remove unused init/yield args.
4325struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
4326 using OpRewritePattern<WhileOp>::OpRewritePattern;
4327
4328 LogicalResult matchAndRewrite(WhileOp op,
4329 PatternRewriter &rewriter) const override {
4330
4331 if (!llvm::any_of(op.getBeforeArguments(),
4332 [](Value arg) { return arg.use_empty(); }))
4333 return rewriter.notifyMatchFailure(op, "No args to remove");
4334
4335 YieldOp yield = op.getYieldOp();
4336
4337 // Collect results mapping, new terminator args and new result types.
4338 SmallVector<Value> newYields;
4339 SmallVector<Value> newInits;
4340 llvm::BitVector argsToErase;
4341
4342 size_t argsCount = op.getBeforeArguments().size();
4343 newYields.reserve(argsCount);
4344 newInits.reserve(argsCount);
4345 argsToErase.reserve(argsCount);
4346 for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
4347 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
4348 if (beforeArg.use_empty()) {
4349 argsToErase.push_back(true);
4350 } else {
4351 argsToErase.push_back(false);
4352 newYields.emplace_back(yieldValue);
4353 newInits.emplace_back(initValue);
4354 }
4355 }
4356
4357 Block &beforeBlock = *op.getBeforeBody();
4358 Block &afterBlock = *op.getAfterBody();
4359
4360 beforeBlock.eraseArguments(argsToErase);
4361
4362 Location loc = op.getLoc();
4363 auto newWhileOp =
4364 WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
4365 /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
4366 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4367 Block &newAfterBlock = *newWhileOp.getAfterBody();
4368
4369 OpBuilder::InsertionGuard g(rewriter);
4370 rewriter.setInsertionPoint(yield);
4371 rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4372
4373 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4374 newBeforeBlock.getArguments());
4375 rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4376 newAfterBlock.getArguments());
4377
4378 rewriter.replaceOp(op, newWhileOp.getResults());
4379 return success();
4380 }
4381};
4382
4383/// Remove duplicated ConditionOp args.
4384struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4386
4387 LogicalResult matchAndRewrite(WhileOp op,
4388 PatternRewriter &rewriter) const override {
4389 ConditionOp condOp = op.getConditionOp();
4390 ValueRange condOpArgs = condOp.getArgs();
4391
4392 llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4393
4394 if (argsSet.size() == condOpArgs.size())
4395 return rewriter.notifyMatchFailure(op, "No results to remove");
4396
4397 llvm::SmallDenseMap<Value, unsigned> argsMap;
4398 SmallVector<Value> newArgs;
4399 argsMap.reserve(condOpArgs.size());
4400 newArgs.reserve(condOpArgs.size());
4401 for (Value arg : condOpArgs) {
4402 if (!argsMap.count(arg)) {
4403 auto pos = static_cast<unsigned>(argsMap.size());
4404 argsMap.insert({arg, pos});
4405 newArgs.emplace_back(arg);
4406 }
4407 }
4408
4409 ValueRange argsRange(newArgs);
4410
4411 Location loc = op.getLoc();
4412 auto newWhileOp =
4413 scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
4414 /*beforeBody*/ nullptr,
4415 /*afterBody*/ nullptr);
4416 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4417 Block &newAfterBlock = *newWhileOp.getAfterBody();
4418
4419 SmallVector<Value> afterArgsMapping;
4420 SmallVector<Value> resultsMapping;
4421 for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4422 auto it = argsMap.find(arg);
4423 assert(it != argsMap.end());
4424 auto pos = it->second;
4425 afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4426 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4427 }
4428
4429 OpBuilder::InsertionGuard g(rewriter);
4430 rewriter.setInsertionPoint(condOp);
4431 rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4432 argsRange);
4433
4434 Block &beforeBlock = *op.getBeforeBody();
4435 Block &afterBlock = *op.getAfterBody();
4436
4437 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4438 newBeforeBlock.getArguments());
4439 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4440 rewriter.replaceOp(op, resultsMapping);
4441 return success();
4442 }
4443};
4444
4445/// If both ranges contain same values return mappping indices from args2 to
4446/// args1. Otherwise return std::nullopt.
4447static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4448 ValueRange args2) {
4449 if (args1.size() != args2.size())
4450 return std::nullopt;
4451
4452 SmallVector<unsigned> ret(args1.size());
4453 for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4454 auto it = llvm::find(args2, arg1);
4455 if (it == args2.end())
4456 return std::nullopt;
4457
4458 ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4459 }
4460
4461 return ret;
4462}
4463
4464static bool hasDuplicates(ValueRange args) {
4465 llvm::SmallDenseSet<Value> set;
4466 for (Value arg : args) {
4467 if (!set.insert(arg).second)
4468 return true;
4469 }
4470 return false;
4471}
4472
4473/// If `before` block args are directly forwarded to `scf.condition`, rearrange
4474/// `scf.condition` args into same order as block args. Update `after` block
4475/// args and op result values accordingly.
4476/// Needed to simplify `scf.while` -> `scf.for` uplifting.
4477struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4479
4480 LogicalResult matchAndRewrite(WhileOp loop,
4481 PatternRewriter &rewriter) const override {
4482 auto *oldBefore = loop.getBeforeBody();
4483 ConditionOp oldTerm = loop.getConditionOp();
4484 ValueRange beforeArgs = oldBefore->getArguments();
4485 ValueRange termArgs = oldTerm.getArgs();
4486 if (beforeArgs == termArgs)
4487 return failure();
4488
4489 if (hasDuplicates(termArgs))
4490 return failure();
4491
4492 auto mapping = getArgsMapping(beforeArgs, termArgs);
4493 if (!mapping)
4494 return failure();
4495
4496 {
4497 OpBuilder::InsertionGuard g(rewriter);
4498 rewriter.setInsertionPoint(oldTerm);
4499 rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4500 beforeArgs);
4501 }
4502
4503 auto *oldAfter = loop.getAfterBody();
4504
4505 SmallVector<Type> newResultTypes(beforeArgs.size());
4506 for (auto &&[i, j] : llvm::enumerate(*mapping))
4507 newResultTypes[j] = loop.getResult(i).getType();
4508
4509 auto newLoop = WhileOp::create(
4510 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
4511 /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4512 auto *newBefore = newLoop.getBeforeBody();
4513 auto *newAfter = newLoop.getAfterBody();
4514
4515 SmallVector<Value> newResults(beforeArgs.size());
4516 SmallVector<Value> newAfterArgs(beforeArgs.size());
4517 for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4518 newResults[i] = newLoop.getResult(j);
4519 newAfterArgs[i] = newAfter->getArgument(j);
4520 }
4521
4522 rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4523 newBefore->getArguments());
4524 rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4525 newAfterArgs);
4526
4527 rewriter.replaceOp(loop, newResults);
4528 return success();
4529 }
4530};
4531} // namespace
4532
4533void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4534 MLIRContext *context) {
4535 results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4536 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4537 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4538 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
4539 context);
4540}
4541
4542//===----------------------------------------------------------------------===//
4543// IndexSwitchOp
4544//===----------------------------------------------------------------------===//
4545
4546/// Parse the case regions and values.
4547static ParseResult
4549 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4550 SmallVector<int64_t> caseValues;
4551 while (succeeded(p.parseOptionalKeyword("case"))) {
4552 int64_t value;
4553 Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4554 if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4555 return failure();
4556 caseValues.push_back(value);
4557 }
4558 cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4559 return success();
4560}
4561
4562/// Print the case regions and values.
4564 DenseI64ArrayAttr cases, RegionRange caseRegions) {
4565 for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4566 p.printNewline();
4567 p << "case " << value << ' ';
4568 p.printRegion(*region, /*printEntryBlockArgs=*/false);
4569 }
4570}
4571
4572LogicalResult scf::IndexSwitchOp::verify() {
4573 if (getCases().size() != getCaseRegions().size()) {
4574 return emitOpError("has ")
4575 << getCaseRegions().size() << " case regions but "
4576 << getCases().size() << " case values";
4577 }
4578
4579 DenseSet<int64_t> valueSet;
4580 for (int64_t value : getCases())
4581 if (!valueSet.insert(value).second)
4582 return emitOpError("has duplicate case value: ") << value;
4583 auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4584 auto yield = dyn_cast<YieldOp>(region.front().back());
4585 if (!yield)
4586 return emitOpError("expected region to end with scf.yield, but got ")
4587 << region.front().back().getName();
4588
4589 if (yield.getNumOperands() != getNumResults()) {
4590 return (emitOpError("expected each region to return ")
4591 << getNumResults() << " values, but " << name << " returns "
4592 << yield.getNumOperands())
4593 .attachNote(yield.getLoc())
4594 << "see yield operation here";
4595 }
4596 for (auto [idx, result, operand] :
4597 llvm::enumerate(getResultTypes(), yield.getOperands())) {
4598 if (!operand)
4599 return yield.emitOpError() << "operand " << idx << " is null\n";
4600 if (result == operand.getType())
4601 continue;
4602 return (emitOpError("expected result #")
4603 << idx << " of each region to be " << result)
4604 .attachNote(yield.getLoc())
4605 << name << " returns " << operand.getType() << " here";
4606 }
4607 return success();
4608 };
4609
4610 if (failed(verifyRegion(getDefaultRegion(), "default region")))
4611 return failure();
4612 for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4613 if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4614 return failure();
4615
4616 return success();
4617}
4618
4619unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4620
4621Block &scf::IndexSwitchOp::getDefaultBlock() {
4622 return getDefaultRegion().front();
4623}
4624
4625Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4626 assert(idx < getNumCases() && "case index out-of-bounds");
4627 return getCaseRegions()[idx].front();
4628}
4629
4630void IndexSwitchOp::getSuccessorRegions(
4631 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4632 // All regions branch back to the parent op.
4633 if (!point.isParent()) {
4634 successors.emplace_back(getOperation(), getResults());
4635 return;
4636 }
4637
4638 llvm::append_range(successors, getRegions());
4639}
4640
4641void IndexSwitchOp::getEntrySuccessorRegions(
4642 ArrayRef<Attribute> operands,
4643 SmallVectorImpl<RegionSuccessor> &successors) {
4644 FoldAdaptor adaptor(operands, *this);
4645
4646 // If a constant was not provided, all regions are possible successors.
4647 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4648 if (!arg) {
4649 llvm::append_range(successors, getRegions());
4650 return;
4651 }
4652
4653 // Otherwise, try to find a case with a matching value. If not, the
4654 // default region is the only successor.
4655 for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4656 if (caseValue == arg.getInt()) {
4657 successors.emplace_back(&caseRegion);
4658 return;
4659 }
4660 }
4661 successors.emplace_back(&getDefaultRegion());
4662}
4663
4664void IndexSwitchOp::getRegionInvocationBounds(
4665 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4666 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4667 if (!operandValue) {
4668 // All regions are invoked at most once.
4669 bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4670 return;
4671 }
4672
4673 unsigned liveIndex = getNumRegions() - 1;
4674 const auto *it = llvm::find(getCases(), operandValue.getInt());
4675 if (it != getCases().end())
4676 liveIndex = std::distance(getCases().begin(), it);
4677 for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4678 bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4679}
4680
4681struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4682 using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
4683
4684 LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4685 PatternRewriter &rewriter) const override {
4686 // If `op.getArg()` is a constant, select the region that matches with
4687 // the constant value. Use the default region if no matche is found.
4688 std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4689 if (!maybeCst.has_value())
4690 return failure();
4691 int64_t cst = *maybeCst;
4692 int64_t caseIdx, e = op.getNumCases();
4693 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4694 if (cst == op.getCases()[caseIdx])
4695 break;
4696 }
4697
4698 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4699 : op.getDefaultRegion();
4700 Block &source = r.front();
4701 Operation *terminator = source.getTerminator();
4702 SmallVector<Value> results = terminator->getOperands();
4703
4704 rewriter.inlineBlockBefore(&source, op);
4705 rewriter.eraseOp(terminator);
4706 // Replace the operation with a potentially empty list of results.
4707 // Fold mechanism doesn't support the case where the result list is empty.
4708 rewriter.replaceOp(op, results);
4709
4710 return success();
4711 }
4712};
4713
/// Registers the canonicalization patterns of `scf.index_switch`. Currently
/// this is only `FoldConstantCase`, which rewrites a switch whose selector is
/// an integer constant.
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldConstantCase>(context);
}
4718
4719//===----------------------------------------------------------------------===//
4720// TableGen'd op method definitions
4721//===----------------------------------------------------------------------===//
4722
4723#define GET_OP_CLASSES
4724#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1400
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
Definition EmitC.cpp:1375
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition EmitC.cpp:1391
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition SCF.cpp:137
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition SCF.cpp:3667
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:566
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition SCF.cpp:101
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
@ None
static std::string diag(const llvm::Value &value)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:85
bool empty()
Definition Block.h:148
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:160
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation & front()
Definition Block.h:153
Operation & back()
Definition Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:153
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition Block.cpp:201
BlockArgListType getArguments()
Definition Block.h:87
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:212
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
bool getValue() const
Return the boolean value of this attribute.
UnitAttr getUnitAttr()
Definition Builders.cpp:98
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:167
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
IntegerType getI1Type()
Definition Builders.cpp:53
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition Builders.cpp:589
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition Builders.h:594
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OperandRange operand_range
Definition Operation.h:371
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:346
This class represents a successor of a region.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:222
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
BlockArgument getArgument(unsigned i)
Definition Region.h:124
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition SCF.cpp:3271
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition SCF.cpp:94
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition SCF.cpp:838
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition SCF.cpp:2152
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition SCF.cpp:793
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition SCF.cpp:115
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition SCF.cpp:745
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition SCF.h:64
llvm::function_ref< Value(OpBuilder &, Location loc, Type, Value)> ValueTypeCastFnTy
Perform a replacement of one iter OpOperand of an scf.for to the replacement value with a different t...
Definition SCF.h:107
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition SCF.cpp:1610
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition SCF.cpp:927
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that the values have as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:302
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:4684
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:261
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition SCF.cpp:212
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.