26 void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
33 transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
34 transform::TypeConverterBuilderOpInterface builder) {
35 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
36 return emitOpError(
"expected LLVMTypeConverter");
50 llvm::append_range(inputs, state.getPayloadValues(getInputs()));
54 outputs.insert_range(state.getPayloadValues(getOutputs()));
58 llvm::range_size(state.getPayloadValues(getOutputs()))) {
60 <<
"cast and call output values must be unique";
65 auto insertionOps = state.getPayloadOps(getInsertionPoint());
66 if (!llvm::hasSingleElement(insertionOps)) {
68 <<
"Only one op can be specified as an insertion point";
70 bool insertAfter = getInsertAfter();
71 Operation *insertionPoint = *insertionOps.begin();
76 for (
Value output : outputs) {
81 bool doesDominate = insertAfter
82 ? dom.properlyDominates(insertionPoint, user)
83 : dom.dominates(insertionPoint, user);
86 <<
"User " << user <<
" is not dominated by insertion point "
92 for (
Value input : inputs) {
96 bool doesDominate = insertAfter
97 ? dom.dominates(input, insertionPoint)
98 : dom.properlyDominates(input, insertionPoint);
101 <<
"input " << input <<
" does not dominate insertion point "
108 func::FuncOp targetFunction =
nullptr;
109 if (getFunctionName()) {
110 targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
111 insertionPoint, *getFunctionName());
112 if (!targetFunction) {
114 <<
"unresolved symbol " << *getFunctionName();
116 }
else if (getFunction()) {
117 auto payloadOps = state.getPayloadOps(getFunction());
118 if (!llvm::hasSingleElement(payloadOps)) {
121 targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
122 if (!targetFunction) {
126 llvm_unreachable(
"Invalid CastAndCall op without a function to call");
132 if (targetFunction.getNumArguments() != inputs.size()) {
134 <<
"mismatch between number of function arguments "
135 << targetFunction.getNumArguments() <<
" and number of inputs "
138 if (targetFunction.getNumResults() != outputs.size()) {
140 <<
"mismatch between number of function results "
141 << targetFunction->getNumResults() <<
" and number of outputs "
147 if (!getRegion().empty()) {
148 for (
Operation &op : getRegion().front()) {
149 cast<transform::TypeConverterBuilderOpInterface>(&op)
150 .populateTypeMaterializations(converter);
159 for (
auto [input, type] :
160 llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
161 if (input.getType() != type) {
163 rewriter, input.getLoc(), type, input);
166 << input <<
" to type " << type;
172 auto callOp = rewriter.
create<func::CallOp>(insertionPoint->
getLoc(),
173 targetFunction, inputs);
177 for (
auto [output, newOutput] :
178 llvm::zip_equal(outputs, callOp.getResults())) {
179 Value convertedOutput = newOutput;
180 if (output.getType() != newOutput.getType()) {
182 rewriter, output.getLoc(), output.getType(), newOutput);
183 if (!convertedOutput) {
185 <<
"Failed to materialize conversion of " << newOutput
186 <<
" to type " << output.getType();
191 results.
set(cast<OpResult>(getResult()), {callOp});
196 if (!getRegion().empty()) {
197 for (
Operation &op : getRegion().front()) {
198 if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
200 <<
"expected children ops to implement "
201 "TypeConverterBuilderOpInterface";
202 diag.attachNote(op.getLoc()) <<
"op without interface";
207 if (!getFunction() && !getFunctionName()) {
208 return emitOpError() <<
"expected a function handle or name to call";
210 if (getFunction() && getFunctionName()) {
211 return emitOpError() <<
"function handle and name are mutually exclusive";
216 void transform::CastAndCallOp::getEffects(
234 class FuncTransformDialectExtension
236 FuncTransformDialectExtension> {
243 declareGeneratedDialect<LLVM::LLVMDialect>();
245 registerTransformOps<
247 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
253 #define GET_OP_CLASSES
254 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
This class represents a diagnostic that is inflight and set to be reported.
Conversion from types to the LLVM IR dialect.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
user_range getUsers()
Returns a range of all users.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of some kind.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Base class for extensions of the Transform dialect that supports injecting operations into the Transform dialect at load time.
void registerTransformDialectExtension(DialectRegistry &registry)
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.