MLIR  19.0.0git
TransformInterfaces.cpp
Go to the documentation of this file.
1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/IR/PatternMatch.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/ErrorHandling.h"
21 
// Debug channels used by this file. DEBUG_TYPE is the channel consulted by
// LLVM_DEBUG (and therefore by LDBG); DEBUG_TYPE_FULL gates the more verbose
// output emitted through FULL_LDBG and DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, ...).
22 #define DEBUG_TYPE "transform-dialect"
23 #define DEBUG_TYPE_FULL "transform-dialect-full"
// NOTE(review): not referenced in the visible part of this file; judging by
// its name it presumably gates printing of the top-level payload after each
// transform — confirm against the rest of the file.
24 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
// Debug stream whose output lines are prefixed with "[transform-dialect] ".
25 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
// Prints X on the DEBUG_TYPE channel.
26 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
// Prints X only when the DEBUG_TYPE_FULL channel is enabled.
27 #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
28 
29 using namespace mlir;
30 
31 //===----------------------------------------------------------------------===//
32 // Helper functions
33 //===----------------------------------------------------------------------===//
34 
35 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
36 /// properly dominates `b` and `b` is not inside `a`.
37 static bool happensBefore(Operation *a, Operation *b) {
38  do {
39  if (a->isProperAncestor(b))
40  return false;
41  if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
42  return a->isBeforeInBlock(bAncestor);
43  }
44  } while ((a = a->getParentOp()));
45  return false;
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // TransformState
50 //===----------------------------------------------------------------------===//
51 
// Out-of-line definition of the static constexpr data member `kTopLevelValue`.
// NOTE(review): since C++17, static constexpr data members are implicitly
// inline, which makes this redeclaration redundant (and deprecated); confirm
// the minimum language standard before removing it.
52 constexpr const Value transform::TransformState::kTopLevelValue;
53 
54 transform::TransformState::TransformState(
55  Region *region, Operation *payloadRoot,
56  const RaggedArray<MappedValue> &extraMappings,
57  const TransformOptions &options)
58  : topLevel(payloadRoot), options(options) {
59  topLevelMappedValues.reserve(extraMappings.size());
60  for (ArrayRef<MappedValue> mapping : extraMappings)
61  topLevelMappedValues.push_back(mapping);
62  if (region) {
63  RegionScope *scope = new RegionScope(*this, *region);
64  topLevelRegionScope.reset(scope);
65  }
66 }
67 
68 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
69 
71 transform::TransformState::getPayloadOpsView(Value value) const {
72  const TransformOpMapping &operationMapping = getMapping(value).direct;
73  auto iter = operationMapping.find(value);
74  assert(iter != operationMapping.end() &&
75  "cannot find mapping for payload handle (param/value handle "
76  "provided?)");
77  return iter->getSecond();
78 }
79 
81  const ParamMapping &mapping = getMapping(value).params;
82  auto iter = mapping.find(value);
83  assert(iter != mapping.end() && "cannot find mapping for param handle "
84  "(operation/value handle provided?)");
85  return iter->getSecond();
86 }
87 
89 transform::TransformState::getPayloadValuesView(Value handleValue) const {
90  const ValueMapping &mapping = getMapping(handleValue).values;
91  auto iter = mapping.find(handleValue);
92  assert(iter != mapping.end() && "cannot find mapping for value handle "
93  "(param/operation handle provided?)");
94  return iter->getSecond();
95 }
96 
98  Operation *op, SmallVectorImpl<Value> &handles,
99  bool includeOutOfScope) const {
100  bool found = false;
101  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
102  auto iterator = mapping->reverse.find(op);
103  if (iterator != mapping->reverse.end()) {
104  llvm::append_range(handles, iterator->getSecond());
105  found = true;
106  }
107  // Stop looking when reaching a region that is isolated from above.
108  if (!includeOutOfScope &&
110  break;
111  }
112 
113  return success(found);
114 }
115 
117  Value payloadValue, SmallVectorImpl<Value> &handles,
118  bool includeOutOfScope) const {
119  bool found = false;
120  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
121  auto iterator = mapping->reverseValues.find(payloadValue);
122  if (iterator != mapping->reverseValues.end()) {
123  llvm::append_range(handles, iterator->getSecond());
124  found = true;
125  }
126  // Stop looking when reaching a region that is isolated from above.
127  if (!includeOutOfScope &&
129  break;
130  }
131 
132  return success(found);
133 }
134 
135 /// Given a list of MappedValues, cast them to the value kind implied by the
136 /// interface of the handle type, and dispatch to one of the callbacks.
142  if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
143  SmallVector<Operation *> operations;
144  operations.reserve(values.size());
145  for (transform::MappedValue value : values) {
146  if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
147  operations.push_back(op);
148  continue;
149  }
150  return emitSilenceableFailure(handle.getLoc())
151  << "wrong kind of value provided for top-level operation handle";
152  }
153  if (failed(operationsFn(operations)))
156  }
157 
158  if (llvm::isa<transform::TransformValueHandleTypeInterface>(
159  handle.getType())) {
160  SmallVector<Value> payloadValues;
161  payloadValues.reserve(values.size());
162  for (transform::MappedValue value : values) {
163  if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
164  payloadValues.push_back(v);
165  continue;
166  }
167  return emitSilenceableFailure(handle.getLoc())
168  << "wrong kind of value provided for the top-level value handle";
169  }
170  if (failed(valuesFn(payloadValues)))
173  }
174 
175  assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
176  "unsupported kind of block argument");
178  parameters.reserve(values.size());
179  for (transform::MappedValue value : values) {
180  if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
181  parameters.push_back(attr);
182  continue;
183  }
184  return emitSilenceableFailure(handle.getLoc())
185  << "wrong kind of value provided for top-level parameter";
186  }
187  if (failed(paramsFn(parameters)))
190 }
191 
194  ArrayRef<MappedValue> values) {
195  return dispatchMappedValues(
196  argument, values,
197  [&](ArrayRef<Operation *> operations) {
198  return setPayloadOps(argument, operations);
199  },
200  [&](ArrayRef<Param> params) {
201  return setParams(argument, params);
202  },
203  [&](ValueRange payloadValues) {
204  return setPayloadValues(argument, payloadValues);
205  })
206  .checkAndReport();
207 }
208 
210  Block::BlockArgListType arguments,
212  for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
213  if (failed(mapBlockArgument(argument, values)))
214  return failure();
215  return success();
216 }
217 
219 transform::TransformState::setPayloadOps(Value value,
220  ArrayRef<Operation *> targets) {
221  assert(value != kTopLevelValue &&
222  "attempting to reset the transformation root");
223  assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
224  "wrong handle type");
225 
226  for (Operation *target : targets) {
227  if (target)
228  continue;
229  return emitError(value.getLoc())
230  << "attempting to assign a null payload op to this transform value";
231  }
232 
233  auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
235  iface.checkPayload(value.getLoc(), targets);
236  if (failed(result.checkAndReport()))
237  return failure();
238 
239  // Setting new payload for the value without cleaning it first is a misuse of
240  // the API, assert here.
241  SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
242  Mappings &mappings = getMapping(value);
243  bool inserted =
244  mappings.direct.insert({value, std::move(storedTargets)}).second;
245  assert(inserted && "value is already associated with another list");
246  (void)inserted;
247 
248  for (Operation *op : targets)
249  mappings.reverse[op].push_back(value);
250 
251  return success();
252 }
253 
255 transform::TransformState::setPayloadValues(Value handle,
256  ValueRange payloadValues) {
257  assert(handle != nullptr && "attempting to set params for a null value");
258  assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
259  "wrong handle type");
260 
261  for (Value payload : payloadValues) {
262  if (payload)
263  continue;
264  return emitError(handle.getLoc()) << "attempting to assign a null payload "
265  "value to this transform handle";
266  }
267 
268  auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
269  SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
271  iface.checkPayload(handle.getLoc(), payloadValueVector);
272  if (failed(result.checkAndReport()))
273  return failure();
274 
275  Mappings &mappings = getMapping(handle);
276  bool inserted =
277  mappings.values.insert({handle, std::move(payloadValueVector)}).second;
278  assert(
279  inserted &&
280  "value handle is already associated with another list of payload values");
281  (void)inserted;
282 
283  for (Value payload : payloadValues)
284  mappings.reverseValues[payload].push_back(handle);
285 
286  return success();
287 }
288 
289 LogicalResult transform::TransformState::setParams(Value value,
290  ArrayRef<Param> params) {
291  assert(value != nullptr && "attempting to set params for a null value");
292 
293  for (Attribute attr : params) {
294  if (attr)
295  continue;
296  return emitError(value.getLoc())
297  << "attempting to assign a null parameter to this transform value";
298  }
299 
300  auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
301  assert(value &&
302  "cannot associate parameter with a value of non-parameter type");
304  valueType.checkPayload(value.getLoc(), params);
305  if (failed(result.checkAndReport()))
306  return failure();
307 
308  Mappings &mappings = getMapping(value);
309  bool inserted =
310  mappings.params.insert({value, llvm::to_vector(params)}).second;
311  assert(inserted && "value is already associated with another list of params");
312  (void)inserted;
313  return success();
314 }
315 
316 template <typename Mapping, typename Key, typename Mapped>
317 void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
318  auto it = mapping.find(key);
319  if (it == mapping.end())
320  return;
321 
322  llvm::erase(it->getSecond(), mapped);
323  if (it->getSecond().empty())
324  mapping.erase(it);
325 }
326 
327 void transform::TransformState::forgetMapping(Value opHandle,
328  ValueRange origOpFlatResults,
329  bool allowOutOfScope) {
330  Mappings &mappings = getMapping(opHandle, allowOutOfScope);
331  for (Operation *op : mappings.direct[opHandle])
332  dropMappingEntry(mappings.reverse, op, opHandle);
333  mappings.direct.erase(opHandle);
334 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
335  // Payload IR is removed from the mapping. This invalidates the respective
336  // iterators.
337  mappings.incrementTimestamp(opHandle);
338 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
339 
340  for (Value opResult : origOpFlatResults) {
341  SmallVector<Value> resultHandles;
342  (void)getHandlesForPayloadValue(opResult, resultHandles);
343  for (Value resultHandle : resultHandles) {
344  Mappings &localMappings = getMapping(resultHandle);
345  dropMappingEntry(localMappings.values, resultHandle, opResult);
346 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
347  // Payload IR is removed from the mapping. This invalidates the respective
348  // iterators.
349  mappings.incrementTimestamp(resultHandle);
350 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
351  dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
352  }
353  }
354 }
355 
356 void transform::TransformState::forgetValueMapping(
357  Value valueHandle, ArrayRef<Operation *> payloadOperations) {
358  Mappings &mappings = getMapping(valueHandle);
359  for (Value payloadValue : mappings.reverseValues[valueHandle])
360  dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
361  mappings.values.erase(valueHandle);
362 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
363  // Payload IR is removed from the mapping. This invalidates the respective
364  // iterators.
365  mappings.incrementTimestamp(valueHandle);
366 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
367 
368  for (Operation *payloadOp : payloadOperations) {
369  SmallVector<Value> opHandles;
370  (void)getHandlesForPayloadOp(payloadOp, opHandles);
371  for (Value opHandle : opHandles) {
372  Mappings &localMappings = getMapping(opHandle);
373  dropMappingEntry(localMappings.direct, opHandle, payloadOp);
374  dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
375 
376 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
377  // Payload IR is removed from the mapping. This invalidates the respective
378  // iterators.
379  localMappings.incrementTimestamp(opHandle);
380 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
381  }
382  }
383 }
384 
386 transform::TransformState::replacePayloadOp(Operation *op,
387  Operation *replacement) {
388  // TODO: consider invalidating the handles to nested objects here.
389 
390 #ifndef NDEBUG
391  for (Value opResult : op->getResults()) {
392  SmallVector<Value> valueHandles;
393  (void)getHandlesForPayloadValue(opResult, valueHandles,
394  /*includeOutOfScope=*/true);
395  assert(valueHandles.empty() && "expected no mapping to old results");
396  }
397 #endif // NDEBUG
398 
399  // Drop the mapping between the op and all handles that point to it. Fail if
400  // there are no handles.
401  SmallVector<Value> opHandles;
402  if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
403  return failure();
404  for (Value handle : opHandles) {
405  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
406  dropMappingEntry(mappings.reverse, op, handle);
407  }
408 
409  // Replace the pointed-to object of all handles with the replacement object.
410  // In case a payload op was erased (replacement object is nullptr), a nullptr
411  // is stored in the mapping. These nullptrs are removed after each transform.
412  // Furthermore, nullptrs are not enumerated by payload op iterators. The
413  // relative order of ops is preserved.
414  //
415  // Removing an op from the mapping would be problematic because removing an
416  // element from an array invalidates iterators; merely changing the value of
417  // elements does not.
418  for (Value handle : opHandles) {
419  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
420  auto it = mappings.direct.find(handle);
421  if (it == mappings.direct.end())
422  continue;
423 
424  SmallVector<Operation *, 2> &association = it->getSecond();
425  // Note that an operation may be associated with the handle more than once.
426  for (Operation *&mapped : association) {
427  if (mapped == op)
428  mapped = replacement;
429  }
430 
431  if (replacement) {
432  mappings.reverse[replacement].push_back(handle);
433  } else {
434  opHandlesToCompact.insert(handle);
435  }
436  }
437 
438  return success();
439 }
440 
442 transform::TransformState::replacePayloadValue(Value value, Value replacement) {
443  SmallVector<Value> valueHandles;
444  if (failed(getHandlesForPayloadValue(value, valueHandles,
445  /*includeOutOfScope=*/true)))
446  return failure();
447 
448  for (Value handle : valueHandles) {
449  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
450  dropMappingEntry(mappings.reverseValues, value, handle);
451 
452  // If replacing with null, that is erasing the mapping, drop the mapping
453  // between the handles and the IR objects
454  if (!replacement) {
455  dropMappingEntry(mappings.values, handle, value);
456 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
457  // Payload IR is removed from the mapping. This invalidates the respective
458  // iterators.
459  mappings.incrementTimestamp(handle);
460 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
461  } else {
462  auto it = mappings.values.find(handle);
463  if (it == mappings.values.end())
464  continue;
465 
466  SmallVector<Value> &association = it->getSecond();
467  for (Value &mapped : association) {
468  if (mapped == value)
469  mapped = replacement;
470  }
471  mappings.reverseValues[replacement].push_back(handle);
472  }
473  }
474 
475  return success();
476 }
477 
478 void transform::TransformState::recordOpHandleInvalidationOne(
479  OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
480  Operation *payloadOp, Value otherHandle, Value throughValue,
481  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
482  // If the op is associated with invalidated handle, skip the check as it
483  // may be reading invalid IR. This also ensures we report the first
484  // invalidation and not the last one.
485  if (invalidatedHandles.count(otherHandle) ||
486  newlyInvalidated.count(otherHandle))
487  return;
488 
489  FULL_LDBG("--recordOpHandleInvalidationOne\n");
490  DEBUG_WITH_TYPE(
492  llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ",
493  [](Operation *op) { llvm::dbgs() << *op; });
494  llvm::dbgs() << "\n");
495 
496  Operation *owner = consumingHandle.getOwner();
497  unsigned operandNo = consumingHandle.getOperandNumber();
498  for (Operation *ancestor : potentialAncestors) {
499  // clang-format off
500  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
501  { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
502  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
503  { (DBGS() << "----of payload with name: "
504  << payloadOp->getName().getIdentifier() << "\n"); });
505  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
506  { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
507  // clang-format on
508  if (!ancestor->isAncestor(payloadOp))
509  continue;
510 
511  // Make sure the error-reporting lambda doesn't capture anything
512  // by-reference because it will go out of scope. Additionally, extract
513  // location from Payload IR ops because the ops themselves may be
514  // deleted before the lambda gets called.
515  Location ancestorLoc = ancestor->getLoc();
516  Location opLoc = payloadOp->getLoc();
517  std::optional<Location> throughValueLoc =
518  throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
519  newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
520  otherHandle,
521  throughValueLoc](Location currentLoc) {
522  InFlightDiagnostic diag = emitError(currentLoc)
523  << "op uses a handle invalidated by a "
524  "previously executed transform op";
525  diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
526  diag.attachNote(owner->getLoc())
527  << "invalidated by this transform op that consumes its operand #"
528  << operandNo
529  << " and invalidates all handles to payload IR entities associated "
530  "with this operand and entities nested in them";
531  diag.attachNote(ancestorLoc) << "ancestor payload op";
532  diag.attachNote(opLoc) << "nested payload op";
533  if (throughValueLoc) {
534  diag.attachNote(*throughValueLoc)
535  << "consumed handle points to this payload value";
536  }
537  };
538  }
539 }
540 
541 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
542  OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
543  Value payloadValue, Value valueHandle,
544  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
545  // If the op is associated with invalidated handle, skip the check as it
546  // may be reading invalid IR. This also ensures we report the first
547  // invalidation and not the last one.
548  if (invalidatedHandles.count(valueHandle) ||
549  newlyInvalidated.count(valueHandle))
550  return;
551 
552  for (Operation *ancestor : potentialAncestors) {
553  Operation *definingOp;
554  std::optional<unsigned> resultNo;
555  unsigned argumentNo = std::numeric_limits<unsigned>::max();
556  unsigned blockNo = std::numeric_limits<unsigned>::max();
557  unsigned regionNo = std::numeric_limits<unsigned>::max();
558  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
559  definingOp = opResult.getOwner();
560  resultNo = opResult.getResultNumber();
561  } else {
562  auto arg = llvm::cast<BlockArgument>(payloadValue);
563  definingOp = arg.getParentBlock()->getParentOp();
564  argumentNo = arg.getArgNumber();
565  blockNo = std::distance(arg.getOwner()->getParent()->begin(),
566  arg.getOwner()->getIterator());
567  regionNo = arg.getOwner()->getParent()->getRegionNumber();
568  }
569  assert(definingOp && "expected the value to be defined by an op as result "
570  "or block argument");
571  if (!ancestor->isAncestor(definingOp))
572  continue;
573 
574  Operation *owner = opHandle.getOwner();
575  unsigned operandNo = opHandle.getOperandNumber();
576  Location ancestorLoc = ancestor->getLoc();
577  Location opLoc = definingOp->getLoc();
578  Location valueLoc = payloadValue.getLoc();
579  newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
580  argumentNo, blockNo, regionNo, ancestorLoc,
581  opLoc, valueLoc](Location currentLoc) {
582  InFlightDiagnostic diag = emitError(currentLoc)
583  << "op uses a handle invalidated by a "
584  "previously executed transform op";
585  diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
586  diag.attachNote(owner->getLoc())
587  << "invalidated by this transform op that consumes its operand #"
588  << operandNo
589  << " and invalidates all handles to payload IR entities "
590  "associated with this operand and entities nested in them";
591  diag.attachNote(ancestorLoc)
592  << "ancestor op associated with the consumed handle";
593  if (resultNo) {
594  diag.attachNote(opLoc)
595  << "op defining the value as result #" << *resultNo;
596  } else {
597  diag.attachNote(opLoc)
598  << "op defining the value as block argument #" << argumentNo
599  << " of block #" << blockNo << " in region #" << regionNo;
600  }
601  diag.attachNote(valueLoc) << "payload value";
602  };
603  }
604 }
605 
606 void transform::TransformState::recordOpHandleInvalidation(
607  OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
608  Value throughValue,
609  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
610 
611  if (potentialAncestors.empty()) {
612  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
613  (DBGS() << "----recording invalidation for empty handle: " << handle.get()
614  << "\n");
615  });
616 
617  Operation *owner = handle.getOwner();
618  unsigned operandNo = handle.getOperandNumber();
619  newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
620  InFlightDiagnostic diag = emitError(currentLoc)
621  << "op uses a handle associated with empty "
622  "payload and invalidated by a "
623  "previously executed transform op";
624  diag.attachNote(owner->getLoc())
625  << "invalidated by this transform op that consumes its operand #"
626  << operandNo;
627  };
628  return;
629  }
630 
631  // Iterate over the mapping and invalidate aliasing handles. This is quite
632  // expensive and only necessary for error reporting in case of transform
633  // dialect misuse with dangling handles. Iteration over the handles is based
634  // on the assumption that the number of handles is significantly less than the
635  // number of IR objects (operations and values). Alternatively, we could walk
636  // the IR nested in each payload op associated with the given handle and look
637  // for handles associated with each operation and value.
638  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
639  // Go over all op handle mappings and mark as invalidated any handle
640  // pointing to any of the payload ops associated with the given handle or
641  // any op nested in them.
642  for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
643  for (Value otherHandle : otherHandles)
644  recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
645  otherHandle, throughValue,
646  newlyInvalidated);
647  }
648  // Go over all value handle mappings and mark as invalidated any handle
649  // pointing to any result of the payload op associated with the given handle
650  // or any op nested in them. Similarly invalidate handles to argument of
651  // blocks belonging to any region of any payload op associated with the
652  // given handle or any op nested in them.
653  for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
654  for (Value valueHandle : valueHandles)
655  recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
656  payloadValue, valueHandle,
657  newlyInvalidated);
658  }
659 
660  // Stop lookup when reaching a region that is isolated from above.
662  break;
663  }
664 }
665 
666 void transform::TransformState::recordValueHandleInvalidation(
667  OpOperand &valueHandle,
668  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
669  // Invalidate other handles to the same value.
670  for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
671  SmallVector<Value> otherValueHandles;
672  (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
673  for (Value otherHandle : otherValueHandles) {
674  Operation *owner = valueHandle.getOwner();
675  unsigned operandNo = valueHandle.getOperandNumber();
676  Location valueLoc = payloadValue.getLoc();
677  newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
678  valueLoc](Location currentLoc) {
679  InFlightDiagnostic diag = emitError(currentLoc)
680  << "op uses a handle invalidated by a "
681  "previously executed transform op";
682  diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
683  diag.attachNote(owner->getLoc())
684  << "invalidated by this transform op that consumes its operand #"
685  << operandNo
686  << " and invalidates handles to the same values as associated with "
687  "it";
688  diag.attachNote(valueLoc) << "payload value";
689  };
690  }
691 
692  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
693  Operation *payloadOp = opResult.getOwner();
694  recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
695  newlyInvalidated);
696  } else {
697  auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
698  for (Operation &payloadOp : *arg.getOwner())
699  recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
700  newlyInvalidated);
701  }
702  }
703 }
704 
705 /// Checks that the operation does not use invalidated handles as operands.
706 /// Reports errors and returns failure if it does. Otherwise, invalidates the
707 /// handles consumed by the operation as well as any handles pointing to payload
708 /// IR operations nested in the operations associated with the consumed handles.
709 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
710  transform::TransformOpInterface transform,
711  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
712  FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
713  auto memoryEffectsIface =
714  cast<MemoryEffectOpInterface>(transform.getOperation());
716  memoryEffectsIface.getEffectsOnResource(
718 
719  for (OpOperand &target : transform->getOpOperands()) {
720  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
721  (DBGS() << "----iterate on handle: " << target.get() << "\n");
722  });
723  // If the operand uses an invalidated handle, report it. If the operation
724  // allows handles to point to repeated payload operations, only report
725  // pre-existing invalidation errors. Otherwise, also report invalidations
726  // caused by the current transform operation affecting its other operands.
727  auto it = invalidatedHandles.find(target.get());
728  auto nit = newlyInvalidated.find(target.get());
729  if (it != invalidatedHandles.end()) {
730  FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
731  "invalidated -> FAILURE\n");
732  return it->getSecond()(transform->getLoc()), failure();
733  }
734  if (!transform.allowsRepeatedHandleOperands() &&
735  nit != newlyInvalidated.end()) {
736  FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
737  "invalidated (by this op) -> FAILURE\n");
738  return nit->getSecond()(transform->getLoc()), failure();
739  }
740 
741  // Invalidate handles pointing to the operations nested in the operation
742  // associated with the handle consumed by this operation.
743  auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
744  return isa<MemoryEffects::Free>(effect.getEffect()) &&
745  effect.getValue() == target.get();
746  };
747  if (llvm::any_of(effects, consumesTarget)) {
748  FULL_LDBG("----found consume effect\n");
749  if (llvm::isa<transform::TransformHandleTypeInterface>(
750  target.get().getType())) {
751  FULL_LDBG("----recordOpHandleInvalidation\n");
752  SmallVector<Operation *> payloadOps =
753  llvm::to_vector(getPayloadOps(target.get()));
754  recordOpHandleInvalidation(target, payloadOps, nullptr,
755  newlyInvalidated);
756  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
757  target.get().getType())) {
758  FULL_LDBG("----recordValueHandleInvalidation\n");
759  recordValueHandleInvalidation(target, newlyInvalidated);
760  } else {
761  FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
762  }
763  } else {
764  FULL_LDBG("----no consume effect -> SKIP\n");
765  }
766  }
767 
768  FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
769  return success();
770 }
771 
772 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
773  transform::TransformOpInterface transform) {
774  InvalidatedHandleMap newlyInvalidated;
775  LogicalResult checkResult =
776  checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
777  invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
778  std::make_move_iterator(newlyInvalidated.end()));
779  return checkResult;
780 }
781 
782 template <typename T>
785  transform::TransformOpInterface transform,
786  unsigned operandNumber) {
787  DenseSet<T> seen;
788  for (T p : payload) {
789  if (!seen.insert(p).second) {
791  transform.emitSilenceableError()
792  << "a handle passed as operand #" << operandNumber
793  << " and consumed by this operation points to a payload "
794  "entity more than once";
795  if constexpr (std::is_pointer_v<T>)
796  diag.attachNote(p->getLoc()) << "repeated target op";
797  else
798  diag.attachNote(p.getLoc()) << "repeated target value";
799  return diag;
800  }
801  }
803 }
804 
805 void transform::TransformState::compactOpHandles() {
806  for (Value handle : opHandlesToCompact) {
807  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
808 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
809  if (llvm::find(mappings.direct[handle], nullptr) !=
810  mappings.direct[handle].end())
811  // Payload IR is removed from the mapping. This invalidates the respective
812  // iterators.
813  mappings.incrementTimestamp(handle);
814 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
815  llvm::erase(mappings.direct[handle], nullptr);
816  }
817  opHandlesToCompact.clear();
818 }
819 
/// Applies `transform` to the payload IR tracked by this state. The steps are:
/// - optionally run the expensive handle-invalidation and
///   repeated-consumption checks;
/// - remember what the consumed handles point to;
/// - apply the op through a TransformRewriter that uses an
///   ErrorCheckingTrackingListener;
/// - fold tracking-listener failures into the result;
/// - forget the consumed handles;
/// - record the op's results in the state.
/// A definite failure is returned immediately. A silenceable failure still
/// propagates (empty) results so that "suppress" mode can continue.
821 transform::TransformState::applyTransform(TransformOpInterface transform) {
822  LLVM_DEBUG({
823  DBGS() << "applying: ";
824  transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
825  llvm::dbgs() << "\n";
826  });
827  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
828  DBGS() << "Top-level payload before application:\n"
829  << *getTopLevel() << "\n");
// Dumps the top-level payload on every early exit. It is released on the
// success path just before the final return.
830  auto printOnFailureRAII = llvm::make_scope_exit([this] {
// Presumably here so that `this` counts as used when LLVM_DEBUG expands to
// nothing (avoids an unused-capture warning) — confirm.
831  (void)this;
832  LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
833  llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
834  });
835 
836  // Set current transform op.
837  regionStack.back()->currentTransform = transform;
838 
839  // Expensive checks to detect invalid transform IR.
840  if (options.getExpensiveChecksEnabled()) {
841  FULL_LDBG("ExpensiveChecksEnabled\n");
842  if (failed(checkAndRecordHandleInvalidation(transform)))
844 
// A consumed operand must not point to the same payload entity twice, unless
// the op explicitly allows repeated handle operands.
845  for (OpOperand &operand : transform->getOpOperands()) {
846  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
847  (DBGS() << "iterate on handle: " << operand.get() << "\n");
848  });
849  if (!isHandleConsumed(operand.get(), transform)) {
850  FULL_LDBG("--handle not consumed -> SKIP\n");
851  continue;
852  }
853  if (transform.allowsRepeatedHandleOperands()) {
854  FULL_LDBG("--op allows repeated handles -> SKIP\n");
855  continue;
856  }
857  FULL_LDBG("--handle is consumed\n");
858 
859  Type operandType = operand.get().getType();
860  if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
861  FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
863  checkRepeatedConsumptionInOperand<Operation *>(
864  getPayloadOpsView(operand.get()), transform,
865  operand.getOperandNumber());
866  if (!check.succeeded()) {
867  FULL_LDBG("----FAILED\n");
868  return check;
869  }
870  } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
871  FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
873  checkRepeatedConsumptionInOperand<Value>(
874  getPayloadValuesView(operand.get()), transform,
875  operand.getOperandNumber());
876  if (!check.succeeded()) {
877  FULL_LDBG("----FAILED\n");
878  return check;
879  }
880  } else {
881  FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
882  }
883  }
884  }
885 
886  // Find which operands are consumed.
887  SmallVector<OpOperand *> consumedOperands =
888  transform.getConsumedHandleOpOperands();
889 
890  // Remember the results of the payload ops associated with the consumed
891  // op handles or the ops defining the value handles so we can drop the
892  // association with them later. This must happen here because the
893  // transformation may destroy or mutate them so we cannot traverse the payload
894  // IR after that.
895  SmallVector<Value> origOpFlatResults;
896  SmallVector<Operation *> origAssociatedOps;
897  for (OpOperand *opOperand : consumedOperands) {
898  Value operand = opOperand->get();
899  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
900  for (Operation *payloadOp : getPayloadOps(operand)) {
901  llvm::append_range(origOpFlatResults, payloadOp->getResults());
902  }
903  continue;
904  }
905  if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
906  for (Value payloadValue : getPayloadValuesView(operand)) {
907  if (llvm::isa<OpResult>(payloadValue)) {
908  origAssociatedOps.push_back(payloadValue.getDefiningOp());
909  continue;
910  }
// For a block argument, every op of its owning block is recorded as
// associated.
911  llvm::append_range(
912  origAssociatedOps,
913  llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
914  [](Operation &op) { return &op; }));
915  }
916  continue;
917  }
919  emitDefiniteFailure(transform->getLoc())
920  << "unexpectedly consumed a value that is not a handle as operand #"
921  << opOperand->getOperandNumber();
922  diag.attachNote(operand.getLoc())
923  << "value defined here with type " << operand.getType();
924  return diag;
925  }
926 
927  // Prepare rewriter and listener.
928  TrackingListenerConfig config;
// A handle is treated as dead (and skipped by the listener) when each of its
// users is the transform currently executing in the handle's region scope, or
// happens before that transform.
929  config.skipHandleFn = [&](Value handle) {
930  // Skip handle if it is dead.
931  auto scopeIt =
932  llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
933  return handle.getParentRegion() == scope->region;
934  });
935  assert(scopeIt != regionStack.rend() &&
936  "could not find region scope for handle");
937  RegionScope *scope = *scopeIt;
938  for (Operation *user : handle.getUsers()) {
939  if (user != scope->currentTransform &&
940  !happensBefore(user, scope->currentTransform))
941  return false;
942  }
943  return true;
944  };
945  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
946  config);
947  transform::TransformRewriter rewriter(transform->getContext(),
948  &trackingListener);
949 
950  // Compute the result but do not short-circuit the silenceable failure case as
951  // we still want the handles to propagate properly so the "suppress" mode can
952  // proceed on a best effort basis.
953  transform::TransformResults results(transform->getNumResults());
954  DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
955  compactOpHandles();
956 
957  // Error handling: fail if transform or listener failed.
958  DiagnosedSilenceableFailure trackingFailure =
959  trackingListener.checkAndResetError();
960  if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
961  transform->hasAttr(FindPayloadReplacementOpInterface::
962  kSilenceTrackingFailuresAttrName)) {
963  // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
964  // do not report failures if the above mentioned attribute is set.
965  if (trackingFailure.isSilenceableFailure())
966  (void)trackingFailure.silence();
967  trackingFailure = DiagnosedSilenceableFailure::success();
968  }
969  if (!trackingFailure.succeeded()) {
970  if (result.succeeded()) {
971  result = std::move(trackingFailure);
972  } else {
973  // Transform op errors have precedence, report those first.
974  if (result.isSilenceableFailure())
975  result.attachNote() << "tracking listener also failed: "
976  << trackingFailure.getMessage();
977  (void)trackingFailure.silence();
978  }
979  }
980  if (result.isDefiniteFailure())
981  return result;
982 
983  // If a silenceable failure was produced, some results may be unset, set them
984  // to empty lists.
985  if (result.isSilenceableFailure())
986  results.setRemainingToEmpty(transform);
987 
988  // Remove the mapping for the operand if it is consumed by the operation. This
989  // allows us to catch use-after-free with assertions later on.
990  for (OpOperand *opOperand : consumedOperands) {
991  Value operand = opOperand->get();
992  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
993  forgetMapping(operand, origOpFlatResults);
994  } else if (llvm::isa<TransformValueHandleTypeInterface>(
995  operand.getType())) {
996  forgetValueMapping(operand, origAssociatedOps);
997  }
998  }
999 
1000  if (failed(updateStateFromResults(results, transform->getResults())))
1002 
// Success path: do not dump the "failing" payload on scope exit.
1003  printOnFailureRAII.release();
1004  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
1005  DBGS() << "Top-level payload:\n";
1006  getTopLevel()->print(llvm::dbgs());
1007  });
1008  return result;
1009 }
1010 
1011 LogicalResult transform::TransformState::updateStateFromResults(
1012  const TransformResults &results, ResultRange opResults) {
1013  for (OpResult result : opResults) {
1014  if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1015  assert(results.isParam(result.getResultNumber()) &&
1016  "expected parameters for the parameter-typed result");
1017  if (failed(
1018  setParams(result, results.getParams(result.getResultNumber())))) {
1019  return failure();
1020  }
1021  } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1022  assert(results.isValue(result.getResultNumber()) &&
1023  "expected values for value-type-result");
1024  if (failed(setPayloadValues(
1025  result, results.getValues(result.getResultNumber())))) {
1026  return failure();
1027  }
1028  } else {
1029  assert(!results.isParam(result.getResultNumber()) &&
1030  "expected payload ops for the non-parameter typed result");
1031  if (failed(
1032  setPayloadOps(result, results.get(result.getResultNumber())))) {
1033  return failure();
1034  }
1035  }
1036  }
1037  return success();
1038 }
1039 
1040 //===----------------------------------------------------------------------===//
1041 // TransformState::Extension
1042 //===----------------------------------------------------------------------===//
1043 
1045 
/// Forwards the replacement of payload op `op` by `replacement` to the
/// underlying TransformState.
1048  Operation *replacement) {
1049  // TODO: we may need to invalidate handles to operations and values nested in
1050  // the operation being replaced.
1051  return state.replacePayloadOp(op, replacement);
1052 }
1053 
/// Forwards the replacement of payload value `value` by `replacement` to the
/// underlying TransformState.
1056  Value replacement) {
1057  return state.replacePayloadValue(value, replacement);
1058 }
1059 
1060 //===----------------------------------------------------------------------===//
1061 // TransformState::RegionScope
1062 //===----------------------------------------------------------------------===//
1063 
/// Leaves the region scope. It clears the invalidation records of every block
/// argument and op result defined directly in the region's blocks, erases the
/// region's mappings from the state, and pops the region stack.
1065  // Remove handle invalidation notices as handles are going out of scope.
1066  // The same region may be re-entered leading to incorrect invalidation
1067  // errors.
1068  for (Block &block : *region) {
1069  for (Value handle : block.getArguments()) {
1070  state.invalidatedHandles.erase(handle);
1071  }
1072  for (Operation &op : block) {
1073  for (Value handle : op.getResults()) {
1074  state.invalidatedHandles.erase(handle);
1075  }
1076  }
1077  }
1078 
1079 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1080  // Remember pointers to payload ops referenced by the handles going out of
1081  // scope.
// NOTE(review): `referencedOps` is not read anywhere in the rest of this
// function, which may trigger an unused-variable warning. Confirm whether it
// is still needed.
1082  SmallVector<Operation *> referencedOps =
1083  llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1084 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1085 
1086  state.mappings.erase(region);
1087  state.regionStack.pop_back();
1088 }
1089 
1090 //===----------------------------------------------------------------------===//
1091 // TransformResults
1092 //===----------------------------------------------------------------------===//
1093 
1094 transform::TransformResults::TransformResults(unsigned numSegments) {
1095  operations.appendEmptyRows(numSegments);
1096  params.appendEmptyRows(numSegments);
1097  values.appendEmptyRows(numSegments);
1098 }
1099 
/// Associates `params` with the transform op result `value`. Asserts that the
/// result exists and that no kind of payload (params, ops or values) has been
/// set for it yet.
1102  int64_t position = value.getResultNumber();
// `this->params` names the member table; plain `params` is the argument.
1103  assert(position < static_cast<int64_t>(this->params.size()) &&
1104  "setting params for a non-existent handle");
1105  assert(this->params[position].data() == nullptr && "params already set");
1106  assert(operations[position].data() == nullptr &&
1107  "another kind of results already set");
1108  assert(values[position].data() == nullptr &&
1109  "another kind of results already set");
1110  this->params.replace(position, params);
1111 }
1112 
/// Associates the mixed list `values` with `handle`. The three callbacks below
/// route it to `set`, `setParams` or `setValues`, depending on which kind the
/// dispatcher selects. In builds without NDEBUG, a dispatch failure is printed
/// and asserted on.
1114  OpResult handle, ArrayRef<MappedValue> values) {
1116  handle, values,
// Each callback uses the comma operator: it calls the setter, then reports
// success to the dispatcher.
1117  [&](ArrayRef<Operation *> operations) {
1118  return set(handle, operations), success();
1119  },
1120  [&](ArrayRef<Param> params) {
1121  return setParams(handle, params), success();
1122  },
1123  [&](ValueRange payloadValues) {
1124  return setValues(handle, payloadValues), success();
1125  });
1126 #ifndef NDEBUG
1127  if (!diag.succeeded())
1128  llvm::dbgs() << diag.getStatusString() << "\n";
1129  assert(diag.succeeded() && "incorrect mapping");
1130 #endif // NDEBUG
// Silenced unconditionally, presumably so that an unchecked diagnostic is not
// reported when `diag` is destroyed in NDEBUG builds — confirm.
1131  (void)diag.silence();
1132 }
1133 
/// Maps every result of `transform` that has not been set yet to an empty
/// list.
1135  transform::TransformOpInterface transform) {
1136  for (OpResult opResult : transform->getResults()) {
1137  if (!isSet(opResult.getResultNumber()))
1138  setMappedValues(opResult, {});
1139  }
1140 }
1141 
/// Returns the payload operations associated with result `resultNumber`.
/// Asserts that the result exists and that operations (not params or values)
/// were set for it.
1143 transform::TransformResults::get(unsigned resultNumber) const {
1144  assert(resultNumber < operations.size() &&
1145  "querying results for a non-existent handle");
1146  assert(operations[resultNumber].data() != nullptr &&
1147  "querying unset results (values or params expected?)");
1148  return operations[resultNumber];
1149 }
1150 
/// Returns the parameters associated with result `resultNumber`. Asserts that
/// the result exists and that parameters were set for it.
1152 transform::TransformResults::getParams(unsigned resultNumber) const {
1153  assert(resultNumber < params.size() &&
1154  "querying params for a non-existent handle");
1155  assert(params[resultNumber].data() != nullptr &&
1156  "querying unset params (ops or values expected?)");
1157  return params[resultNumber];
1158 }
1159 
/// Returns the payload values associated with result `resultNumber`. Asserts
/// that the result exists and that values were set for it.
1161 transform::TransformResults::getValues(unsigned resultNumber) const {
1162  assert(resultNumber < values.size() &&
1163  "querying values for a non-existent handle");
1164  assert(values[resultNumber].data() != nullptr &&
1165  "querying unset values (ops or params expected?)");
1166  return values[resultNumber];
1167 }
1168 
1169 bool transform::TransformResults::isParam(unsigned resultNumber) const {
1170  assert(resultNumber < params.size() &&
1171  "querying association for a non-existent handle");
1172  return params[resultNumber].data() != nullptr;
1173 }
1174 
1175 bool transform::TransformResults::isValue(unsigned resultNumber) const {
1176  assert(resultNumber < values.size() &&
1177  "querying association for a non-existent handle");
1178  return values[resultNumber].data() != nullptr;
1179 }
1180 
1181 bool transform::TransformResults::isSet(unsigned resultNumber) const {
1182  assert(resultNumber < params.size() &&
1183  "querying association for a non-existent handle");
1184  return params[resultNumber].data() != nullptr ||
1185  operations[resultNumber].data() != nullptr ||
1186  values[resultNumber].data() != nullptr;
1187 }
1188 
1189 //===----------------------------------------------------------------------===//
1190 // TrackingListener
1191 //===----------------------------------------------------------------------===//
1192 
/// Creates a tracking listener for transform op `op` with `config`. When `op`
/// is non-null, the handles it consumes are recorded in `consumedHandles`.
/// `notifyOperationReplaced` uses that set to tell whether a handle mapped to
/// a replaced op was consumed.
1194  TransformOpInterface op,
1195  TrackingListenerConfig config)
1196  : TransformState::Extension(state), transformOp(op), config(config) {
1197  if (op) {
1198  for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1199  consumedHandles.insert(opOperand->get());
1200  }
1201  }
1202 }
1203 
/// Returns the op that defines all non-null entries of `values`, or nullptr if
/// two entries are defined by different ops.
/// NOTE(review): a block argument has no defining op. If a *leading* entry is
/// a block argument, `defOp` stays null and the next entry's defining op is
/// adopted instead. A block argument that comes *after* an op result makes the
/// function return nullptr. Confirm that this asymmetry is intended.
1205  Operation *defOp = nullptr;
1206  for (Value v : values) {
1207  // Skip empty values.
1208  if (!v)
1209  continue;
1210  if (!defOp) {
1211  defOp = v.getDefiningOp();
1212  continue;
1213  }
1214  if (defOp != v.getDefiningOp())
1215  return nullptr;
1216  }
1217  return defOp;
1218 }
1219 
/// Looks for a single op that can stand in for `op` after it was replaced by
/// `newValues`. The search starts at the common defining op of the replacement
/// values. It walks through ops implementing CastOpInterface (when
/// `config.skipCastOps` is set) and through ops implementing
/// FindPayloadReplacementOpInterface. It accepts a defining op that has the
/// same name as `op` (or any op, if names need not match), or any
/// constant-like op. The accepted op is written to `result`. If the search
/// runs out of candidates, the diagnostic is returned with notes explaining
/// each step.
1221  Operation *&result, Operation *op, ValueRange newValues) const {
1222  assert(op->getNumResults() == newValues.size() &&
1223  "invalid number of replacement values");
1224  SmallVector<Value> values(newValues.begin(), newValues.end());
1225 
1227  getTransformOp(), "tracking listener failed to find replacement op "
1228  "during application of this transform op");
1229 
1230  do {
1231  // If the replacement values belong to different ops, drop the mapping.
1232  Operation *defOp = getCommonDefiningOp(values);
1233  if (!defOp) {
1234  diag.attachNote() << "replacement values belong to different ops";
1235  return diag;
1236  }
1237 
1238  // Skip through ops that implement CastOpInterface.
1239  if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1240  values.clear();
1241  values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1242  diag.attachNote(defOp->getLoc())
1243  << "using output of 'CastOpInterface' op";
1244  continue;
1245  }
1246 
1247  // If the defining op has the same name or we do not care about the name of
1248  // op replacements at all, we take it as a replacement.
1249  if (!config.requireMatchingReplacementOpName ||
1250  op->getName() == defOp->getName()) {
1251  result = defOp;
1253  }
1254 
1255  // Replacing an op with a constant-like equivalent is a common
1256  // canonicalization.
1257  if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1258  result = defOp;
1260  }
1261 
1262  values.clear();
1263 
1264  // Skip through ops that implement FindPayloadReplacementOpInterface.
1265  if (auto findReplacementOpInterface =
1266  dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1267  values.assign(findReplacementOpInterface.getNextOperands());
1268  diag.attachNote(defOp->getLoc()) << "using operands provided by "
1269  "'FindPayloadReplacementOpInterface'";
1270  continue;
1271  }
1272  } while (!values.empty());
1273 
1274  diag.attachNote() << "ran out of suitable replacement values";
1275  return diag;
1276 }
1277 
/// Debug-only hook: when debug output for this file is enabled, lets
/// `reasonCallback` populate a diagnostic and prints it as a match failure.
/// Otherwise it does nothing.
1279  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1280  LLVM_DEBUG({
1282  reasonCallback(diag);
1283  DBGS() << "Match Failure : " << diag.str() << "\n";
1284  });
1285 }
1286 
1287 void transform::TrackingListener::notifyOperationErased(Operation *op) {
1288  // Remove mappings for result values.
1289  for (OpResult value : op->getResults())
1290  (void)replacePayloadValue(value, nullptr);
1291  // Remove mapping for op.
1292  (void)replacePayloadOp(op, nullptr);
1293 }
1294 
/// Called when payload op `op` is replaced by `newValues`. Value handles are
/// updated one-to-one. For the op handles:
/// - if `op` is not tracked, nothing else happens;
/// - if all of its handles are dead (per `config.skipHandleFn`) or one of them
///   is consumed by the current transform, `op` is dropped from the mapping;
/// - otherwise a replacement op is searched for. If none is found, a
///   notification is sent and `op` is dropped from the mapping.
1295 void transform::TrackingListener::notifyOperationReplaced(
1296  Operation *op, ValueRange newValues) {
1297  assert(op->getNumResults() == newValues.size() &&
1298  "invalid number of replacement values");
1299 
1300  // Replace value handles.
1301  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
1302  (void)replacePayloadValue(oldValue, newValue);
1303 
1304  // Replace op handle.
1305  SmallVector<Value> opHandles;
1306  if (failed(getTransformState().getHandlesForPayloadOp(
1307  op, opHandles, /*includeOutOfScope=*/true))) {
1308  // Op is not tracked.
1309  return;
1310  }
1311 
1312  // Helper function to check if the current transform op consumes any handle
1313  // that is mapped to `op`.
1314  //
1315  // Note: If a handle was consumed, there shouldn't be any alive users, so it
1316  // is not really necessary to check for consumed handles. However, in case
1317  // there are indeed alive handles that were consumed (which is undefined
1318  // behavior) and a replacement op could not be found, we want to fail with a
1319  // nicer error message: "op uses a handle invalidated..." instead of "could
1320  // not find replacement op". This nicer error is produced later.
1321  auto handleWasConsumed = [&] {
1322  return llvm::any_of(opHandles,
1323  [&](Value h) { return consumedHandles.contains(h); });
1324  };
1325 
1326  // Check if there are any handles that must be updated.
1327  Value aliveHandle;
1328  if (config.skipHandleFn) {
1329  auto it = llvm::find_if(opHandles,
1330  [&](Value v) { return !config.skipHandleFn(v); });
1331  if (it != opHandles.end())
1332  aliveHandle = *it;
1333  } else if (!opHandles.empty()) {
1334  aliveHandle = opHandles.front();
1335  }
1336  if (!aliveHandle || handleWasConsumed()) {
1337  // The op is tracked but the corresponding handles are dead or were
1338  // consumed. Drop the op from the mapping.
1339  (void)replacePayloadOp(op, nullptr);
1340  return;
1341  }
1342 
1343  Operation *replacement;
1345  findReplacementOp(replacement, op, newValues);
1346  // If the op is tracked but no replacement op was found, send a
1347  // notification.
1348  if (!diag.succeeded()) {
1349  diag.attachNote(aliveHandle.getLoc())
1350  << "replacement is required because this handle must be updated";
1351  notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
1352  (void)replacePayloadOp(op, nullptr);
1353  return;
1354  }
1355 
1356  (void)replacePayloadOp(op, replacement);
1357 }
1358 
/// Destructor: only asserts that no unchecked error status remains.
1360  // The state of the ErrorCheckingTrackingListener must be checked and reset
1361  // if there was an error. This is to prevent errors from accidentally being
1362  // missed.
1363  assert(status.succeeded() && "listener state was not checked");
1364 }
1365 
/// Hands the accumulated status over to the caller and resets the error
/// counter used to number notes.
1368  DiagnosedSilenceableFailure s = std::move(status);
1370  errorCounter = 0;
1371  return s;
1372 }
1373 
/// Returns true if the listener currently holds a non-success status.
1375  return !status.succeeded();
1376 }
1377 
/// Records that no replacement was found for `op`. The diagnostics of `diag`
/// are merged with any previously stored in `status`, and the combined list
/// becomes a silenceable failure. Notes numbered with `errorCounter` are then
/// added for the replaced op and each replacement value.
1380 
1381  // Merge potentially existing diags and store the result in the listener.
1383  diag.takeDiagnostics(diags);
1384  if (!status.succeeded())
1385  status.takeDiagnostics(diags);
1386  status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
1387 
1388  // Report more details.
1389  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
1390  for (auto &&[index, value] : llvm::enumerate(values))
1391  status.attachNote(value.getLoc())
1392  << "[" << errorCounter << "] replacement value " << index;
1393  ++errorCounter;
1394 }
1395 
1396 //===----------------------------------------------------------------------===//
1397 // TransformRewriter
1398 //===----------------------------------------------------------------------===//
1399 
/// Creates a rewriter that notifies `listener` of all IR changes. A typed
/// pointer to the listener is also kept for the tracking-failure queries below.
1402  : RewriterBase(ctx), listener(listener) {
1403  setListener(listener);
1404 }
1405 
/// Returns true if the tracking listener has recorded a failure.
1407  return listener->failed();
1408 }
1409 
1410 /// Silence all tracking failures that have been encountered so far.
// The status is taken from the listener (which resets it there) and then
// silenced, so the failure is neither reported nor left unchecked.
1412  if (hasTrackingFailures()) {
1413  DiagnosedSilenceableFailure status = listener->checkAndResetError();
1414  (void)status.silence();
1415  }
1416 }
1417 
/// Forwards an explicit payload-op replacement to the tracking listener.
1419  Operation *op, Operation *replacement) {
1420  return listener->replacePayloadOp(op, replacement);
1421 }
1422 
1423 //===----------------------------------------------------------------------===//
1424 // Utilities for TransformEachOpTrait.
1425 //===----------------------------------------------------------------------===//
1426 
/// Emits a definite error at `loc` if some target in `targets` is an ancestor
/// of a target that appears later in the list. Rewriting the ancestor first
/// could erase or alter the descendant before it is processed. Every pair of
/// targets is compared, so the cost is quadratic in the number of targets.
1429  ArrayRef<Operation *> targets) {
1430  for (auto &&[position, parent] : llvm::enumerate(targets)) {
// Only targets after `parent` matter: they are processed after it.
1431  for (Operation *child : targets.drop_front(position + 1)) {
1432  if (parent->isAncestor(child)) {
1434  emitError(loc)
1435  << "transform operation consumes a handle pointing to an ancestor "
1436  "payload operation before its descendant";
1437  diag.attachNote()
1438  << "the ancestor is likely erased or rewritten before the "
1439  "descendant is accessed, leading to undefined behavior";
1440  diag.attachNote(parent->getLoc()) << "ancestor payload op";
1441  diag.attachNote(child->getLoc()) << "descendant payload op";
1442  return diag;
1443  }
1444  }
1445  }
1446  return success();
1447 }
1448 
/// Checks the result list produced by applying `transformOp` to one payload op
/// (located at `payloadOpLoc`):
/// - the list must have exactly one entry per transform op result;
/// - each non-null entry must match the kind of its result (Operation * for
///   op handles, Attribute for params, Value for value handles).
/// Every failure is reported as an error at the transform op, with a note at
/// the payload op.
1451  Location payloadOpLoc,
1452  const ApplyToEachResultList &partialResult) {
1453  Location transformOpLoc = transformOp->getLoc();
1454  StringRef transformOpName = transformOp->getName().getStringRef();
1455  unsigned expectedNumResults = transformOp->getNumResults();
1456 
1457  // Reuse the emission of the diagnostic note.
1458  auto emitDiag = [&]() {
1459  auto diag = mlir::emitError(transformOpLoc);
1460  diag.attachNote(payloadOpLoc) << "when applied to this op";
1461  return diag;
1462  };
1463 
1464  if (partialResult.size() != expectedNumResults) {
// `diag` is presumably reported when it goes out of scope at the
// `return failure()` below — confirm.
1465  auto diag = emitDiag() << "application of " << transformOpName
1466  << " expected to produce " << expectedNumResults
1467  << " results (actually produced "
1468  << partialResult.size() << ").";
1469  diag.attachNote(transformOpLoc)
1470  << "if you need variadic results, consider a generic `apply` "
1471  << "instead of the specialized `applyToOne`.";
1472  return failure();
1473  }
1474 
1475  // Check that the right kind of value was produced.
1476  for (const auto &[ptr, res] :
1477  llvm::zip(partialResult, transformOp->getResults())) {
1478  if (ptr.isNull())
1479  continue;
1480  if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1481  !ptr.is<Operation *>()) {
1482  return emitDiag() << "application of " << transformOpName
1483  << " expected to produce an Operation * for result #"
1484  << res.getResultNumber();
1485  }
1486  if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1487  !ptr.is<Attribute>()) {
1488  return emitDiag() << "application of " << transformOpName
1489  << " expected to produce an Attribute for result #"
1490  << res.getResultNumber();
1491  }
1492  if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1493  !ptr.is<Value>()) {
1494  return emitDiag() << "application of " << transformOpName
1495  << " expected to produce a Value for result #"
1496  << res.getResultNumber();
1497  }
1498  }
1499  return success();
1500 }
1501 
/// Converts a range of MappedValue into a vector of `T` by extracting the `T`
/// alternative of each element with `get<T>()`.
1502 template <typename T>
1504  return llvm::to_vector(llvm::map_range(
1505  range, [](transform::MappedValue value) { return value.get<T>(); }));
1506 }
1507 
/// Turns the per-target result lists in `results` into per-result lists and
/// stores them in `transformResults`, using each op result's type to pick the
/// kind. A per-target list that contains any null entry is skipped as a whole.
1509  Operation *transformOp, TransformResults &transformResults,
1512  transposed.resize(transformOp->getNumResults());
1513  for (const ApplyToEachResultList &partialResults : results) {
1514  if (llvm::any_of(partialResults,
1515  [](MappedValue value) { return value.isNull(); }))
1516  continue;
1517  assert(transformOp->getNumResults() == partialResults.size() &&
1518  "expected as many partial results as op as results");
1519  for (auto [i, value] : llvm::enumerate(partialResults))
1520  transposed[i].push_back(value);
1521  }
1522 
1523  for (OpResult r : transformOp->getResults()) {
1524  unsigned position = r.getResultNumber();
1525  if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1526  transformResults.setParams(r,
1527  castVector<Attribute>(transposed[position]));
1528  } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1529  transformResults.setValues(r, castVector<Value>(transposed[position]));
1530  } else {
1531  transformResults.set(r, castVector<Operation *>(transposed[position]));
1532  }
1533  }
1534 }
1535 
1536 //===----------------------------------------------------------------------===//
1537 // Utilities for implementing transform ops with regions.
1538 //===----------------------------------------------------------------------===//
1539 
/// For each transform value in `values`, appends its payload (ops, values or
/// params, depending on its type) to the matching entry of `mappings`. When
/// `flatten` is false, fails as soon as a value contributes anything other than
/// exactly one payload entity. Entries already appended are kept.
1542  ValueRange values, const transform::TransformState &state, bool flatten) {
1543  assert(mappings.size() == values.size() && "mismatching number of mappings");
1544  for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
// Remember the previous size to count how many entries this value added.
1545  size_t mappedSize = mapped.size();
1546  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1547  llvm::append_range(mapped, state.getPayloadOps(operand));
1548  } else if (llvm::isa<TransformValueHandleTypeInterface>(
1549  operand.getType())) {
1550  llvm::append_range(mapped, state.getPayloadValues(operand));
1551  } else {
1552  assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1553  "unsupported kind of transform dialect value");
1554  llvm::append_range(mapped, state.getParams(operand));
1555  }
1556 
1557  if (mapped.size() - mappedSize != 1 && !flatten)
1558  return failure();
1559  }
1560  return success();
1561 }
1562 
/// Grows `mappings` by one entry per value in `values` and fills the new
/// entries with the payload of those values. The result of the append is
/// explicitly ignored.
1565  ValueRange values, const transform::TransformState &state) {
1566  mappings.resize(mappings.size() + values.size());
1567  (void)appendValueMappings(
1569  values.size()),
1570  values, state);
1571 }
1572 
/// Makes each result of `block`'s parent op carry the payload of the matching
/// operand of `block`'s terminator. The result's type selects ops, values or
/// params. `llvm::zip` stops at the shorter of the two ranges.
1574  Block *block, transform::TransformState &state,
1575  transform::TransformResults &results) {
1576  for (auto &&[terminatorOperand, result] :
1577  llvm::zip(block->getTerminator()->getOperands(),
1578  block->getParentOp()->getOpResults())) {
1579  if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1580  results.set(result, state.getPayloadOps(terminatorOperand));
1581  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1582  result.getType())) {
1583  results.setValues(result, state.getPayloadValues(terminatorOperand));
1584  } else {
1585  assert(
1586  llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1587  "unhandled transform type interface");
1588  results.setParams(result, state.getParams(terminatorOperand));
1589  }
1590  }
1591 }
1592 
/// Testing helper: builds a TransformState for `region` rooted at
/// `payloadRoot`.
1595  Operation *payloadRoot) {
1596  return TransformState(region, payloadRoot);
1597 }
1598 
1599 //===----------------------------------------------------------------------===//
1600 // Utilities for PossibleTopLevelTransformOpTrait.
1601 //===----------------------------------------------------------------------===//
1602 
1603 /// Appends to `effects` the memory effect instances on `target` with the same
1604 /// resource and effect as the ones the operation `iface` has on `source`.
1605 static void
1606 remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
1607  OpOperand *target,
1610  iface.getEffectsOnValue(source, nestedEffects);
1611  for (const auto &effect : nestedEffects)
1612  effects.emplace_back(effect.getEffect(), target, effect.getResource());
1613 }
1614 
1615 /// Appends to `effects` the same effects as the operations of `block` have on
1616 /// block arguments but associated with `operands`.
// Each op's effects on the payload IR resource are also forwarded unchanged.
1617 static void
1620  for (Operation &op : block) {
1621  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1622  if (!iface)
1623  continue;
1624 
// Block argument i is paired with operand i; zip stops at the shorter range.
1625  for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
1626  remapEffects(iface, source, &target, effects);
1627  }
1628 
1630  iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1631  nestedEffects);
1632  llvm::append_range(effects, nestedEffects);
1633  }
1634 }
1635 
/// Computes the memory effects of an op that may be a top-level transform.
/// It only reads its operand handles and produces its result handles. The
/// effects of its body depend on `root`:
/// - with no `root`, the effects of every op in `body` are collected as-is;
/// - otherwise, effects on `body`'s block arguments are remapped onto the op's
///   operands.
1637  Operation *operation, Value root, Block &body,
1639  transform::onlyReadsHandle(operation->getOpOperands(), effects);
1640  transform::producesHandle(operation->getOpResults(), effects);
1641 
1642  if (!root) {
1643  for (Operation &op : body) {
1644  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1645  if (!iface)
1646  continue;
1647 
1649  iface.getEffects(effects);
1650  }
1651  return;
1652  }
1653 
1654  // Carry over all effects on arguments of the entry block as those on the
1655  // operands, this is the same value just remapped.
1656  remapArgumentEffects(body, operation->getOpOperands(), effects);
1657 }
1658 
/// Binds the entry block arguments of `region` for a possibly-top-level op.
/// Argument #0 receives the payload targets; the remaining arguments receive
/// the extra mappings.
/// - If `op` has operands, the targets are the payload ops of operand #0 and
///   the extra mappings come from the remaining operands.
/// - Otherwise, the target is the state's top-level payload op and the extra
///   mappings are the state's top-level mappings, whose count must equal the
///   number of block arguments minus one.
1660  TransformState &state, Operation *op, Region &region) {
1661  SmallVector<Operation *> targets;
1662  SmallVector<SmallVector<MappedValue>> extraMappings;
1663  if (op->getNumOperands() != 0) {
1664  llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
1665  prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
1666  } else {
1667  if (state.getNumTopLevelMappings() !=
1668  region.front().getNumArguments() - 1) {
1669  return emitError(op->getLoc())
1670  << "operation expects " << region.front().getNumArguments() - 1
1671  << " extra value bindings, but " << state.getNumTopLevelMappings()
1672  << " were provided to the interpreter";
1673  }
1674 
1675  targets.push_back(state.getTopLevel());
1676 
1677  for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1678  extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1679  }
1680 
1681  if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
1682  return failure();
1683 
// Block argument #k (k >= 1) takes extraMappings[k - 1].
1684  for (BlockArgument argument : region.front().getArguments().drop_front()) {
1685  if (failed(state.mapBlockArgument(
1686  argument, extraMappings[argument.getArgNumber() - 1])))
1687  return failure();
1688  }
1689 
1690  return success();
1691 }
1692 
/// Verifies the structural requirements of PossibleTopLevelTransformOpTrait:
/// - at least one region, whose first region has exactly one block;
/// - an entry block with a first argument implementing
///   TransformHandleTypeInterface, whose type matches operand #0 if present;
/// - trailing arguments of handle, value-handle or param type;
/// - when nested in another possible top-level op, as many operands as entry
///   block arguments.
1695  // Attaching this trait without the interface is a misuse of the API, but it
1696  // cannot be caught via a static_assert because interface registration is
1697  // dynamic.
1698  assert(isa<TransformOpInterface>(op) &&
1699  "should implement TransformOpInterface to have "
1700  "PossibleTopLevelTransformOpTrait");
1701 
1702  if (op->getNumRegions() < 1)
1703  return op->emitOpError() << "expects at least one region";
1704 
1705  Region *bodyRegion = &op->getRegion(0);
1706  if (!llvm::hasNItems(*bodyRegion, 1))
1707  return op->emitOpError() << "expects a single-block region";
1708 
1709  Block *body = &bodyRegion->front();
1710  if (body->getNumArguments() == 0) {
1711  return op->emitOpError()
1712  << "expects the entry block to have at least one argument";
1713  }
1714  if (!llvm::isa<TransformHandleTypeInterface>(
1715  body->getArgument(0).getType())) {
1716  return op->emitOpError()
1717  << "expects the first entry block argument to be of type "
1718  "implementing TransformHandleTypeInterface";
1719  }
1720  BlockArgument arg = body->getArgument(0);
1721  if (op->getNumOperands() != 0) {
1722  if (arg.getType() != op->getOperand(0).getType()) {
1723  return op->emitOpError()
1724  << "expects the type of the block argument to match "
1725  "the type of the operand";
1726  }
1727  }
// NOTE(review): this loop variable shadows the `arg` declared above.
1728  for (BlockArgument arg : body->getArguments().drop_front()) {
1729  if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1730  TransformValueHandleTypeInterface>(arg.getType()))
1731  continue;
1732 
1734  op->emitOpError()
1735  << "expects trailing entry block arguments to be of type implementing "
1736  "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1737  "TransformParamTypeInterface";
1738  diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1739  return diag;
1740  }
1741 
1742  if (auto *parent =
1744  if (op->getNumOperands() != body->getNumArguments()) {
1746  op->emitOpError()
1747  << "expects operands to be provided for a nested op";
1748  diag.attachNote(parent->getLoc())
1749  << "nested in another possible top-level op";
1750  return diag;
1751  }
1752  }
1753 
1754  return success();
1755 }
1756 
1757 //===----------------------------------------------------------------------===//
1758 // Utilities for ParamProducerTransformOpTrait.
1759 //===----------------------------------------------------------------------===//
1760 
1763  producesHandle(op->getResults(), effects);
1764  bool hasPayloadOperands = false;
1765  for (OpOperand &operand : op->getOpOperands()) {
1766  onlyReadsHandle(operand, effects);
1767  if (llvm::isa<TransformHandleTypeInterface,
1768  TransformValueHandleTypeInterface>(operand.get().getType()))
1769  hasPayloadOperands = true;
1770  }
1771  if (hasPayloadOperands)
1772  onlyReadsPayload(effects);
1773 }
1774 
1777  // Interfaces can be attached dynamically, so this cannot be a static
1778  // assert.
1779  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1780  llvm::report_fatal_error(
1781  Twine("ParamProducerTransformOpTrait must be attached to an op that "
1782  "implements MemoryEffectsOpInterface, found on ") +
1783  op->getName().getStringRef());
1784  }
1785  for (Value result : op->getResults()) {
1786  if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1787  continue;
1788  return op->emitOpError()
1789  << "ParamProducerTransformOpTrait attached to this op expects "
1790  "result types to implement TransformParamTypeInterface";
1791  }
1792  return success();
1793 }
1794 
1795 //===----------------------------------------------------------------------===//
1796 // Memory effects.
1797 //===----------------------------------------------------------------------===//
1798 
1802  for (OpOperand &handle : handles) {
1803  effects.emplace_back(MemoryEffects::Read::get(), &handle,
1805  effects.emplace_back(MemoryEffects::Free::get(), &handle,
1807  }
1808 }
1809 
1810 /// Returns `true` if the given list of effects instances contains an instance
1811 /// with the effect type specified as template parameter.
1812 template <typename EffectTy, typename ResourceTy, typename Range>
1813 static bool hasEffect(Range &&effects) {
1814  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1815  return isa<EffectTy>(effect.getEffect()) &&
1816  isa<ResourceTy>(effect.getResource());
1817  });
1818 }
1819 
1821  transform::TransformOpInterface transform) {
1822  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1824  iface.getEffectsOnValue(handle, effects);
1825  return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1826  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1827 }
1828 
1830  ResultRange handles,
1832  for (OpResult handle : handles) {
1833  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1835  effects.emplace_back(MemoryEffects::Write::get(), handle,
1837  }
1838 }
1839 
1843  for (BlockArgument handle : handles) {
1844  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1846  effects.emplace_back(MemoryEffects::Write::get(), handle,
1848  }
1849 }
1850 
1854  for (OpOperand &handle : handles) {
1855  effects.emplace_back(MemoryEffects::Read::get(), &handle,
1857  }
1858 }
1859 
1862  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1863  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
1864 }
1865 
1868  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1869 }
1870 
1871 bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1872  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1874  iface.getEffects(effects);
1875  return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1876 }
1877 
1878 bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1879  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1881  iface.getEffects(effects);
1882  return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1883 }
1884 
1886  Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1888  for (Operation &nested : block) {
1889  auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1890  if (!iface)
1891  continue;
1892 
1893  effects.clear();
1894  iface.getEffects(effects);
1895  for (const MemoryEffects::EffectInstance &effect : effects) {
1896  BlockArgument argument =
1897  dyn_cast_or_null<BlockArgument>(effect.getValue());
1898  if (!argument || argument.getOwner() != &block ||
1899  !isa<MemoryEffects::Free>(effect.getEffect()) ||
1900  effect.getResource() != transform::TransformMappingResource::get()) {
1901  continue;
1902  }
1903  consumedArguments.insert(argument.getArgNumber());
1904  }
1905  }
1906 }
1907 
1908 //===----------------------------------------------------------------------===//
1909 // Utilities for TransformOpInterface.
1910 //===----------------------------------------------------------------------===//
1911 
1913  TransformOpInterface transformOp) {
1914  SmallVector<OpOperand *> consumedOperands;
1915  consumedOperands.reserve(transformOp->getNumOperands());
1916  auto memEffectInterface =
1917  cast<MemoryEffectOpInterface>(transformOp.getOperation());
1919  for (OpOperand &target : transformOp->getOpOperands()) {
1920  effects.clear();
1921  memEffectInterface.getEffectsOnValue(target.get(), effects);
1922  if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1923  return isa<transform::TransformMappingResource>(
1924  effect.getResource()) &&
1925  isa<MemoryEffects::Free>(effect.getEffect());
1926  })) {
1927  consumedOperands.push_back(&target);
1928  }
1929  }
1930  return consumedOperands;
1931 }
1932 
1934  auto iface = cast<MemoryEffectOpInterface>(op);
1936  iface.getEffects(effects);
1937 
1938  auto effectsOn = [&](Value value) {
1939  return llvm::make_filter_range(
1940  effects, [value](const MemoryEffects::EffectInstance &instance) {
1941  return instance.getValue() == value;
1942  });
1943  };
1944 
1945  std::optional<unsigned> firstConsumedOperand;
1946  for (OpOperand &operand : op->getOpOperands()) {
1947  auto range = effectsOn(operand.get());
1948  if (range.empty()) {
1950  op->emitError() << "TransformOpInterface requires memory effects "
1951  "on operands to be specified";
1952  diag.attachNote() << "no effects specified for operand #"
1953  << operand.getOperandNumber();
1954  return diag;
1955  }
1956  if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1958  << "TransformOpInterface did not expect "
1959  "'allocate' memory effect on an operand";
1960  diag.attachNote() << "specified for operand #"
1961  << operand.getOperandNumber();
1962  return diag;
1963  }
1964  if (!firstConsumedOperand &&
1965  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1966  firstConsumedOperand = operand.getOperandNumber();
1967  }
1968  }
1969 
1970  if (firstConsumedOperand &&
1971  !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1973  op->emitError()
1974  << "TransformOpInterface expects ops consuming operands to have a "
1975  "'write' effect on the payload resource";
1976  diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
1977  return diag;
1978  }
1979 
1980  for (OpResult result : op->getResults()) {
1981  auto range = effectsOn(result);
1982  if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1983  range)) {
1985  op->emitError() << "TransformOpInterface requires 'allocate' memory "
1986  "effect to be specified for results";
1987  diag.attachNote() << "no 'allocate' effect specified for result #"
1988  << result.getResultNumber();
1989  return diag;
1990  }
1991  }
1992 
1993  return success();
1994 }
1995 
1996 //===----------------------------------------------------------------------===//
1997 // Entry point.
1998 //===----------------------------------------------------------------------===//
1999 
2001  Operation *payloadRoot, TransformOpInterface transform,
2002  const RaggedArray<MappedValue> &extraMapping,
2003  const TransformOptions &options, bool enforceToplevelTransformOp) {
2004  if (enforceToplevelTransformOp) {
2005  if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
2006  transform->getNumOperands() != 0) {
2007  return transform->emitError()
2008  << "expected transform to start at the top-level transform op";
2009  }
2010  } else if (failed(
2012  return failure();
2013  }
2014 
2015  TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2016  options);
2017  return state.applyTransform(transform).checkAndReport();
2018 }
2019 
2020 //===----------------------------------------------------------------------===//
2021 // Generated interface implementation.
2022 //===----------------------------------------------------------------------===//
2023 
2024 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2025 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define DEBUG_TYPE_FULL
#define FULL_LDBG(X)
static void remapEffects(MemoryEffectOpInterface iface, BlockArgument source, OpOperand *target, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the memory effect instances on target with the same resource and effect as the one...
static void remapArgumentEffects(Block &block, MutableArrayRef< OpOperand > operands, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the same effects as the operations of block have on block arguments but associated...
static bool happensBefore(Operation *a, Operation *b)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not ...
void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped)
#define DEBUG_PRINT_AFTER_ALL
DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef< T > payload, transform::TransformOpInterface transform, unsigned operandNumber)
#define DBGS()
static DiagnosedSilenceableFailure dispatchMappedValues(Value handle, ArrayRef< transform::MappedValue > values, function_ref< LogicalResult(ArrayRef< Operation * >)> operationsFn, function_ref< LogicalResult(ArrayRef< transform::Param >)> paramsFn, function_ref< LogicalResult(ValueRange)> valuesFn)
Given a list of MappedValues, cast them to the value kind implied by the interface of the handle type...
static SmallVector< T > castVector(ArrayRef< transform::MappedValue > range)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:31
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition: Block.cpp:73
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgListType getArguments()
Definition: Block.h:85
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:318
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getOpResults()
Definition: Operation.h:415
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A 2D array where each row may have different length.
Definition: RaggedArray.h:18
size_t size() const
Returns the number of rows in the 2D array.
Definition: RaggedArray.h:21
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
size_t size() const
Returns the number of elements in the list.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
TrackingListener failures are reported only for ops that have this trait.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Options controlling the application of transform operations by the TransformState.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult matches the number, kind (payload op or parameter) and nulli...
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true)
Entry point to the Transform dialect infrastructure.
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns true if op has an effect of type EffectTy.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A configuration object for customizing a TrackingListener.
SkipHandleFn skipHandleFn
An optional function that returns "true" for handles that do not have to be updated.