MLIR  20.0.0git
TransformInterfaces.cpp
Go to the documentation of this file.
1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/IR/PatternMatch.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/ScopeExit.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/ErrorHandling.h"
20 
// Debug-output channels for this file: DEBUG_TYPE is the regular channel and
// DEBUG_TYPE_FULL is a separate, more verbose channel used by FULL_LDBG.
21 #define DEBUG_TYPE "transform-dialect"
22 #define DEBUG_TYPE_FULL "transform-dialect-full"
// Separate debug type name; not referenced in the visible portion of this file.
23 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
// DBGS() starts a debug line prefixed with "[transform-dialect] ".
// LDBG(X) / FULL_LDBG(X) print X under DEBUG_TYPE / DEBUG_TYPE_FULL.
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
25 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
26 #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Helper functions
32 //===----------------------------------------------------------------------===//
33 
34 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
35 /// properly dominates `b` and `b` is not inside `a`.
36 static bool happensBefore(Operation *a, Operation *b) {
37  do {
38  if (a->isProperAncestor(b))
39  return false;
40  if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
41  return a->isBeforeInBlock(bAncestor);
42  }
43  } while ((a = a->getParentOp()));
44  return false;
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // TransformState
49 //===----------------------------------------------------------------------===//
50 
// Out-of-line definition of the static `kTopLevelValue` member declared in
// TransformState; setPayloadOps asserts that this value is never re-associated.
51 constexpr const Value transform::TransformState::kTopLevelValue;
52 
53 transform::TransformState::TransformState(
54  Region *region, Operation *payloadRoot,
55  const RaggedArray<MappedValue> &extraMappings,
56  const TransformOptions &options)
57  : topLevel(payloadRoot), options(options) {
58  topLevelMappedValues.reserve(extraMappings.size());
59  for (ArrayRef<MappedValue> mapping : extraMappings)
60  topLevelMappedValues.push_back(mapping);
61  if (region) {
62  RegionScope *scope = new RegionScope(*this, *region);
63  topLevelRegionScope.reset(scope);
64  }
65 }
66 
67 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
68 
70 transform::TransformState::getPayloadOpsView(Value value) const {
  // Returns the payload ops stored for the op handle `value` in the `direct`
  // mapping selected by getMapping(value). Asserts if `value` has no entry,
  // e.g. when a param or value handle is passed by mistake.
71  const TransformOpMapping &operationMapping = getMapping(value).direct;
72  auto iter = operationMapping.find(value);
73  assert(iter != operationMapping.end() &&
74  "cannot find mapping for payload handle (param/value handle "
75  "provided?)");
76  return iter->getSecond();
77 }
78 
  // Returns the parameters stored for the param handle `value` in the
  // `params` mapping selected by getMapping(value). Asserts if `value` has no
  // entry, e.g. when an op or value handle is passed by mistake.
80  const ParamMapping &mapping = getMapping(value).params;
81  auto iter = mapping.find(value);
82  assert(iter != mapping.end() && "cannot find mapping for param handle "
83  "(operation/value handle provided?)");
84  return iter->getSecond();
85 }
86 
88 transform::TransformState::getPayloadValuesView(Value handleValue) const {
  // Returns the payload values stored for the value handle `handleValue` in
  // the `values` mapping selected by getMapping(handleValue). Asserts if the
  // handle has no entry, e.g. when a param or op handle is passed by mistake.
89  const ValueMapping &mapping = getMapping(handleValue).values;
90  auto iter = mapping.find(handleValue);
91  assert(iter != mapping.end() && "cannot find mapping for value handle "
92  "(param/operation handle provided?)");
93  return iter->getSecond();
94 }
95 
97  Operation *op, SmallVectorImpl<Value> &handles,
98  bool includeOutOfScope) const {
  // Appends to `handles` every op handle listed in a `reverse` mapping for
  // `op`, visiting `mappings` in reverse order. Returns success iff at least
  // one handle was found.
99  bool found = false;
100  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
101  auto iterator = mapping->reverse.find(op);
102  if (iterator != mapping->reverse.end()) {
103  llvm::append_range(handles, iterator->getSecond());
104  found = true;
105  }
106  // Stop looking when reaching a region that is isolated from above.
107  if (!includeOutOfScope &&
109  break;
110  }
111 
112  return success(found);
113 }
114 
116  Value payloadValue, SmallVectorImpl<Value> &handles,
117  bool includeOutOfScope) const {
  // Appends to `handles` every value handle listed in a `reverseValues`
  // mapping for `payloadValue`, visiting `mappings` in reverse order. Returns
  // success iff at least one handle was found.
118  bool found = false;
119  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
120  auto iterator = mapping->reverseValues.find(payloadValue);
121  if (iterator != mapping->reverseValues.end()) {
122  llvm::append_range(handles, iterator->getSecond());
123  found = true;
124  }
125  // Stop looking when reaching a region that is isolated from above.
126  if (!includeOutOfScope &&
128  break;
129  }
130 
131  return success(found);
132 }
133 
134 /// Given a list of MappedValues, cast them to the value kind implied by the
135 /// interface of the handle type, and dispatch to one of the callbacks.
/// The callback is chosen by the type interface of `handle`: operation
/// handles use `operationsFn`, value handles use `valuesFn`, and any other
/// type is asserted to be a parameter type and uses `paramsFn`. A mapped
/// value of the wrong kind (including a null one) produces a silenceable
/// failure located at `handle` before any callback is invoked.
138  function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
139  function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
140  function_ref<LogicalResult(ValueRange)> valuesFn) {
141  if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
  // Operation handle: every mapped value must hold an Operation *.
142  SmallVector<Operation *> operations;
143  operations.reserve(values.size());
144  for (transform::MappedValue value : values) {
145  if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
146  operations.push_back(op);
147  continue;
148  }
149  return emitSilenceableFailure(handle.getLoc())
150  << "wrong kind of value provided for top-level operation handle";
151  }
152  if (failed(operationsFn(operations)))
155  }
156 
157  if (llvm::isa<transform::TransformValueHandleTypeInterface>(
158  handle.getType())) {
  // Value handle: every mapped value must hold a Value.
159  SmallVector<Value> payloadValues;
160  payloadValues.reserve(values.size());
161  for (transform::MappedValue value : values) {
162  if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
163  payloadValues.push_back(v);
164  continue;
165  }
166  return emitSilenceableFailure(handle.getLoc())
167  << "wrong kind of value provided for the top-level value handle";
168  }
169  if (failed(valuesFn(payloadValues)))
172  }
173 
  // Remaining case: a parameter handle; every mapped value must hold an
  // Attribute.
174  assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
175  "unsupported kind of block argument");
177  parameters.reserve(values.size());
178  for (transform::MappedValue value : values) {
179  if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
180  parameters.push_back(attr);
181  continue;
182  }
183  return emitSilenceableFailure(handle.getLoc())
184  << "wrong kind of value provided for top-level parameter";
185  }
186  if (failed(paramsFn(parameters)))
189 }
190 
191 LogicalResult
193  ArrayRef<MappedValue> values) {
  // Binds `argument` to `values` using the setter that matches the handle
  // kind of `argument` (ops, params or payload values). Silenceable failures
  // from the dispatch are reported via checkAndReport().
194  return dispatchMappedValues(
195  argument, values,
196  [&](ArrayRef<Operation *> operations) {
197  return setPayloadOps(argument, operations);
198  },
199  [&](ArrayRef<Param> params) {
200  return setParams(argument, params);
201  },
202  [&](ValueRange payloadValues) {
203  return setPayloadValues(argument, payloadValues);
204  })
205  .checkAndReport();
206 }
207 
209  Block::BlockArgListType arguments,
211  for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
  // Map each block argument to its corresponding group of values, stopping
  // at the first argument that fails to map.
212  if (failed(mapBlockArgument(argument, values)))
213  return failure();
214  return success();
215 }
216 
/// Associates the op handle `value` with the payload operations `targets`.
/// Emits an error and fails if any target is null or if the payload check
/// of the handle type (checkPayload) fails. On success, `value` is also
/// recorded in the reverse mapping of every target, once per occurrence.
217 LogicalResult
218 transform::TransformState::setPayloadOps(Value value,
219  ArrayRef<Operation *> targets) {
220  assert(value != kTopLevelValue &&
221  "attempting to reset the transformation root");
222  assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
223  "wrong handle type");
224 
225  for (Operation *target : targets) {
226  if (target)
227  continue;
228  return emitError(value.getLoc())
229  << "attempting to assign a null payload op to this transform value";
230  }
231 
  // Let the handle type validate the payload it is about to hold.
232  auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
234  iface.checkPayload(value.getLoc(), targets);
235  if (failed(result.checkAndReport()))
236  return failure();
237 
238  // Setting new payload for the value without cleaning it first is a misuse of
239  // the API, assert here.
240  SmallVector<Operation *> storedTargets(targets);
241  Mappings &mappings = getMapping(value);
242  bool inserted =
243  mappings.direct.insert({value, std::move(storedTargets)}).second;
244  assert(inserted && "value is already associated with another list");
245  (void)inserted;
246 
247  for (Operation *op : targets)
248  mappings.reverse[op].push_back(value);
249 
250  return success();
251 }
252 
/// Associates the value handle `handle` with `payloadValues`. Emits an error
/// and fails if any payload value is null or if the payload check of the
/// handle type (checkPayload) fails. On success, `handle` is also recorded in
/// the reverse value mapping of every payload value.
253 LogicalResult
254 transform::TransformState::setPayloadValues(Value handle,
255  ValueRange payloadValues) {
  // NOTE(review): this message mentions "params" although the function sets
  // payload values; possibly copied from setParams.
256  assert(handle != nullptr && "attempting to set params for a null value");
257  assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
258  "wrong handle type");
259 
260  for (Value payload : payloadValues) {
261  if (payload)
262  continue;
263  return emitError(handle.getLoc()) << "attempting to assign a null payload "
264  "value to this transform handle";
265  }
266 
267  auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
  // This vector is passed to checkPayload and then moved into the mapping.
268  SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
270  iface.checkPayload(handle.getLoc(), payloadValueVector);
271  if (failed(result.checkAndReport()))
272  return failure();
273 
274  Mappings &mappings = getMapping(handle);
275  bool inserted =
276  mappings.values.insert({handle, std::move(payloadValueVector)}).second;
277  assert(
278  inserted &&
279  "value handle is already associated with another list of payload values");
280  (void)inserted;
281 
282  for (Value payload : payloadValues)
283  mappings.reverseValues[payload].push_back(handle);
284 
285  return success();
286 }
287 
/// Associates the parameter handle `value` with `params`. Emits an error and
/// fails if any parameter is null or if the payload check of the parameter
/// type (checkPayload) fails. Only the forward `params` mapping is updated;
/// no reverse mapping is recorded for parameters.
288 LogicalResult transform::TransformState::setParams(Value value,
289  ArrayRef<Param> params) {
290  assert(value != nullptr && "attempting to set params for a null value");
291 
292  for (Attribute attr : params) {
293  if (attr)
294  continue;
295  return emitError(value.getLoc())
296  << "attempting to assign a null parameter to this transform value";
297  }
298 
299  auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
  // NOTE(review): this asserts `value`, which was already asserted non-null
  // above. It presumably meant to assert `valueType`. As written, a value of
  // non-parameter type yields a null `valueType` that reaches checkPayload.
300  assert(value &&
301  "cannot associate parameter with a value of non-parameter type");
303  valueType.checkPayload(value.getLoc(), params);
304  if (failed(result.checkAndReport()))
305  return failure();
306 
307  Mappings &mappings = getMapping(value);
308  bool inserted =
309  mappings.params.insert({value, llvm::to_vector(params)}).second;
310  assert(inserted && "value is already associated with another list of params");
311  (void)inserted;
312  return success();
313 }
314 
315 template <typename Mapping, typename Key, typename Mapped>
316 void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
317  auto it = mapping.find(key);
318  if (it == mapping.end())
319  return;
320 
321  llvm::erase(it->getSecond(), mapped);
322  if (it->getSecond().empty())
323  mapping.erase(it);
324 }
325 
326 void transform::TransformState::forgetMapping(Value opHandle,
327  ValueRange origOpFlatResults,
328  bool allowOutOfScope) {
329  Mappings &mappings = getMapping(opHandle, allowOutOfScope);
330  for (Operation *op : mappings.direct[opHandle])
331  dropMappingEntry(mappings.reverse, op, opHandle);
332  mappings.direct.erase(opHandle);
333 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
334  // Payload IR is removed from the mapping. This invalidates the respective
335  // iterators.
336  mappings.incrementTimestamp(opHandle);
337 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
338 
339  for (Value opResult : origOpFlatResults) {
340  SmallVector<Value> resultHandles;
341  (void)getHandlesForPayloadValue(opResult, resultHandles);
342  for (Value resultHandle : resultHandles) {
343  Mappings &localMappings = getMapping(resultHandle);
344  dropMappingEntry(localMappings.values, resultHandle, opResult);
345 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
346  // Payload IR is removed from the mapping. This invalidates the respective
347  // iterators.
348  mappings.incrementTimestamp(resultHandle);
349 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
350  dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
351  }
352  }
353 }
354 
355 void transform::TransformState::forgetValueMapping(
356  Value valueHandle, ArrayRef<Operation *> payloadOperations) {
357  Mappings &mappings = getMapping(valueHandle);
358  for (Value payloadValue : mappings.reverseValues[valueHandle])
359  dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
360  mappings.values.erase(valueHandle);
361 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
362  // Payload IR is removed from the mapping. This invalidates the respective
363  // iterators.
364  mappings.incrementTimestamp(valueHandle);
365 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
366 
367  for (Operation *payloadOp : payloadOperations) {
368  SmallVector<Value> opHandles;
369  (void)getHandlesForPayloadOp(payloadOp, opHandles);
370  for (Value opHandle : opHandles) {
371  Mappings &localMappings = getMapping(opHandle);
372  dropMappingEntry(localMappings.direct, opHandle, payloadOp);
373  dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
374 
375 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
376  // Payload IR is removed from the mapping. This invalidates the respective
377  // iterators.
378  localMappings.incrementTimestamp(opHandle);
379 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
380  }
381  }
382 }
383 
384 LogicalResult
385 transform::TransformState::replacePayloadOp(Operation *op,
386  Operation *replacement) {
387  // TODO: consider invalidating the handles to nested objects here.
388 
389 #ifndef NDEBUG
390  for (Value opResult : op->getResults()) {
391  SmallVector<Value> valueHandles;
392  (void)getHandlesForPayloadValue(opResult, valueHandles,
393  /*includeOutOfScope=*/true);
394  assert(valueHandles.empty() && "expected no mapping to old results");
395  }
396 #endif // NDEBUG
397 
398  // Drop the mapping between the op and all handles that point to it. Fail if
399  // there are no handles.
400  SmallVector<Value> opHandles;
401  if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
402  return failure();
403  for (Value handle : opHandles) {
404  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
405  dropMappingEntry(mappings.reverse, op, handle);
406  }
407 
408  // Replace the pointed-to object of all handles with the replacement object.
409  // In case a payload op was erased (replacement object is nullptr), a nullptr
410  // is stored in the mapping. These nullptrs are removed after each transform.
411  // Furthermore, nullptrs are not enumerated by payload op iterators. The
412  // relative order of ops is preserved.
413  //
414  // Removing an op from the mapping would be problematic because removing an
415  // element from an array invalidates iterators; merely changing the value of
416  // elements does not.
417  for (Value handle : opHandles) {
418  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
419  auto it = mappings.direct.find(handle);
420  if (it == mappings.direct.end())
421  continue;
422 
423  SmallVector<Operation *, 2> &association = it->getSecond();
424  // Note that an operation may be associated with the handle more than once.
425  for (Operation *&mapped : association) {
426  if (mapped == op)
427  mapped = replacement;
428  }
429 
430  if (replacement) {
431  mappings.reverse[replacement].push_back(handle);
432  } else {
433  opHandlesToCompact.insert(handle);
434  }
435  }
436 
437  return success();
438 }
439 
440 LogicalResult
441 transform::TransformState::replacePayloadValue(Value value, Value replacement) {
442  SmallVector<Value> valueHandles;
443  if (failed(getHandlesForPayloadValue(value, valueHandles,
444  /*includeOutOfScope=*/true)))
445  return failure();
446 
447  for (Value handle : valueHandles) {
448  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
449  dropMappingEntry(mappings.reverseValues, value, handle);
450 
451  // If replacing with null, that is erasing the mapping, drop the mapping
452  // between the handles and the IR objects
453  if (!replacement) {
454  dropMappingEntry(mappings.values, handle, value);
455 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
456  // Payload IR is removed from the mapping. This invalidates the respective
457  // iterators.
458  mappings.incrementTimestamp(handle);
459 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
460  } else {
461  auto it = mappings.values.find(handle);
462  if (it == mappings.values.end())
463  continue;
464 
465  SmallVector<Value> &association = it->getSecond();
466  for (Value &mapped : association) {
467  if (mapped == value)
468  mapped = replacement;
469  }
470  mappings.reverseValues[replacement].push_back(handle);
471  }
472  }
473 
474  return success();
475 }
476 
477 void transform::TransformState::recordOpHandleInvalidationOne(
478  OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
479  Operation *payloadOp, Value otherHandle, Value throughValue,
480  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
  // Records in `newlyInvalidated` a diagnostic reporter for `otherHandle` if
  // `payloadOp` (one of its payload ops) is nested in, or equal to, one of
  // `potentialAncestors` (checked with isAncestor). When `throughValue` is
  // non-null, its location is attached as an extra note.
481  // If the op is associated with invalidated handle, skip the check as it
482  // may be reading invalid IR. This also ensures we report the first
483  // invalidation and not the last one.
484  if (invalidatedHandles.count(otherHandle) ||
485  newlyInvalidated.count(otherHandle))
486  return;
487 
488  FULL_LDBG("--recordOpHandleInvalidationOne\n");
489  DEBUG_WITH_TYPE(
491  llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ",
492  [](Operation *op) { llvm::dbgs() << *op; });
493  llvm::dbgs() << "\n");
494 
495  Operation *owner = consumingHandle.getOwner();
496  unsigned operandNo = consumingHandle.getOperandNumber();
  // Each matching ancestor overwrites the entry for `otherHandle`. When
  // several ancestors contain `payloadOp`, the last one is the one reported.
497  for (Operation *ancestor : potentialAncestors) {
498  // clang-format off
499  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
500  { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
501  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
502  { (DBGS() << "----of payload with name: "
503  << payloadOp->getName().getIdentifier() << "\n"); });
504  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
505  { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
506  // clang-format on
507  if (!ancestor->isAncestor(payloadOp))
508  continue;
509 
510  // Make sure the error-reporting lambda doesn't capture anything
511  // by-reference because it will go out of scope. Additionally, extract
512  // location from Payload IR ops because the ops themselves may be
513  // deleted before the lambda gets called.
514  Location ancestorLoc = ancestor->getLoc();
515  Location opLoc = payloadOp->getLoc();
516  std::optional<Location> throughValueLoc =
517  throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
518  newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
519  otherHandle,
520  throughValueLoc](Location currentLoc) {
521  InFlightDiagnostic diag = emitError(currentLoc)
522  << "op uses a handle invalidated by a "
523  "previously executed transform op";
524  diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
525  diag.attachNote(owner->getLoc())
526  << "invalidated by this transform op that consumes its operand #"
527  << operandNo
528  << " and invalidates all handles to payload IR entities associated "
529  "with this operand and entities nested in them";
530  diag.attachNote(ancestorLoc) << "ancestor payload op";
531  diag.attachNote(opLoc) << "nested payload op";
532  if (throughValueLoc) {
533  diag.attachNote(*throughValueLoc)
534  << "consumed handle points to this payload value";
535  }
536  };
537  }
538 }
539 
540 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
541  OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
542  Value payloadValue, Value valueHandle,
543  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
544  // If the op is associated with invalidated handle, skip the check as it
545  // may be reading invalid IR. This also ensures we report the first
546  // invalidation and not the last one.
547  if (invalidatedHandles.count(valueHandle) ||
548  newlyInvalidated.count(valueHandle))
549  return;
550 
551  for (Operation *ancestor : potentialAncestors) {
552  Operation *definingOp;
553  std::optional<unsigned> resultNo;
554  unsigned argumentNo = std::numeric_limits<unsigned>::max();
555  unsigned blockNo = std::numeric_limits<unsigned>::max();
556  unsigned regionNo = std::numeric_limits<unsigned>::max();
557  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
558  definingOp = opResult.getOwner();
559  resultNo = opResult.getResultNumber();
560  } else {
561  auto arg = llvm::cast<BlockArgument>(payloadValue);
562  definingOp = arg.getParentBlock()->getParentOp();
563  argumentNo = arg.getArgNumber();
564  blockNo = std::distance(arg.getOwner()->getParent()->begin(),
565  arg.getOwner()->getIterator());
566  regionNo = arg.getOwner()->getParent()->getRegionNumber();
567  }
568  assert(definingOp && "expected the value to be defined by an op as result "
569  "or block argument");
570  if (!ancestor->isAncestor(definingOp))
571  continue;
572 
573  Operation *owner = opHandle.getOwner();
574  unsigned operandNo = opHandle.getOperandNumber();
575  Location ancestorLoc = ancestor->getLoc();
576  Location opLoc = definingOp->getLoc();
577  Location valueLoc = payloadValue.getLoc();
578  newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
579  argumentNo, blockNo, regionNo, ancestorLoc,
580  opLoc, valueLoc](Location currentLoc) {
581  InFlightDiagnostic diag = emitError(currentLoc)
582  << "op uses a handle invalidated by a "
583  "previously executed transform op";
584  diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
585  diag.attachNote(owner->getLoc())
586  << "invalidated by this transform op that consumes its operand #"
587  << operandNo
588  << " and invalidates all handles to payload IR entities "
589  "associated with this operand and entities nested in them";
590  diag.attachNote(ancestorLoc)
591  << "ancestor op associated with the consumed handle";
592  if (resultNo) {
593  diag.attachNote(opLoc)
594  << "op defining the value as result #" << *resultNo;
595  } else {
596  diag.attachNote(opLoc)
597  << "op defining the value as block argument #" << argumentNo
598  << " of block #" << blockNo << " in region #" << regionNo;
599  }
600  diag.attachNote(valueLoc) << "payload value";
601  };
602  }
603 }
604 
/// Records in `newlyInvalidated` the handles invalidated by consuming
/// `handle`, whose associated payload ops are `potentialAncestors`. With an
/// empty payload, only `handle` itself is recorded. Otherwise every op handle
/// and every value handle found in the visited mappings is checked against
/// `potentialAncestors`. `throughValue` is only forwarded for diagnostics.
605 void transform::TransformState::recordOpHandleInvalidation(
606  OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
607  Value throughValue,
608  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
609 
  // With no payload ops there is nothing nested to search for, so only the
  // consumed handle itself gets a reporter.
610  if (potentialAncestors.empty()) {
611  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
612  (DBGS() << "----recording invalidation for empty handle: " << handle.get()
613  << "\n");
614  });
615 
616  Operation *owner = handle.getOwner();
617  unsigned operandNo = handle.getOperandNumber();
618  newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
619  InFlightDiagnostic diag = emitError(currentLoc)
620  << "op uses a handle associated with empty "
621  "payload and invalidated by a "
622  "previously executed transform op";
623  diag.attachNote(owner->getLoc())
624  << "invalidated by this transform op that consumes its operand #"
625  << operandNo;
626  };
627  return;
628  }
629 
630  // Iterate over the mapping and invalidate aliasing handles. This is quite
631  // expensive and only necessary for error reporting in case of transform
632  // dialect misuse with dangling handles. Iteration over the handles is based
633  // on the assumption that the number of handles is significantly less than the
634  // number of IR objects (operations and values). Alternatively, we could walk
635  // the IR nested in each payload op associated with the given handle and look
636  // for handles associated with each operation and value.
637  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
638  // Go over all op handle mappings and mark as invalidated any handle
639  // pointing to any of the payload ops associated with the given handle or
640  // any op nested in them.
641  for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
642  for (Value otherHandle : otherHandles)
643  recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
644  otherHandle, throughValue,
645  newlyInvalidated);
646  }
647  // Go over all value handle mappings and mark as invalidated any handle
648  // pointing to any result of the payload op associated with the given handle
649  // or any op nested in them. Similarly invalidate handles to argument of
650  // blocks belonging to any region of any payload op associated with the
651  // given handle or any op nested in them.
652  for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
653  for (Value valueHandle : valueHandles)
654  recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
655  payloadValue, valueHandle,
656  newlyInvalidated);
657  }
658 
659  // Stop lookup when reaching a region that is isolated from above.
661  break;
662  }
663 }
664 
665 void transform::TransformState::recordValueHandleInvalidation(
666  OpOperand &valueHandle,
667  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
668  // Invalidate other handles to the same value.
669  for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
670  SmallVector<Value> otherValueHandles;
671  (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
672  for (Value otherHandle : otherValueHandles) {
673  Operation *owner = valueHandle.getOwner();
674  unsigned operandNo = valueHandle.getOperandNumber();
675  Location valueLoc = payloadValue.getLoc();
676  newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
677  valueLoc](Location currentLoc) {
678  InFlightDiagnostic diag = emitError(currentLoc)
679  << "op uses a handle invalidated by a "
680  "previously executed transform op";
681  diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
682  diag.attachNote(owner->getLoc())
683  << "invalidated by this transform op that consumes its operand #"
684  << operandNo
685  << " and invalidates handles to the same values as associated with "
686  "it";
687  diag.attachNote(valueLoc) << "payload value";
688  };
689  }
690 
691  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
692  Operation *payloadOp = opResult.getOwner();
693  recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
694  newlyInvalidated);
695  } else {
696  auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
697  for (Operation &payloadOp : *arg.getOwner())
698  recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
699  newlyInvalidated);
700  }
701  }
702 }
703 
704 /// Checks that the operation does not use invalidated handles as operands.
705 /// Reports errors and returns failure if it does. Otherwise, invalidates the
706 /// handles consumed by the operation as well as any handles pointing to payload
707 /// IR operations nested in the operations associated with the consumed handles.
/// Invalidations caused by this op are accumulated in `newlyInvalidated`.
/// That map is also consulted while iterating, so an operand invalidated by an
/// earlier operand of the same op is reported, unless the op allows repeated
/// handle operands.
708 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
709  transform::TransformOpInterface transform,
710  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
711  FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
  // Gather the op's memory effects; an operand counts as consumed below when
  // `effects` contains a Free effect on it.
712  auto memoryEffectsIface =
713  cast<MemoryEffectOpInterface>(transform.getOperation());
715  memoryEffectsIface.getEffectsOnResource(
717 
718  for (OpOperand &target : transform->getOpOperands()) {
719  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
720  (DBGS() << "----iterate on handle: " << target.get() << "\n");
721  });
722  // If the operand uses an invalidated handle, report it. If the operation
723  // allows handles to point to repeated payload operations, only report
724  // pre-existing invalidation errors. Otherwise, also report invalidations
725  // caused by the current transform operation affecting its other operands.
726  auto it = invalidatedHandles.find(target.get());
727  auto nit = newlyInvalidated.find(target.get());
728  if (it != invalidatedHandles.end()) {
729  FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
730  "invalidated -> FAILURE\n");
  // Invoke the stored reporter to emit its diagnostic at this op, then
  // return failure (comma operator).
731  return it->getSecond()(transform->getLoc()), failure();
732  }
733  if (!transform.allowsRepeatedHandleOperands() &&
734  nit != newlyInvalidated.end()) {
735  FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
736  "invalidated (by this op) -> FAILURE\n");
737  return nit->getSecond()(transform->getLoc()), failure();
738  }
739 
740  // Invalidate handles pointing to the operations nested in the operation
741  // associated with the handle consumed by this operation.
742  auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
743  return isa<MemoryEffects::Free>(effect.getEffect()) &&
744  effect.getValue() == target.get();
745  };
746  if (llvm::any_of(effects, consumesTarget)) {
747  FULL_LDBG("----found consume effect\n");
748  if (llvm::isa<transform::TransformHandleTypeInterface>(
749  target.get().getType())) {
750  FULL_LDBG("----recordOpHandleInvalidation\n");
751  SmallVector<Operation *> payloadOps =
752  llvm::to_vector(getPayloadOps(target.get()));
753  recordOpHandleInvalidation(target, payloadOps, nullptr,
754  newlyInvalidated);
755  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
756  target.get().getType())) {
757  FULL_LDBG("----recordValueHandleInvalidation\n");
758  recordValueHandleInvalidation(target, newlyInvalidated);
759  } else {
  // Handles of any other kind are not tracked for invalidation.
760  FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
761  }
762  } else {
763  FULL_LDBG("----no consume effect -> SKIP\n");
764  }
765  }
766 
767  FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
768  return success();
769 }
770 
771 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
772  transform::TransformOpInterface transform) {
773  InvalidatedHandleMap newlyInvalidated;
774  LogicalResult checkResult =
775  checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
776  invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
777  std::make_move_iterator(newlyInvalidated.end()));
778  return checkResult;
779 }
780 
781 template <typename T>
784  transform::TransformOpInterface transform,
785  unsigned operandNumber) {
  // Scans `payload` and reports a silenceable error at the first entity that
  // occurs more than once, with a note at that entity's location. `T` is
  // either a pointer type (e.g. Operation *) or a value type exposing
  // getLoc() (e.g. Value).
786  DenseSet<T> seen;
787  for (T p : payload) {
788  if (!seen.insert(p).second) {
790  transform.emitSilenceableError()
791  << "a handle passed as operand #" << operandNumber
792  << " and consumed by this operation points to a payload "
793  "entity more than once";
794  if constexpr (std::is_pointer_v<T>)
795  diag.attachNote(p->getLoc()) << "repeated target op";
796  else
797  diag.attachNote(p.getLoc()) << "repeated target value";
798  return diag;
799  }
800  }
802 }
803 
804 void transform::TransformState::compactOpHandles() {
805  for (Value handle : opHandlesToCompact) {
806  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
807 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
808  if (llvm::find(mappings.direct[handle], nullptr) !=
809  mappings.direct[handle].end())
810  // Payload IR is removed from the mapping. This invalidates the respective
811  // iterators.
812  mappings.incrementTimestamp(handle);
813 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
814  llvm::erase(mappings.direct[handle], nullptr);
815  }
816  opHandlesToCompact.clear();
817 }
818 
/// Applies `transform` to the payload IR associated with its operands and
/// updates the state's handle mappings. In order, the code below:
///   1. when expensive checks are enabled, checks and records handle
///      invalidation and rejects consumed handle operands whose payload holds
///      the same entity more than once (unless the op allows repeated handle
///      operands);
///   2. snapshots the results of payload ops of consumed op handles and the
///      ops associated with consumed value handles, since the transform may
///      destroy or mutate them;
///   3. runs `transform.apply` with a rewriter whose tracking listener skips
///      dead handles, then compacts op handles;
///   4. folds tracking-listener failures into the result (only for ops with
///      ReportTrackingListenerFailuresOpTrait and without the silencing
///      attribute; transform errors take precedence);
///   5. unless the failure is definite, sets unset results to empty lists (on
///      silenceable failure), forgets the consumed handles and records the
///      op's results in the state.
/// NOTE(review): this listing is missing the return-type line (original 819)
/// and lines 842, 861, 871, 917 and 998.
820 transform::TransformState::applyTransform(TransformOpInterface transform) {
821  LLVM_DEBUG({
822  DBGS() << "applying: ";
823  transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
824  llvm::dbgs() << "\n";
825  });
826  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
827  DBGS() << "Top-level payload before application:\n"
828  << *getTopLevel() << "\n");
// In debug output, print the whole top-level payload if this function is left
// before the guard is released on the success path near the end.
829  auto printOnFailureRAII = llvm::make_scope_exit([this] {
830  (void)this;
831  LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
832  llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
833  });
834 
835  // Set current transform op.
836  regionStack.back()->currentTransform = transform;
837 
838  // Expensive checks to detect invalid transform IR.
839  if (options.getExpensiveChecksEnabled()) {
840  FULL_LDBG("ExpensiveChecksEnabled\n");
841  if (failed(checkAndRecordHandleInvalidation(transform)))
// NOTE(review): original line 842 (the statement guarded by the `if` above) is
// missing from this listing.
843 
844  for (OpOperand &operand : transform->getOpOperands()) {
845  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
846  (DBGS() << "iterate on handle: " << operand.get() << "\n");
847  });
848  if (!isHandleConsumed(operand.get(), transform)) {
849  FULL_LDBG("--handle not consumed -> SKIP\n");
850  continue;
851  }
852  if (transform.allowsRepeatedHandleOperands()) {
853  FULL_LDBG("--op allows repeated handles -> SKIP\n");
854  continue;
855  }
856  FULL_LDBG("--handle is consumed\n");
857 
858  Type operandType = operand.get().getType();
859  if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
860  FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
// NOTE(review): original line 861 (declaration of `check`) is missing.
862  checkRepeatedConsumptionInOperand<Operation *>(
863  getPayloadOpsView(operand.get()), transform,
864  operand.getOperandNumber());
865  if (!check.succeeded()) {
866  FULL_LDBG("----FAILED\n");
867  return check;
868  }
869  } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
870  FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
// NOTE(review): original line 871 (declaration of `check`) is missing.
872  checkRepeatedConsumptionInOperand<Value>(
873  getPayloadValuesView(operand.get()), transform,
874  operand.getOperandNumber());
875  if (!check.succeeded()) {
876  FULL_LDBG("----FAILED\n");
877  return check;
878  }
879  } else {
880  FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
881  }
882  }
883  }
884 
885  // Find which operands are consumed.
886  SmallVector<OpOperand *> consumedOperands =
887  transform.getConsumedHandleOpOperands();
888 
889  // Remember the results of the payload ops associated with the consumed
890  // op handles or the ops defining the value handles so we can drop the
891  // association with them later. This must happen here because the
892  // transformation may destroy or mutate them so we cannot traverse the payload
893  // IR after that.
894  SmallVector<Value> origOpFlatResults;
895  SmallVector<Operation *> origAssociatedOps;
896  for (OpOperand *opOperand : consumedOperands) {
897  Value operand = opOperand->get();
898  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
899  for (Operation *payloadOp : getPayloadOps(operand)) {
900  llvm::append_range(origOpFlatResults, payloadOp->getResults());
901  }
902  continue;
903  }
904  if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
905  for (Value payloadValue : getPayloadValuesView(operand)) {
// An op result is associated with its defining op; a block argument is
// associated with every op in its owner block.
906  if (llvm::isa<OpResult>(payloadValue)) {
907  origAssociatedOps.push_back(payloadValue.getDefiningOp());
908  continue;
909  }
910  llvm::append_range(
911  origAssociatedOps,
912  llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
913  [](Operation &op) { return &op; }));
914  }
915  continue;
916  }
// A consumed operand that is neither an op handle nor a value handle is a
// definite failure.
// NOTE(review): original line 917 (declaration of `diag`) is missing.
918  emitDefiniteFailure(transform->getLoc())
919  << "unexpectedly consumed a value that is not a handle as operand #"
920  << opOperand->getOperandNumber();
921  diag.attachNote(operand.getLoc())
922  << "value defined here with type " << operand.getType();
923  return diag;
924  }
925 
926  // Prepare rewriter and listener.
927  TrackingListenerConfig config;
// A handle is treated as dead (skipped by the listener) when every user is the
// current transform of the most recent region scope for the handle's parent
// region, or happens before that transform.
928  config.skipHandleFn = [&](Value handle) {
929  // Skip handle if it is dead.
930  auto scopeIt =
931  llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
932  return handle.getParentRegion() == scope->region;
933  });
934  assert(scopeIt != regionStack.rend() &&
935  "could not find region scope for handle");
936  RegionScope *scope = *scopeIt;
937  return llvm::all_of(handle.getUsers(), [&](Operation *user) {
938  return user == scope->currentTransform ||
939  happensBefore(user, scope->currentTransform);
940  });
941  };
942  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
943  config);
944  transform::TransformRewriter rewriter(transform->getContext(),
945  &trackingListener);
946 
947  // Compute the result but do not short-circuit the silenceable failure case as
948  // we still want the handles to propagate properly so the "suppress" mode can
949  // proceed on a best effort basis.
950  transform::TransformResults results(transform->getNumResults());
951  DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
952  compactOpHandles();
953 
954  // Error handling: fail if transform or listener failed.
955  DiagnosedSilenceableFailure trackingFailure =
956  trackingListener.checkAndResetError();
957  if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
958  transform->hasAttr(FindPayloadReplacementOpInterface::
959  kSilenceTrackingFailuresAttrName)) {
960  // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
961  // do not report failures if the above mentioned attribute is set.
962  if (trackingFailure.isSilenceableFailure())
963  (void)trackingFailure.silence();
964  trackingFailure = DiagnosedSilenceableFailure::success();
965  }
966  if (!trackingFailure.succeeded()) {
967  if (result.succeeded()) {
968  result = std::move(trackingFailure);
969  } else {
970  // Transform op errors have precedence, report those first.
971  if (result.isSilenceableFailure())
972  result.attachNote() << "tracking listener also failed: "
973  << trackingFailure.getMessage();
974  (void)trackingFailure.silence();
975  }
976  }
977  if (result.isDefiniteFailure())
978  return result;
979 
980  // If a silenceable failure was produced, some results may be unset, set them
981  // to empty lists.
982  if (result.isSilenceableFailure())
983  results.setRemainingToEmpty(transform);
984 
985  // Remove the mapping for the operand if it is consumed by the operation. This
986  // allows us to catch use-after-free with assertions later on.
987  for (OpOperand *opOperand : consumedOperands) {
988  Value operand = opOperand->get();
989  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
990  forgetMapping(operand, origOpFlatResults);
991  } else if (llvm::isa<TransformValueHandleTypeInterface>(
992  operand.getType())) {
993  forgetValueMapping(operand, origAssociatedOps);
994  }
995  }
996 
997  if (failed(updateStateFromResults(results, transform->getResults())))
// NOTE(review): original line 998 (the statement guarded by the `if` above) is
// missing from this listing.
999 
1000  printOnFailureRAII.release();
1001  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
1002  DBGS() << "Top-level payload:\n";
1003  getTopLevel()->print(llvm::dbgs());
1004  });
1005  return result;
1006 }
1007 
1008 LogicalResult transform::TransformState::updateStateFromResults(
1009  const TransformResults &results, ResultRange opResults) {
1010  for (OpResult result : opResults) {
1011  if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1012  assert(results.isParam(result.getResultNumber()) &&
1013  "expected parameters for the parameter-typed result");
1014  if (failed(
1015  setParams(result, results.getParams(result.getResultNumber())))) {
1016  return failure();
1017  }
1018  } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1019  assert(results.isValue(result.getResultNumber()) &&
1020  "expected values for value-type-result");
1021  if (failed(setPayloadValues(
1022  result, results.getValues(result.getResultNumber())))) {
1023  return failure();
1024  }
1025  } else {
1026  assert(!results.isParam(result.getResultNumber()) &&
1027  "expected payload ops for the non-parameter typed result");
1028  if (failed(
1029  setPayloadOps(result, results.get(result.getResultNumber())))) {
1030  return failure();
1031  }
1032  }
1033  }
1034  return success();
1035 }
1036 
1037 //===----------------------------------------------------------------------===//
1038 // TransformState::Extension
1039 //===----------------------------------------------------------------------===//
1040 
1042 
/// Forwards the replacement of payload op `op` by `replacement` to
/// `state.replacePayloadOp` and returns its result.
/// NOTE(review): original line 1044 (qualified name and the `op` parameter) is
/// missing from this listing.
1043 LogicalResult
1045  Operation *replacement) {
1046  // TODO: we may need to invalidate handles to operations and values nested in
1047  // the operation being replaced.
1048  return state.replacePayloadOp(op, replacement);
1049 }
1050 
/// Forwards the replacement of payload value `value` by `replacement` to
/// `state.replacePayloadValue` and returns its result.
/// NOTE(review): original line 1052 (qualified name and the `value` parameter)
/// is missing from this listing.
1051 LogicalResult
1053  Value replacement) {
1054  return state.replacePayloadValue(value, replacement);
1055 }
1056 
1057 //===----------------------------------------------------------------------===//
1058 // TransformState::RegionScope
1059 //===----------------------------------------------------------------------===//
1060 
/// Tears down a region scope. It erases the invalidation notices of every
/// handle defined in the scope's region (block arguments and op results),
/// drops the region's mappings from the state and pops the scope from
/// `state.regionStack`.
/// NOTE(review): the destructor's signature (original line 1061) is missing
/// from this listing.
1062  // Remove handle invalidation notices as handles are going out of scope.
1063  // The same region may be re-entered leading to incorrect invalidation
1064  // errors.
1065  for (Block &block : *region) {
1066  for (Value handle : block.getArguments()) {
1067  state.invalidatedHandles.erase(handle);
1068  }
1069  for (Operation &op : block) {
1070  for (Value handle : op.getResults()) {
1071  state.invalidatedHandles.erase(handle);
1072  }
1073  }
1074  }
1075 
1076 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1077  // Remember pointers to payload ops referenced by the handles going out of
1078  // scope.
// NOTE(review): `referencedOps` is never read before the destructor ends, so
// this computation (including the `state.mappings[region]` lookup) appears to
// be dead code; consider removing it.
1079  SmallVector<Operation *> referencedOps =
1080  llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1081 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1082 
1083  state.mappings.erase(region);
1084  state.regionStack.pop_back();
1085 }
1086 
1087 //===----------------------------------------------------------------------===//
1088 // TransformResults
1089 //===----------------------------------------------------------------------===//
1090 
1091 transform::TransformResults::TransformResults(unsigned numSegments) {
1092  operations.appendEmptyRows(numSegments);
1093  params.appendEmptyRows(numSegments);
1094  values.appendEmptyRows(numSegments);
1095 }
1096 
/// Associates result `value` with the parameters `params`. Debug builds assert
/// that the result exists and that it has no params, ops or values set yet.
/// `this->params` is the member table; plain `params` is the argument.
/// NOTE(review): the signature (original lines 1097-1098) is missing from this
/// listing.
1099  int64_t position = value.getResultNumber();
1100  assert(position < static_cast<int64_t>(this->params.size()) &&
1101  "setting params for a non-existent handle");
1102  assert(this->params[position].data() == nullptr && "params already set");
1103  assert(operations[position].data() == nullptr &&
1104  "another kind of results already set");
1105  assert(values[position].data() == nullptr &&
1106  "another kind of results already set");
1107  this->params.replace(position, params);
1108 }
1109 
/// Associates `handle` with the mapped entities in `values`. Depending on what
/// the entries hold, it dispatches to `set` (operations), `setParams`
/// (parameters) or `setValues` (payload values). In NDEBUG-less builds, a
/// failed dispatch is printed and then asserts; the resulting diagnostic is
/// always silenced.
/// NOTE(review): original lines 1110 (signature start) and 1112 (presumably the
/// declaration of `diag` and the dispatching call -- confirm) are missing.
1111  OpResult handle, ArrayRef<MappedValue> values) {
1113  handle, values,
1114  [&](ArrayRef<Operation *> operations) {
1115  return set(handle, operations), success();
1116  },
1117  [&](ArrayRef<Param> params) {
1118  return setParams(handle, params), success();
1119  },
1120  [&](ValueRange payloadValues) {
1121  return setValues(handle, payloadValues), success();
1122  });
1123 #ifndef NDEBUG
1124  if (!diag.succeeded())
1125  llvm::dbgs() << diag.getStatusString() << "\n";
1126  assert(diag.succeeded() && "incorrect mapping");
1127 #endif // NDEBUG
1128  (void)diag.silence();
1129 }
1130 
/// Maps every result of `transform` that has not been set yet to an empty list
/// of mapped values.
/// NOTE(review): the signature start (original line 1131) is missing from this
/// listing.
1132  transform::TransformOpInterface transform) {
1133  for (OpResult opResult : transform->getResults()) {
1134  if (!isSet(opResult.getResultNumber()))
1135  setMappedValues(opResult, {});
1136  }
1137 }
1138 
/// Returns the payload operations stored for result `resultNumber`. Debug
/// builds assert that the result exists and has operations set.
/// NOTE(review): the return-type line (original 1139) is missing from this
/// listing.
1140 transform::TransformResults::get(unsigned resultNumber) const {
1141  assert(resultNumber < operations.size() &&
1142  "querying results for a non-existent handle");
1143  assert(operations[resultNumber].data() != nullptr &&
1144  "querying unset results (values or params expected?)");
1145  return operations[resultNumber];
1146 }
1147 
/// Returns the parameters stored for result `resultNumber`. Debug builds assert
/// that the result exists and has parameters set.
/// NOTE(review): the return-type line (original 1148) is missing from this
/// listing.
1149 transform::TransformResults::getParams(unsigned resultNumber) const {
1150  assert(resultNumber < params.size() &&
1151  "querying params for a non-existent handle");
1152  assert(params[resultNumber].data() != nullptr &&
1153  "querying unset params (ops or values expected?)");
1154  return params[resultNumber];
1155 }
1156 
/// Returns the payload values stored for result `resultNumber`. Debug builds
/// assert that the result exists and has values set.
/// NOTE(review): the return-type line (original 1157) is missing from this
/// listing.
1158 transform::TransformResults::getValues(unsigned resultNumber) const {
1159  assert(resultNumber < values.size() &&
1160  "querying values for a non-existent handle");
1161  assert(values[resultNumber].data() != nullptr &&
1162  "querying unset values (ops or params expected?)");
1163  return values[resultNumber];
1164 }
1165 
1166 bool transform::TransformResults::isParam(unsigned resultNumber) const {
1167  assert(resultNumber < params.size() &&
1168  "querying association for a non-existent handle");
1169  return params[resultNumber].data() != nullptr;
1170 }
1171 
1172 bool transform::TransformResults::isValue(unsigned resultNumber) const {
1173  assert(resultNumber < values.size() &&
1174  "querying association for a non-existent handle");
1175  return values[resultNumber].data() != nullptr;
1176 }
1177 
1178 bool transform::TransformResults::isSet(unsigned resultNumber) const {
1179  assert(resultNumber < params.size() &&
1180  "querying association for a non-existent handle");
1181  return params[resultNumber].data() != nullptr ||
1182  operations[resultNumber].data() != nullptr ||
1183  values[resultNumber].data() != nullptr;
1184 }
1185 
1186 //===----------------------------------------------------------------------===//
1187 // TrackingListener
1188 //===----------------------------------------------------------------------===//
1189 
/// Registers the listener as an extension of `state` for transform op `op`,
/// using `config`. When `op` is non-null, it also records the handles `op`
/// consumes in `consumedHandles`.
/// NOTE(review): the signature start (original line 1190) is missing from this
/// listing.
1191  TransformOpInterface op,
1192  TrackingListenerConfig config)
1193  : TransformState::Extension(state), transformOp(op), config(config) {
1194  if (op) {
1195  for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1196  consumedHandles.insert(opOperand->get());
1197  }
1198  }
1199 }
1200 
/// Returns the op that defines all non-null values in `values`. Returns nullptr
/// if two of them are defined by different ops, or if no defining op is found
/// at all. Null values are skipped.
/// NOTE(review): a non-null value whose `getDefiningOp()` is null (presumably a
/// block argument) leaves `defOp` unset, so a later value re-seeds it. For
/// example, [blockArg, opResult] yields opResult's owner while
/// [opResult, blockArg] yields nullptr. Confirm whether mixed inputs should
/// always yield nullptr.
/// NOTE(review): the signature (original line 1201) is missing from this
/// listing.
1202  Operation *defOp = nullptr;
1203  for (Value v : values) {
1204  // Skip empty values.
1205  if (!v)
1206  continue;
1207  if (!defOp) {
1208  defOp = v.getDefiningOp();
1209  continue;
1210  }
1211  if (defOp != v.getDefiningOp())
1212  return nullptr;
1213  }
1214  return defOp;
1215 }
1216 
/// Searches for an op that can stand in for `op`, which is being replaced by
/// `newValues`, and stores it in `result`. Starting from the common defining
/// op of the candidate values, it:
///   - looks through `CastOpInterface` ops (when `config.skipCastOps`);
///   - accepts an op with the same name as `op` (or any op when
///     `config.requireMatchingReplacementOpName` is false), or a constant-like
///     op;
///   - otherwise continues with the operands provided by
///     `FindPayloadReplacementOpInterface`.
/// Each step is recorded as a note on `diag`. That diagnostic is returned when
/// the values belong to different ops or no suitable op is found.
/// NOTE(review): this listing is missing original lines 1217 (return type and
/// name), 1223 (declaration of `diag`) and 1249/1256 (the statements after each
/// `result = defOp;`, presumably successful returns -- confirm).
1218  Operation *&result, Operation *op, ValueRange newValues) const {
1219  assert(op->getNumResults() == newValues.size() &&
1220  "invalid number of replacement values");
1221  SmallVector<Value> values(newValues.begin(), newValues.end());
1222 
1224  getTransformOp(), "tracking listener failed to find replacement op "
1225  "during application of this transform op");
1226 
1227  do {
1228  // If the replacement values belong to different ops, drop the mapping.
1229  Operation *defOp = getCommonDefiningOp(values);
1230  if (!defOp) {
1231  diag.attachNote() << "replacement values belong to different ops";
1232  return diag;
1233  }
1234 
1235  // Skip through ops that implement CastOpInterface.
// NOTE(review): the `clear()` below is redundant because `assign` replaces the
// contents of `values` anyway.
1236  if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1237  values.clear();
1238  values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1239  diag.attachNote(defOp->getLoc())
1240  << "using output of 'CastOpInterface' op";
1241  continue;
1242  }
1243 
1244  // If the defining op has the same name or we do not care about the name of
1245  // op replacements at all, we take it as a replacement.
1246  if (!config.requireMatchingReplacementOpName ||
1247  op->getName() == defOp->getName()) {
1248  result = defOp;
1250  }
1251 
1252  // Replacing an op with a constant-like equivalent is a common
1253  // canonicalization.
1254  if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1255  result = defOp;
1257  }
1258 
1259  values.clear();
1260 
1261  // Skip through ops that implement FindPayloadReplacementOpInterface.
1262  if (auto findReplacementOpInterface =
1263  dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1264  values.assign(findReplacementOpInterface.getNextOperands());
1265  diag.attachNote(defOp->getLoc()) << "using operands provided by "
1266  "'FindPayloadReplacementOpInterface'";
1267  continue;
1268  }
1269  } while (!values.empty());
1270 
1271  diag.attachNote() << "ran out of suitable replacement values";
1272  return diag;
1273 }
1274 
/// When LLVM_DEBUG output is enabled, builds the failure reason through
/// `reasonCallback` and prints it; otherwise this does nothing.
/// NOTE(review): original lines 1275 (signature start) and 1278 (presumably the
/// declaration of the `diag` Diagnostic at `loc` -- confirm) are missing.
1276  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1277  LLVM_DEBUG({
1279  reasonCallback(diag);
1280  DBGS() << "Match Failure : " << diag.str() << "\n";
1281  });
1282 }
1283 
1284 void transform::TrackingListener::notifyOperationErased(Operation *op) {
1285  // Remove mappings for result values.
1286  for (OpResult value : op->getResults())
1287  (void)replacePayloadValue(value, nullptr);
1288  // Remove mapping for op.
1289  (void)replacePayloadOp(op, nullptr);
1290 }
1291 
/// Updates the tracked state after `op` was replaced by `newValues`:
///   - value handles are redirected result-by-result;
///   - if `op` is tracked by a live, non-consumed handle, a replacement op is
///     searched for with `findReplacementOp`;
///   - otherwise, or if no replacement is found (which is reported through
///     `notifyPayloadReplacementNotFound`), `op` is dropped from the mapping.
/// NOTE(review): original line 1341 (presumably the declaration of `diag` that
/// receives the result of `findReplacementOp`) is missing from this listing.
1292 void transform::TrackingListener::notifyOperationReplaced(
1293  Operation *op, ValueRange newValues) {
1294  assert(op->getNumResults() == newValues.size() &&
1295  "invalid number of replacement values");
1296 
1297  // Replace value handles.
1298  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
1299  (void)replacePayloadValue(oldValue, newValue);
1300 
1301  // Replace op handle.
1302  SmallVector<Value> opHandles;
1303  if (failed(getTransformState().getHandlesForPayloadOp(
1304  op, opHandles, /*includeOutOfScope=*/true))) {
1305  // Op is not tracked.
1306  return;
1307  }
1308 
1309  // Helper function to check if the current transform op consumes any handle
1310  // that is mapped to `op`.
1311  //
1312  // Note: If a handle was consumed, there shouldn't be any alive users, so it
1313  // is not really necessary to check for consumed handles. However, in case
1314  // there are indeed alive handles that were consumed (which is undefined
1315  // behavior) and a replacement op could not be found, we want to fail with a
1316  // nicer error message: "op uses a handle invalidated..." instead of "could
1317  // not find replacement op". This nicer error is produced later.
1318  auto handleWasConsumed = [&] {
1319  return llvm::any_of(opHandles,
1320  [&](Value h) { return consumedHandles.contains(h); });
1321  };
1322 
1323  // Check if there are any handles that must be updated.
1324  Value aliveHandle;
1325  if (config.skipHandleFn) {
1326  auto it = llvm::find_if(opHandles,
1327  [&](Value v) { return !config.skipHandleFn(v); });
1328  if (it != opHandles.end())
1329  aliveHandle = *it;
1330  } else if (!opHandles.empty()) {
1331  aliveHandle = opHandles.front();
1332  }
1333  if (!aliveHandle || handleWasConsumed()) {
1334  // The op is tracked but the corresponding handles are dead or were
1335  // consumed. Drop the op from the mapping.
1336  (void)replacePayloadOp(op, nullptr);
1337  return;
1338  }
1339 
// `replacement` is only read after `findReplacementOp` reported success.
1340  Operation *replacement;
1342  findReplacementOp(replacement, op, newValues);
1343  // If the op is tracked but no replacement op was found, send a
1344  // notification.
1345  if (!diag.succeeded()) {
1346  diag.attachNote(aliveHandle.getLoc())
1347  << "replacement is required because this handle must be updated";
1348  notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
1349  (void)replacePayloadOp(op, nullptr);
1350  return;
1351  }
1352 
1353  (void)replacePayloadOp(op, replacement);
1354 }
1355 
/// Asserts, in debug builds, that any recorded error status was consumed
/// (checked and reset) before the listener is destroyed.
/// NOTE(review): the destructor's signature (original line 1356) is missing
/// from this listing.
1357  // The state of the ErrorCheckingTrackingListener must be checked and reset
1358  // if there was an error. This is to prevent errors from accidentally being
1359  // missed.
1360  assert(status.succeeded() && "listener state was not checked");
1361 }
1362 
/// Hands the accumulated error status to the caller (by move) and resets the
/// error counter to zero.
/// NOTE(review): original lines 1363-1364 (signature) and 1366 (presumably a
/// reset of `status` to success -- confirm) are missing from this listing.
1365  DiagnosedSilenceableFailure s = std::move(status);
1367  errorCounter = 0;
1368  return s;
1369 }
1370 
/// Returns true if the listener has recorded a non-success status.
/// NOTE(review): the signature (original line 1371) is missing from this
/// listing.
1372  return !status.succeeded();
1373 }
1374 
/// Records a "replacement not found" failure for `op` (replaced by `values`).
/// It merges the diagnostics of `diag` with any previously recorded failure
/// into a single silenceable failure stored in `status`. It then attaches
/// notes tagged with the running `errorCounter` for the replaced op and for
/// each replacement value, and increments the counter.
/// NOTE(review): original lines 1375-1376 (signature) and 1379 (presumably the
/// declaration of `diags` -- confirm) are missing from this listing.
1377 
1378  // Merge potentially existing diags and store the result in the listener.
1380  diag.takeDiagnostics(diags);
1381  if (!status.succeeded())
1382  status.takeDiagnostics(diags);
1383  status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
1384 
1385  // Report more details.
1386  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
1387  for (auto &&[index, value] : llvm::enumerate(values))
1388  status.attachNote(value.getLoc())
1389  << "[" << errorCounter << "] replacement value " << index;
1390  ++errorCounter;
1391 }
1392 
1393 //===----------------------------------------------------------------------===//
1394 // TransformRewriter
1395 //===----------------------------------------------------------------------===//
1396 
/// Creates a rewriter on `ctx` that keeps a pointer to `listener` and also
/// installs it as the rewriter's listener.
/// NOTE(review): the signature (original lines 1397-1398) is missing from this
/// listing.
1399  : RewriterBase(ctx), listener(listener) {
1400  setListener(listener);
1401 }
1402 
/// Returns true if the tracking listener has recorded a failure.
/// NOTE(review): the signature (original line 1403) is missing from this
/// listing.
1404  return listener->failed();
1405 }
1406 
1407 /// Silence all tracking failures that have been encountered so far.
/// If the listener has failed, its error status is taken (which also resets
/// the listener) and silenced.
/// NOTE(review): the signature (original line 1408) is missing from this
/// listing.
1409  if (hasTrackingFailures()) {
1410  DiagnosedSilenceableFailure status = listener->checkAndResetError();
1411  (void)status.silence();
1412  }
1413 }
1414 
/// Forwards the replacement of payload op `op` by `replacement` to the
/// tracking listener and returns its result.
/// NOTE(review): the signature start (original line 1415) is missing from this
/// listing.
1416  Operation *op, Operation *replacement) {
1417  return listener->replacePayloadOp(op, replacement);
1418 }
1419 
1420 //===----------------------------------------------------------------------===//
1421 // Utilities for TransformEachOpTrait.
1422 //===----------------------------------------------------------------------===//
1423 
/// Emits an error at `loc` and fails if some op in `targets` is an ancestor
/// (per `isAncestor`) of an op that appears later in `targets`. Notes point at
/// the ancestor and the descendant. Every pair is compared, so the cost is
/// quadratic in `targets.size()`.
/// NOTE(review): original lines 1425 (name and `loc` parameter) and 1430
/// (declaration of `diag`) are missing from this listing.
1424 LogicalResult
1426  ArrayRef<Operation *> targets) {
1427  for (auto &&[position, parent] : llvm::enumerate(targets)) {
1428  for (Operation *child : targets.drop_front(position + 1)) {
1429  if (parent->isAncestor(child)) {
1431  emitError(loc)
1432  << "transform operation consumes a handle pointing to an ancestor "
1433  "payload operation before its descendant";
1434  diag.attachNote()
1435  << "the ancestor is likely erased or rewritten before the "
1436  "descendant is accessed, leading to undefined behavior";
1437  diag.attachNote(parent->getLoc()) << "ancestor payload op";
1438  diag.attachNote(child->getLoc()) << "descendant payload op";
1439  return diag;
1440  }
1441  }
1442  }
1443  return success();
1444 }
1445 
/// Validates `partialResult`, produced by applying `transformOp` to the payload
/// op at `payloadOpLoc`. It must have one entry per result of `transformOp`,
/// and each non-null entry must have the kind the result type expects:
///   - an `Operation *` for op handles;
///   - an `Attribute` for params;
///   - a `Value` for value handles.
/// Errors are emitted at the transform op's location with a note pointing at
/// the payload op.
/// NOTE(review): original line 1447 (name and the `transformOp` parameter) is
/// missing from this listing.
1446 LogicalResult
1448  Location payloadOpLoc,
1449  const ApplyToEachResultList &partialResult) {
1450  Location transformOpLoc = transformOp->getLoc();
1451  StringRef transformOpName = transformOp->getName().getStringRef();
1452  unsigned expectedNumResults = transformOp->getNumResults();
1453 
1454  // Reuse the emission of the diagnostic note.
1455  auto emitDiag = [&]() {
1456  auto diag = mlir::emitError(transformOpLoc);
1457  diag.attachNote(payloadOpLoc) << "when applied to this op";
1458  return diag;
1459  };
1460 
1461  if (partialResult.size() != expectedNumResults) {
1462  auto diag = emitDiag() << "application of " << transformOpName
1463  << " expected to produce " << expectedNumResults
1464  << " results (actually produced "
1465  << partialResult.size() << ").";
1466  diag.attachNote(transformOpLoc)
1467  << "if you need variadic results, consider a generic `apply` "
1468  << "instead of the specialized `applyToOne`.";
1469  return failure();
1470  }
1471 
1472  // Check that the right kind of value was produced.
1473  for (const auto &[ptr, res] :
1474  llvm::zip(partialResult, transformOp->getResults())) {
1475  if (ptr.isNull())
1476  continue;
1477  if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1478  !ptr.is<Operation *>()) {
1479  return emitDiag() << "application of " << transformOpName
1480  << " expected to produce an Operation * for result #"
1481  << res.getResultNumber();
1482  }
1483  if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1484  !ptr.is<Attribute>()) {
1485  return emitDiag() << "application of " << transformOpName
1486  << " expected to produce an Attribute for result #"
1487  << res.getResultNumber();
1488  }
1489  if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1490  !ptr.is<Value>()) {
1491  return emitDiag() << "application of " << transformOpName
1492  << " expected to produce a Value for result #"
1493  << res.getResultNumber();
1494  }
1495  }
1496  return success();
1497 }
1498 
/// Converts a range of `transform::MappedValue` into a vector of `T` by calling
/// `get<T>()` on each element.
/// NOTE(review): the signature (original line 1500) is missing from this
/// listing.
1499 template <typename T>
1501  return llvm::to_vector(llvm::map_range(
1502  range, [](transform::MappedValue value) { return value.get<T>(); }));
1503 }
1504 
/// Transposes the per-payload-op result lists in `results` into one list per
/// result of `transformOp`, then stores each list in `transformResults`. Params
/// become params, value handles become values, and everything else becomes
/// operations. Any partial result list containing a null entry is skipped
/// entirely.
/// NOTE(review): original lines 1505 (signature start) and 1507-1508
/// (presumably the `results` parameter and the declaration of `transposed`)
/// are missing from this listing.
1506  Operation *transformOp, TransformResults &transformResults,
1509  transposed.resize(transformOp->getNumResults());
1510  for (const ApplyToEachResultList &partialResults : results) {
1511  if (llvm::any_of(partialResults,
1512  [](MappedValue value) { return value.isNull(); }))
1513  continue;
1514  assert(transformOp->getNumResults() == partialResults.size() &&
1515  "expected as many partial results as op as results");
1516  for (auto [i, value] : llvm::enumerate(partialResults))
1517  transposed[i].push_back(value);
1518  }
1519 
1520  for (OpResult r : transformOp->getResults()) {
1521  unsigned position = r.getResultNumber();
1522  if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1523  transformResults.setParams(r,
1524  castVector<Attribute>(transposed[position]));
1525  } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1526  transformResults.setValues(r, castVector<Value>(transposed[position]));
1527  } else {
1528  transformResults.set(r, castVector<Operation *>(transposed[position]));
1529  }
1530  }
1531 }
1532 
1533 //===----------------------------------------------------------------------===//
1534 // Utilities for implementing transform ops with regions.
1535 //===----------------------------------------------------------------------===//
1536 
/// For each value in `values`, appends its payload to the matching entry of
/// `mappings`: payload ops for op handles, payload values for value handles,
/// params for param-typed values. Unless `flatten` is set, it fails as soon as
/// a value contributes a number of entries other than one. Entries appended
/// before the failing value are kept.
/// NOTE(review): original lines 1537-1538 (return type, name and the `mappings`
/// parameter) are missing from this listing.
1539  ValueRange values, const transform::TransformState &state, bool flatten) {
1540  assert(mappings.size() == values.size() && "mismatching number of mappings");
1541  for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1542  size_t mappedSize = mapped.size();
1543  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1544  llvm::append_range(mapped, state.getPayloadOps(operand));
1545  } else if (llvm::isa<TransformValueHandleTypeInterface>(
1546  operand.getType())) {
1547  llvm::append_range(mapped, state.getPayloadValues(operand));
1548  } else {
1549  assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1550  "unsupported kind of transform dialect value");
1551  llvm::append_range(mapped, state.getParams(operand));
1552  }
1553 
// Only the entries added for this value are counted.
1554  if (mapped.size() - mappedSize != 1 && !flatten)
1555  return failure();
1556  }
1557  return success();
1558 }
1559 
/// Grows `mappings` by one entry per value in `values`, then fills the new
/// entries through `appendValueMappings`, discarding its result.
/// NOTE(review): original lines 1560-1561 (signature start) and 1565
/// (presumably the first argument, a view of the last `values.size()` entries
/// of `mappings` -- confirm) are missing from this listing.
1562  ValueRange values, const transform::TransformState &state) {
1563  mappings.resize(mappings.size() + values.size());
1564  (void)appendValueMappings(
1566  values.size()),
1567  values, state);
1568 }
1569 
/// Sets each result of `block`'s parent op to the payload of the corresponding
/// operand of `block`'s terminator: payload ops for op handles, payload values
/// for value handles, params otherwise.
/// NOTE(review): the signature start (original line 1570) is missing from this
/// listing.
1571  Block *block, transform::TransformState &state,
1572  transform::TransformResults &results) {
1573  for (auto &&[terminatorOperand, result] :
1574  llvm::zip(block->getTerminator()->getOperands(),
1575  block->getParentOp()->getOpResults())) {
1576  if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1577  results.set(result, state.getPayloadOps(terminatorOperand));
1578  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1579  result.getType())) {
1580  results.setValues(result, state.getPayloadValues(terminatorOperand));
1581  } else {
1582  assert(
1583  llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1584  "unhandled transform type interface");
1585  results.setParams(result, state.getParams(terminatorOperand));
1586  }
1587  }
1588 }
1589 
/// Constructs a `TransformState` for `region` rooted at `payloadRoot`.
/// NOTE(review): the signature start (original lines 1590-1591) is missing from
/// this listing.
1592  Operation *payloadRoot) {
1593  return TransformState(region, payloadRoot);
1594 }
1595 
1596 //===----------------------------------------------------------------------===//
1597 // Utilities for PossibleTopLevelTransformOpTrait.
1598 //===----------------------------------------------------------------------===//
1599 
1600 /// Appends to `effects` the memory effect instances on `target` with the same
1601 /// resource and effect as the ones the operation `iface` has on `source`.
/// NOTE(review): original lines 1605-1606 (presumably the `effects` parameter
/// and the declaration of `nestedEffects` -- confirm) are missing from this
/// listing.
1602 static void
1603 remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
1604  OpOperand *target,
1607  iface.getEffectsOnValue(source, nestedEffects);
1608  for (const auto &effect : nestedEffects)
1609  effects.emplace_back(effect.getEffect(), target, effect.getResource());
1610 }
1611 
1612 /// Appends to `effects` the same effects as the operations of `block` have on
1613 /// block arguments but associated with `operands`.
/// Effects those operations have on `transform::PayloadIRResource` are appended
/// unchanged. Argument/operand pairs are formed with `llvm::zip`.
/// NOTE(review): original lines 1615-1616 (name and parameters) and 1626
/// (presumably the declaration of `nestedEffects` -- confirm) are missing from
/// this listing.
1614 static void
1617  for (Operation &op : block) {
1618  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1619  if (!iface)
1620  continue;
1621 
1622  for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
1623  remapEffects(iface, source, &target, effects);
1624  }
1625 
1627  iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1628  nestedEffects);
1629  llvm::append_range(effects, nestedEffects);
1630  }
1631 }
1632 
/// Computes the memory effects of a possible top-level transform op. Operands
/// are only read and results are produced. Then:
///   - without a `root`, the effects of every op in `body` that implements
///     `MemoryEffectOpInterface` are collected into `effects`;
///   - with a `root`, the effects on the entry block arguments are remapped
///     onto the op's operands.
/// NOTE(review): original lines 1633 (signature start), 1635 (presumably the
/// `effects` parameter) and 1645 are missing from this listing.
1634  Operation *operation, Value root, Block &body,
1636  transform::onlyReadsHandle(operation->getOpOperands(), effects);
1637  transform::producesHandle(operation->getOpResults(), effects);
1638 
1639  if (!root) {
1640  for (Operation &op : body) {
1641  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1642  if (!iface)
1643  continue;
1644 
1646  iface.getEffects(effects);
1647  }
1648  return;
1649  }
1650 
1651  // Carry over all effects on arguments of the entry block as those on the
1652  // operands, this is the same value just remapped.
1653  remapArgumentEffects(body, operation->getOpOperands(), effects);
1654 }
1655 
/// Maps the entry block arguments of `region` for op `op`:
///   - the first argument gets the payload ops of `op`'s first operand, or the
///     top-level payload op when `op` has no operands;
///   - the remaining arguments get the payloads of the remaining operands, or
///     the state's top-level mappings, whose count must equal the number of
///     trailing block arguments (otherwise an error is emitted).
/// NOTE(review): when `op` has operands, `extraMappings` has one entry per
/// extra operand, so this assumes the entry block has exactly one argument per
/// operand. That is presumably enforced by the verifier -- confirm.
/// NOTE(review): the signature start (original line 1656) is missing from this
/// listing.
1657  TransformState &state, Operation *op, Region &region) {
1658  SmallVector<Operation *> targets;
1659  SmallVector<SmallVector<MappedValue>> extraMappings;
1660  if (op->getNumOperands() != 0) {
1661  llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
1662  prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
1663  } else {
1664  if (state.getNumTopLevelMappings() !=
1665  region.front().getNumArguments() - 1) {
1666  return emitError(op->getLoc())
1667  << "operation expects " << region.front().getNumArguments() - 1
1668  << " extra value bindings, but " << state.getNumTopLevelMappings()
1669  << " were provided to the interpreter";
1670  }
1671 
1672  targets.push_back(state.getTopLevel());
1673 
1674  for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1675  extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1676  }
1677 
1678  if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
1679  return failure();
1680 
1681  for (BlockArgument argument : region.front().getArguments().drop_front()) {
1682  if (failed(state.mapBlockArgument(
1683  argument, extraMappings[argument.getArgNumber() - 1])))
1684  return failure();
1685  }
1686 
1687  return success();
1688 }
1689 
/// Verifies the structural requirements of PossibleTopLevelTransformOpTrait:
///   - at least one region, whose first region has a single block;
///   - an entry block with at least one argument, the first implementing
///     TransformHandleTypeInterface and matching the type of the first operand
///     if there is one;
///   - trailing arguments of handle, value-handle or param type;
///   - when nested in another possible top-level op, one operand per entry
///     block argument.
/// NOTE(review): original lines 1691 (name and parameter), 1730 and 1742
/// (declarations of `diag`) and 1740 (rest of the `parent` lookup) are missing
/// from this listing.
1690 LogicalResult
1692  // Attaching this trait without the interface is a misuse of the API, but it
1693  // cannot be caught via a static_assert because interface registration is
1694  // dynamic.
1695  assert(isa<TransformOpInterface>(op) &&
1696  "should implement TransformOpInterface to have "
1697  "PossibleTopLevelTransformOpTrait");
1698 
1699  if (op->getNumRegions() < 1)
1700  return op->emitOpError() << "expects at least one region";
1701 
1702  Region *bodyRegion = &op->getRegion(0);
1703  if (!llvm::hasNItems(*bodyRegion, 1))
1704  return op->emitOpError() << "expects a single-block region";
1705 
1706  Block *body = &bodyRegion->front();
1707  if (body->getNumArguments() == 0) {
1708  return op->emitOpError()
1709  << "expects the entry block to have at least one argument";
1710  }
1711  if (!llvm::isa<TransformHandleTypeInterface>(
1712  body->getArgument(0).getType())) {
1713  return op->emitOpError()
1714  << "expects the first entry block argument to be of type "
1715  "implementing TransformHandleTypeInterface";
1716  }
1717  BlockArgument arg = body->getArgument(0);
1718  if (op->getNumOperands() != 0) {
1719  if (arg.getType() != op->getOperand(0).getType()) {
1720  return op->emitOpError()
1721  << "expects the type of the block argument to match "
1722  "the type of the operand";
1723  }
1724  }
// NOTE(review): this loop variable `arg` shadows the `arg` declared above;
// consider renaming one of them.
1725  for (BlockArgument arg : body->getArguments().drop_front()) {
1726  if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1727  TransformValueHandleTypeInterface>(arg.getType()))
1728  continue;
1729 
1731  op->emitOpError()
1732  << "expects trailing entry block arguments to be of type implementing "
1733  "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1734  "TransformParamTypeInterface";
1735  diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1736  return diag;
1737  }
1738 
// When nested in another possible top-level op, all entry block arguments must
// be bound from operands.
1739  if (auto *parent =
1741  if (op->getNumOperands() != body->getNumArguments()) {
1743  op->emitOpError()
1744  << "expects operands to be provided for a nested op";
1745  diag.attachNote(parent->getLoc())
1746  << "nested in another possible top-level op";
1747  return diag;
1748  }
1749  }
1750 
1751  return success();
1752 }
1753 
1754 //===----------------------------------------------------------------------===//
1755 // Utilities for ParamProducerTransformOpTrait.
1756 //===----------------------------------------------------------------------===//
1757 
/// Non-template implementation of ParamProducerTransformOpTrait::getEffects(),
/// declared as `void getParamProducerTransformOpTraitEffects(Operation *op,
/// SmallVectorImpl<MemoryEffects::EffectInstance> &effects)`; its signature
/// lines are elided from this listing. Appends to `effects`:
///   - handle-producing effects for every result;
///   - read-only handle effects for every operand;
///   - a payload read effect, but only when some operand is an op or value
///     handle.
1760  producesHandle(op->getResults(), effects);
1761  bool hasPayloadOperands = false;
1762  for (OpOperand &operand : op->getOpOperands()) {
1763  onlyReadsHandle(operand, effects);
  // Only operands of op-handle or value-handle type make the op read payload.
1764  if (llvm::isa<TransformHandleTypeInterface,
1765  TransformValueHandleTypeInterface>(operand.get().getType()))
1766  hasPayloadOperands = true;
1767  }
1768  if (hasPayloadOperands)
1769  onlyReadsPayload(effects);
1770 }
1771 
/// Non-template implementation of ParamProducerTransformOpTrait::verify(),
/// declared as `LogicalResult verifyParamProducerTransformOpTrait(Operation
/// *op)`; the line carrying the function name is elided from this listing.
/// It behaves as follows:
///   - if the op's registered name lacks MemoryEffectOpInterface, it aborts
///     via report_fatal_error;
///   - if any result type does not implement TransformParamTypeInterface, it
///     emits an op error and returns failure;
///   - otherwise it returns success.
1772 LogicalResult
1774  // Interfaces can be attached dynamically, so this cannot be a static
1775  // assert.
1776  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
  // NOTE(review): the message below says "MemoryEffectsOpInterface", but the
  // interface actually checked above is MemoryEffectOpInterface.
