26 void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
33 transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
34 transform::TypeConverterBuilderOpInterface builder) {
35 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
36 return emitOpError(
"expected LLVMTypeConverter");
50 llvm::append_range(inputs, state.getPayloadValues(getInputs()));
54 for (
auto output : state.getPayloadValues(getOutputs()))
55 outputs.insert(output);
59 llvm::range_size(state.getPayloadValues(getOutputs()))) {
61 <<
"cast and call output values must be unique";
66 auto insertionOps = state.getPayloadOps(getInsertionPoint());
67 if (!llvm::hasSingleElement(insertionOps)) {
69 <<
"Only one op can be specified as an insertion point";
71 bool insertAfter = getInsertAfter();
72 Operation *insertionPoint = *insertionOps.begin();
77 for (
Value output : outputs) {
82 bool doesDominate = insertAfter
83 ? dom.properlyDominates(insertionPoint, user)
84 : dom.dominates(insertionPoint, user);
87 <<
"User " << user <<
" is not dominated by insertion point "
93 for (
Value input : inputs) {
97 bool doesDominate = insertAfter
98 ? dom.dominates(input, insertionPoint)
99 : dom.properlyDominates(input, insertionPoint);
102 <<
"input " << input <<
" does not dominate insertion point "
109 func::FuncOp targetFunction =
nullptr;
110 if (getFunctionName()) {
111 targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
112 insertionPoint, *getFunctionName());
113 if (!targetFunction) {
115 <<
"unresolved symbol " << *getFunctionName();
117 }
else if (getFunction()) {
118 auto payloadOps = state.getPayloadOps(getFunction());
119 if (!llvm::hasSingleElement(payloadOps)) {
122 targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
123 if (!targetFunction) {
127 llvm_unreachable(
"Invalid CastAndCall op without a function to call");
133 if (targetFunction.getNumArguments() != inputs.size()) {
135 <<
"mismatch between number of function arguments "
136 << targetFunction.getNumArguments() <<
" and number of inputs "
139 if (targetFunction.getNumResults() != outputs.size()) {
141 <<
"mismatch between number of function results "
142 << targetFunction->getNumResults() <<
" and number of outputs "
148 if (!getRegion().empty()) {
149 for (
Operation &op : getRegion().front()) {
150 cast<transform::TypeConverterBuilderOpInterface>(&op)
151 .populateTypeMaterializations(converter);
160 for (
auto [input, type] :
161 llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
162 if (input.getType() != type) {
164 rewriter, input.getLoc(), type, input);
167 << input <<
" to type " << type;
173 auto callOp = rewriter.
create<func::CallOp>(insertionPoint->
getLoc(),
174 targetFunction, inputs);
178 for (
auto [output, newOutput] :
179 llvm::zip_equal(outputs, callOp.getResults())) {
180 Value convertedOutput = newOutput;
181 if (output.getType() != newOutput.getType()) {
183 rewriter, output.getLoc(), output.getType(), newOutput);
184 if (!convertedOutput) {
186 <<
"Failed to materialize conversion of " << newOutput
187 <<
" to type " << output.getType();
192 results.
set(cast<OpResult>(getResult()), {callOp});
197 if (!getRegion().empty()) {
198 for (
Operation &op : getRegion().front()) {
199 if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
201 <<
"expected children ops to implement "
202 "TypeConverterBuilderOpInterface";
203 diag.attachNote(op.getLoc()) <<
"op without interface";
208 if (!getFunction() && !getFunctionName()) {
209 return emitOpError() <<
"expected a function handle or name to call";
211 if (getFunction() && getFunctionName()) {
212 return emitOpError() <<
"function handle and name are mutually exclusive";
217 void transform::CastAndCallOp::getEffects(
235 class FuncTransformDialectExtension
237 FuncTransformDialectExtension> {
244 declareGeneratedDialect<LLVM::LLVMDialect>();
246 registerTransformOps<
248 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
254 #define GET_OP_CLASSES
255 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
This class represents a diagnostic that is inflight and set to be reported.
Conversion from types to the LLVM IR dialect.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
user_range getUsers()
Returns a range of all users.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Base class for extensions of the Transform dialect that supports injecting operations into the Transform dialect at load time.
void registerTransformDialectExtension(DialectRegistry &registry)
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...