MLIR  21.0.0git
WrapInZeroTripCheck.cpp
Go to the documentation of this file.
1 //===- WrapInZeroTripCheck.cpp - Loop transforms to add zero-trip-check ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/IRMapping.h"
12 #include "mlir/IR/PatternMatch.h"
13 
14 using namespace mlir;
15 
16 /// Create zero-trip-check around a `while` op and return the new loop op in the
17 /// check. The while loop is rotated to avoid evaluating the condition twice.
18 ///
19 /// Given an example below:
20 ///
21 /// scf.while (%arg0 = %init) : (i32) -> i64 {
22 /// %val = .., %arg0 : i64
23 /// %cond = arith.cmpi .., %arg0 : i32
24 /// scf.condition(%cond) %val : i64
25 /// } do {
26 /// ^bb0(%arg1: i64):
27 /// %next = .., %arg1 : i32
28 /// scf.yield %next : i32
29 /// }
30 ///
31 /// First clone before block to the front of the loop:
32 ///
33 /// %pre_val = .., %init : i64
34 /// %pre_cond = arith.cmpi .., %init : i32
35 /// scf.while (%arg0 = %init) : (i32) -> i64 {
36 /// %val = .., %arg0 : i64
37 /// %cond = arith.cmpi .., %arg0 : i32
38 /// scf.condition(%cond) %val : i64
39 /// } do {
40 /// ^bb0(%arg1: i64):
41 /// %next = .., %arg1 : i32
42 /// scf.yield %next : i32
43 /// }
44 ///
45 /// Create `if` op with the condition, rotate and move the loop into the else
46 /// branch:
47 ///
48 /// %pre_val = .., %init : i64
49 /// %pre_cond = arith.cmpi .., %init : i32
50 /// scf.if %pre_cond -> i64 {
51 /// %res = scf.while (%arg1 = %va0) : (i64) -> i64 {
52 /// // Original after block
53 /// %next = .., %arg1 : i32
54 /// // Original before block
55 /// %val = .., %next : i64
56 /// %cond = arith.cmpi .., %next : i32
57 /// scf.condition(%cond) %val : i64
58 /// } do {
59 /// ^bb0(%arg2: i64):
60 /// %scf.yield %arg2 : i32
61 /// }
62 /// scf.yield %res : i64
63 /// } else {
64 /// scf.yield %pre_val : i64
65 /// }
66 FailureOr<scf::WhileOp> mlir::scf::wrapWhileLoopInZeroTripCheck(
67  scf::WhileOp whileOp, RewriterBase &rewriter, bool forceCreateCheck) {
68  // If the loop is in do-while form (after block only passes through values),
69  // there is no need to create a zero-trip-check as before block is always run.
70  if (!forceCreateCheck && isa<scf::YieldOp>(whileOp.getAfterBody()->front())) {
71  return whileOp;
72  }
73 
74  OpBuilder::InsertionGuard insertion_guard(rewriter);
75 
76  IRMapping mapper;
77  Block *beforeBlock = whileOp.getBeforeBody();
78  // Clone before block before the loop for zero-trip-check.
79  for (auto [arg, init] :
80  llvm::zip_equal(beforeBlock->getArguments(), whileOp.getInits())) {
81  mapper.map(arg, init);
82  }
83  rewriter.setInsertionPoint(whileOp);
84  for (auto &op : *beforeBlock) {
85  if (isa<scf::ConditionOp>(op)) {
86  break;
87  }
88  // Safe to clone everything as in a single block all defs have been cloned
89  // and added to mapper in order.
90  rewriter.insert(op.clone(mapper));
91  }
92 
93  scf::ConditionOp condOp = whileOp.getConditionOp();
94  Value clonedCondition = mapper.lookupOrDefault(condOp.getCondition());
95  SmallVector<Value> clonedCondArgs = llvm::map_to_vector(
96  condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); });
97 
98  // Create rotated while loop.
99  auto newLoopOp = rewriter.create<scf::WhileOp>(
100  whileOp.getLoc(), whileOp.getResultTypes(), clonedCondArgs,
101  [&](OpBuilder &builder, Location loc, ValueRange args) {
102  // Rotate and move the loop body into before block.
103  auto newBlock = builder.getBlock();
104  rewriter.mergeBlocks(whileOp.getAfterBody(), newBlock, args);
105  auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
106  rewriter.mergeBlocks(whileOp.getBeforeBody(), newBlock,
107  yieldOp.getResults());
108  rewriter.eraseOp(yieldOp);
109  },
110  [&](OpBuilder &builder, Location loc, ValueRange args) {
111  // Pass through values.
112  builder.create<scf::YieldOp>(loc, args);
113  });
114 
115  // Create zero-trip-check and move the while loop in.
116  auto ifOp = rewriter.create<scf::IfOp>(
117  whileOp.getLoc(), clonedCondition,
118  [&](OpBuilder &builder, Location loc) {
119  // Then runs the while loop.
120  rewriter.moveOpBefore(newLoopOp, builder.getInsertionBlock(),
121  builder.getInsertionPoint());
122  builder.create<scf::YieldOp>(loc, newLoopOp.getResults());
123  },
124  [&](OpBuilder &builder, Location loc) {
125  // Else returns the results from precondition.
126  builder.create<scf::YieldOp>(loc, clonedCondArgs);
127  });
128 
129  rewriter.replaceOp(whileOp, ifOp);
130 
131  return newLoopOp;
132 }
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgListType getArguments()
Definition: Block.h:87
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition: Builders.cpp:419
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:358
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block in the same function.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
FailureOr< WhileOp > wrapWhileLoopInZeroTripCheck(WhileOp whileOp, RewriterBase &rewriter, bool forceCreateCheck=false)
Create zero-trip-check around a while op and return the new loop op in the check.
Include the generated interface declarations.