MLIR 22.0.0git
FuncTransformOps.cpp
Go to the documentation of this file.
1//===- FuncTransformOps.cpp - Implementation of Func transform ops --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
20#include "llvm/ADT/STLExtras.h"
21
22using namespace mlir;
23
24//===----------------------------------------------------------------------===//
25// Apply...ConversionPatternsOp
26//===----------------------------------------------------------------------===//
27
28void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
29 TypeConverter &typeConverter, RewritePatternSet &patterns) {
31 static_cast<LLVMTypeConverter &>(typeConverter), patterns);
32}
33
34LogicalResult
35transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
36 transform::TypeConverterBuilderOpInterface builder) {
37 if (builder.getTypeConverterType() != "LLVMTypeConverter")
38 return emitOpError("expected LLVMTypeConverter");
39 return success();
40}
41
42//===----------------------------------------------------------------------===//
43// CastAndCallOp
44//===----------------------------------------------------------------------===//
45
47transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
50 SmallVector<Value> inputs;
51 if (getInputs())
52 llvm::append_range(inputs, state.getPayloadValues(getInputs()));
53
54 SetVector<Value> outputs;
55 if (getOutputs()) {
56 outputs.insert_range(state.getPayloadValues(getOutputs()));
57
58 // Verify that the set of output values to be replaced is unique.
59 if (outputs.size() !=
60 llvm::range_size(state.getPayloadValues(getOutputs()))) {
61 return emitSilenceableFailure(getLoc())
62 << "cast and call output values must be unique";
63 }
64 }
65
66 // Get the insertion point for the call.
67 auto insertionOps = state.getPayloadOps(getInsertionPoint());
68 if (!llvm::hasSingleElement(insertionOps)) {
69 return emitSilenceableFailure(getLoc())
70 << "Only one op can be specified as an insertion point";
71 }
72 bool insertAfter = getInsertAfter();
73 Operation *insertionPoint = *insertionOps.begin();
74
75 // Check that all inputs dominate the insertion point, and the insertion
76 // point dominates all users of the outputs.
77 DominanceInfo dom(insertionPoint);
78 for (Value output : outputs) {
79 for (Operation *user : output.getUsers()) {
80 // If we are inserting after the insertion point operation, the
81 // insertion point operation must properly dominate the user. Otherwise
82 // basic dominance is enough.
83 bool doesDominate = insertAfter
84 ? dom.properlyDominates(insertionPoint, user)
85 : dom.dominates(insertionPoint, user);
86 if (!doesDominate) {
87 return emitDefiniteFailure()
88 << "User " << user << " is not dominated by insertion point "
89 << insertionPoint;
90 }
91 }
92 }
93
94 for (Value input : inputs) {
95 // If we are inserting before the insertion point operation, the
96 // input must properly dominate the insertion point operation. Otherwise
97 // basic dominance is enough.
98 bool doesDominate = insertAfter
99 ? dom.dominates(input, insertionPoint)
100 : dom.properlyDominates(input, insertionPoint);
101 if (!doesDominate) {
102 return emitDefiniteFailure()
103 << "input " << input << " does not dominate insertion point "
104 << insertionPoint;
105 }
106 }
107
108 // Get the function to call. This can either be specified by symbol or as a
109 // transform handle.
110 func::FuncOp targetFunction = nullptr;
111 if (getFunctionName()) {
113 insertionPoint, *getFunctionName());
114 if (!targetFunction) {
115 return emitDefiniteFailure()
116 << "unresolved symbol " << *getFunctionName();
117 }
118 } else if (getFunction()) {
119 auto payloadOps = state.getPayloadOps(getFunction());
120 if (!llvm::hasSingleElement(payloadOps)) {
121 return emitDefiniteFailure() << "requires a single function to call";
122 }
123 targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
124 if (!targetFunction) {
125 return emitDefiniteFailure() << "invalid non-function callee";
126 }
127 } else {
128 llvm_unreachable("Invalid CastAndCall op without a function to call");
129 return emitDefiniteFailure();
130 }
131
132 // Verify that the function argument and result lengths match the inputs and
133 // outputs given to this op.
134 if (targetFunction.getNumArguments() != inputs.size()) {
135 return emitSilenceableFailure(targetFunction.getLoc())
136 << "mismatch between number of function arguments "
137 << targetFunction.getNumArguments() << " and number of inputs "
138 << inputs.size();
139 }
140 if (targetFunction.getNumResults() != outputs.size()) {
141 return emitSilenceableFailure(targetFunction.getLoc())
142 << "mismatch between number of function results "
143 << targetFunction->getNumResults() << " and number of outputs "
144 << outputs.size();
145 }
146
147 // Gather all specified converters.
148 mlir::TypeConverter converter;
149 if (!getRegion().empty()) {
150 for (Operation &op : getRegion().front()) {
151 cast<transform::TypeConverterBuilderOpInterface>(&op)
152 .populateTypeMaterializations(converter);
153 }
154 }
155
156 if (insertAfter)
157 rewriter.setInsertionPointAfter(insertionPoint);
158 else
159 rewriter.setInsertionPoint(insertionPoint);
160
161 for (auto [input, type] :
162 llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
163 if (input.getType() != type) {
164 Value newInput = converter.materializeSourceConversion(
165 rewriter, input.getLoc(), type, input);
166 if (!newInput) {
167 return emitDefiniteFailure() << "Failed to materialize conversion of "
168 << input << " to type " << type;
169 }
170 input = newInput;
171 }
172 }
173
174 auto callOp = func::CallOp::create(rewriter, insertionPoint->getLoc(),
175 targetFunction, inputs);
176
177 // Cast the call results back to the expected types. If any conversions fail
178 // this is a definite failure as the call has been constructed at this point.
179 for (auto [output, newOutput] :
180 llvm::zip_equal(outputs, callOp.getResults())) {
181 Value convertedOutput = newOutput;
182 if (output.getType() != newOutput.getType()) {
183 convertedOutput = converter.materializeTargetConversion(
184 rewriter, output.getLoc(), output.getType(), newOutput);
185 if (!convertedOutput) {
186 return emitDefiniteFailure()
187 << "Failed to materialize conversion of " << newOutput
188 << " to type " << output.getType();
189 }
190 }
191 rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
192 }
193 results.set(cast<OpResult>(getResult()), {callOp});
195}
196
197LogicalResult transform::CastAndCallOp::verify() {
198 if (!getRegion().empty()) {
199 for (Operation &op : getRegion().front()) {
200 if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
202 << "expected children ops to implement "
203 "TypeConverterBuilderOpInterface";
204 diag.attachNote(op.getLoc()) << "op without interface";
205 return diag;
206 }
207 }
208 }
209 if (!getFunction() && !getFunctionName()) {
210 return emitOpError() << "expected a function handle or name to call";
211 }
212 if (getFunction() && getFunctionName()) {
213 return emitOpError() << "function handle and name are mutually exclusive";
214 }
215 return success();
216}
217
218void transform::CastAndCallOp::getEffects(
220 transform::onlyReadsHandle(getInsertionPointMutable(), effects);
221 if (getInputs())
222 transform::onlyReadsHandle(getInputsMutable(), effects);
223 if (getOutputs())
224 transform::onlyReadsHandle(getOutputsMutable(), effects);
225 if (getFunction())
226 transform::onlyReadsHandle(getFunctionMutable(), effects);
227 transform::producesHandle(getOperation()->getOpResults(), effects);
229}
230
231//===----------------------------------------------------------------------===//
232// ReplaceFuncSignatureOp
233//===----------------------------------------------------------------------===//
234
236transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
239 auto payloadOps = state.getPayloadOps(getModule());
240 if (!llvm::hasSingleElement(payloadOps))
241 return emitDefiniteFailure() << "requires a single module to operate on";
242
243 auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
244 if (!targetModuleOp)
245 return emitSilenceableFailure(getLoc())
246 << "target is expected to be module operation";
247
248 func::FuncOp funcOp =
249 targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
250 if (!funcOp)
251 return emitSilenceableFailure(getLoc())
252 << "function with name '" << getFunctionName() << "' not found";
253
254 unsigned numArgs = funcOp.getNumArguments();
255 unsigned numResults = funcOp.getNumResults();
256 // Check that the number of arguments and results matches the
257 // interchange sizes.
258 if (numArgs != getArgsInterchange().size())
259 return emitSilenceableFailure(getLoc())
260 << "function with name '" << getFunctionName() << "' has " << numArgs
261 << " arguments, but " << getArgsInterchange().size()
262 << " args interchange were given";
263
264 if (numResults != getResultsInterchange().size())
265 return emitSilenceableFailure(getLoc())
266 << "function with name '" << getFunctionName() << "' has "
267 << numResults << " results, but " << getResultsInterchange().size()
268 << " results interchange were given";
269
270 // Check that the args and results interchanges are unique.
271 SetVector<unsigned> argsInterchange, resultsInterchange;
272 argsInterchange.insert_range(getArgsInterchange());
273 resultsInterchange.insert_range(getResultsInterchange());
274 if (argsInterchange.size() != getArgsInterchange().size())
275 return emitSilenceableFailure(getLoc())
276 << "args interchange must be unique";
277
278 if (resultsInterchange.size() != getResultsInterchange().size())
279 return emitSilenceableFailure(getLoc())
280 << "results interchange must be unique";
281
282 // Check that the args and results interchange indices are in bounds.
283 for (unsigned index : argsInterchange) {
284 if (index >= numArgs) {
285 return emitSilenceableFailure(getLoc())
286 << "args interchange index " << index
287 << " is out of bounds for function with name '"
288 << getFunctionName() << "' with " << numArgs << " arguments";
289 }
290 }
291 for (unsigned index : resultsInterchange) {
292 if (index >= numResults) {
293 return emitSilenceableFailure(getLoc())
294 << "results interchange index " << index
295 << " is out of bounds for function with name '"
296 << getFunctionName() << "' with " << numResults << " results";
297 }
298 }
299
300 llvm::SmallVector<int> oldArgToNewArg(argsInterchange.size());
301 for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(argsInterchange))
302 oldArgToNewArg[oldArgIdx] = newArgIdx;
303
304 llvm::SmallVector<int> oldResToNewRes(resultsInterchange.size());
305 for (auto [newResIdx, oldResIdx] : llvm::enumerate(resultsInterchange))
306 oldResToNewRes[oldResIdx] = newResIdx;
307
308 FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewMapping(
309 rewriter, funcOp, oldArgToNewArg, oldResToNewRes);
310 if (failed(newFuncOpOrFailure))
311 return emitSilenceableFailure(getLoc())
312 << "failed to replace function signature '" << getFunctionName()
313 << "' with new order";
314
315 if (getAdjustFuncCalls()) {
317 targetModuleOp.walk([&](func::CallOp callOp) {
318 if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
319 callOps.push_back(callOp);
320 });
321
322 for (func::CallOp callOp : callOps)
323 func::replaceCallOpWithNewMapping(rewriter, callOp, oldArgToNewArg,
324 oldResToNewRes);
325 }
326
327 results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
328 results.set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure});
329
331}
332
333void transform::ReplaceFuncSignatureOp::getEffects(
335 transform::consumesHandle(getModuleMutable(), effects);
336 transform::producesHandle(getOperation()->getOpResults(), effects);
338}
339
340//===----------------------------------------------------------------------===//
341// DeduplicateFuncArgsOp
342//===----------------------------------------------------------------------===//
343
345transform::DeduplicateFuncArgsOp::apply(transform::TransformRewriter &rewriter,
348 auto payloadOps = state.getPayloadOps(getModule());
349 if (!llvm::hasSingleElement(payloadOps))
350 return emitDefiniteFailure() << "requires a single module to operate on";
351
352 auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
353 if (!targetModuleOp)
354 return emitSilenceableFailure(getLoc())
355 << "target is expected to be module operation";
356
357 func::FuncOp funcOp =
358 targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
359 if (!funcOp)
360 return emitSilenceableFailure(getLoc())
361 << "function with name '" << getFunctionName() << "' is not found";
362
363 auto transformationResult =
364 func::deduplicateArgsOfFuncOp(rewriter, funcOp, targetModuleOp);
365 if (failed(transformationResult))
366 return emitSilenceableFailure(getLoc())
367 << "failed to deduplicate function arguments of function "
368 << funcOp.getName();
369
370 auto [newFuncOp, newCallOp] = *transformationResult;
371
372 results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
373 results.set(cast<OpResult>(getTransformedFunction()), {newFuncOp});
374
376}
377
378void transform::DeduplicateFuncArgsOp::getEffects(
380 transform::consumesHandle(getModuleMutable(), effects);
381 transform::producesHandle(getOperation()->getOpResults(), effects);
383}
384
385//===----------------------------------------------------------------------===//
386// Transform op registration
387//===----------------------------------------------------------------------===//
388
389namespace {
390class FuncTransformDialectExtension
392 FuncTransformDialectExtension> {
393public:
394 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension)
395
396 using Base::Base;
397
398 void init() {
399 declareGeneratedDialect<LLVM::LLVMDialect>();
400
401 registerTransformOps<
402#define GET_OP_LIST
403#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
404 >();
405 }
406};
407} // namespace
408
409#define GET_OP_CLASSES
410#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
411
413 registry.addExtensions<FuncTransformDialectExtension>();
414}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
Definition Dominance.h:140
This class represents a diagnostic that is inflight and set to be reported.
Conversion from types to the LLVM IR dialect.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
mlir::func::CallOp replaceCallOpWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::CallOp callOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new call operation with the values as the original call operation, but with the arguments m...
void registerTransformDialectExtension(DialectRegistry &registry)
mlir::FailureOr< mlir::func::FuncOp > replaceFuncWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new function operation with the same name as the original function operation,...
mlir::FailureOr< std::pair< mlir::func::FuncOp, mlir::func::CallOp > > deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, mlir::ModuleOp moduleOp)
This utility function examines all call operations within the given moduleOp that target the specifie...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns