MLIR 22.0.0git
Transforms.cpp
Go to the documentation of this file.
1//===- Transforms.cpp - Patterns and transforms for the EmitC dialect -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/IRMapping.h"
12#include "mlir/IR/Location.h"
14#include "llvm/ADT/STLExtras.h"
15
16namespace mlir {
17namespace emitc {
18
19ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
20 assert(isa<emitc::CExpressionInterface>(op) && "Expected a C expression");
21
22 // Create an expression yielding the value returned by op.
23 assert(op->getNumResults() == 1 && "Expected exactly one result");
24 Value result = op->getResult(0);
25 Type resultType = result.getType();
26 Location loc = op->getLoc();
27
28 builder.setInsertionPointAfter(op);
29 auto expressionOp =
30 emitc::ExpressionOp::create(builder, loc, resultType, op->getOperands());
31
32 // Replace all op's uses with the new expression's result.
33 result.replaceAllUsesWith(expressionOp.getResult());
34
35 Block &block = expressionOp.createBody();
36 IRMapping mapper;
37 for (auto [operand, arg] :
38 llvm::zip(expressionOp.getOperands(), block.getArguments()))
39 mapper.map(operand, arg);
40 builder.setInsertionPointToEnd(&block);
41
42 Operation *rootOp = builder.clone(*op, mapper);
43 op->erase();
44
45 // Create an op to yield op's value.
46 emitc::YieldOp::create(builder, loc, rootOp->getResults()[0]);
47 return expressionOp;
48}
49
50} // namespace emitc
51} // namespace mlir
52
53using namespace mlir;
54using namespace mlir::emitc;
55
56namespace {
57
58struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
59 using OpRewritePattern<ExpressionOp>::OpRewritePattern;
60 LogicalResult matchAndRewrite(ExpressionOp expressionOp,
61 PatternRewriter &rewriter) const override {
62 Block *expressionBody = expressionOp.getBody();
63 ExpressionOp usedExpression;
64 SetVector<Value> foldedOperands;
65
66 auto takesItsOperandsAddress = [](Operation *user) {
67 auto applyOp = dyn_cast<emitc::ApplyOp>(user);
68 return applyOp && applyOp.getApplicableOperator() == "&";
69 };
70
71 // Select as expression to fold the first operand expression that
72 // - doesn't have its result value's address taken,
73 // - has a single user: assume any re-materialization was done separately,
74 // - has no side effects,
75 // and save all other operands to be used later as operands in the folded
76 // expression.
77 for (auto [operand, arg] : llvm::zip(expressionOp.getOperands(),
78 expressionBody->getArguments())) {
79 ExpressionOp operandExpression = operand.getDefiningOp<ExpressionOp>();
80 if (usedExpression || !operandExpression ||
81 llvm::any_of(arg.getUsers(), takesItsOperandsAddress) ||
82 !operandExpression.getResult().hasOneUse() ||
83 operandExpression.hasSideEffects())
84 foldedOperands.insert(operand);
85 else
86 usedExpression = operandExpression;
87 }
88
89 // If no operand expression was selected, bail out.
90 if (!usedExpression)
91 return failure();
92
93 // Collect additional operands from the folded expression.
94 for (Value operand : usedExpression.getOperands())
95 foldedOperands.insert(operand);
96
97 // Create a new expression to hold the folding result.
98 rewriter.setInsertionPointAfter(expressionOp);
99 auto foldedExpression = emitc::ExpressionOp::create(
100 rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
101 foldedOperands.getArrayRef(), expressionOp.getDoNotInline());
102 Block &foldedExpressionBody = foldedExpression.createBody();
103
104 // Map each operand of the new expression to its matching block argument.
105 IRMapping mapper;
106 for (auto [operand, arg] : llvm::zip(foldedExpression.getOperands(),
107 foldedExpressionBody.getArguments()))
108 mapper.map(operand, arg);
109
110 // Prepare to fold the used expression and the matched expression into the
111 // newly created folded expression.
112 auto foldExpression = [&rewriter, &mapper](ExpressionOp expressionToFold,
113 bool withTerminator) {
114 Block *expressionToFoldBody = expressionToFold.getBody();
115 for (auto [operand, arg] :
116 llvm::zip(expressionToFold.getOperands(),
117 expressionToFoldBody->getArguments())) {
118 mapper.map(arg, mapper.lookup(operand));
119 }
120
121 for (Operation &opToClone : expressionToFoldBody->without_terminator())
122 rewriter.clone(opToClone, mapper);
123
124 if (withTerminator)
125 rewriter.clone(*expressionToFoldBody->getTerminator(), mapper);
126 };
127 rewriter.setInsertionPointToStart(&foldedExpressionBody);
128
129 // First, fold the used expression into the new expression and map its
130 // result to the clone of its root operation within the new expression.
131 foldExpression(usedExpression, /*withTerminator=*/false);
132 Operation *expressionRoot = usedExpression.getRootOp();
133 Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
134 assert(clonedExpressionRootOp &&
135 "Expected cloned expression root to be in mapper");
136 assert(clonedExpressionRootOp->getNumResults() == 1 &&
137 "Expected cloned root to have a single result");
138 mapper.map(usedExpression.getResult(),
139 clonedExpressionRootOp->getResults()[0]);
140
141 // Now fold the matched expression into the new expression.
142 foldExpression(expressionOp, /*withTerminator=*/true);
143
144 // Complete the rewrite.
145 rewriter.replaceOp(expressionOp, foldedExpression);
146 rewriter.eraseOp(usedExpression);
147
148 return success();
149 }
150};
151
152} // namespace
153
155 patterns.add<FoldExpressionOp>(patterns.getContext());
156}
return success()
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgListType getArguments()
Definition Block.h:87
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block, excluding the terminator operation at the end.
Definition Block.h:212
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the provided map (leaving them unchanged if no entry is present).
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition Builders.h:412
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements); the original operation is then erased.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
void populateExpressionPatterns(RewritePatternSet &patterns)
Populates patterns with expression-related patterns.
ExpressionOp createExpression(Operation *op, OpBuilder &builder)
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.