MLIR  21.0.0git
FuncTransformOps.cpp
Go to the documentation of this file.
//===- FuncTransformOps.cpp - Implementation of Func transform ops ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 // Apply...ConversionPatternsOp
24 //===----------------------------------------------------------------------===//
25 
26 void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
27  TypeConverter &typeConverter, RewritePatternSet &patterns) {
29  static_cast<LLVMTypeConverter &>(typeConverter), patterns);
30 }
31 
32 LogicalResult
33 transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
34  transform::TypeConverterBuilderOpInterface builder) {
35  if (builder.getTypeConverterType() != "LLVMTypeConverter")
36  return emitOpError("expected LLVMTypeConverter");
37  return success();
38 }
39 
40 //===----------------------------------------------------------------------===//
41 // CastAndCallOp
42 //===----------------------------------------------------------------------===//
43 
45 transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
48  SmallVector<Value> inputs;
49  if (getInputs())
50  llvm::append_range(inputs, state.getPayloadValues(getInputs()));
51 
52  SetVector<Value> outputs;
53  if (getOutputs()) {
54  outputs.insert_range(state.getPayloadValues(getOutputs()));
55 
56  // Verify that the set of output values to be replaced is unique.
57  if (outputs.size() !=
58  llvm::range_size(state.getPayloadValues(getOutputs()))) {
59  return emitSilenceableFailure(getLoc())
60  << "cast and call output values must be unique";
61  }
62  }
63 
64  // Get the insertion point for the call.
65  auto insertionOps = state.getPayloadOps(getInsertionPoint());
66  if (!llvm::hasSingleElement(insertionOps)) {
67  return emitSilenceableFailure(getLoc())
68  << "Only one op can be specified as an insertion point";
69  }
70  bool insertAfter = getInsertAfter();
71  Operation *insertionPoint = *insertionOps.begin();
72 
73  // Check that all inputs dominate the insertion point, and the insertion
74  // point dominates all users of the outputs.
75  DominanceInfo dom(insertionPoint);
76  for (Value output : outputs) {
77  for (Operation *user : output.getUsers()) {
78  // If we are inserting after the insertion point operation, the
79  // insertion point operation must properly dominate the user. Otherwise
80  // basic dominance is enough.
81  bool doesDominate = insertAfter
82  ? dom.properlyDominates(insertionPoint, user)
83  : dom.dominates(insertionPoint, user);
84  if (!doesDominate) {
85  return emitDefiniteFailure()
86  << "User " << user << " is not dominated by insertion point "
87  << insertionPoint;
88  }
89  }
90  }
91 
92  for (Value input : inputs) {
93  // If we are inserting before the insertion point operation, the
94  // input must properly dominate the insertion point operation. Otherwise
95  // basic dominance is enough.
96  bool doesDominate = insertAfter
97  ? dom.dominates(input, insertionPoint)
98  : dom.properlyDominates(input, insertionPoint);
99  if (!doesDominate) {
100  return emitDefiniteFailure()
101  << "input " << input << " does not dominate insertion point "
102  << insertionPoint;
103  }
104  }
105 
106  // Get the function to call. This can either be specified by symbol or as a
107  // transform handle.
108  func::FuncOp targetFunction = nullptr;
109  if (getFunctionName()) {
110  targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
111  insertionPoint, *getFunctionName());
112  if (!targetFunction) {
113  return emitDefiniteFailure()
114  << "unresolved symbol " << *getFunctionName();
115  }
116  } else if (getFunction()) {
117  auto payloadOps = state.getPayloadOps(getFunction());
118  if (!llvm::hasSingleElement(payloadOps)) {
119  return emitDefiniteFailure() << "requires a single function to call";
120  }
121  targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
122  if (!targetFunction) {
123  return emitDefiniteFailure() << "invalid non-function callee";
124  }
125  } else {
126  llvm_unreachable("Invalid CastAndCall op without a function to call");
127  return emitDefiniteFailure();
128  }
129 
130  // Verify that the function argument and result lengths match the inputs and
131  // outputs given to this op.
132  if (targetFunction.getNumArguments() != inputs.size()) {
133  return emitSilenceableFailure(targetFunction.getLoc())
134  << "mismatch between number of function arguments "
135  << targetFunction.getNumArguments() << " and number of inputs "
136  << inputs.size();
137  }
138  if (targetFunction.getNumResults() != outputs.size()) {
139  return emitSilenceableFailure(targetFunction.getLoc())
140  << "mismatch between number of function results "
141  << targetFunction->getNumResults() << " and number of outputs "
142  << outputs.size();
143  }
144 
145  // Gather all specified converters.
146  mlir::TypeConverter converter;
147  if (!getRegion().empty()) {
148  for (Operation &op : getRegion().front()) {
149  cast<transform::TypeConverterBuilderOpInterface>(&op)
150  .populateTypeMaterializations(converter);
151  }
152  }
153 
154  if (insertAfter)
155  rewriter.setInsertionPointAfter(insertionPoint);
156  else
157  rewriter.setInsertionPoint(insertionPoint);
158 
159  for (auto [input, type] :
160  llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
161  if (input.getType() != type) {
162  Value newInput = converter.materializeSourceConversion(
163  rewriter, input.getLoc(), type, input);
164  if (!newInput) {
165  return emitDefiniteFailure() << "Failed to materialize conversion of "
166  << input << " to type " << type;
167  }
168  input = newInput;
169  }
170  }
171 
172  auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
173  targetFunction, inputs);
174 
175  // Cast the call results back to the expected types. If any conversions fail
176  // this is a definite failure as the call has been constructed at this point.
177  for (auto [output, newOutput] :
178  llvm::zip_equal(outputs, callOp.getResults())) {
179  Value convertedOutput = newOutput;
180  if (output.getType() != newOutput.getType()) {
181  convertedOutput = converter.materializeTargetConversion(
182  rewriter, output.getLoc(), output.getType(), newOutput);
183  if (!convertedOutput) {
184  return emitDefiniteFailure()
185  << "Failed to materialize conversion of " << newOutput
186  << " to type " << output.getType();
187  }
188  }
189  rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
190  }
191  results.set(cast<OpResult>(getResult()), {callOp});
193 }
194 
195 LogicalResult transform::CastAndCallOp::verify() {
196  if (!getRegion().empty()) {
197  for (Operation &op : getRegion().front()) {
198  if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
199  InFlightDiagnostic diag = emitOpError()
200  << "expected children ops to implement "
201  "TypeConverterBuilderOpInterface";
202  diag.attachNote(op.getLoc()) << "op without interface";
203  return diag;
204  }
205  }
206  }
207  if (!getFunction() && !getFunctionName()) {
208  return emitOpError() << "expected a function handle or name to call";
209  }
210  if (getFunction() && getFunctionName()) {
211  return emitOpError() << "function handle and name are mutually exclusive";
212  }
213  return success();
214 }
215 
216 void transform::CastAndCallOp::getEffects(
218  transform::onlyReadsHandle(getInsertionPointMutable(), effects);
219  if (getInputs())
220  transform::onlyReadsHandle(getInputsMutable(), effects);
221  if (getOutputs())
222  transform::onlyReadsHandle(getOutputsMutable(), effects);
223  if (getFunction())
224  transform::onlyReadsHandle(getFunctionMutable(), effects);
225  transform::producesHandle(getOperation()->getOpResults(), effects);
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // Transform op registration
231 //===----------------------------------------------------------------------===//
232 
233 namespace {
234 class FuncTransformDialectExtension
236  FuncTransformDialectExtension> {
237 public:
238  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension)
239 
240  using Base::Base;
241 
242  void init() {
243  declareGeneratedDialect<LLVM::LLVMDialect>();
244 
245  registerTransformOps<
246 #define GET_OP_LIST
247 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
248  >();
249  }
250 };
251 } // namespace
252 
253 #define GET_OP_CLASSES
254 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
255 
257  registry.addExtensions<FuncTransformDialectExtension>();
258 }
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
Definition: Dominance.h:140
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:410
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:666
Type conversion class.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Base class for extensions of the Transform dialect that supports injecting operations into the Transform dialect at load time.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the payload IR ops they correspond to.
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list of payload IR ops.
This is a special rewriter to be used in transform op implementations, providing additional helper functions to update the transform state, etc.
The state maintained across applications of various ops implementing the TransformOpInterface.
void registerTransformDialectExtension(DialectRegistry &registry)
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:742
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:424