MLIR  16.0.0git
TransformInterfaces.cpp
Go to the documentation of this file.
1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "llvm/ADT/STLExtras.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Support/ErrorHandling.h"
16 
17 #define DEBUG_TYPE "transform-dialect"
18 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
19 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // TransformState
25 //===----------------------------------------------------------------------===//
26 
27 constexpr const Value transform::TransformState::kTopLevelValue;
28 
29 transform::TransformState::TransformState(Region *region,
30  Operation *payloadRoot,
31  const TransformOptions &options)
32  : topLevel(payloadRoot), options(options) {
33  auto result = mappings.try_emplace(region);
34  assert(result.second && "the region scope is already present");
35  (void)result;
36 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
37  regionStack.push_back(region);
38 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
39 }
40 
41 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
42 
45  const TransformOpMapping &operationMapping = getMapping(value).direct;
46  auto iter = operationMapping.find(value);
47  assert(iter != operationMapping.end() && "unknown handle");
48  return iter->getSecond();
49 }
50 
52  Operation *op, SmallVectorImpl<Value> &handles) const {
53  bool found = false;
54  for (const Mappings &mapping : llvm::make_second_range(mappings)) {
55  auto iterator = mapping.reverse.find(op);
56  if (iterator != mapping.reverse.end()) {
57  llvm::append_range(handles, iterator->getSecond());
58  found = true;
59  }
60  }
61 
62  return success(found);
63 }
64 
66 transform::TransformState::setPayloadOps(Value value,
67  ArrayRef<Operation *> targets) {
68  assert(value != kTopLevelValue &&
69  "attempting to reset the transformation root");
70 
71  auto iface = value.getType().cast<TransformTypeInterface>();
73  iface.checkPayload(value.getLoc(), targets);
74  if (failed(result.checkAndReport()))
75  return failure();
76 
77  // Setting new payload for the value without cleaning it first is a misuse of
78  // the API, assert here.
79  SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
80  Mappings &mappings = getMapping(value);
81  bool inserted =
82  mappings.direct.insert({value, std::move(storedTargets)}).second;
83  assert(inserted && "value is already associated with another list");
84  (void)inserted;
85 
86  for (Operation *op : targets)
87  mappings.reverse[op].push_back(value);
88 
89  return success();
90 }
91 
92 void transform::TransformState::dropReverseMapping(Mappings &mappings,
93  Operation *op, Value value) {
94  auto it = mappings.reverse.find(op);
95  if (it == mappings.reverse.end())
96  return;
97 
98  llvm::erase_value(it->getSecond(), value);
99  if (it->getSecond().empty())
100  mappings.reverse.erase(it);
101 }
102 
103 void transform::TransformState::removePayloadOps(Value value) {
104  Mappings &mappings = getMapping(value);
105  for (Operation *op : mappings.direct[value])
106  dropReverseMapping(mappings, op, value);
107  mappings.direct.erase(value);
108 }
109 
110 LogicalResult transform::TransformState::updatePayloadOps(
111  Value value, function_ref<Operation *(Operation *)> callback) {
112  Mappings &mappings = getMapping(value);
113  auto it = mappings.direct.find(value);
114  assert(it != mappings.direct.end() && "unknown handle");
115  SmallVector<Operation *> &association = it->getSecond();
116  SmallVector<Operation *> updated;
117  updated.reserve(association.size());
118 
119  for (Operation *op : association) {
120  dropReverseMapping(mappings, op, value);
121  if (Operation *updatedOp = callback(op)) {
122  updated.push_back(updatedOp);
123  mappings.reverse[updatedOp].push_back(value);
124  }
125  }
126 
127  auto iface = value.getType().cast<TransformTypeInterface>();
129  iface.checkPayload(value.getLoc(), updated);
130  if (failed(result.checkAndReport()))
131  return failure();
132 
133  it->second = updated;
134  return success();
135 }
136 
137 void transform::TransformState::recordHandleInvalidationOne(
138  OpOperand &handle, Operation *payloadOp, Value otherHandle) {
139  ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get());
140  // If the op is associated with invalidated handle, skip the check as it
141  // may be reading invalid IR.
142  if (invalidatedHandles.count(otherHandle))
143  return;
144 
145  for (Operation *ancestor : potentialAncestors) {
146  if (!ancestor->isAncestor(payloadOp))
147  continue;
148 
149  // Make sure the error-reporting lambda doesn't capture anything
150  // by-reference because it will go out of scope. Additionally, extract
151  // location from Payload IR ops because the ops themselves may be
152  // deleted before the lambda gets called.
153  Location ancestorLoc = ancestor->getLoc();
154  Location opLoc = payloadOp->getLoc();
155  Operation *owner = handle.getOwner();
156  unsigned operandNo = handle.getOperandNumber();
157  invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
158  otherHandle](Location currentLoc) {
159  InFlightDiagnostic diag = emitError(currentLoc)
160  << "op uses a handle invalidated by a "
161  "previously executed transform op";
162  diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
163  diag.attachNote(owner->getLoc())
164  << "invalidated by this transform op that consumes its operand #"
165  << operandNo
166  << " and invalidates handles to payload ops nested in payload "
167  "ops associated with the consumed handle";
168  diag.attachNote(ancestorLoc) << "ancestor payload op";
169  diag.attachNote(opLoc) << "nested payload op";
170  };
171  }
172 }
173 
174 void transform::TransformState::recordHandleInvalidation(OpOperand &handle) {
175  for (const Mappings &mapping : llvm::make_second_range(mappings))
176  for (const auto &[payloadOp, otherHandles] : mapping.reverse)
177  for (Value otherHandle : otherHandles)
178  recordHandleInvalidationOne(handle, payloadOp, otherHandle);
179 }
180 
181 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
182  TransformOpInterface transform) {
183  auto memoryEffectsIface =
184  cast<MemoryEffectOpInterface>(transform.getOperation());
186  memoryEffectsIface.getEffectsOnResource(
188 
189  for (OpOperand &target : transform->getOpOperands()) {
190  // If the operand uses an invalidated handle, report it.
191  auto it = invalidatedHandles.find(target.get());
192  if (it != invalidatedHandles.end())
193  return it->getSecond()(transform->getLoc()), failure();
194 
195  // Invalidate handles pointing to the operations nested in the operation
196  // associated with the handle consumed by this operation.
197  auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
198  return isa<MemoryEffects::Free>(effect.getEffect()) &&
199  effect.getValue() == target.get();
200  };
201  if (llvm::any_of(effects, consumesTarget))
202  recordHandleInvalidation(target);
203  }
204  return success();
205 }
206 
208 transform::TransformState::applyTransform(TransformOpInterface transform) {
209  LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
210  auto printOnFailureRAII = llvm::make_scope_exit([this] {
211  (void)this;
212  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
213  DBGS() << "Top-level payload:\n";
214  getTopLevel()->print(llvm::dbgs(),
215  mlir::OpPrintingFlags().printGenericOpForm());
216  });
217  });
218  if (options.getExpensiveChecksEnabled()) {
219  if (failed(checkAndRecordHandleInvalidation(transform)))
221 
222  for (OpOperand &operand : transform->getOpOperands()) {
223  if (!isHandleConsumed(operand.get(), transform))
224  continue;
225 
227  for (Operation *op : getPayloadOps(operand.get())) {
228  if (!seen.insert(op).second) {
230  transform.emitSilenceableError()
231  << "a handle passed as operand #" << operand.getOperandNumber()
232  << " and consumed by this operation points to a payload "
233  "operation more than once";
234  diag.attachNote(op->getLoc()) << "repeated target op";
235  return diag;
236  }
237  }
238  }
239  }
240 
241  transform::TransformResults results(transform->getNumResults());
242  // Compute the result but do not short-circuit the silenceable failure case as
243  // we still want the handles to propagate properly so the "suppress" mode can
244  // proceed on a best effort basis.
245  DiagnosedSilenceableFailure result(transform.apply(results, *this));
246  if (result.isDefiniteFailure())
247  return result;
248 
249  // Remove the mapping for the operand if it is consumed by the operation. This
250  // allows us to catch use-after-free with assertions later on.
251  auto memEffectInterface =
252  cast<MemoryEffectOpInterface>(transform.getOperation());
254  for (OpOperand &target : transform->getOpOperands()) {
255  effects.clear();
256  memEffectInterface.getEffectsOnValue(target.get(), effects);
257  if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
258  return isa<transform::TransformMappingResource>(
259  effect.getResource()) &&
260  isa<MemoryEffects::Free>(effect.getEffect());
261  })) {
262  removePayloadOps(target.get());
263  }
264  }
265 
266  for (OpResult result : transform->getResults()) {
267  assert(result.getDefiningOp() == transform.getOperation() &&
268  "payload IR association for a value other than the result of the "
269  "current transform op");
270  if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
272  }
273 
274  printOnFailureRAII.release();
275  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
276  DBGS() << "Top-level payload:\n";
277  getTopLevel()->print(llvm::dbgs());
278  });
279  return result;
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // TransformState::Extension
284 //===----------------------------------------------------------------------===//
285 
287 
290  Operation *replacement) {
291  SmallVector<Value> handles;
292  if (failed(state.getHandlesForPayloadOp(op, handles)))
293  return failure();
294 
295  for (Value handle : handles) {
296  LogicalResult result =
297  state.updatePayloadOps(handle, [&](Operation *current) {
298  return current == op ? replacement : current;
299  });
300  if (failed(result))
301  return failure();
302  }
303  return success();
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // TransformResults
308 //===----------------------------------------------------------------------===//
309 
310 transform::TransformResults::TransformResults(unsigned numSegments) {
311  segments.resize(numSegments,
312  ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
313 }
314 
316  ArrayRef<Operation *> ops) {
317  int64_t position = value.getResultNumber();
318  assert(position < static_cast<int64_t>(segments.size()) &&
319  "setting results for a non-existent handle");
320  assert(segments[position].data() == nullptr && "results already set");
321  int64_t start = operations.size();
322  llvm::append_range(operations, ops);
323  segments[position] = makeArrayRef(operations).drop_front(start);
324 }
325 
327 transform::TransformResults::get(unsigned resultNumber) const {
328  assert(resultNumber < segments.size() &&
329  "querying results for a non-existent handle");
330  assert(segments[resultNumber].data() != nullptr && "querying unset results");
331  return segments[resultNumber];
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // Utilities for PossibleTopLevelTransformOpTrait.
336 //===----------------------------------------------------------------------===//
337 
339  TransformState &state, Operation *op, Region &region) {
340  SmallVector<Operation *> targets;
341  if (op->getNumOperands() != 0)
342  llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
343  else
344  targets.push_back(state.getTopLevel());
345 
346  return state.mapBlockArguments(region.front().getArgument(0), targets);
347 }
348 
351  // Attaching this trait without the interface is a misuse of the API, but it
352  // cannot be caught via a static_assert because interface registration is
353  // dynamic.
354  assert(isa<TransformOpInterface>(op) &&
355  "should implement TransformOpInterface to have "
356  "PossibleTopLevelTransformOpTrait");
357 
358  if (op->getNumRegions() < 1)
359  return op->emitOpError() << "expects at least one region";
360 
361  Region *bodyRegion = &op->getRegion(0);
362  if (!llvm::hasNItems(*bodyRegion, 1))
363  return op->emitOpError() << "expects a single-block region";
364 
365  Block *body = &bodyRegion->front();
366  if (body->getNumArguments() != 1 ||
367  !body->getArgumentTypes()[0].isa<TransformTypeInterface>()) {
368  return op->emitOpError() << "expects the entry block to have one argument "
369  "of type implementing TransformTypeInterface";
370  }
371 
372  if (auto *parent =
374  if (op->getNumOperands() == 0) {
376  op->emitOpError()
377  << "expects the root operation to be provided for a nested op";
378  diag.attachNote(parent->getLoc())
379  << "nested in another possible top-level op";
380  return diag;
381  }
382  }
383 
384  return success();
385 }
386 
387 //===----------------------------------------------------------------------===//
388 // Memory effects.
389 //===----------------------------------------------------------------------===//
390 
392  ValueRange handles,
394  for (Value handle : handles) {
395  effects.emplace_back(MemoryEffects::Read::get(), handle,
397  effects.emplace_back(MemoryEffects::Free::get(), handle,
399  }
400 }
401 
402 /// Returns `true` if the given list of effects instances contains an instance
403 /// with the effect type specified as template parameter.
404 template <typename EffectTy, typename ResourceTy = SideEffects::DefaultResource>
406  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
407  return isa<EffectTy>(effect.getEffect()) &&
408  isa<ResourceTy>(effect.getResource());
409  });
410 }
411 
413  transform::TransformOpInterface transform) {
414  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
416  iface.getEffectsOnValue(handle, effects);
417  return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
418  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
419 }
420 
422  ValueRange handles,
424  for (Value handle : handles) {
425  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
427  effects.emplace_back(MemoryEffects::Write::get(), handle,
429  }
430 }
431 
433  ValueRange handles,
435  for (Value handle : handles) {
436  effects.emplace_back(MemoryEffects::Read::get(), handle,
438  }
439 }
440 
443  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
444  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
445 }
446 
449  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
450 }
451 
452 //===----------------------------------------------------------------------===//
453 // Entry point.
454 //===----------------------------------------------------------------------===//
455 
457  TransformOpInterface transform,
458  const TransformOptions &options) {
459 #ifndef NDEBUG
460  if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
461  transform->getNumOperands() != 0) {
462  transform->emitError()
463  << "expected transform to start at the top-level transform op";
464  llvm::report_fatal_error("could not run transforms",
465  /*gen_crash_diag=*/false);
466  }
467 #endif // NDEBUG
468 
469  TransformState state(transform->getParentRegion(), payloadRoot, options);
470  return state.applyTransform(transform).checkAndReport();
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // Generated interface implementation.
475 //===----------------------------------------------------------------------===//
476 
477 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
static std::string diag(llvm::Value &value)
static constexpr const bool value
static llvm::ManagedStatic< PassManagerOptions > options
#define DEBUG_PRINT_AFTER_ALL
#define DBGS()
Block represents an ordered list of Operations.
Definition: Block.h:30
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:137
BlockArgument getArgument(unsigned i)
Definition: Block.h:118
unsigned getNumArguments()
Definition: Block.h:117
The result of a transform IR operation application.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:137
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:307
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
This class represents an operand of an operation.
Definition: Value.h:247
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:212
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition: Value.h:442
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
Value getOperand(unsigned idx)
Definition: Operation.h:267
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:477
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
unsigned getNumOperands()
Definition: Operation.h:263
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:486
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:179
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:512
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:40
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
Options controlling the application of transform operations by the TransformState.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, ArrayRef< Operation * > ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles) const
Populates handles with all handles pointing to the given Payload IR op.
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
ArrayRef< Operation * > getPayloadOps(Value value) const
Returns the list of ops that the given transform IR value corresponds to.
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const TransformOptions &options=TransformOptions())
Entry point to the Transform dialect infrastructure.
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool hasEffect(Operation *op, Value value=nullptr)
Returns true if op has an effect of type EffectTy on value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26