28 void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
35 transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
36 transform::TypeConverterBuilderOpInterface builder) {
37 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
38 return emitOpError(
"expected LLVMTypeConverter");
52 llvm::append_range(inputs, state.getPayloadValues(getInputs()));
56 outputs.insert_range(state.getPayloadValues(getOutputs()));
60 llvm::range_size(state.getPayloadValues(getOutputs()))) {
62 <<
"cast and call output values must be unique";
67 auto insertionOps = state.getPayloadOps(getInsertionPoint());
68 if (!llvm::hasSingleElement(insertionOps)) {
70 <<
"Only one op can be specified as an insertion point";
72 bool insertAfter = getInsertAfter();
73 Operation *insertionPoint = *insertionOps.begin();
78 for (
Value output : outputs) {
83 bool doesDominate = insertAfter
84 ? dom.properlyDominates(insertionPoint, user)
85 : dom.dominates(insertionPoint, user);
88 <<
"User " << user <<
" is not dominated by insertion point "
94 for (
Value input : inputs) {
98 bool doesDominate = insertAfter
99 ? dom.dominates(input, insertionPoint)
100 : dom.properlyDominates(input, insertionPoint);
103 <<
"input " << input <<
" does not dominate insertion point "
110 func::FuncOp targetFunction =
nullptr;
111 if (getFunctionName()) {
112 targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
113 insertionPoint, *getFunctionName());
114 if (!targetFunction) {
116 <<
"unresolved symbol " << *getFunctionName();
118 }
else if (getFunction()) {
119 auto payloadOps = state.getPayloadOps(getFunction());
120 if (!llvm::hasSingleElement(payloadOps)) {
123 targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
124 if (!targetFunction) {
128 llvm_unreachable(
"Invalid CastAndCall op without a function to call");
134 if (targetFunction.getNumArguments() != inputs.size()) {
136 <<
"mismatch between number of function arguments "
137 << targetFunction.getNumArguments() <<
" and number of inputs "
140 if (targetFunction.getNumResults() != outputs.size()) {
142 <<
"mismatch between number of function results "
143 << targetFunction->getNumResults() <<
" and number of outputs "
149 if (!getRegion().empty()) {
150 for (
Operation &op : getRegion().front()) {
151 cast<transform::TypeConverterBuilderOpInterface>(&op)
152 .populateTypeMaterializations(converter);
161 for (
auto [input, type] :
162 llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
163 if (input.getType() != type) {
165 rewriter, input.getLoc(), type, input);
168 << input <<
" to type " << type;
174 auto callOp = rewriter.
create<func::CallOp>(insertionPoint->
getLoc(),
175 targetFunction, inputs);
179 for (
auto [output, newOutput] :
180 llvm::zip_equal(outputs, callOp.getResults())) {
181 Value convertedOutput = newOutput;
182 if (output.getType() != newOutput.getType()) {
184 rewriter, output.getLoc(), output.getType(), newOutput);
185 if (!convertedOutput) {
187 <<
"Failed to materialize conversion of " << newOutput
188 <<
" to type " << output.getType();
193 results.
set(cast<OpResult>(getResult()), {callOp});
198 if (!getRegion().empty()) {
199 for (
Operation &op : getRegion().front()) {
200 if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
202 <<
"expected children ops to implement "
203 "TypeConverterBuilderOpInterface";
204 diag.attachNote(op.getLoc()) <<
"op without interface";
209 if (!getFunction() && !getFunctionName()) {
210 return emitOpError() <<
"expected a function handle or name to call";
212 if (getFunction() && getFunctionName()) {
213 return emitOpError() <<
"function handle and name are mutually exclusive";
218 void transform::CastAndCallOp::getEffects(
239 auto payloadOps = state.getPayloadOps(getModule());
240 if (!llvm::hasSingleElement(payloadOps))
243 auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
246 <<
"target is expected to be module operation";
248 func::FuncOp funcOp =
249 targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
252 <<
"function with name '" << getFunctionName() <<
"' not found";
254 unsigned numArgs = funcOp.getNumArguments();
255 unsigned numResults = funcOp.getNumResults();
258 if (numArgs != getArgsInterchange().size())
260 <<
"function with name '" << getFunctionName() <<
"' has " << numArgs
261 <<
" arguments, but " << getArgsInterchange().size()
262 <<
" args interchange were given";
264 if (numResults != getResultsInterchange().size())
266 <<
"function with name '" << getFunctionName() <<
"' has "
267 << numResults <<
" results, but " << getResultsInterchange().size()
268 <<
" results interchange were given";
272 argsInterchange.insert_range(getArgsInterchange());
273 resultsInterchange.insert_range(getResultsInterchange());
274 if (argsInterchange.size() != getArgsInterchange().size())
276 <<
"args interchange must be unique";
278 if (resultsInterchange.size() != getResultsInterchange().size())
280 <<
"results interchange must be unique";
283 for (
unsigned index : argsInterchange) {
284 if (index >= numArgs) {
286 <<
"args interchange index " << index
287 <<
" is out of bounds for function with name '"
288 << getFunctionName() <<
"' with " << numArgs <<
" arguments";
291 for (
unsigned index : resultsInterchange) {
292 if (index >= numResults) {
294 <<
"results interchange index " << index
295 <<
" is out of bounds for function with name '"
296 << getFunctionName() <<
"' with " << numResults <<
" results";
301 rewriter, funcOp, argsInterchange.getArrayRef(),
302 resultsInterchange.getArrayRef());
303 if (failed(newFuncOpOrFailure))
305 <<
"failed to replace function signature '" << getFunctionName()
306 <<
"' with new order";
308 if (getAdjustFuncCalls()) {
310 targetModuleOp.walk([&](func::CallOp callOp) {
311 if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
312 callOps.push_back(callOp);
315 for (func::CallOp callOp : callOps)
317 argsInterchange.getArrayRef(),
318 resultsInterchange.getArrayRef());
321 results.
set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
322 results.
set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure});
327 void transform::ReplaceFuncSignatureOp::getEffects(
339 class FuncTransformDialectExtension
341 FuncTransformDialectExtension> {
348 declareGeneratedDialect<LLVM::LLVMDialect>();
350 registerTransformOps<
352 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
358 #define GET_OP_CLASSES
359 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
This class represents a diagnostic that is inflight and set to be reported.
Conversion from types to the LLVM IR dialect.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
user_range getUsers()
Returns a range of all users.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of some kind.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Base class for extensions of the Transform dialect that supports injecting operations into the Transform dialect at load time.
CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp, llvm::ArrayRef< unsigned > newArgsOrder, llvm::ArrayRef< unsigned > newResultsOrder)
Creates a new call operation with the same values as the original call operation, but with the arguments reordered according to newArgsOrder and the results reordered according to newResultsOrder.
void registerTransformDialectExtension(DialectRegistry &registry)
FailureOr< FuncOp > replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp, llvm::ArrayRef< unsigned > newArgsOrder, llvm::ArrayRef< unsigned > newResultsOrder)
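Creates a new function operation with the same name as the original function operation, but with the arguments reordered according to newArgsOrder and the results reordered according to newResultsOrder.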
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.