1777  llvm::report_fatal_error(
1778  Twine("ParamProducerTransformOpTrait must be attached to an op that "
1779  "implements MemoryEffectsOpInterface, found on ") +
1780  op->getName().getStringRef());
1781  }
  // Every result must be a param; the first non-param result is reported.
1782  for (Value result : op->getResults()) {
1783  if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1784  continue;
1785  return op->emitOpError()
1786  << "ParamProducerTransformOpTrait attached to this op expects "
1787  "result types to implement TransformParamTypeInterface";
1788  }
1789  return success();
1790 }
1791 
1792 //===----------------------------------------------------------------------===//
1793 // Memory effects.
1794 //===----------------------------------------------------------------------===//
1795 
/// Declared as `void consumesHandle(MutableArrayRef<OpOperand> handles,
/// SmallVectorImpl<MemoryEffects::EffectInstance> &effects)`. The signature
/// lines and the resource-argument continuation lines are elided from this
/// listing. For each handle operand, appends a Read effect and then a Free
/// effect, both attached to that operand.
/// NOTE(review): the elided resource is presumably
/// TransformMappingResource::get(). isHandleConsumed and
/// getConsumedHandleOpOperands test for that resource.
1799  for (OpOperand &handle : handles) {
1800  effects.emplace_back(MemoryEffects::Read::get(), &handle,
1802  effects.emplace_back(MemoryEffects::Free::get(), &handle,
1804  }
1805 }
1806 
1807 /// Returns `true` if the given list of effects instances contains an instance
1808 /// with the effect type specified as template parameter.
1809 template <typename EffectTy, typename ResourceTy, typename Range>
1810 static bool hasEffect(Range &&effects) {
1811  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1812  return isa<EffectTy>(effect.getEffect()) &&
1813  isa<ResourceTy>(effect.getResource());
1814  });
1815 }
1816 
/// Returns true if `transform` reports both a Read and a Free effect on
/// TransformMappingResource for `handle`. Declared as `bool
/// isHandleConsumed(Value handle, transform::TransformOpInterface transform)`.
/// The first signature line and the declaration of the local `effects`
/// vector are elided from this listing. Only effects attached to `handle` are
/// queried (getEffectsOnValue). The op is cast with cast<>, not dyn_cast<>,
/// so it is expected to implement MemoryEffectOpInterface.
1818  transform::TransformOpInterface transform) {
1819  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1821  iface.getEffectsOnValue(handle, effects);
1822  return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1823  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1824 }
1825 
/// Declared as `void producesHandle(ResultRange handles,
/// SmallVectorImpl<MemoryEffects::EffectInstance> &effects)`. The first
/// signature line, the `effects` parameter line and the resource-argument
/// continuation lines are elided from this listing. For each result handle,
/// appends an Allocate effect and then a Write effect on that result.
/// NOTE(review): the elided resource is presumably
/// TransformMappingResource::get(). verifyTransformOpInterface requires an
/// 'allocate' effect on that resource for every result.
1827  ResultRange handles,
1829  for (OpResult handle : handles) {
1830  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1832  effects.emplace_back(MemoryEffects::Write::get(), handle,
1834  }
1835 }
1836 
/// Block-argument counterpart of producesHandle(ResultRange, ...). For each
/// block-argument handle, appends an Allocate effect and then a Write effect
/// on that argument. The signature and resource-argument lines are elided
/// from this listing.
/// NOTE(review): the parameter type is presumably
/// MutableArrayRef<BlockArgument>, and the resource TransformMappingResource.
/// Confirm against the header.
1840  for (BlockArgument handle : handles) {
1841  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1843  effects.emplace_back(MemoryEffects::Write::get(), handle,
1845  }
1846 }
1847 
/// Declared as `void onlyReadsHandle(MutableArrayRef<OpOperand> handles,
/// SmallVectorImpl<MemoryEffects::EffectInstance> &effects)`. The signature
/// lines and the resource-argument continuation line are elided from this
/// listing. Appends a single Read effect per handle operand. Unlike
/// consumesHandle, no Free effect is added.
/// NOTE(review): the elided resource is presumably TransformMappingResource.
1851  for (OpOperand &handle : handles) {
1852  effects.emplace_back(MemoryEffects::Read::get(), &handle,
1854  }
1855 }
1856 
/// Declared as `void modifiesPayload(
/// SmallVectorImpl<MemoryEffects::EffectInstance> &effects)`; its signature
/// lines are elided from this listing. Appends a Read effect and a Write
/// effect on PayloadIRResource. No value is attached to either effect.
1859  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1860  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
1861 }
1862 
/// Declared as `void onlyReadsPayload(
/// SmallVectorImpl<MemoryEffects::EffectInstance> &effects)`; its signature
/// lines are elided from this listing. Appends only a Read effect on
/// PayloadIRResource (no Write, unlike modifiesPayload).
1865  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1866 }
1867 
/// Returns true if any effect reported by the MemoryEffectOpInterface of
/// `transform` is a Write on PayloadIRResource. The declaration of the local
/// `effects` vector is elided from this listing.
1868 bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1869  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1871  iface.getEffects(effects);
1872  return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1873 }
1874 
/// Returns true if any effect reported by the MemoryEffectOpInterface of
/// `transform` is a Read on PayloadIRResource. The declaration of the local
/// `effects` vector is elided from this listing.
1875 bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1876  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1878  iface.getEffects(effects);
1879  return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1880 }
1881 
/// Declared as `void getConsumedBlockArguments(Block &block,
/// llvm::SmallDenseSet<unsigned> &consumedArguments)`. The first signature
/// line and the declaration of the reused local `effects` vector are elided
/// from this listing. Inserts into `consumedArguments` the position of every
/// argument of `block` on which some op directly inside `block` reports a
/// Free effect on TransformMappingResource. Pre-existing entries of
/// `consumedArguments` are kept; the set is never cleared here.
1883  Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1885  for (Operation &nested : block) {
  // Ops that do not describe their memory effects are ignored.
1886  auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1887  if (!iface)
1888  continue;
1889 
1890  effects.clear();
1891  iface.getEffects(effects);
1892  for (const MemoryEffects::EffectInstance &effect : effects) {
  // Keep only Free effects on the mapping resource whose value is an
  // argument of this very block; everything else is skipped.
1893  BlockArgument argument =
1894  dyn_cast_or_null<BlockArgument>(effect.getValue());
1895  if (!argument || argument.getOwner() != &block ||
1896  !isa<MemoryEffects::Free>(effect.getEffect()) ||
1897  effect.getResource() != transform::TransformMappingResource::get()) {
1898  continue;
1899  }
1900  consumedArguments.insert(argument.getArgNumber());
1901  }
1902  }
1903 }
1904 
1905 //===----------------------------------------------------------------------===//
1906 // Utilities for TransformOpInterface.
1907 //===----------------------------------------------------------------------===//
1908 
/// Declared as `SmallVector<OpOperand *> getConsumedHandleOpOperands(
/// transform::TransformOpInterface transformOp)`. The first signature line
/// and the declaration of the reused local `effects` vector are elided from
/// this listing. Returns, in operand order, every operand of `transformOp`
/// whose value has a Free effect on TransformMappingResource according to
/// the op's MemoryEffectOpInterface.
1910  TransformOpInterface transformOp) {
1911  SmallVector<OpOperand *> consumedOperands;
1912  consumedOperands.reserve(transformOp->getNumOperands());
1913  auto memEffectInterface =
1914  cast<MemoryEffectOpInterface>(transformOp.getOperation());
1916  for (OpOperand &target : transformOp->getOpOperands()) {
  // Effects are queried by Value (target.get()), not by OpOperand. An
  // operand slot therefore counts as consumed if its value is consumed
  // through any slot.
1917  effects.clear();
1918  memEffectInterface.getEffectsOnValue(target.get(), effects);
1919  if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1920  return isa<transform::TransformMappingResource>(
1921  effect.getResource()) &&
1922  isa<MemoryEffects::Free>(effect.getEffect());
1923  })) {
1924  consumedOperands.push_back(&target);
1925  }
1926  }
1927  return consumedOperands;
1928 }
1929 
/// Verification hook for TransformOpInterface, declared as `LogicalResult
/// verifyTransformOpInterface(Operation *op)`. This listing elides the
/// signature line, the declaration of the local `effects` vector and
/// (presumably) the `InFlightDiagnostic diag =` lines. Using the effects
/// reported by the op's MemoryEffectOpInterface, it requires that:
///   - every operand value has at least one effect attached;
///   - no operand has an Allocate effect on TransformMappingResource;
///   - if some operand has a Free effect on TransformMappingResource, the op
///     also has a Write effect on PayloadIRResource;
///   - every result has an Allocate effect on TransformMappingResource.
/// The first violation is reported as an error with an attached note, and
/// failure is returned.
1931  auto iface = cast<MemoryEffectOpInterface>(op);
1933  iface.getEffects(effects);
1934 
  // Lazily filters `effects` down to the instances attached to `value`.
