MLIR  20.0.0git
WrapInZeroTripCheck.cpp
Go to the documentation of this file.
1 //===- WrapInZeroTripCheck.cpp - Loop transforms to add zero-trip-check ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/IRMapping.h"
12 #include "mlir/IR/PatternMatch.h"
13 
14 using namespace mlir;
15 
16 /// Create zero-trip-check around a `while` op and return the new loop op in the
17 /// check. The while loop is rotated to avoid evaluating the condition twice.
18 ///
19 /// Given an example below:
20 ///
21 /// scf.while (%arg0 = %init) : (i32) -> i64 {
22 /// %val = .., %arg0 : i64
23 /// %cond = arith.cmpi .., %arg0 : i32
24 /// scf.condition(%cond) %val : i64
25 /// } do {
26 /// ^bb0(%arg1: i64):
27 /// %next = .., %arg1 : i32
28 /// scf.yield %next : i32
29 /// }
30 ///
31 /// First clone before block to the front of the loop:
32 ///
33 /// %pre_val = .., %init : i64
34 /// %pre_cond = arith.cmpi .., %init : i32
35 /// scf.while (%arg0 = %init) : (i32) -> i64 {
36 /// %val = .., %arg0 : i64
37 /// %cond = arith.cmpi .., %arg0 : i32
38 /// scf.condition(%cond) %val : i64
39 /// } do {
40 /// ^bb0(%arg1: i64):
41 /// %next = .., %arg1 : i32
42 /// scf.yield %next : i32
43 /// }
44 ///
45 /// Create `if` op with the condition, rotate and move the loop into the else
46 /// branch:
47 ///
48 /// %pre_val = .., %init : i64
49 /// %pre_cond = arith.cmpi .., %init : i32
50 /// scf.if %pre_cond -> i64 {
51 /// %res = scf.while (%arg1 = %va0) : (i64) -> i64 {
52 /// // Original after block
53 /// %next = .., %arg1 : i32
54 /// // Original before block
55 /// %val = .., %next : i64
56 /// %cond = arith.cmpi .., %next : i32
57 /// scf.condition(%cond) %val : i64
58 /// } do {
59 /// ^bb0(%arg2: i64):
60 /// %scf.yield %arg2 : i32
61 /// }
62 /// scf.yield %res : i64
63 /// } else {
64 /// scf.yield %pre_val : i64
65 /// }
66 FailureOr<scf::WhileOp> mlir::scf::wrapWhileLoopInZeroTripCheck(
67  scf::WhileOp whileOp, RewriterBase &rewriter, bool forceCreateCheck) {
68  // If the loop is in do-while form (after block only passes through values),
69  // there is no need to create a zero-trip-check as before block is always run.
70  if (!forceCreateCheck && isa<scf::YieldOp>(whileOp.getAfterBody()->front())) {
71  return whileOp;
72  }
73 
74  OpBuilder::InsertionGuard insertion_guard(rewriter);
75 
76  IRMapping mapper;
77  Block *beforeBlock = whileOp.getBeforeBody();
78  // Clone before block before the loop for zero-trip-check.
79  for (auto [arg, init] :
80  llvm::zip_equal(beforeBlock->getArguments(), whileOp.getInits())) {
81  mapper.map(arg, init);
82  }
83  rewriter.setInsertionPoint(whileOp);
84  for (auto &op : *beforeBlock) {
85  if (isa<scf::ConditionOp>(op)) {
86  break;
87  }
88  // Safe to clone everything as in a single block all defs have been cloned
89  // and added to mapper in order.
90  rewriter.insert(op.clone(mapper));
91  }
92 
93  scf::ConditionOp condOp = whileOp.getConditionOp();
94  Value clonedCondition = mapper.lookupOrDefault(condOp.getCondition());
95  SmallVector<Value> clonedCondArgs = llvm::map_to_vector(
96  condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); });
97 
98  // Create rotated while loop.
99  auto newLoopOp = rewriter.create<scf::WhileOp>(
100  whileOp.getLoc(), whileOp.getResultTypes(), clonedCondArgs,
101  [&](OpBuilder &builder, Location loc, ValueRange args) {
102  // Rotate and move the loop body into before block.
103  auto newBlock = builder.getBlock();
104  rewriter.mergeBlocks(whileOp.getAfterBody(), newBlock, args);
105  auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
106  rewriter.mergeBlocks(whileOp.getBeforeBody(), newBlock,
107  yieldOp.getResults());
108  rewriter.eraseOp(yieldOp);
109  },
110  [&](OpBuilder &builder, Location loc, ValueRange args) {
111  // Pass through values.
112  builder.create<scf::YieldOp>(loc, args);
113  });
114 
115  // Create zero-trip-check and move the while loop in.
116  auto ifOp = rewriter.create<scf::IfOp>(
117  whileOp.getLoc(), clonedCondition,
118  [&](OpBuilder &builder, Location loc) {
119  // Then runs the while loop.
120  rewriter.moveOpBefore(newLoopOp, builder.getInsertionBlock(),
121  builder.getInsertionPoint());
122  builder.create<scf::YieldOp>(loc, newLoopOp.getResults());
123  },
124  [&](OpBuilder &builder, Location loc) {
125  // Else returns the results from precondition.
126  builder.create<scf::YieldOp>(loc, clonedCondArgs);
127  });
128 
129  rewriter.replaceOp(whileOp, ifOp);
130 
131  return newLoopOp;
132 }
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgListType getArguments()
Definition: Block.h:85
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:488
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition: Builders.cpp:452
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
FailureOr< WhileOp > wrapWhileLoopInZeroTripCheck(WhileOp whileOp, RewriterBase &rewriter, bool forceCreateCheck=false)
Create zero-trip-check around a while op and return the new loop op in the check.
Include the generated interface declarations.