14#include "llvm/ADT/STLExtras.h"
20 assert(isa<emitc::CExpressionInterface>(op) &&
"Expected a C expression");
23 assert(op->
getNumResults() == 1 &&
"Expected exactly one result");
30 emitc::ExpressionOp::create(builder, loc, resultType, op->
getOperands());
33 result.replaceAllUsesWith(expressionOp.getResult());
35 Block &block = expressionOp.createBody();
37 for (
auto [operand, arg] :
38 llvm::zip(expressionOp.getOperands(), block.
getArguments()))
39 mapper.
map(operand, arg);
46 emitc::YieldOp::create(builder, loc, rootOp->
getResults()[0]);
59 using OpRewritePattern<ExpressionOp>::OpRewritePattern;
60 LogicalResult matchAndRewrite(ExpressionOp expressionOp,
61 PatternRewriter &rewriter)
const override {
62 Block *expressionBody = expressionOp.getBody();
63 ExpressionOp usedExpression;
66 auto takesItsOperandsAddress = [](Operation *user) {
67 auto applyOp = dyn_cast<emitc::ApplyOp>(user);
68 return applyOp && applyOp.getApplicableOperator() ==
"&";
77 for (
auto [operand, arg] : llvm::zip(expressionOp.getOperands(),
79 ExpressionOp operandExpression = operand.getDefiningOp<ExpressionOp>();
80 if (usedExpression || !operandExpression ||
81 llvm::any_of(arg.getUsers(), takesItsOperandsAddress) ||
82 !operandExpression.getResult().hasOneUse() ||
83 operandExpression.hasSideEffects())
84 foldedOperands.insert(operand);
86 usedExpression = operandExpression;
94 for (Value operand : usedExpression.getOperands())
95 foldedOperands.insert(operand);
99 auto foldedExpression = emitc::ExpressionOp::create(
100 rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
101 foldedOperands.getArrayRef(), expressionOp.getDoNotInline());
102 Block &foldedExpressionBody = foldedExpression.createBody();
106 for (
auto [operand, arg] : llvm::zip(foldedExpression.getOperands(),
108 mapper.
map(operand, arg);
112 auto foldExpression = [&rewriter, &mapper](ExpressionOp expressionToFold,
113 bool withTerminator) {
114 Block *expressionToFoldBody = expressionToFold.getBody();
115 for (
auto [operand, arg] :
116 llvm::zip(expressionToFold.getOperands(),
122 rewriter.
clone(opToClone, mapper);
131 foldExpression(usedExpression,
false);
132 Operation *expressionRoot = usedExpression.getRootOp();
133 Operation *clonedExpressionRootOp = mapper.
lookup(expressionRoot);
134 assert(clonedExpressionRootOp &&
135 "Expected cloned expression root to be in mapper");
137 "Expected cloned root to have a single result");
138 mapper.
map(usedExpression.getResult(),
142 foldExpression(expressionOp,
true);
145 rewriter.
replaceOp(expressionOp, foldedExpression);
146 rewriter.
eraseOp(usedExpression);
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateExpressionPatterns(RewritePatternSet &patterns)
Populates patterns with expression-related patterns.
ExpressionOp createExpression(Operation *op, OpBuilder &builder)
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...