1935  auto effectsOn = [&](Value value) {
1936  return llvm::make_filter_range(
1937  effects, [value](const MemoryEffects::EffectInstance &instance) {
1938  return instance.getValue() == value;
1939  });
1940  };
1941 
  // Only the first consuming operand is remembered, for the diagnostic note.
1942  std::optional<unsigned> firstConsumedOperand;
1943  for (OpOperand &operand : op->getOpOperands()) {
1944  auto range = effectsOn(operand.get());
1945  if (range.empty()) {
1947  op->emitError() << "TransformOpInterface requires memory effects "
1948  "on operands to be specified";
1949  diag.attachNote() << "no effects specified for operand #"
1950  << operand.getOperandNumber();
1951  return diag;
1952  }
1953  if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1955  << "TransformOpInterface did not expect "
1956  "'allocate' memory effect on an operand";
1957  diag.attachNote() << "specified for operand #"
1958  << operand.getOperandNumber();
1959  return diag;
1960  }
1961  if (!firstConsumedOperand &&
1962  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1963  firstConsumedOperand = operand.getOperandNumber();
1964  }
1965  }
1966 
  // The payload check looks at all effects of the op, not only those tied to
  // a particular operand.
1967  if (firstConsumedOperand &&
1968  !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1970  op->emitError()
1971  << "TransformOpInterface expects ops consuming operands to have a "
1972  "'write' effect on the payload resource";
1973  diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
1974  return diag;
1975  }
1976 
1977  for (OpResult result : op->getResults()) {
1978  auto range = effectsOn(result);
1979  if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1980  range)) {
1982  op->emitError() << "TransformOpInterface requires 'allocate' memory "
1983  "effect to be specified for results";
1984  diag.attachNote() << "no 'allocate' effect specified for result #"
1985  << result.getResultNumber();
1986  return diag;
1987  }
1988  }
1989 
1990  return success();
1991 }
1992 
1993 //===----------------------------------------------------------------------===//
1994 // Entry point.
1995 //===----------------------------------------------------------------------===//
1996 
/// Entry point to the Transform dialect infrastructure, declared as
/// `LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface
/// transform, ...)` with defaults for the trailing parameters. The first
/// signature line and the second line of the `else if (failed(` condition
/// are elided from this listing.
///
/// With `enforceToplevelTransformOp`, `transform` must carry
/// PossibleTopLevelTransformOpTrait and take no operands. Otherwise
/// `transform` is verified instead.
/// NOTE(review): presumably via verifyPossibleTopLevelTransformOpTrait; the
/// call is elided.
///
/// After that check, a TransformState is built over `transform`'s parent
/// region and `payloadRoot`, optionally customized by `stateInitializer`,
/// and `transform` is applied. Any failure is reported through
/// checkAndReport(). On success, returns `stateExporter(state)` when an
/// exporter is given, and success otherwise.
1998  Operation *payloadRoot, TransformOpInterface transform,
1999  const RaggedArray<MappedValue> &extraMapping,
2000  const TransformOptions &options, bool enforceToplevelTransformOp,
2001  function_ref<void(TransformState &)> stateInitializer,
2002  function_ref<LogicalResult(TransformState &)> stateExporter) {
2003  if (enforceToplevelTransformOp) {
2004  if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
2005  transform->getNumOperands() != 0) {
2006  return transform->emitError()
2007  << "expected transform to start at the top-level transform op";
2008  }
2009  } else if (failed(
2011  return failure();
2012  }
2013 
2014  TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2015  options);
2016  if (stateInitializer)
2017  stateInitializer(state);
2018  if (state.applyTransform(transform).checkAndReport().failed())
2019  return failure();
2020  if (stateExporter)
2021  return stateExporter(state);
2022  return success();
2023 }
2024 
2025 //===----------------------------------------------------------------------===//
2026 // Generated interface implementation.
2027 //===----------------------------------------------------------------------===//
2028 
2029 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2030 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define DEBUG_TYPE_FULL
#define FULL_LDBG(X)
static void remapEffects(MemoryEffectOpInterface iface, BlockArgument source, OpOperand *target, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the memory effect instances on target with the same resource and effect as the one...
static void remapArgumentEffects(Block &block, MutableArrayRef< OpOperand > operands, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the same effects as the operations of block have on block arguments but associated...
static bool happensBefore(Operation *a, Operation *b)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not ...
void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped)
#define DEBUG_PRINT_AFTER_ALL
DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef< T > payload, transform::TransformOpInterface transform, unsigned operandNumber)
#define DBGS()
static DiagnosedSilenceableFailure dispatchMappedValues(Value handle, ArrayRef< transform::MappedValue > values, function_ref< LogicalResult(ArrayRef< Operation * >)> operationsFn, function_ref< LogicalResult(ArrayRef< transform::Param >)> paramsFn, function_ref< LogicalResult(ValueRange)> valuesFn)
Given a list of MappedValues, cast them to the value kind implied by the interface of the handle type...
static SmallVector< T > castVector(ArrayRef< transform::MappedValue > range)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:31
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition: Block.cpp:73
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgListType getArguments()
Definition: Block.h:85
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:313
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:324
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getOpResults()
Definition: Operation.h:415
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A 2D array where each row may have different length.
Definition: RaggedArray.h:18
size_t size() const
Returns the number of rows in the 2D array.
Definition: RaggedArray.h:21
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
size_t size() const
Returns the number of elements in the list.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
TrackingListener failures are reported only for ops that have this trait.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Options controlling the application of transform operations by the TransformState.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult matches the number, kind (payload op or parameter) and nulli...
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true, function_ref< void(TransformState &)> stateInitializer=nullptr, function_ref< LogicalResult(TransformState &)> stateExporter=nullptr)
Entry point to the Transform dialect infrastructure.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns true if op has an effect of type EffectTy.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A configuration object for customizing a TrackingListener.
SkipHandleFn skipHandleFn
An optional function that returns "true" for handles that do not have to be updated.