20#include "llvm/ADT/STLExtras.h"
28void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
35transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
36 transform::TypeConverterBuilderOpInterface builder) {
37 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
62 <<
"cast and call output values must be unique";
68 if (!llvm::hasSingleElement(insertionOps)) {
70 <<
"Only one op can be specified as an insertion point";
72 bool insertAfter = getInsertAfter();
73 Operation *insertionPoint = *insertionOps.begin();
78 for (
Value output : outputs) {
83 bool doesDominate = insertAfter
84 ? dom.properlyDominates(insertionPoint, user)
85 : dom.dominates(insertionPoint, user);
88 <<
"User " << user <<
" is not dominated by insertion point "
94 for (
Value input : inputs) {
98 bool doesDominate = insertAfter
99 ? dom.dominates(input, insertionPoint)
100 : dom.properlyDominates(input, insertionPoint);
103 <<
"input " << input <<
" does not dominate insertion point "
110 func::FuncOp targetFunction =
nullptr;
111 if (getFunctionName()) {
113 insertionPoint, *getFunctionName());
114 if (!targetFunction) {
116 <<
"unresolved symbol " << *getFunctionName();
118 }
else if (getFunction()) {
120 if (!llvm::hasSingleElement(payloadOps)) {
123 targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
124 if (!targetFunction) {
128 llvm_unreachable(
"Invalid CastAndCall op without a function to call");
134 if (targetFunction.getNumArguments() != inputs.size()) {
136 <<
"mismatch between number of function arguments "
137 << targetFunction.getNumArguments() <<
" and number of inputs "
140 if (targetFunction.getNumResults() != outputs.size()) {
142 <<
"mismatch between number of function results "
143 << targetFunction->getNumResults() <<
" and number of outputs "
148 mlir::TypeConverter converter;
149 if (!getRegion().empty()) {
150 for (
Operation &op : getRegion().front()) {
151 cast<transform::TypeConverterBuilderOpInterface>(&op)
152 .populateTypeMaterializations(converter);
161 for (
auto [input, type] :
162 llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
163 if (input.getType() != type) {
164 Value newInput = converter.materializeSourceConversion(
165 rewriter, input.getLoc(), type, input);
168 << input <<
" to type " << type;
174 auto callOp = func::CallOp::create(rewriter, insertionPoint->
getLoc(),
175 targetFunction, inputs);
179 for (
auto [output, newOutput] :
180 llvm::zip_equal(outputs, callOp.getResults())) {
181 Value convertedOutput = newOutput;
182 if (output.getType() != newOutput.getType()) {
183 convertedOutput = converter.materializeTargetConversion(
184 rewriter, output.getLoc(), output.getType(), newOutput);
185 if (!convertedOutput) {
187 <<
"Failed to materialize conversion of " << newOutput
188 <<
" to type " << output.getType();
193 results.
set(cast<OpResult>(getResult()), {callOp});
197LogicalResult transform::CastAndCallOp::verify() {
198 if (!getRegion().empty()) {
199 for (
Operation &op : getRegion().front()) {
200 if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
202 <<
"expected children ops to implement "
203 "TypeConverterBuilderOpInterface";
204 diag.attachNote(op.getLoc()) <<
"op without interface";
209 if (!getFunction() && !getFunctionName()) {
210 return emitOpError() <<
"expected a function handle or name to call";
212 if (getFunction() && getFunctionName()) {
213 return emitOpError() <<
"function handle and name are mutually exclusive";
218void transform::CastAndCallOp::getEffects(
240 if (!llvm::hasSingleElement(payloadOps))
243 auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
246 <<
"target is expected to be module operation";
248 func::FuncOp funcOp =
249 targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
252 <<
"function with name '" << getFunctionName() <<
"' not found";
254 unsigned numArgs = funcOp.getNumArguments();
255 unsigned numResults = funcOp.getNumResults();
258 if (numArgs != getArgsInterchange().size())
260 <<
"function with name '" << getFunctionName() <<
"' has " << numArgs
261 <<
" arguments, but " << getArgsInterchange().size()
262 <<
" args interchange were given";
264 if (numResults != getResultsInterchange().size())
266 <<
"function with name '" << getFunctionName() <<
"' has "
267 << numResults <<
" results, but " << getResultsInterchange().size()
268 <<
" results interchange were given";
272 argsInterchange.insert_range(getArgsInterchange());
273 resultsInterchange.insert_range(getResultsInterchange());
274 if (argsInterchange.size() != getArgsInterchange().size())
276 <<
"args interchange must be unique";
278 if (resultsInterchange.size() != getResultsInterchange().size())
280 <<
"results interchange must be unique";
283 for (
unsigned index : argsInterchange) {
284 if (
index >= numArgs) {
286 <<
"args interchange index " <<
index
287 <<
" is out of bounds for function with name '"
288 << getFunctionName() <<
"' with " << numArgs <<
" arguments";
291 for (
unsigned index : resultsInterchange) {
292 if (
index >= numResults) {
294 <<
"results interchange index " <<
index
295 <<
" is out of bounds for function with name '"
296 << getFunctionName() <<
"' with " << numResults <<
" results";
301 for (
auto [newArgIdx, oldArgIdx] : llvm::enumerate(argsInterchange))
302 oldArgToNewArg[oldArgIdx] = newArgIdx;
305 for (
auto [newResIdx, oldResIdx] : llvm::enumerate(resultsInterchange))
306 oldResToNewRes[oldResIdx] = newResIdx;
309 rewriter, funcOp, oldArgToNewArg, oldResToNewRes);
310 if (
failed(newFuncOpOrFailure))
312 <<
"failed to replace function signature '" << getFunctionName()
313 <<
"' with new order";
315 if (getAdjustFuncCalls()) {
317 targetModuleOp.walk([&](func::CallOp callOp) {
318 if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
319 callOps.push_back(callOp);
322 for (func::CallOp callOp : callOps)
327 results.
set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
328 results.
set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure});
333void transform::ReplaceFuncSignatureOp::getEffects(
349 if (!llvm::hasSingleElement(payloadOps))
352 auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
355 <<
"target is expected to be module operation";
357 func::FuncOp funcOp =
358 targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
361 <<
"function with name '" << getFunctionName() <<
"' is not found";
363 auto transformationResult =
365 if (
failed(transformationResult))
367 <<
"failed to deduplicate function arguments of function "
370 auto [newFuncOp, newCallOp] = *transformationResult;
372 results.
set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
373 results.
set(cast<OpResult>(getTransformedFunction()), {newFuncOp});
378void transform::DeduplicateFuncArgsOp::getEffects(
390class FuncTransformDialectExtension
392 FuncTransformDialectExtension> {
399 declareGeneratedDialect<LLVM::LLVMDialect>();
401 registerTransformOps<
403#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
409#define GET_OP_CLASSES
410#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
This class represents a diagnostic that is inflight and set to be reported.
Conversion from types to the LLVM IR dialect.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
user_range getUsers()
Returns a range of all users.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
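Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was found.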
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Base class for extensions of the Transform dialect that supports injecting operations into the Transform dialect at load time.
mlir::func::CallOp replaceCallOpWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::CallOp callOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new call operation with the values as the original call operation, but with the arguments m...
void registerTransformDialectExtension(DialectRegistry &registry)
mlir::FailureOr< mlir::func::FuncOp > replaceFuncWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new function operation with the same name as the original function operation,...
mlir::FailureOr< std::pair< mlir::func::FuncOp, mlir::func::CallOp > > deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, mlir::ModuleOp moduleOp)
This utility function examines all call operations within the given moduleOp that target the specifie...
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
llvm::SetVector< T, Vector, Set, N > SetVector
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns