MLIR 23.0.0git
Transforms.cpp
Go to the documentation of this file.
1//===- Transforms.cpp - Patterns and transforms for the EmitC dialect -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/IRMapping.h"
12#include "mlir/IR/Location.h"
14#include "llvm/ADT/STLExtras.h"
15
16namespace mlir {
17namespace emitc {
18
19ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
20 assert(isa<emitc::CExpressionInterface>(op) && "Expected a C expression");
21
22 // Create an expression yielding the value returned by op.
23 assert(op->getNumResults() == 1 && "Expected exactly one result");
24 Value result = op->getResult(0);
25 Type resultType = result.getType();
26 Location loc = op->getLoc();
27
28 builder.setInsertionPointAfter(op);
29 auto expressionOp =
30 emitc::ExpressionOp::create(builder, loc, resultType, op->getOperands());
31
32 // Replace all op's uses with the new expression's result.
33 result.replaceAllUsesWith(expressionOp.getResult());
34
35 Block &block = expressionOp.createBody();
36 IRMapping mapper;
37 for (auto [operand, arg] :
38 llvm::zip(expressionOp.getOperands(), block.getArguments()))
39 mapper.map(operand, arg);
40 builder.setInsertionPointToEnd(&block);
41
42 Operation *rootOp = builder.clone(*op, mapper);
43 op->erase();
44
45 // Create an op to yield op's value.
46 emitc::YieldOp::create(builder, loc, rootOp->getResults()[0]);
47 return expressionOp;
48}
49
50} // namespace emitc
51} // namespace mlir
52
53using namespace mlir;
54using namespace mlir::emitc;
55
56namespace {
57
58struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
59 using OpRewritePattern<ExpressionOp>::OpRewritePattern;
60 LogicalResult matchAndRewrite(ExpressionOp expressionOp,
61 PatternRewriter &rewriter) const override {
62 Block *expressionBody = expressionOp.getBody();
63 ExpressionOp usedExpression;
64 SetVector<Value> foldedOperands;
65
66 // Select as expression to fold the first operand expression that
67 // - has a single user: assume any re-materialization was done separately,
68 // - has no side effects,
69 // and save all other operands to be used later as operands in the folded
70 // expression.
71 for (auto [operand, arg] : llvm::zip(expressionOp.getOperands(),
72 expressionBody->getArguments())) {
73 ExpressionOp operandExpression = operand.getDefiningOp<ExpressionOp>();
74 if (usedExpression || !operandExpression ||
75 !operandExpression.getResult().hasOneUse() ||
76 operandExpression.hasSideEffects())
77 foldedOperands.insert(operand);
78 else
79 usedExpression = operandExpression;
80 }
81
82 // If no operand expression was selected, bail out.
83 if (!usedExpression)
84 return failure();
85
86 // Collect additional operands from the folded expression.
87 for (Value operand : usedExpression.getOperands())
88 foldedOperands.insert(operand);
89
90 // Create a new expression to hold the folding result.
91 rewriter.setInsertionPointAfter(expressionOp);
92 auto foldedExpression = emitc::ExpressionOp::create(
93 rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
94 foldedOperands.getArrayRef(), expressionOp.getDoNotInline());
95 Block &foldedExpressionBody = foldedExpression.createBody();
96
97 // Map each operand of the new expression to its matching block argument.
98 IRMapping mapper;
99 for (auto [operand, arg] : llvm::zip(foldedExpression.getOperands(),
100 foldedExpressionBody.getArguments()))
101 mapper.map(operand, arg);
102
103 // Prepare to fold the used expression and the matched expression into the
104 // newly created folded expression.
105 auto foldExpression = [&rewriter, &mapper](ExpressionOp expressionToFold,
106 bool withTerminator) {
107 Block *expressionToFoldBody = expressionToFold.getBody();
108 for (auto [operand, arg] :
109 llvm::zip(expressionToFold.getOperands(),
110 expressionToFoldBody->getArguments())) {
111 mapper.map(arg, mapper.lookup(operand));
112 }
113
114 for (Operation &opToClone : expressionToFoldBody->without_terminator())
115 rewriter.clone(opToClone, mapper);
116
117 if (withTerminator)
118 rewriter.clone(*expressionToFoldBody->getTerminator(), mapper);
119 };
120 rewriter.setInsertionPointToStart(&foldedExpressionBody);
121
122 // First, fold the used expression into the new expression and map its
123 // result to the clone of its root operation within the new expression.
124 foldExpression(usedExpression, /*withTerminator=*/false);
125 Operation *expressionRoot = usedExpression.getRootOp();
126 Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
127 assert(clonedExpressionRootOp &&
128 "Expected cloned expression root to be in mapper");
129 assert(clonedExpressionRootOp->getNumResults() == 1 &&
130 "Expected cloned root to have a single result");
131 mapper.map(usedExpression.getResult(),
132 clonedExpressionRootOp->getResults()[0]);
133
134 // Now fold the matched expression into the new expression.
135 foldExpression(expressionOp, /*withTerminator=*/true);
136
137 // Complete the rewrite.
138 rewriter.replaceOp(expressionOp, foldedExpression);
139 rewriter.eraseOp(usedExpression);
140
141 return success();
142 }
143};
144
145} // namespace
146
148 ExpressionOp::getCanonicalizationPatterns(patterns, patterns.getContext());
149 patterns.add<FoldExpressionOp>(patterns.getContext());
150}
return success()
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block excluding the terminator operation at the end.
Definition Block.h:222
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
Definition Builders.cpp:567
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition Builders.h:414
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
result_range getResults()
Definition Operation.h:440
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
void populateExpressionPatterns(RewritePatternSet &patterns)
Populates patterns with expression-related patterns.
ExpressionOp createExpression(Operation *op, OpBuilder &builder)
Wraps a single-result C-expression operation in a new emitc.expression inserted right after it.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.