MLIR  22.0.0git
FuncTransformOps.cpp
Go to the documentation of this file.
1 //===- FuncTransformOps.cpp - Implementation of Func transform ops --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/PatternMatch.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Apply...ConversionPatternsOp
25 //===----------------------------------------------------------------------===//
26 
27 void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
28  TypeConverter &typeConverter, RewritePatternSet &patterns) {
30  static_cast<LLVMTypeConverter &>(typeConverter), patterns);
31 }
32 
33 LogicalResult
34 transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
35  transform::TypeConverterBuilderOpInterface builder) {
36  if (builder.getTypeConverterType() != "LLVMTypeConverter")
37  return emitOpError("expected LLVMTypeConverter");
38  return success();
39 }
40 
41 //===----------------------------------------------------------------------===//
42 // CastAndCallOp
43 //===----------------------------------------------------------------------===//
44 
46 transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
49  SmallVector<Value> inputs;
50  if (getInputs())
51  llvm::append_range(inputs, state.getPayloadValues(getInputs()));
52 
53  SetVector<Value> outputs;
54  if (getOutputs()) {
55  outputs.insert_range(state.getPayloadValues(getOutputs()));
56 
57  // Verify that the set of output values to be replaced is unique.
58  if (outputs.size() !=
59  llvm::range_size(state.getPayloadValues(getOutputs()))) {
60  return emitSilenceableFailure(getLoc())
61  << "cast and call output values must be unique";
62  }
63  }
64 
65  // Get the insertion point for the call.
66  auto insertionOps = state.getPayloadOps(getInsertionPoint());
67  if (!llvm::hasSingleElement(insertionOps)) {
68  return emitSilenceableFailure(getLoc())
69  << "Only one op can be specified as an insertion point";
70  }
71  bool insertAfter = getInsertAfter();
72  Operation *insertionPoint = *insertionOps.begin();
73 
74  // Check that all inputs dominate the insertion point, and the insertion
75  // point dominates all users of the outputs.
76  DominanceInfo dom(insertionPoint);
77  for (Value output : outputs) {
78  for (Operation *user : output.getUsers()) {
79  // If we are inserting after the insertion point operation, the
80  // insertion point operation must properly dominate the user. Otherwise
81  // basic dominance is enough.
82  bool doesDominate = insertAfter
83  ? dom.properlyDominates(insertionPoint, user)
84  : dom.dominates(insertionPoint, user);
85  if (!doesDominate) {
86  return emitDefiniteFailure()
87  << "User " << user << " is not dominated by insertion point "
88  << insertionPoint;
89  }
90  }
91  }
92 
93  for (Value input : inputs) {
94  // If we are inserting before the insertion point operation, the
95  // input must properly dominate the insertion point operation. Otherwise
96  // basic dominance is enough.
97  bool doesDominate = insertAfter
98  ? dom.dominates(input, insertionPoint)
99  : dom.properlyDominates(input, insertionPoint);
100  if (!doesDominate) {
101  return emitDefiniteFailure()
102  << "input " << input << " does not dominate insertion point "
103  << insertionPoint;
104  }
105  }
106 
107  // Get the function to call. This can either be specified by symbol or as a
108  // transform handle.
109  func::FuncOp targetFunction = nullptr;
110  if (getFunctionName()) {
111  targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
112  insertionPoint, *getFunctionName());
113  if (!targetFunction) {
114  return emitDefiniteFailure()
115  << "unresolved symbol " << *getFunctionName();
116  }
117  } else if (getFunction()) {
118  auto payloadOps = state.getPayloadOps(getFunction());
119  if (!llvm::hasSingleElement(payloadOps)) {
120  return emitDefiniteFailure() << "requires a single function to call";
121  }
122  targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
123  if (!targetFunction) {
124  return emitDefiniteFailure() << "invalid non-function callee";
125  }
126  } else {
127  llvm_unreachable("Invalid CastAndCall op without a function to call");
128  return emitDefiniteFailure();
129  }
130 
131  // Verify that the function argument and result lengths match the inputs and
132  // outputs given to this op.
133  if (targetFunction.getNumArguments() != inputs.size()) {
134  return emitSilenceableFailure(targetFunction.getLoc())
135  << "mismatch between number of function arguments "
136  << targetFunction.getNumArguments() << " and number of inputs "
137  << inputs.size();
138  }
139  if (targetFunction.getNumResults() != outputs.size()) {
140  return emitSilenceableFailure(targetFunction.getLoc())
141  << "mismatch between number of function results "
142  << targetFunction->getNumResults() << " and number of outputs "
143  << outputs.size();
144  }
145 
146  // Gather all specified converters.
147  mlir::TypeConverter converter;
148  if (!getRegion().empty()) {
149  for (Operation &op : getRegion().front()) {
150  cast<transform::TypeConverterBuilderOpInterface>(&op)
151  .populateTypeMaterializations(converter);
152  }
153  }
154 
155  if (insertAfter)
156  rewriter.setInsertionPointAfter(insertionPoint);
157  else
158  rewriter.setInsertionPoint(insertionPoint);
159 
160  for (auto [input, type] :
161  llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
162  if (input.getType() != type) {
163  Value newInput = converter.materializeSourceConversion(
164  rewriter, input.getLoc(), type, input);
165  if (!newInput) {
166  return emitDefiniteFailure() << "Failed to materialize conversion of "
167  << input << " to type " << type;
168  }
169  input = newInput;
170  }
171  }
172 
173  auto callOp = func::CallOp::create(rewriter, insertionPoint->getLoc(),
174  targetFunction, inputs);
175 
176  // Cast the call results back to the expected types. If any conversions fail
177  // this is a definite failure as the call has been constructed at this point.
178  for (auto [output, newOutput] :
179  llvm::zip_equal(outputs, callOp.getResults())) {
180  Value convertedOutput = newOutput;
181  if (output.getType() != newOutput.getType()) {
182  convertedOutput = converter.materializeTargetConversion(
183  rewriter, output.getLoc(), output.getType(), newOutput);
184  if (!convertedOutput) {
185  return emitDefiniteFailure()
186  << "Failed to materialize conversion of " << newOutput
187  << " to type " << output.getType();
188  }
189  }
190  rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
191  }
192  results.set(cast<OpResult>(getResult()), {callOp});
194 }
195 
196 LogicalResult transform::CastAndCallOp::verify() {
197  if (!getRegion().empty()) {
198  for (Operation &op : getRegion().front()) {
199  if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
200  InFlightDiagnostic diag = emitOpError()
201  << "expected children ops to implement "
202  "TypeConverterBuilderOpInterface";
203  diag.attachNote(op.getLoc()) << "op without interface";
204  return diag;
205  }
206  }
207  }
208  if (!getFunction() && !getFunctionName()) {
209  return emitOpError() << "expected a function handle or name to call";
210  }
211  if (getFunction() && getFunctionName()) {
212  return emitOpError() << "function handle and name are mutually exclusive";
213  }
214  return success();
215 }
216 
217 void transform::CastAndCallOp::getEffects(
219  transform::onlyReadsHandle(getInsertionPointMutable(), effects);
220  if (getInputs())
221  transform::onlyReadsHandle(getInputsMutable(), effects);
222  if (getOutputs())
223  transform::onlyReadsHandle(getOutputsMutable(), effects);
224  if (getFunction())
225  transform::onlyReadsHandle(getFunctionMutable(), effects);
226  transform::producesHandle(getOperation()->getOpResults(), effects);
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // ReplaceFuncSignatureOp
232 //===----------------------------------------------------------------------===//
233 
235 transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
237  transform::TransformState &state) {
238  auto payloadOps = state.getPayloadOps(getModule());
239  if (!llvm::hasSingleElement(payloadOps))
240  return emitDefiniteFailure() << "requires a single module to operate on";
241 
242  auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
243  if (!targetModuleOp)
244  return emitSilenceableFailure(getLoc())
245  << "target is expected to be module operation";
246 
247  func::FuncOp funcOp =
248  targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
249  if (!funcOp)
250  return emitSilenceableFailure(getLoc())
251  << "function with name '" << getFunctionName() << "' not found";
252 
253  unsigned numArgs = funcOp.getNumArguments();
254  unsigned numResults = funcOp.getNumResults();
255  // Check that the number of arguments and results matches the
256  // interchange sizes.
257  if (numArgs != getArgsInterchange().size())
258  return emitSilenceableFailure(getLoc())
259  << "function with name '" << getFunctionName() << "' has " << numArgs
260  << " arguments, but " << getArgsInterchange().size()
261  << " args interchange were given";
262 
263  if (numResults != getResultsInterchange().size())
264  return emitSilenceableFailure(getLoc())
265  << "function with name '" << getFunctionName() << "' has "
266  << numResults << " results, but " << getResultsInterchange().size()
267  << " results interchange were given";
268 
269  // Check that the args and results interchanges are unique.
270  SetVector<unsigned> argsInterchange, resultsInterchange;
271  argsInterchange.insert_range(getArgsInterchange());
272  resultsInterchange.insert_range(getResultsInterchange());
273  if (argsInterchange.size() != getArgsInterchange().size())
274  return emitSilenceableFailure(getLoc())
275  << "args interchange must be unique";
276 
277  if (resultsInterchange.size() != getResultsInterchange().size())
278  return emitSilenceableFailure(getLoc())
279  << "results interchange must be unique";
280 
281  // Check that the args and results interchange indices are in bounds.
282  for (unsigned index : argsInterchange) {
283  if (index >= numArgs) {
284  return emitSilenceableFailure(getLoc())
285  << "args interchange index " << index
286  << " is out of bounds for function with name '"
287  << getFunctionName() << "' with " << numArgs << " arguments";
288  }
289  }
290  for (unsigned index : resultsInterchange) {
291  if (index >= numResults) {
292  return emitSilenceableFailure(getLoc())
293  << "results interchange index " << index
294  << " is out of bounds for function with name '"
295  << getFunctionName() << "' with " << numResults << " results";
296  }
297  }
298 
299  FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder(
300  rewriter, funcOp, argsInterchange.getArrayRef(),
301  resultsInterchange.getArrayRef());
302  if (failed(newFuncOpOrFailure))
303  return emitSilenceableFailure(getLoc())
304  << "failed to replace function signature '" << getFunctionName()
305  << "' with new order";
306 
307  if (getAdjustFuncCalls()) {
309  targetModuleOp.walk([&](func::CallOp callOp) {
310  if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
311  callOps.push_back(callOp);
312  });
313 
314  for (func::CallOp callOp : callOps)
315  func::replaceCallOpWithNewOrder(rewriter, callOp,
316  argsInterchange.getArrayRef(),
317  resultsInterchange.getArrayRef());
318  }
319 
320  results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
321  results.set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure});
322 
324 }
325 
326 void transform::ReplaceFuncSignatureOp::getEffects(
328  transform::consumesHandle(getModuleMutable(), effects);
329  transform::producesHandle(getOperation()->getOpResults(), effects);
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // Transform op registration
335 //===----------------------------------------------------------------------===//
336 
337 namespace {
338 class FuncTransformDialectExtension
340  FuncTransformDialectExtension> {
341 public:
342  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension)
343 
344  using Base::Base;
345 
346  void init() {
347  declareGeneratedDialect<LLVM::LLVMDialect>();
348 
349  registerTransformOps<
350 #define GET_OP_LIST
351 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
352  >();
353  }
354 };
355 } // namespace
356 
357 #define GET_OP_CLASSES
358 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
359 
361  registry.addExtensions<FuncTransformDialectExtension>();
362 }
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
Definition: Dominance.h:140
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:410
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:700
Type conversion class.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of some kind.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Base class for extensions of the Transform dialect that supports injecting operations into the Transform dialect at load time.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the payload IR ops they correspond to.
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper functions to update the transform state.
The state maintained across applications of various ops implementing the TransformOpInterface.
CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp, llvm::ArrayRef< unsigned > newArgsOrder, llvm::ArrayRef< unsigned > newResultsOrder)
Creates a new call operation with the values as the original call operation, but with the arguments and results reordered according to newArgsOrder and newResultsOrder.
void registerTransformDialectExtension(DialectRegistry &registry)
FailureOr< FuncOp > replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp, llvm::ArrayRef< unsigned > newArgsOrder, llvm::ArrayRef< unsigned > newResultsOrder)
Creates a new function operation with the same name as the original function operation, but with the arguments and results reordered according to newArgsOrder and newResultsOrder.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:788
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on the given operation and any nested operations.
Definition: Verifier.cpp:423