MLIR  19.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp - Utils related to the transform dialect -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Verifier.h"
13 #include "llvm/Support/Debug.h"
14 
15 using namespace mlir;
16 
17 #define DEBUG_TYPE "transform-dialect-utils"
18 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
19 
20 /// Return whether `func1` can be merged into `func2`. For that to work
21 /// `func1` has to be a declaration (aka has to be external) and `func2`
22 /// either has to be a declaration as well, or it has to be public (otherwise,
23 /// it wouldn't be visible by `func1`).
24 static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
25  return func1.isExternal() && (func2.isPublic() || func2.isExternal());
26 }
27 
28 /// Merge `func1` into `func2`. The two ops must be inside the same parent op
29 /// and mergable according to `canMergeInto`. The function erases `func1` such
30 /// that only `func2` exists when the function returns.
31 static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
32  FunctionOpInterface func2) {
33  assert(canMergeInto(func1, func2));
34  assert(func1->getParentOp() == func2->getParentOp() &&
35  "expected func1 and func2 to be in the same parent op");
36 
37  // Check that function signatures match.
38  if (func1.getFunctionType() != func2.getFunctionType()) {
39  return func1.emitError()
40  << "external definition has a mismatching signature ("
41  << func2.getFunctionType() << ")";
42  }
43 
44  // Check and merge argument attributes.
45  MLIRContext *context = func1->getContext();
46  auto *td = context->getLoadedDialect<transform::TransformDialect>();
47  StringAttr consumedName = td->getConsumedAttrName();
48  StringAttr readOnlyName = td->getReadOnlyAttrName();
49  for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
50  bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
51  bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
52  bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
53  bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
54  if (!isExternalConsumed && !isExternalReadonly) {
55  if (isConsumed)
56  func2.setArgAttr(i, consumedName, UnitAttr::get(context));
57  else if (isReadonly)
58  func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
59  continue;
60  }
61 
62  if ((isExternalConsumed && !isConsumed) ||
63  (isExternalReadonly && !isReadonly)) {
64  return func1.emitError()
65  << "external definition has mismatching consumption "
66  "annotations for argument #"
67  << i;
68  }
69  }
70 
71  // `func1` is the external one, so we can remove it.
72  assert(func1.isExternal());
73  func1->erase();
74 
75  return InFlightDiagnostic();
76 }
77 
81  assert(target->hasTrait<OpTrait::SymbolTable>() &&
82  "requires target to implement the 'SymbolTable' trait");
83  assert(other->hasTrait<OpTrait::SymbolTable>() &&
84  "requires target to implement the 'SymbolTable' trait");
85 
86  SymbolTable targetSymbolTable(target);
87  SymbolTable otherSymbolTable(*other);
88 
89  // Step 1:
90  //
91  // Rename private symbols in both ops in order to resolve conflicts that can
92  // be resolved that way.
93  LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
94  // TODO: Do we *actually* need to test in both directions?
95  for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
96  SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
97  SmallVector<SymbolTable *, 2>{&otherSymbolTable,
98  &targetSymbolTable})) {
99  Operation *symbolTableOp = symbolTable->getOp();
100  for (Operation &op : symbolTableOp->getRegion(0).front()) {
101  auto symbolOp = dyn_cast<SymbolOpInterface>(op);
102  if (!symbolOp)
103  continue;
104  StringAttr name = symbolOp.getNameAttr();
105  LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
106 
107  // Check if there is a colliding op in the other module.
108  auto collidingOp =
109  cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
110  if (!collidingOp)
111  continue;
112 
113  LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
114 
115  // Collisions are fine if both opt are functions and can be merged.
116  if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
117  collidingFuncOp =
118  dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
119  funcOp && collidingFuncOp) {
120  if (canMergeInto(funcOp, collidingFuncOp) ||
121  canMergeInto(collidingFuncOp, funcOp)) {
122  LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
123  "will be merged\n");
124  continue;
125  }
126 
127  // If they can't be merged, proceed like any other collision.
128  LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
129  }
130 
131  // Collision can be resolved by renaming if one of the ops is private.
132  auto renameToUnique =
133  [&](SymbolOpInterface op, SymbolOpInterface otherOp,
134  SymbolTable &symbolTable,
135  SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
136  LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
137  FailureOr<StringAttr> maybeNewName =
138  symbolTable.renameToUnique(op, {&otherSymbolTable});
139  if (failed(maybeNewName)) {
140  InFlightDiagnostic diag = op->emitError("failed to rename symbol");
141  diag.attachNote(otherOp->getLoc())
142  << "attempted renaming due to collision with this op";
143  return diag;
144  }
145  LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
146  << "\n");
147  return InFlightDiagnostic();
148  };
149 
150  if (symbolOp.isPrivate()) {
151  InFlightDiagnostic diag = renameToUnique(
152  symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
153  if (failed(diag))
154  return diag;
155  continue;
156  }
157  if (collidingOp.isPrivate()) {
158  InFlightDiagnostic diag = renameToUnique(
159  collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
160  if (failed(diag))
161  return diag;
162  continue;
163  }
164  LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
165  InFlightDiagnostic diag = symbolOp.emitError()
166  << "doubly defined symbol @" << name.getValue();
167  diag.attachNote(collidingOp->getLoc()) << "previously defined here";
168  return diag;
169  }
170  }
171 
172  // TODO: This duplicates pass infrastructure. We should split this pass into
173  // several and let the pass infrastructure do the verification.
174  for (auto *op : SmallVector<Operation *>{target, *other}) {
175  if (failed(mlir::verify(op)))
176  return op->emitError() << "failed to verify input op after renaming";
177  }
178 
179  // Step 2:
180  //
181  // Move all ops from `other` into target and merge public symbols.
182  LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
183  {
185  for (Operation &op : other->getRegion(0).front()) {
186  if (auto symbol = dyn_cast<SymbolOpInterface>(op))
187  opsToMove.push_back(symbol);
188  }
189 
190  for (SymbolOpInterface op : opsToMove) {
191  // Remember potentially colliding op in the target module.
192  auto collidingOp = cast_or_null<SymbolOpInterface>(
193  targetSymbolTable.lookup(op.getNameAttr()));
194 
195  // Move op even if we get a collision.
196  LLVM_DEBUG(DBGS() << " moving @" << op.getName());
197  op->moveBefore(&target->getRegion(0).front(),
198  target->getRegion(0).front().end());
199 
200  // If there is no collision, we are done.
201  if (!collidingOp) {
202  LLVM_DEBUG(llvm::dbgs() << " without collision\n");
203  continue;
204  }
205 
206  // The two colliding ops must both be functions because we have already
207  // emitted errors otherwise earlier.
208  auto funcOp = cast<FunctionOpInterface>(op.getOperation());
209  auto collidingFuncOp =
210  cast<FunctionOpInterface>(collidingOp.getOperation());
211 
212  // Both ops are in the target module now and can be treated
213  // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
214  // `collidingFuncOp`.
215  if (!canMergeInto(funcOp, collidingFuncOp)) {
216  std::swap(funcOp, collidingFuncOp);
217  }
218  assert(canMergeInto(funcOp, collidingFuncOp));
219 
220  LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
221  << collidingFuncOp.getLoc() << ":\n"
222  << collidingFuncOp << "\n");
223 
224  // Update symbol table. This works with or without the previous `swap`.
225  targetSymbolTable.remove(funcOp);
226  targetSymbolTable.insert(collidingFuncOp);
227  assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
228 
229  // Do the actual merging.
230  {
231  InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
232  if (failed(diag))
233  return diag;
234  }
235  }
236  }
237 
238  if (failed(mlir::verify(target)))
239  return target->emitError()
240  << "failed to verify target op after merging symbols";
241 
242  LLVM_DEBUG(DBGS() << "done merging ops\n");
243  return InFlightDiagnostic();
244 }
static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2)
Return whether func1 can be merged into func2.
Definition: Utils.cpp:24
static InFlightDiagnostic mergeInto(FunctionOpInterface func1, FunctionOpInterface func2)
Merge func1 into func2.
Definition: Utils.cpp:31
#define DBGS()
Definition: Utils.cpp:18
static std::string diag(const llvm::Value &value)
iterator end()
Definition: Block.h:141
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:555
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition: OwningOpRef.h:29
Block & front()
Definition: Region.h:65
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
InFlightDiagnostic mergeSymbolsInto(Operation *target, OwningOpRef< Operation * > other)
Merge all symbols from other into target.
Definition: Utils.cpp:79
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72