MLIR  20.0.0git
FuncTransformOps.cpp
Go to the documentation of this file.
1 //===- FuncTransformOps.cpp - Implementation of Func transform ops -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 // Apply...ConversionPatternsOp
24 //===----------------------------------------------------------------------===//
25 
26 void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
27  TypeConverter &typeConverter, RewritePatternSet &patterns) {
29  static_cast<LLVMTypeConverter &>(typeConverter), patterns);
30 }
31 
32 LogicalResult
33 transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
34  transform::TypeConverterBuilderOpInterface builder) {
35  if (builder.getTypeConverterType() != "LLVMTypeConverter")
36  return emitOpError("expected LLVMTypeConverter");
37  return success();
38 }
39 
40 //===----------------------------------------------------------------------===//
41 // CastAndCallOp
42 //===----------------------------------------------------------------------===//
43 
45 transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
48  SmallVector<Value> inputs;
49  if (getInputs())
50  llvm::append_range(inputs, state.getPayloadValues(getInputs()));
51 
52  SetVector<Value> outputs;
53  if (getOutputs()) {
54  for (auto output : state.getPayloadValues(getOutputs()))
55  outputs.insert(output);
56 
57  // Verify that the set of output values to be replaced is unique.
58  if (outputs.size() !=
59  llvm::range_size(state.getPayloadValues(getOutputs()))) {
60  return emitSilenceableFailure(getLoc())
61  << "cast and call output values must be unique";
62  }
63  }
64 
65  // Get the insertion point for the call.
66  auto insertionOps = state.getPayloadOps(getInsertionPoint());
67  if (!llvm::hasSingleElement(insertionOps)) {
68  return emitSilenceableFailure(getLoc())
69  << "Only one op can be specified as an insertion point";
70  }
71  bool insertAfter = getInsertAfter();
72  Operation *insertionPoint = *insertionOps.begin();
73 
74  // Check that all inputs dominate the insertion point, and the insertion
75  // point dominates all users of the outputs.
76  DominanceInfo dom(insertionPoint);
77  for (Value output : outputs) {
78  for (Operation *user : output.getUsers()) {
79  // If we are inserting after the insertion point operation, the
80  // insertion point operation must properly dominate the user. Otherwise
81  // basic dominance is enough.
82  bool doesDominate = insertAfter
83  ? dom.properlyDominates(insertionPoint, user)
84  : dom.dominates(insertionPoint, user);
85  if (!doesDominate) {
86  return emitDefiniteFailure()
87  << "User " << user << " is not dominated by insertion point "
88  << insertionPoint;
89  }
90  }
91  }
92 
93  for (Value input : inputs) {
94  // If we are inserting before the insertion point operation, the
95  // input must properly dominate the insertion point operation. Otherwise
96  // basic dominance is enough.
97  bool doesDominate = insertAfter
98  ? dom.dominates(input, insertionPoint)
99  : dom.properlyDominates(input, insertionPoint);
100  if (!doesDominate) {
101  return emitDefiniteFailure()
102  << "input " << input << " does not dominate insertion point "
103  << insertionPoint;
104  }
105  }
106 
107  // Get the function to call. This can either be specified by symbol or as a
108  // transform handle.
109  func::FuncOp targetFunction = nullptr;
110  if (getFunctionName()) {
111  targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
112  insertionPoint, *getFunctionName());
113  if (!targetFunction) {
114  return emitDefiniteFailure()
115  << "unresolved symbol " << *getFunctionName();
116  }
117  } else if (getFunction()) {
118  auto payloadOps = state.getPayloadOps(getFunction());
119  if (!llvm::hasSingleElement(payloadOps)) {
120  return emitDefiniteFailure() << "requires a single function to call";
121  }
122  targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
123  if (!targetFunction) {
124  return emitDefiniteFailure() << "invalid non-function callee";
125  }
126  } else {
127  llvm_unreachable("Invalid CastAndCall op without a function to call");
128  return emitDefiniteFailure();
129  }
130 
131  // Verify that the function argument and result lengths match the inputs and
132  // outputs given to this op.
133  if (targetFunction.getNumArguments() != inputs.size()) {
134  return emitSilenceableFailure(targetFunction.getLoc())
135  << "mismatch between number of function arguments "
136  << targetFunction.getNumArguments() << " and number of inputs "
137  << inputs.size();
138  }
139  if (targetFunction.getNumResults() != outputs.size()) {
140  return emitSilenceableFailure(targetFunction.getLoc())
141  << "mismatch between number of function results "
142  << targetFunction->getNumResults() << " and number of outputs "
143  << outputs.size();
144  }
145 
146  // Gather all specified converters.
147  mlir::TypeConverter converter;
148  if (!getRegion().empty()) {
149  for (Operation &op : getRegion().front()) {
150  cast<transform::TypeConverterBuilderOpInterface>(&op)
151  .populateTypeMaterializations(converter);
152  }
153  }
154 
155  if (insertAfter)
156  rewriter.setInsertionPointAfter(insertionPoint);
157  else
158  rewriter.setInsertionPoint(insertionPoint);
159 
160  for (auto [input, type] :
161  llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
162  if (input.getType() != type) {
163  Value newInput = converter.materializeSourceConversion(
164  rewriter, input.getLoc(), type, input);
165  if (!newInput) {
166  return emitDefiniteFailure() << "Failed to materialize conversion of "
167  << input << " to type " << type;
168  }
169  input = newInput;
170  }
171  }
172 
173  auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
174  targetFunction, inputs);
175 
176  // Cast the call results back to the expected types. If any conversions fail
177  // this is a definite failure as the call has been constructed at this point.
178  for (auto [output, newOutput] :
179  llvm::zip_equal(outputs, callOp.getResults())) {
180  Value convertedOutput = newOutput;
181  if (output.getType() != newOutput.getType()) {
182  convertedOutput = converter.materializeTargetConversion(
183  rewriter, output.getLoc(), output.getType(), newOutput);
184  if (!convertedOutput) {
185  return emitDefiniteFailure()
186  << "Failed to materialize conversion of " << newOutput
187  << " to type " << output.getType();
188  }
189  }
190  rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
191  }
192  results.set(cast<OpResult>(getResult()), {callOp});
194 }
195 
196 LogicalResult transform::CastAndCallOp::verify() {
197  if (!getRegion().empty()) {
198  for (Operation &op : getRegion().front()) {
199  if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
200  InFlightDiagnostic diag = emitOpError()
201  << "expected children ops to implement "
202  "TypeConverterBuilderOpInterface";
203  diag.attachNote(op.getLoc()) << "op without interface";
204  return diag;
205  }
206  }
207  }
208  if (!getFunction() && !getFunctionName()) {
209  return emitOpError() << "expected a function handle or name to call";
210  }
211  if (getFunction() && getFunctionName()) {
212  return emitOpError() << "function handle and name are mutually exclusive";
213  }
214  return success();
215 }
216 
217 void transform::CastAndCallOp::getEffects(
219  transform::onlyReadsHandle(getInsertionPointMutable(), effects);
220  if (getInputs())
221  transform::onlyReadsHandle(getInputsMutable(), effects);
222  if (getOutputs())
223  transform::onlyReadsHandle(getOutputsMutable(), effects);
224  if (getFunction())
225  transform::onlyReadsHandle(getFunctionMutable(), effects);
226  transform::producesHandle(getOperation()->getOpResults(), effects);
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // Transform op registration
232 //===----------------------------------------------------------------------===//
233 
234 namespace {
235 class FuncTransformDialectExtension
237  FuncTransformDialectExtension> {
238 public:
239  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension)
240 
241  using Base::Base;
242 
243  void init() {
244  declareGeneratedDialect<LLVM::LLVMDialect>();
245 
246  registerTransformOps<
247 #define GET_OP_LIST
248 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
249  >();
250  }
251 };
252 } // namespace
253 
254 #define GET_OP_CLASSES
255 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
256 
258  registry.addExtensions<FuncTransformDialectExtension>();
259 }
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
Definition: Dominance.h:140
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:420
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:708
Type conversion class.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Base class for extensions of the Transform dialect that supports injecting operations into the Transform dialect at load time.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
void registerTransformDialectExtension(DialectRegistry &registry)
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:733
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426