MLIR  21.0.0git
FuncTransformOps.cpp
Go to the documentation of this file.
1 //===- FuncTransformOps.cpp - Implementation of CF transform ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"

#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Apply...ConversionPatternsOp
26 //===----------------------------------------------------------------------===//
27 
28 void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
29  TypeConverter &typeConverter, RewritePatternSet &patterns) {
31  static_cast<LLVMTypeConverter &>(typeConverter), patterns);
32 }
33 
34 LogicalResult
35 transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
36  transform::TypeConverterBuilderOpInterface builder) {
37  if (builder.getTypeConverterType() != "LLVMTypeConverter")
38  return emitOpError("expected LLVMTypeConverter");
39  return success();
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // CastAndCallOp
44 //===----------------------------------------------------------------------===//
45 
47 transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
50  SmallVector<Value> inputs;
51  if (getInputs())
52  llvm::append_range(inputs, state.getPayloadValues(getInputs()));
53 
54  SetVector<Value> outputs;
55  if (getOutputs()) {
56  outputs.insert_range(state.getPayloadValues(getOutputs()));
57 
58  // Verify that the set of output values to be replaced is unique.
59  if (outputs.size() !=
60  llvm::range_size(state.getPayloadValues(getOutputs()))) {
61  return emitSilenceableFailure(getLoc())
62  << "cast and call output values must be unique";
63  }
64  }
65 
66  // Get the insertion point for the call.
67  auto insertionOps = state.getPayloadOps(getInsertionPoint());
68  if (!llvm::hasSingleElement(insertionOps)) {
69  return emitSilenceableFailure(getLoc())
70  << "Only one op can be specified as an insertion point";
71  }
72  bool insertAfter = getInsertAfter();
73  Operation *insertionPoint = *insertionOps.begin();
74 
75  // Check that all inputs dominate the insertion point, and the insertion
76  // point dominates all users of the outputs.
77  DominanceInfo dom(insertionPoint);
78  for (Value output : outputs) {
79  for (Operation *user : output.getUsers()) {
80  // If we are inserting after the insertion point operation, the
81  // insertion point operation must properly dominate the user. Otherwise
82  // basic dominance is enough.
83  bool doesDominate = insertAfter
84  ? dom.properlyDominates(insertionPoint, user)
85  : dom.dominates(insertionPoint, user);
86  if (!doesDominate) {
87  return emitDefiniteFailure()
88  << "User " << user << " is not dominated by insertion point "
89  << insertionPoint;
90  }
91  }
92  }
93 
94  for (Value input : inputs) {
95  // If we are inserting before the insertion point operation, the
96  // input must properly dominate the insertion point operation. Otherwise
97  // basic dominance is enough.
98  bool doesDominate = insertAfter
99  ? dom.dominates(input, insertionPoint)
100  : dom.properlyDominates(input, insertionPoint);
101  if (!doesDominate) {
102  return emitDefiniteFailure()
103  << "input " << input << " does not dominate insertion point "
104  << insertionPoint;
105  }
106  }
107 
108  // Get the function to call. This can either be specified by symbol or as a
109  // transform handle.
110  func::FuncOp targetFunction = nullptr;
111  if (getFunctionName()) {
112  targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
113  insertionPoint, *getFunctionName());
114  if (!targetFunction) {
115  return emitDefiniteFailure()
116  << "unresolved symbol " << *getFunctionName();
117  }
118  } else if (getFunction()) {
119  auto payloadOps = state.getPayloadOps(getFunction());
120  if (!llvm::hasSingleElement(payloadOps)) {
121  return emitDefiniteFailure() << "requires a single function to call";
122  }
123  targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
124  if (!targetFunction) {
125  return emitDefiniteFailure() << "invalid non-function callee";
126  }
127  } else {
128  llvm_unreachable("Invalid CastAndCall op without a function to call");
129  return emitDefiniteFailure();
130  }
131 
132  // Verify that the function argument and result lengths match the inputs and
133  // outputs given to this op.
134  if (targetFunction.getNumArguments() != inputs.size()) {
135  return emitSilenceableFailure(targetFunction.getLoc())
136  << "mismatch between number of function arguments "
137  << targetFunction.getNumArguments() << " and number of inputs "
138  << inputs.size();
139  }
140  if (targetFunction.getNumResults() != outputs.size()) {
141  return emitSilenceableFailure(targetFunction.getLoc())
142  << "mismatch between number of function results "
143  << targetFunction->getNumResults() << " and number of outputs "
144  << outputs.size();
145  }
146 
147  // Gather all specified converters.
148  mlir::TypeConverter converter;
149  if (!getRegion().empty()) {
150  for (Operation &op : getRegion().front()) {
151  cast<transform::TypeConverterBuilderOpInterface>(&op)
152  .populateTypeMaterializations(converter);
153  }
154  }
155 
156  if (insertAfter)
157  rewriter.setInsertionPointAfter(insertionPoint);
158  else
159  rewriter.setInsertionPoint(insertionPoint);
160 
161  for (auto [input, type] :
162  llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
163  if (input.getType() != type) {
164  Value newInput = converter.materializeSourceConversion(
165  rewriter, input.getLoc(), type, input);
166  if (!newInput) {
167  return emitDefiniteFailure() << "Failed to materialize conversion of "
168  << input << " to type " << type;
169  }
170  input = newInput;
171  }
172  }
173 
174  auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
175  targetFunction, inputs);
176 
177  // Cast the call results back to the expected types. If any conversions fail
178  // this is a definite failure as the call has been constructed at this point.
179  for (auto [output, newOutput] :
180  llvm::zip_equal(outputs, callOp.getResults())) {
181  Value convertedOutput = newOutput;
182  if (output.getType() != newOutput.getType()) {
183  convertedOutput = converter.materializeTargetConversion(
184  rewriter, output.getLoc(), output.getType(), newOutput);
185  if (!convertedOutput) {
186  return emitDefiniteFailure()
187  << "Failed to materialize conversion of " << newOutput
188  << " to type " << output.getType();
189  }
190  }
191  rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
192  }
193  results.set(cast<OpResult>(getResult()), {callOp});
195 }
196 
197 LogicalResult transform::CastAndCallOp::verify() {
198  if (!getRegion().empty()) {
199  for (Operation &op : getRegion().front()) {
200  if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
201  InFlightDiagnostic diag = emitOpError()
202  << "expected children ops to implement "
203  "TypeConverterBuilderOpInterface";
204  diag.attachNote(op.getLoc()) << "op without interface";
205  return diag;
206  }
207  }
208  }
209  if (!getFunction() && !getFunctionName()) {
210  return emitOpError() << "expected a function handle or name to call";
211  }
212  if (getFunction() && getFunctionName()) {
213  return emitOpError() << "function handle and name are mutually exclusive";
214  }
215  return success();
216 }
217 
218 void transform::CastAndCallOp::getEffects(
220  transform::onlyReadsHandle(getInsertionPointMutable(), effects);
221  if (getInputs())
222  transform::onlyReadsHandle(getInputsMutable(), effects);
223  if (getOutputs())
224  transform::onlyReadsHandle(getOutputsMutable(), effects);
225  if (getFunction())
226  transform::onlyReadsHandle(getFunctionMutable(), effects);
227  transform::producesHandle(getOperation()->getOpResults(), effects);
229 }
230 
231 //===----------------------------------------------------------------------===//
232 // ReplaceFuncSignatureOp
233 //===----------------------------------------------------------------------===//
234 
236 transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
238  transform::TransformState &state) {
239  auto payloadOps = state.getPayloadOps(getModule());
240  if (!llvm::hasSingleElement(payloadOps))
241  return emitDefiniteFailure() << "requires a single module to operate on";
242 
243  auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
244  if (!targetModuleOp)
245  return emitSilenceableFailure(getLoc())
246  << "target is expected to be module operation";
247 
248  func::FuncOp funcOp =
249  targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
250  if (!funcOp)
251  return emitSilenceableFailure(getLoc())
252  << "function with name '" << getFunctionName() << "' not found";
253 
254  unsigned numArgs = funcOp.getNumArguments();
255  unsigned numResults = funcOp.getNumResults();
256  // Check that the number of arguments and results matches the
257  // interchange sizes.
258  if (numArgs != getArgsInterchange().size())
259  return emitSilenceableFailure(getLoc())
260  << "function with name '" << getFunctionName() << "' has " << numArgs
261  << " arguments, but " << getArgsInterchange().size()
262  << " args interchange were given";
263 
264  if (numResults != getResultsInterchange().size())
265  return emitSilenceableFailure(getLoc())
266  << "function with name '" << getFunctionName() << "' has "
267  << numResults << " results, but " << getResultsInterchange().size()
268  << " results interchange were given";
269 
270  // Check that the args and results interchanges are unique.
271  SetVector<unsigned> argsInterchange, resultsInterchange;
272  argsInterchange.insert_range(getArgsInterchange());
273  resultsInterchange.insert_range(getResultsInterchange());
274  if (argsInterchange.size() != getArgsInterchange().size())
275  return emitSilenceableFailure(getLoc())
276  << "args interchange must be unique";
277 
278  if (resultsInterchange.size() != getResultsInterchange().size())
279  return emitSilenceableFailure(getLoc())
280  << "results interchange must be unique";
281 
282  // Check that the args and results interchange indices are in bounds.
283  for (unsigned index : argsInterchange) {
284  if (index >= numArgs) {
285  return emitSilenceableFailure(getLoc())
286  << "args interchange index " << index
287  << " is out of bounds for function with name '"
288  << getFunctionName() << "' with " << numArgs << " arguments";
289  }
290  }
291  for (unsigned index : resultsInterchange) {
292  if (index >= numResults) {
293  return emitSilenceableFailure(getLoc())
294  << "results interchange index " << index
295  << " is out of bounds for function with name '"
296  << getFunctionName() << "' with " << numResults << " results";
297  }
298  }
299 
300  FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder(
301  rewriter, funcOp, argsInterchange.getArrayRef(),
302  resultsInterchange.getArrayRef());
303  if (failed(newFuncOpOrFailure))
304  return emitSilenceableFailure(getLoc())
305  << "failed to replace function signature '" << getFunctionName()
306  << "' with new order";
307 
308  if (getAdjustFuncCalls()) {
310  targetModuleOp.walk([&](func::CallOp callOp) {
311  if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
312  callOps.push_back(callOp);
313  });
314 
315  for (func::CallOp callOp : callOps)
316  func::replaceCallOpWithNewOrder(rewriter, callOp,
317  argsInterchange.getArrayRef(),
318  resultsInterchange.getArrayRef());
319  }
320 
321  results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
322  results.set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure});
323 
325 }
326 
327 void transform::ReplaceFuncSignatureOp::getEffects(
329  transform::consumesHandle(getModuleMutable(), effects);
330  transform::producesHandle(getOperation()->getOpResults(), effects);
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // Transform op registration
336 //===----------------------------------------------------------------------===//
337 
338 namespace {
339 class FuncTransformDialectExtension
341  FuncTransformDialectExtension> {
342 public:
343  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension)
344 
345  using Base::Base;
346 
347  void init() {
348  declareGeneratedDialect<LLVM::LLVMDialect>();
349 
350  registerTransformOps<
351 #define GET_OP_LIST
352 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
353  >();
354  }
355 };
356 } // namespace
357 
358 #define GET_OP_CLASSES
359 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
360 
362  registry.addExtensions<FuncTransformDialectExtension>();
363 }
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
Definition: Dominance.h:140
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:410
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:665
Type conversion class.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of some kind.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the payload IR ops they correspond to.
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list of payload IR ops.
This is a special rewriter to be used in transform op implementations, providing additional helper functions to update the transform state, etc.
The state maintained across applications of various ops implementing the TransformOpInterface.
CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp, llvm::ArrayRef< unsigned > newArgsOrder, llvm::ArrayRef< unsigned > newResultsOrder)
Creates a new call operation with the values as the original call operation, but with the arguments r...
void registerTransformDialectExtension(DialectRegistry &registry)
FailureOr< FuncOp > replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp, llvm::ArrayRef< unsigned > newArgsOrder, llvm::ArrayRef< unsigned > newResultsOrder)
Creates a new function operation with the same name as the original function operation,...
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:746
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423