MLIR  21.0.0git
Rewrite.cpp
Go to the documentation of this file.
1 //===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/Rewrite.h"

#include "mlir-c/Transforms.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Rewrite.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 /// RewriterBase API inherited from OpBuilder
24 //===----------------------------------------------------------------------===//
25 
26 MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) {
27  return wrap(unwrap(rewriter)->getContext());
28 }
29 
30 //===----------------------------------------------------------------------===//
31 /// Insertion points methods
32 //===----------------------------------------------------------------------===//
33 
34 void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) {
35  unwrap(rewriter)->clearInsertionPoint();
36 }
37 
38 void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter,
39  MlirOperation op) {
40  unwrap(rewriter)->setInsertionPoint(unwrap(op));
41 }
42 
43 void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter,
44  MlirOperation op) {
45  unwrap(rewriter)->setInsertionPointAfter(unwrap(op));
46 }
47 
48 void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter,
49  MlirValue value) {
50  unwrap(rewriter)->setInsertionPointAfterValue(unwrap(value));
51 }
52 
53 void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter,
54  MlirBlock block) {
55  unwrap(rewriter)->setInsertionPointToStart(unwrap(block));
56 }
57 
58 void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter,
59  MlirBlock block) {
60  unwrap(rewriter)->setInsertionPointToEnd(unwrap(block));
61 }
62 
63 MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) {
64  return wrap(unwrap(rewriter)->getInsertionBlock());
65 }
66 
67 MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
68  return wrap(unwrap(rewriter)->getBlock());
69 }
70 
71 //===----------------------------------------------------------------------===//
72 /// Block and operation creation/insertion/cloning
73 //===----------------------------------------------------------------------===//
74 
75 MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter,
76  MlirBlock insertBefore,
77  intptr_t nArgTypes,
78  MlirType const *argTypes,
79  MlirLocation const *locations) {
81  ArrayRef<Type> unwrappedArgs = unwrapList(nArgTypes, argTypes, args);
83  ArrayRef<Location> unwrappedLocs = unwrapList(nArgTypes, locations, locs);
84  return wrap(unwrap(rewriter)->createBlock(unwrap(insertBefore), unwrappedArgs,
85  unwrappedLocs));
86 }
87 
88 MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter,
89  MlirOperation op) {
90  return wrap(unwrap(rewriter)->insert(unwrap(op)));
91 }
92 
93 // Other methods of OpBuilder
94 
95 MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter,
96  MlirOperation op) {
97  return wrap(unwrap(rewriter)->clone(*unwrap(op)));
98 }
99 
100 MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter,
101  MlirOperation op) {
102  return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op)));
103 }
104 
105 void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter,
106  MlirRegion region, MlirBlock before) {
107 
108  unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before));
109 }
110 
111 //===----------------------------------------------------------------------===//
112 /// RewriterBase API
113 //===----------------------------------------------------------------------===//
114 
115 void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter,
116  MlirRegion region, MlirBlock before) {
117  unwrap(rewriter)->inlineRegionBefore(*unwrap(region), unwrap(before));
118 }
119 
120 void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter,
121  MlirOperation op, intptr_t nValues,
122  MlirValue const *values) {
124  ArrayRef<Value> unwrappedVals = unwrapList(nValues, values, vals);
125  unwrap(rewriter)->replaceOp(unwrap(op), unwrappedVals);
126 }
127 
128 void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
129  MlirOperation op,
130  MlirOperation newOp) {
131  unwrap(rewriter)->replaceOp(unwrap(op), unwrap(newOp));
132 }
133 
134 void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) {
135  unwrap(rewriter)->eraseOp(unwrap(op));
136 }
137 
138 void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) {
139  unwrap(rewriter)->eraseBlock(unwrap(block));
140 }
141 
142 void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter,
143  MlirBlock source, MlirOperation op,
144  intptr_t nArgValues,
145  MlirValue const *argValues) {
147  ArrayRef<Value> unwrappedVals = unwrapList(nArgValues, argValues, vals);
148 
149  unwrap(rewriter)->inlineBlockBefore(unwrap(source), unwrap(op),
150  unwrappedVals);
151 }
152 
153 void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source,
154  MlirBlock dest, intptr_t nArgValues,
155  MlirValue const *argValues) {
157  ArrayRef<Value> unwrappedArgs = unwrapList(nArgValues, argValues, args);
158  unwrap(rewriter)->mergeBlocks(unwrap(source), unwrap(dest), unwrappedArgs);
159 }
160 
161 void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op,
162  MlirOperation existingOp) {
163  unwrap(rewriter)->moveOpBefore(unwrap(op), unwrap(existingOp));
164 }
165 
166 void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op,
167  MlirOperation existingOp) {
168  unwrap(rewriter)->moveOpAfter(unwrap(op), unwrap(existingOp));
169 }
170 
171 void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block,
172  MlirBlock existingBlock) {
173  unwrap(rewriter)->moveBlockBefore(unwrap(block), unwrap(existingBlock));
174 }
175 
176 void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter,
177  MlirOperation op) {
178  unwrap(rewriter)->startOpModification(unwrap(op));
179 }
180 
181 void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter,
182  MlirOperation op) {
183  unwrap(rewriter)->finalizeOpModification(unwrap(op));
184 }
185 
186 void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter,
187  MlirOperation op) {
188  unwrap(rewriter)->cancelOpModification(unwrap(op));
189 }
190 
191 void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter,
192  MlirValue from, MlirValue to) {
193  unwrap(rewriter)->replaceAllUsesWith(unwrap(from), unwrap(to));
194 }
195 
196 void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter,
197  intptr_t nValues,
198  MlirValue const *from,
199  MlirValue const *to) {
200  SmallVector<Value, 4> fromVals;
201  ArrayRef<Value> unwrappedFromVals = unwrapList(nValues, from, fromVals);
202  SmallVector<Value, 4> toVals;
203  ArrayRef<Value> unwrappedToVals = unwrapList(nValues, to, toVals);
204  unwrap(rewriter)->replaceAllUsesWith(unwrappedFromVals, unwrappedToVals);
205 }
206 
207 void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter,
208  MlirOperation from,
209  intptr_t nTo,
210  MlirValue const *to) {
211  SmallVector<Value, 4> toVals;
212  ArrayRef<Value> unwrappedToVals = unwrapList(nTo, to, toVals);
213  unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrappedToVals);
214 }
215 
216 void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter,
217  MlirOperation from,
218  MlirOperation to) {
219  unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrap(to));
220 }
221 
222 void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter,
223  MlirOperation op,
224  intptr_t nNewValues,
225  MlirValue const *newValues,
226  MlirBlock block) {
228  ArrayRef<Value> unwrappedVals = unwrapList(nNewValues, newValues, vals);
229  unwrap(rewriter)->replaceOpUsesWithinBlock(unwrap(op), unwrappedVals,
230  unwrap(block));
231 }
232 
233 void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter,
234  MlirValue from, MlirValue to,
235  MlirOperation exceptedUser) {
236  unwrap(rewriter)->replaceAllUsesExcept(unwrap(from), unwrap(to),
237  unwrap(exceptedUser));
238 }
239 
240 //===----------------------------------------------------------------------===//
241 /// IRRewriter API
242 //===----------------------------------------------------------------------===//
243 
244 MlirRewriterBase mlirIRRewriterCreate(MlirContext context) {
245  return wrap(new IRRewriter(unwrap(context)));
246 }
247 
248 MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) {
249  return wrap(new IRRewriter(unwrap(op)));
250 }
251 
252 void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
253  delete static_cast<IRRewriter *>(unwrap(rewriter));
254 }
255 
256 //===----------------------------------------------------------------------===//
257 /// RewritePatternSet and FrozenRewritePatternSet API
258 //===----------------------------------------------------------------------===//
259 
260 inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
261  assert(module.ptr && "unexpected null module");
262  return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
263 }
264 
265 inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
266  return {module};
267 }
268 
270 unwrap(MlirFrozenRewritePatternSet module) {
271  assert(module.ptr && "unexpected null module");
272  return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
273 }
274 
275 inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) {
276  return {module};
277 }
278 
279 MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
280  auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
281  op.ptr = nullptr;
282  return wrap(m);
283 }
284 
285 void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
286  delete unwrap(op);
287  op.ptr = nullptr;
288 }
289 
292  MlirFrozenRewritePatternSet patterns,
293  MlirGreedyRewriteDriverConfig) {
295 }
296 
297 //===----------------------------------------------------------------------===//
298 /// PDLPatternModule API
299 //===----------------------------------------------------------------------===//
300 
301 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
302 inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
303  assert(module.ptr && "unexpected null module");
304  return static_cast<mlir::PDLPatternModule *>(module.ptr);
305 }
306 
307 inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
308  return {module};
309 }
310 
311 MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
312  return wrap(new mlir::PDLPatternModule(
314 }
315 
316 void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) {
317  delete unwrap(op);
318  op.ptr = nullptr;
319 }
320 
321 MlirRewritePatternSet
322 mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
323  auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op)));
324  op.ptr = nullptr;
325  return wrap(m);
326 }
327 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter, MlirOperation op, intptr_t nNewValues, MlirValue const *newValues, MlirBlock block)
Find uses of from within block and replace them with to.
Definition: Rewrite.cpp:222
mlir::RewritePatternSet & unwrap(MlirRewritePatternSet module)
Unwrap a non-null C API handle into a reference to the RewritePatternSet it points to.
Definition: Rewrite.cpp:260
void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source, MlirBlock dest, intptr_t nArgValues, MlirValue const *argValues)
Inline the operations of block 'source' into the end of block 'dest'.
Definition: Rewrite.cpp:153
void mlirIRRewriterDestroy(MlirRewriterBase rewriter)
Takes an IRRewriter owned by the caller and destroys it.
Definition: Rewrite.cpp:252
void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter, MlirOperation op)
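This method is used to notify the rewriter that an in-place operation modification is about to happen. A call to this function must be followed by a call to either finalizeOpModification or cancelOpModification.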
Definition: Rewrite.cpp:176
MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter, MlirOperation op)
Insert the given operation at the current insertion point and return it.
Definition: Rewrite.cpp:88
MlirRewriterBase mlirIRRewriterCreate(MlirContext context)
Create an IRRewriter for the given context and transfer ownership to the caller.
Definition: Rewrite.cpp:244
void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op, MlirOperation existingOp)
Unlink this operation from its current block and insert it right after existingOp, which may be in the same or another block in the same function.
Definition: Rewrite.cpp:166
void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, MlirRegion region, MlirBlock before)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Rewrite.cpp:105
void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter, MlirOperation op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Rewrite.cpp:43
void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op)
Definition: Rewrite.cpp:285
MlirRewritePatternSet wrap(mlir::RewritePatternSet *module)
Definition: Rewrite.cpp:265
void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter, MlirOperation from, MlirOperation to)
Find uses of from and replace them with to.
Definition: Rewrite.cpp:216
void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block, MlirBlock existingBlock)
Unlink this block and insert it right before existingBlock.
Definition: Rewrite.cpp:171
void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter, intptr_t nValues, MlirValue const *from, MlirValue const *to)
Find uses of from and replace them with to.
Definition: Rewrite.cpp:196
void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block)
Erases a block along with all operations inside it.
Definition: Rewrite.cpp:138
void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, MlirValue from, MlirValue to, MlirOperation exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: Rewrite.cpp:233
MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter, MlirBlock insertBefore, intptr_t nArgTypes, MlirType const *argTypes, MlirLocation const *locations)
Create a new block with the given argument types and locations immediately before 'insertBefore' and return it.
Definition: Rewrite.cpp:75
MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig)
Definition: Rewrite.cpp:291
void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter, MlirBlock block)
Sets the insertion point to the start of the specified block.
Definition: Rewrite.cpp:53
MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter)
Get the MLIR context of the rewriter.
Definition: Rewrite.cpp:26
void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter, MlirOperation from, intptr_t nTo, MlirValue const *to)
Find uses of from and replace them with to.
Definition: Rewrite.cpp:207
MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter, MlirOperation op)
Creates a deep copy of the specified operation.
Definition: Rewrite.cpp:95
void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, MlirBlock source, MlirOperation op, intptr_t nArgValues, MlirValue const *argValues)
Inline the operations of block 'source' before the operation 'op'.
Definition: Rewrite.cpp:142
void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op, intptr_t nValues, MlirValue const *values)
Replace the results of the given (original) operation with the specified list of values (replacements). The result types of the given op and the replacements must match, and the original op is erased.
Definition: Rewrite.cpp:120
void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter, MlirOperation op)
This method cancels a pending in-place modification.
Definition: Rewrite.cpp:186
void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter, MlirValue value)
Sets the insertion point to the node after the specified value.
Definition: Rewrite.cpp:48
void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter, MlirBlock block)
Sets the insertion point to the end of the specified block.
Definition: Rewrite.cpp:58
MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter, MlirOperation op)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Rewrite.cpp:100
MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter)
Returns the current block of the rewriter.
Definition: Rewrite.cpp:67
void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter)
Reset the insertion point to no location.
Definition: Rewrite.cpp:34
void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, MlirRegion region, MlirBlock before)
Move the blocks that belong to 'region' before the given position in another region.
Definition: Rewrite.cpp:115
void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, MlirOperation op, MlirOperation newOp)
Replace the results of the given (original) operation with the specified new op (replacement).
Definition: Rewrite.cpp:128
void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, MlirValue from, MlirValue to)
Find uses of from and replace them with to.
Definition: Rewrite.cpp:191
MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op)
Freeze the patterns of the given RewritePatternSet into a newly allocated FrozenRewritePatternSet owned by the caller.
Definition: Rewrite.cpp:279
void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter, MlirOperation op)
This method is used to signal the end of an in-place modification of the given operation.
Definition: Rewrite.cpp:181
MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op)
Create an IRRewriter and transfer ownership to the caller.
Definition: Rewrite.cpp:248
void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op, MlirOperation existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block in the same function.
Definition: Rewrite.cpp:161
void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter, MlirOperation op)
Sets the insertion point to the specified operation, which will cause subsequent insertions to go right before it.
Definition: Rewrite.cpp:38
MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter)
Return the block the current insertion point belongs to.
Definition: Rewrite.cpp:63
void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op)
Erases an operation that is known to have no uses.
Definition: Rewrite.cpp:134
static MlirBlock createBlock(const nb::sequence &pyArgTypes, const std::optional< nb::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Definition: IRCore.cpp:230
static MLIRContext * getContext(OpFoldResult val)
static llvm::ArrayRef< CppTy > unwrapList(size_t size, CTy *first, llvm::SmallVectorImpl< CppTy > &storage)
Definition: Wrap.h:40
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
Definition: PatternMatch.h:734
This class acts as an owning reference to an op, and will automatically destroy the held op on destruction if the held op is valid.
Definition: OwningOpRef.h:29
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op)
MLIR_CAPI_EXPORTED MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op)
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op)
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy, worklist-driven manner until a fixpoint is reached.
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
A logical result value, essentially a boolean with named states.
Definition: Support.h:116