MLIR  16.0.0git
TransformInterfaces.cpp
Go to the documentation of this file.
1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "llvm/Support/Debug.h"
14 
15 #define DEBUG_TYPE "transform-dialect"
16 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // TransformState
22 //===----------------------------------------------------------------------===//
23 
24 constexpr const Value transform::TransformState::kTopLevelValue;
25 
28  : topLevel(root), options(options) {
29  auto result = mappings.try_emplace(&region);
30  assert(result.second && "the region scope is already present");
31  (void)result;
32 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
33  regionStack.push_back(&region);
34 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
35 }
36 
38 
41  const TransformOpMapping &operationMapping = getMapping(value).direct;
42  auto iter = operationMapping.find(value);
43  assert(iter != operationMapping.end() && "unknown handle");
44  return iter->getSecond();
45 }
46 
48  for (const Mappings &mapping : llvm::make_second_range(mappings)) {
49  if (Value handle = mapping.reverse.lookup(op))
50  return handle;
51  }
52  return Value();
53 }
54 
55 LogicalResult transform::TransformState::tryEmplaceReverseMapping(
56  Mappings &map, Operation *operation, Value handle) {
57  auto insertionResult = map.reverse.insert({operation, handle});
58  if (!insertionResult.second && insertionResult.first->second != handle) {
59  InFlightDiagnostic diag = operation->emitError()
60  << "operation tracked by two handles";
61  diag.attachNote(handle.getLoc()) << "handle";
62  diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
63  return diag;
64  }
65  return success();
66 }
67 
69 transform::TransformState::setPayloadOps(Value value,
70  ArrayRef<Operation *> targets) {
71  assert(value != kTopLevelValue &&
72  "attempting to reset the transformation root");
73 
74  if (value.use_empty())
75  return success();
76 
77  // Setting new payload for the value without cleaning it first is a misuse of
78  // the API, assert here.
79  SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
80  Mappings &mappings = getMapping(value);
81  bool inserted =
82  mappings.direct.insert({value, std::move(storedTargets)}).second;
83  assert(inserted && "value is already associated with another list");
84  (void)inserted;
85 
86  // Having multiple handles to the same operation is an error in the transform
87  // expressed using the dialect and may be constructed by valid API calls from
88  // valid IR. Emit an error here.
89  for (Operation *op : targets) {
90  if (failed(tryEmplaceReverseMapping(mappings, op, value)))
91  return failure();
92  }
93 
94  return success();
95 }
96 
97 void transform::TransformState::removePayloadOps(Value value) {
98  Mappings &mappings = getMapping(value);
99  for (Operation *op : mappings.direct[value])
100  mappings.reverse.erase(op);
101  mappings.direct.erase(value);
102 }
103 
104 LogicalResult transform::TransformState::updatePayloadOps(
105  Value value, function_ref<Operation *(Operation *)> callback) {
106  Mappings &mappings = getMapping(value);
107  auto it = mappings.direct.find(value);
108  assert(it != mappings.direct.end() && "unknown handle");
109  SmallVector<Operation *> &association = it->getSecond();
110  SmallVector<Operation *> updated;
111  updated.reserve(association.size());
112 
113  for (Operation *op : association) {
114  mappings.reverse.erase(op);
115  if (Operation *updatedOp = callback(op)) {
116  updated.push_back(updatedOp);
117  if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value)))
118  return failure();
119  }
120  }
121 
122  std::swap(association, updated);
123  return success();
124 }
125 
126 void transform::TransformState::recordHandleInvalidation(OpOperand &handle) {
127  ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get());
128  for (const Mappings &mapping : llvm::make_second_range(mappings)) {
129  for (const auto &kvp : mapping.reverse) {
130  // If the op is associated with invalidated handle, skip the check as it
131  // may be reading invalid IR.
132  Operation *op = kvp.first;
133  Value otherHandle = kvp.second;
134  if (invalidatedHandles.count(otherHandle))
135  continue;
136 
137  for (Operation *ancestor : potentialAncestors) {
138  if (!ancestor->isProperAncestor(op))
139  continue;
140 
141  // Make sure the error-reporting lambda doesn't capture anything
142  // by-reference because it will go out of scope. Additionally, extract
143  // location from Payload IR ops because the ops themselves may be
144  // deleted before the lambda gets called.
145  Location ancestorLoc = ancestor->getLoc();
146  Location opLoc = op->getLoc();
147  Operation *owner = handle.getOwner();
148  unsigned operandNo = handle.getOperandNumber();
149  invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
150  otherHandle]() {
152  owner->emitOpError()
153  << "invalidated the handle to payload operations nested in the "
154  "payload operation associated with its operand #"
155  << operandNo;
156  diag.attachNote(ancestorLoc) << "ancestor op";
157  diag.attachNote(opLoc) << "nested op";
158  diag.attachNote(otherHandle.getLoc()) << "other handle";
159  };
160  }
161  }
162  }
163 }
164 
165 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
166  TransformOpInterface transform) {
167  auto memoryEffectsIface =
168  cast<MemoryEffectOpInterface>(transform.getOperation());
170  memoryEffectsIface.getEffectsOnResource(
172 
173  for (OpOperand &target : transform->getOpOperands()) {
174  // If the operand uses an invalidated handle, report it.
175  auto it = invalidatedHandles.find(target.get());
176  if (it != invalidatedHandles.end())
177  return it->getSecond()(), failure();
178 
179  // Invalidate handles pointing to the operations nested in the operation
180  // associated with the handle consumed by this operation.
181  auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
182  return isa<MemoryEffects::Free>(effect.getEffect()) &&
183  effect.getValue() == target.get();
184  };
185  if (llvm::any_of(effects, consumesTarget))
186  recordHandleInvalidation(target);
187  }
188  return success();
189 }
190 
192 transform::TransformState::applyTransform(TransformOpInterface transform) {
193  LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
194  if (options.getExpensiveChecksEnabled()) {
195  if (failed(checkAndRecordHandleInvalidation(transform)))
197 
198  for (OpOperand &operand : transform->getOpOperands()) {
199  if (!isHandleConsumed(operand.get(), transform))
200  continue;
201 
203  for (Operation *op : getPayloadOps(operand.get())) {
204  if (!seen.insert(op).second) {
206  transform.emitSilenceableError()
207  << "a handle passed as operand #" << operand.getOperandNumber()
208  << " and consumed by this operation points to a payload "
209  "operation more than once";
210  diag.attachNote(op->getLoc()) << "repeated target op";
211  return diag;
212  }
213  }
214  }
215  }
216 
217  transform::TransformResults results(transform->getNumResults());
218  DiagnosedSilenceableFailure result(transform.apply(results, *this));
219  if (!result.succeeded())
220  return result;
221 
222  // Remove the mapping for the operand if it is consumed by the operation. This
223  // allows us to catch use-after-free with assertions later on.
224  auto memEffectInterface =
225  cast<MemoryEffectOpInterface>(transform.getOperation());
227  for (OpOperand &target : transform->getOpOperands()) {
228  effects.clear();
229  memEffectInterface.getEffectsOnValue(target.get(), effects);
230  if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
231  return isa<transform::TransformMappingResource>(
232  effect.getResource()) &&
233  isa<MemoryEffects::Free>(effect.getEffect());
234  })) {
235  removePayloadOps(target.get());
236  }
237  }
238 
239  for (OpResult result : transform->getResults()) {
240  assert(result.getDefiningOp() == transform.getOperation() &&
241  "payload IR association for a value other than the result of the "
242  "current transform op");
243  if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
245  }
246 
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // TransformState::Extension
252 //===----------------------------------------------------------------------===//
253 
255 
258  Operation *replacement) {
259  return state.updatePayloadOps(state.getHandleForPayloadOp(op),
260  [&](Operation *current) {
261  return current == op ? replacement : current;
262  });
263 }
264 
265 //===----------------------------------------------------------------------===//
266 // TransformResults
267 //===----------------------------------------------------------------------===//
268 
269 transform::TransformResults::TransformResults(unsigned numSegments) {
270  segments.resize(numSegments,
271  ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
272 }
273 
275  ArrayRef<Operation *> ops) {
276  unsigned position = value.getResultNumber();
277  assert(position < segments.size() &&
278  "setting results for a non-existent handle");
279  assert(segments[position].data() == nullptr && "results already set");
280  unsigned start = operations.size();
281  llvm::append_range(operations, ops);
282  segments[position] = makeArrayRef(operations).drop_front(start);
283 }
284 
286 transform::TransformResults::get(unsigned resultNumber) const {
287  assert(resultNumber < segments.size() &&
288  "querying results for a non-existent handle");
289  assert(segments[resultNumber].data() != nullptr && "querying unset results");
290  return segments[resultNumber];
291 }
292 
293 //===----------------------------------------------------------------------===//
294 // Utilities for PossibleTopLevelTransformOpTrait.
295 //===----------------------------------------------------------------------===//
296 
298  TransformState &state, Operation *op, Region &region) {
299  SmallVector<Operation *> targets;
300  if (op->getNumOperands() != 0)
301  llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
302  else
303  targets.push_back(state.getTopLevel());
304 
305  return state.mapBlockArguments(region.front().getArgument(0), targets);
306 }
307 
310  // Attaching this trait without the interface is a misuse of the API, but it
311  // cannot be caught via a static_assert because interface registration is
312  // dynamic.
313  assert(isa<TransformOpInterface>(op) &&
314  "should implement TransformOpInterface to have "
315  "PossibleTopLevelTransformOpTrait");
316 
317  if (op->getNumRegions() < 1)
318  return op->emitOpError() << "expects at least one region";
319 
320  Region *bodyRegion = &op->getRegion(0);
321  if (!llvm::hasNItems(*bodyRegion, 1))
322  return op->emitOpError() << "expects a single-block region";
323 
324  Block *body = &bodyRegion->front();
325  if (body->getNumArguments() != 1 ||
326  !body->getArgumentTypes()[0].isa<pdl::OperationType>()) {
327  return op->emitOpError()
328  << "expects the entry block to have one argument of type "
329  << pdl::OperationType::get(op->getContext());
330  }
331 
332  if (auto *parent =
334  if (op->getNumOperands() == 0) {
336  op->emitOpError()
337  << "expects the root operation to be provided for a nested op";
338  diag.attachNote(parent->getLoc())
339  << "nested in another possible top-level op";
340  return diag;
341  }
342  }
343 
344  return success();
345 }
346 
347 //===----------------------------------------------------------------------===//
348 // Memory effects.
349 //===----------------------------------------------------------------------===//
350 
352  ValueRange handles,
354  for (Value handle : handles) {
355  effects.emplace_back(MemoryEffects::Read::get(), handle,
357  effects.emplace_back(MemoryEffects::Free::get(), handle,
359  }
360 }
361 
362 /// Returns `true` if the given list of effects instances contains an instance
363 /// with the effect type specified as template parameter.
364 template <typename EffectTy, typename ResourceTy = SideEffects::DefaultResource>
366  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
367  return isa<EffectTy>(effect.getEffect()) &&
368  isa<ResourceTy>(effect.getResource());
369  });
370 }
371 
373  transform::TransformOpInterface transform) {
374  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
376  iface.getEffectsOnValue(handle, effects);
377  return hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
378  hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
379 }
380 
382  ValueRange handles,
384  for (Value handle : handles) {
385  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
387  effects.emplace_back(MemoryEffects::Write::get(), handle,
389  }
390 }
391 
393  ValueRange handles,
395  for (Value handle : handles) {
396  effects.emplace_back(MemoryEffects::Read::get(), handle,
398  }
399 }
400 
403  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
404  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
405 }
406 
409  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // Generated interface implementation.
414 //===----------------------------------------------------------------------===//
415 
416 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
Diagnostic & attachNote(Optional< Location > loc=llvm::None)
Attaches a note to the last diagnostic.
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
void set(OpResult value, ArrayRef< Operation *> ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
static std::string diag(llvm::Value &v)
The result of a transform IR operation application.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
This is a value defined by a result of an operation.
Definition: Value.h:425
EffectT * getEffect() const
Return the effect being applied.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:477
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:310
Block represents an ordered list of Operations.
Definition: Block.h:29
Block & front()
Definition: Region.h:65
Value getOperand(unsigned idx)
Definition: Operation.h:267
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
virtual ~Extension()
Base virtual destructor.
unsigned getNumOperands()
Definition: Operation.h:263
Options controlling the application of transform operations by the TransformState.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:179
bool getExpensiveChecksEnabled() const
Returns true if the expensive checks are requested.
ArrayRef< Operation * > getPayloadOps(Value value) const
Returns the list of ops that the given transform IR value corresponds to.
static bool hasEffect(ArrayRef< MemoryEffects::EffectInstance > effects)
Returns true if the given list of effects instances contains an instance with the effect type specifi...
BlockArgument getArgument(unsigned i)
Definition: Block.h:120
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:147
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:212
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value: ...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Diagnostic & attachNote(Optional< Location > noteLoc=llvm::None)
Attaches a note to this diagnostic.
Definition: Diagnostics.h:348
unsigned getNumArguments()
Definition: Block.h:119
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:437
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:137
The state maintained across applications of various ops implementing the TransformOpInterface.
This class represents a specific instance of an effect.
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:203
Value getHandleForPayloadOp(Operation *op) const
Returns the Transform IR handle for the given Payload IR op if it exists in the state, null otherwise.
static llvm::ManagedStatic< PassManagerOptions > options
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation *> operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:137
Resource * getResource() const
Return the resource that the effect applies to.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
#define DBGS()
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:40
This class represents an operand of an operation.
Definition: Value.h:251
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
Definition: Operation.cpp:508
TransformState(Region &region, Operation *root, const TransformOptions &options=TransformOptions())
Creates a state for transform ops living in the given region.
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:221
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:486