MLIR  22.0.0git
Transforms.cpp
Go to the documentation of this file.
1 //===- Transforms.cpp - Patterns and transforms for the EmitC dialect -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/IRMapping.h"
12 #include "mlir/IR/Location.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "llvm/ADT/STLExtras.h"
15 
16 namespace mlir {
17 namespace emitc {
18 
19 ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
20  assert(isa<emitc::CExpressionInterface>(op) && "Expected a C expression");
21 
22  // Create an expression yielding the value returned by op.
23  assert(op->getNumResults() == 1 && "Expected exactly one result");
24  Value result = op->getResult(0);
25  Type resultType = result.getType();
26  Location loc = op->getLoc();
27 
28  builder.setInsertionPointAfter(op);
29  auto expressionOp =
30  emitc::ExpressionOp::create(builder, loc, resultType, op->getOperands());
31 
32  // Replace all op's uses with the new expression's result.
33  result.replaceAllUsesWith(expressionOp.getResult());
34 
35  Block &block = expressionOp.createBody();
36  IRMapping mapper;
37  for (auto [operand, arg] :
38  llvm::zip(expressionOp.getOperands(), block.getArguments()))
39  mapper.map(operand, arg);
40  builder.setInsertionPointToEnd(&block);
41 
42  Operation *rootOp = builder.clone(*op, mapper);
43  op->erase();
44 
45  // Create an op to yield op's value.
46  emitc::YieldOp::create(builder, loc, rootOp->getResults()[0]);
47  return expressionOp;
48 }
49 
50 } // namespace emitc
51 } // namespace mlir
52 
53 using namespace mlir;
54 using namespace mlir::emitc;
55 
56 namespace {
57 
58 struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
60  LogicalResult matchAndRewrite(ExpressionOp expressionOp,
61  PatternRewriter &rewriter) const override {
62  Block *expressionBody = expressionOp.getBody();
63  ExpressionOp usedExpression;
64  SetVector<Value> foldedOperands;
65 
66  auto takesItsOperandsAddress = [](Operation *user) {
67  auto applyOp = dyn_cast<emitc::ApplyOp>(user);
68  return applyOp && applyOp.getApplicableOperator() == "&";
69  };
70 
71  // Select as expression to fold the first operand expression that
72  // - doesn't have its result value's address taken,
73  // - has a single user: assume any re-materialization was done separately,
74  // - has no side effects,
75  // and save all other operands to be used later as operands in the folded
76  // expression.
77  for (auto [operand, arg] : llvm::zip(expressionOp.getOperands(),
78  expressionBody->getArguments())) {
79  ExpressionOp operandExpression = operand.getDefiningOp<ExpressionOp>();
80  if (usedExpression || !operandExpression ||
81  llvm::any_of(arg.getUsers(), takesItsOperandsAddress) ||
82  !operandExpression.getResult().hasOneUse() ||
83  operandExpression.hasSideEffects())
84  foldedOperands.insert(operand);
85  else
86  usedExpression = operandExpression;
87  }
88 
89  // If no operand expression was selected, bail out.
90  if (!usedExpression)
91  return failure();
92 
93  // Collect additional operands from the folded expression.
94  for (Value operand : usedExpression.getOperands())
95  foldedOperands.insert(operand);
96 
97  // Create a new expression to hold the folding result.
98  rewriter.setInsertionPointAfter(expressionOp);
99  auto foldedExpression = emitc::ExpressionOp::create(
100  rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
101  foldedOperands.getArrayRef(), expressionOp.getDoNotInline());
102  Block &foldedExpressionBody = foldedExpression.createBody();
103 
104  // Map each operand of the new expression to its matching block argument.
105  IRMapping mapper;
106  for (auto [operand, arg] : llvm::zip(foldedExpression.getOperands(),
107  foldedExpressionBody.getArguments()))
108  mapper.map(operand, arg);
109 
110  // Prepare to fold the used expression and the matched expression into the
111  // newly created folded expression.
112  auto foldExpression = [&rewriter, &mapper](ExpressionOp expressionToFold,
113  bool withTerminator) {
114  Block *expressionToFoldBody = expressionToFold.getBody();
115  for (auto [operand, arg] :
116  llvm::zip(expressionToFold.getOperands(),
117  expressionToFoldBody->getArguments())) {
118  mapper.map(arg, mapper.lookup(operand));
119  }
120 
121  for (Operation &opToClone : expressionToFoldBody->without_terminator())
122  rewriter.clone(opToClone, mapper);
123 
124  if (withTerminator)
125  rewriter.clone(*expressionToFoldBody->getTerminator(), mapper);
126  };
127  rewriter.setInsertionPointToStart(&foldedExpressionBody);
128 
129  // First, fold the used expression into the new expression and map its
130  // result to the clone of its root operation within the new expression.
131  foldExpression(usedExpression, /*withTerminator=*/false);
132  Operation *expressionRoot = usedExpression.getRootOp();
133  Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
134  assert(clonedExpressionRootOp &&
135  "Expected cloned expression root to be in mapper");
136  assert(clonedExpressionRootOp->getNumResults() == 1 &&
137  "Expected cloned root to have a single result");
138  mapper.map(usedExpression.getResult(),
139  clonedExpressionRootOp->getResults()[0]);
140 
141  // Now fold the matched expression into the new expression.
142  foldExpression(expressionOp, /*withTerminator=*/true);
143 
144  // Complete the rewrite.
145  rewriter.replaceOp(expressionOp, foldedExpression);
146  rewriter.eraseOp(usedExpression);
147 
148  return success();
149  }
150 };
151 
152 } // namespace
153 
155  patterns.add<FoldExpressionOp>(patterns.getContext());
156 }
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgListType getArguments()
Definition: Block.h:87
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:212
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:552
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:149
void populateExpressionPatterns(RewritePatternSet &patterns)
Populates patterns with expression-related patterns.
Definition: Transforms.cpp:154
ExpressionOp createExpression(Operation *op, OpBuilder &builder)
Definition: Transforms.cpp:19
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314