MLIR  20.0.0git
Rewrite.cpp
Go to the documentation of this file.
1 //===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/Rewrite.h"

#include "mlir-c/Transforms.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Rewrite.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 /// RewriterBase API inherited from OpBuilder
24 //===----------------------------------------------------------------------===//
25 
26 MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) {
27  return wrap(unwrap(rewriter)->getContext());
28 }
29 
30 //===----------------------------------------------------------------------===//
31 /// Insertion points methods
32 
33 void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) {
34  unwrap(rewriter)->clearInsertionPoint();
35 }
36 
37 void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter,
38  MlirOperation op) {
39  unwrap(rewriter)->setInsertionPoint(unwrap(op));
40 }
41 
42 void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter,
43  MlirOperation op) {
44  unwrap(rewriter)->setInsertionPointAfter(unwrap(op));
45 }
46 
47 void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter,
48  MlirValue value) {
49  unwrap(rewriter)->setInsertionPointAfterValue(unwrap(value));
50 }
51 
52 void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter,
53  MlirBlock block) {
54  unwrap(rewriter)->setInsertionPointToStart(unwrap(block));
55 }
56 
57 void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter,
58  MlirBlock block) {
59  unwrap(rewriter)->setInsertionPointToEnd(unwrap(block));
60 }
61 
62 MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) {
63  return wrap(unwrap(rewriter)->getInsertionBlock());
64 }
65 
66 MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
67  return wrap(unwrap(rewriter)->getBlock());
68 }
69 
70 //===----------------------------------------------------------------------===//
71 /// Block and operation creation/insertion/cloning
72 
73 MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter,
74  MlirBlock insertBefore,
75  intptr_t nArgTypes,
76  MlirType const *argTypes,
77  MlirLocation const *locations) {
79  ArrayRef<Type> unwrappedArgs = unwrapList(nArgTypes, argTypes, args);
81  ArrayRef<Location> unwrappedLocs = unwrapList(nArgTypes, locations, locs);
82  return wrap(unwrap(rewriter)->createBlock(unwrap(insertBefore), unwrappedArgs,
83  unwrappedLocs));
84 }
85 
86 MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter,
87  MlirOperation op) {
88  return wrap(unwrap(rewriter)->insert(unwrap(op)));
89 }
90 
91 // Other methods of OpBuilder
92 
93 MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter,
94  MlirOperation op) {
95  return wrap(unwrap(rewriter)->clone(*unwrap(op)));
96 }
97 
98 MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter,
99  MlirOperation op) {
100  return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op)));
101 }
102 
103 void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter,
104  MlirRegion region, MlirBlock before) {
105 
106  unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before));
107 }
108 
109 //===----------------------------------------------------------------------===//
110 /// RewriterBase API
111 //===----------------------------------------------------------------------===//
112 
113 void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter,
114  MlirRegion region, MlirBlock before) {
115  unwrap(rewriter)->inlineRegionBefore(*unwrap(region), unwrap(before));
116 }
117 
118 void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter,
119  MlirOperation op, intptr_t nValues,
120  MlirValue const *values) {
122  ArrayRef<Value> unwrappedVals = unwrapList(nValues, values, vals);
123  unwrap(rewriter)->replaceOp(unwrap(op), unwrappedVals);
124 }
125 
126 void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
127  MlirOperation op,
128  MlirOperation newOp) {
129  unwrap(rewriter)->replaceOp(unwrap(op), unwrap(newOp));
130 }
131 
132 void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) {
133  unwrap(rewriter)->eraseOp(unwrap(op));
134 }
135 
136 void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) {
137  unwrap(rewriter)->eraseBlock(unwrap(block));
138 }
139 
140 void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter,
141  MlirBlock source, MlirOperation op,
142  intptr_t nArgValues,
143  MlirValue const *argValues) {
145  ArrayRef<Value> unwrappedVals = unwrapList(nArgValues, argValues, vals);
146 
147  unwrap(rewriter)->inlineBlockBefore(unwrap(source), unwrap(op),
148  unwrappedVals);
149 }
150 
151 void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source,
152  MlirBlock dest, intptr_t nArgValues,
153  MlirValue const *argValues) {
155  ArrayRef<Value> unwrappedArgs = unwrapList(nArgValues, argValues, args);
156  unwrap(rewriter)->mergeBlocks(unwrap(source), unwrap(dest), unwrappedArgs);
157 }
158 
159 void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op,
160  MlirOperation existingOp) {
161  unwrap(rewriter)->moveOpBefore(unwrap(op), unwrap(existingOp));
162 }
163 
164 void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op,
165  MlirOperation existingOp) {
166  unwrap(rewriter)->moveOpAfter(unwrap(op), unwrap(existingOp));
167 }
168 
169 void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block,
170  MlirBlock existingBlock) {
171  unwrap(rewriter)->moveBlockBefore(unwrap(block), unwrap(existingBlock));
172 }
173 
174 void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter,
175  MlirOperation op) {
176  unwrap(rewriter)->startOpModification(unwrap(op));
177 }
178 
179 void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter,
180  MlirOperation op) {
181  unwrap(rewriter)->finalizeOpModification(unwrap(op));
182 }
183 
184 void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter,
185  MlirOperation op) {
186  unwrap(rewriter)->cancelOpModification(unwrap(op));
187 }
188 
189 void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter,
190  MlirValue from, MlirValue to) {
191  unwrap(rewriter)->replaceAllUsesWith(unwrap(from), unwrap(to));
192 }
193 
194 void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter,
195  intptr_t nValues,
196  MlirValue const *from,
197  MlirValue const *to) {
198  SmallVector<Value, 4> fromVals;
199  ArrayRef<Value> unwrappedFromVals = unwrapList(nValues, from, fromVals);
200  SmallVector<Value, 4> toVals;
201  ArrayRef<Value> unwrappedToVals = unwrapList(nValues, to, toVals);
202  unwrap(rewriter)->replaceAllUsesWith(unwrappedFromVals, unwrappedToVals);
203 }
204 
205 void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter,
206  MlirOperation from,
207  intptr_t nTo,
208  MlirValue const *to) {
209  SmallVector<Value, 4> toVals;
210  ArrayRef<Value> unwrappedToVals = unwrapList(nTo, to, toVals);
211  unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrappedToVals);
212 }
213 
214 void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter,
215  MlirOperation from,
216  MlirOperation to) {
217  unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrap(to));
218 }
219 
220 void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter,
221  MlirOperation op,
222  intptr_t nNewValues,
223  MlirValue const *newValues,
224  MlirBlock block) {
226  ArrayRef<Value> unwrappedVals = unwrapList(nNewValues, newValues, vals);
227  unwrap(rewriter)->replaceOpUsesWithinBlock(unwrap(op), unwrappedVals,
228  unwrap(block));
229 }
230 
231 void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter,
232  MlirValue from, MlirValue to,
233  MlirOperation exceptedUser) {
234  unwrap(rewriter)->replaceAllUsesExcept(unwrap(from), unwrap(to),
235  unwrap(exceptedUser));
236 }
237 
238 //===----------------------------------------------------------------------===//
239 /// IRRewriter API
240 //===----------------------------------------------------------------------===//
241 
242 MlirRewriterBase mlirIRRewriterCreate(MlirContext context) {
243  return wrap(new IRRewriter(unwrap(context)));
244 }
245 
246 MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) {
247  return wrap(new IRRewriter(unwrap(op)));
248 }
249 
250 void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
251  delete static_cast<IRRewriter *>(unwrap(rewriter));
252 }
253 
254 //===----------------------------------------------------------------------===//
255 /// RewritePatternSet and FrozenRewritePatternSet API
256 //===----------------------------------------------------------------------===//
257 
258 inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
259  assert(module.ptr && "unexpected null module");
260  return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
261 }
262 
263 inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
264  return {module};
265 }
266 
268 unwrap(MlirFrozenRewritePatternSet module) {
269  assert(module.ptr && "unexpected null module");
270  return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
271 }
272 
273 inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) {
274  return {module};
275 }
276 
277 MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
278  auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
279  op.ptr = nullptr;
280  return wrap(m);
281 }
282 
283 void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
284  delete unwrap(op);
285  op.ptr = nullptr;
286 }
287 
290  MlirFrozenRewritePatternSet patterns,
291  MlirGreedyRewriteDriverConfig) {
292  return wrap(
294 }
295 
296 //===----------------------------------------------------------------------===//
297 /// PDLPatternModule API
298 //===----------------------------------------------------------------------===//
299 
300 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
301 inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
302  assert(module.ptr && "unexpected null module");
303  return static_cast<mlir::PDLPatternModule *>(module.ptr);
304 }
305 
306 inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
307  return {module};
308 }
309 
310 MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
311  return wrap(new mlir::PDLPatternModule(
313 }
314 
315 void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) {
316  delete unwrap(op);
317  op.ptr = nullptr;
318 }
319 
320 MlirRewritePatternSet
321 mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
322  auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op)));
323  op.ptr = nullptr;
324  return wrap(m);
325 }
326 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter, MlirOperation op, intptr_t nNewValues, MlirValue const *newValues, MlirBlock block)
Find uses of from within block and replace them with to.
Definition: Rewrite.cpp:220
mlir::RewritePatternSet & unwrap(MlirRewritePatternSet module)
RewritePatternSet and FrozenRewritePatternSet API.
Definition: Rewrite.cpp:258
void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source, MlirBlock dest, intptr_t nArgValues, MlirValue const *argValues)
Inline the operations of block 'source' into the end of block 'dest'.
Definition: Rewrite.cpp:151
void mlirIRRewriterDestroy(MlirRewriterBase rewriter)
Takes an IRRewriter owned by the caller and destroys it.
Definition: Rewrite.cpp:250
void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter, MlirOperation op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: Rewrite.cpp:174
MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter, MlirOperation op)
Insert the given operation at the current insertion point and return it.
Definition: Rewrite.cpp:86
MlirRewriterBase mlirIRRewriterCreate(MlirContext context)
IRRewriter API.
Definition: Rewrite.cpp:242
void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op, MlirOperation existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
Definition: Rewrite.cpp:164
void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, MlirRegion region, MlirBlock before)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Rewrite.cpp:103
void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter, MlirOperation op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Rewrite.cpp:42
void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op)
Definition: Rewrite.cpp:283
MlirRewritePatternSet wrap(mlir::RewritePatternSet *module)
Definition: Rewrite.cpp:263
void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter, MlirOperation from, MlirOperation to)
Find uses of from and replace them with to.
Definition: Rewrite.cpp:214
void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block, MlirBlock existingBlock)
Unlink this block and insert it right before existingBlock.
Definition: Rewrite.cpp:169
void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter, intptr_t nValues, MlirValue const *from, MlirValue const *to)
Find uses of from and replace them with to.
Definition: Rewrite.cpp:194
void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block)
Erases a block along with all operations inside it.
Definition: Rewrite.cpp:136
void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, MlirValue from, MlirValue to, MlirOperation exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: Rewrite.cpp:231
MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter, MlirBlock insertBefore, intptr_t nArgTypes, MlirType const *argTypes, MlirLocation const *locations)
Block and operation creation/insertion/cloning.
Definition: Rewrite.cpp:73
MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig)
Definition: Rewrite.cpp:289
void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter, MlirBlock block)
Sets the insertion point to the start of the specified block.
Definition: Rewrite.cpp:52
MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter)
RewriterBase API inherited from OpBuilder.
Definition: Rewrite.cpp:26
void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter, MlirOperation from, intptr_t nTo, MlirValue const *to)
Find uses of from and replace them with to.
Definition: Rewrite.cpp:205
MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter, MlirOperation op)
Creates a deep copy of the specified operation.
Definition: Rewrite.cpp:93
void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, MlirBlock source, MlirOperation op, intptr_t nArgValues, MlirValue const *argValues)
Inline the operations of block 'source' before the operation 'op'.
Definition: Rewrite.cpp:140
void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op, intptr_t nValues, MlirValue const *values)
Replace the results of the given (original) operation with the specified list of values (replacements...
Definition: Rewrite.cpp:118
void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter, MlirOperation op)
This method cancels a pending in-place modification.
Definition: Rewrite.cpp:184
void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter, MlirValue value)
Sets the insertion point to the node after the specified value.
Definition: Rewrite.cpp:47
void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter, MlirBlock block)
Sets the insertion point to the end of the specified block.
Definition: Rewrite.cpp:57
MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter, MlirOperation op)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Rewrite.cpp:98
MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter)
Returns the current block of the rewriter.
Definition: Rewrite.cpp:66
void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter)
Insertion points methods.
Definition: Rewrite.cpp:33
void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, MlirRegion region, MlirBlock before)
RewriterBase API.
Definition: Rewrite.cpp:113
void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, MlirOperation op, MlirOperation newOp)
Replace the results of the given (original) operation with the specified new op (replacement).
Definition: Rewrite.cpp:126
void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, MlirValue from, MlirValue to)
Find uses of from and replace them with to.
Definition: Rewrite.cpp:189
MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op)
FrozenRewritePatternSet API.
Definition: Rewrite.cpp:277
void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter, MlirOperation op)
This method is used to signal the end of an in-place modification of the given operation.
Definition: Rewrite.cpp:179
MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op)
Create an IRRewriter and transfer ownership to the caller.
Definition: Rewrite.cpp:246
void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op, MlirOperation existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Rewrite.cpp:159
void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter, MlirOperation op)
Sets the insertion point to the specified operation, which will cause subsequent insertions to go rig...
Definition: Rewrite.cpp:37
MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter)
Return the block the current insertion point belongs to.
Definition: Rewrite.cpp:62
void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op)
Erases an operation that is known to have no uses.
Definition: Rewrite.cpp:132
static MlirBlock createBlock(const py::sequence &pyArgTypes, const std::optional< py::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Definition: IRCore.cpp:211
static MLIRContext * getContext(OpFoldResult val)
static llvm::ArrayRef< CppTy > unwrapList(size_t size, CTy *first, llvm::SmallVectorImpl< CppTy > &storage)
Definition: Wrap.h:40
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition: OwningOpRef.h:29
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op)
MLIR_CAPI_EXPORTED MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op)
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op)
Include the generated interface declarations.
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
A logical result value, essentially a boolean with named states.
Definition: Support.h:116