MLIR 22.0.0git
WrapInZeroTripCheck.cpp
Go to the documentation of this file.
1//===- WrapInZeroTripCheck.cpp - Loop transforms to add zero-trip-check ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/IRMapping.h"
13
14using namespace mlir;
15
16/// Create zero-trip-check around a `while` op and return the new loop op in the
17/// check. The while loop is rotated to avoid evaluating the condition twice.
18///
19/// Given an example below:
20///
21/// scf.while (%arg0 = %init) : (i32) -> i64 {
22/// %val = .., %arg0 : i64
23/// %cond = arith.cmpi .., %arg0 : i32
24/// scf.condition(%cond) %val : i64
25/// } do {
26/// ^bb0(%arg1: i64):
27/// %next = .., %arg1 : i32
28/// scf.yield %next : i32
29/// }
30///
31/// First clone before block to the front of the loop:
32///
33/// %pre_val = .., %init : i64
34/// %pre_cond = arith.cmpi .., %init : i32
35/// scf.while (%arg0 = %init) : (i32) -> i64 {
36/// %val = .., %arg0 : i64
37/// %cond = arith.cmpi .., %arg0 : i32
38/// scf.condition(%cond) %val : i64
39/// } do {
40/// ^bb0(%arg1: i64):
41/// %next = .., %arg1 : i32
42/// scf.yield %next : i32
43/// }
44///
45/// Create `if` op with the condition, rotate and move the loop into the else
46/// branch:
47///
48/// %pre_val = .., %init : i64
49/// %pre_cond = arith.cmpi .., %init : i32
50/// scf.if %pre_cond -> i64 {
51/// %res = scf.while (%arg1 = %va0) : (i64) -> i64 {
52/// // Original after block
53/// %next = .., %arg1 : i32
54/// // Original before block
55/// %val = .., %next : i64
56/// %cond = arith.cmpi .., %next : i32
57/// scf.condition(%cond) %val : i64
58/// } do {
59/// ^bb0(%arg2: i64):
60/// %scf.yield %arg2 : i32
61/// }
62/// scf.yield %res : i64
63/// } else {
64/// scf.yield %pre_val : i64
65/// }
66FailureOr<scf::WhileOp> mlir::scf::wrapWhileLoopInZeroTripCheck(
67 scf::WhileOp whileOp, RewriterBase &rewriter, bool forceCreateCheck) {
68 // If the loop is in do-while form (after block only passes through values),
69 // there is no need to create a zero-trip-check as before block is always run.
70 if (!forceCreateCheck && isa<scf::YieldOp>(whileOp.getAfterBody()->front())) {
71 return whileOp;
72 }
73
74 OpBuilder::InsertionGuard insertion_guard(rewriter);
75
76 IRMapping mapper;
77 Block *beforeBlock = whileOp.getBeforeBody();
78 // Clone before block before the loop for zero-trip-check.
79 for (auto [arg, init] :
80 llvm::zip_equal(beforeBlock->getArguments(), whileOp.getInits())) {
81 mapper.map(arg, init);
82 }
83 rewriter.setInsertionPoint(whileOp);
84 for (auto &op : *beforeBlock) {
85 if (isa<scf::ConditionOp>(op)) {
86 break;
87 }
88 // Safe to clone everything as in a single block all defs have been cloned
89 // and added to mapper in order.
90 rewriter.insert(op.clone(mapper));
91 }
92
93 scf::ConditionOp condOp = whileOp.getConditionOp();
94 Value clonedCondition = mapper.lookupOrDefault(condOp.getCondition());
95 SmallVector<Value> clonedCondArgs = llvm::map_to_vector(
96 condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); });
97
98 // Create rotated while loop.
99 auto newLoopOp = scf::WhileOp::create(
100 rewriter, whileOp.getLoc(), whileOp.getResultTypes(), clonedCondArgs,
101 [&](OpBuilder &builder, Location loc, ValueRange args) {
102 // Rotate and move the loop body into before block.
103 auto newBlock = builder.getBlock();
104 rewriter.mergeBlocks(whileOp.getAfterBody(), newBlock, args);
105 auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
106 rewriter.mergeBlocks(whileOp.getBeforeBody(), newBlock,
107 yieldOp.getResults());
108 rewriter.eraseOp(yieldOp);
109 },
110 [&](OpBuilder &builder, Location loc, ValueRange args) {
111 // Pass through values.
112 scf::YieldOp::create(builder, loc, args);
113 });
114
115 // Create zero-trip-check and move the while loop in.
116 auto ifOp = scf::IfOp::create(
117 rewriter, whileOp.getLoc(), clonedCondition,
118 [&](OpBuilder &builder, Location loc) {
119 // Then runs the while loop.
120 rewriter.moveOpBefore(newLoopOp, builder.getInsertionBlock(),
121 builder.getInsertionPoint());
122 scf::YieldOp::create(builder, loc, newLoopOp.getResults());
123 },
124 [&](OpBuilder &builder, Location loc) {
125 // Else returns the results from precondition.
126 scf::YieldOp::create(builder, loc, clonedCondArgs);
127 });
128
129 rewriter.replaceOp(whileOp, ifOp);
130
131 return newLoopOp;
132}
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgListType getArguments()
Definition Block.h:87
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:421
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
FailureOr< WhileOp > wrapWhileLoopInZeroTripCheck(WhileOp whileOp, RewriterBase &rewriter, bool forceCreateCheck=false)
Create zero-trip-check around a while op and return the new loop op in the check.
Include the generated interface declarations.