MLIR  19.0.0git
TransformInterfaces.cpp
Go to the documentation of this file.
1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/IR/PatternMatch.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/ErrorHandling.h"
21 
// Debug channels used by this file. `DEBUG_TYPE` is the regular channel;
// `DEBUG_TYPE_FULL` is a separate, more verbose channel used by `FULL_LDBG`
// and by the `DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, ...)` blocks in this file.
#define DEBUG_TYPE "transform-dialect"
#define DEBUG_TYPE_FULL "transform-dialect-full"
// Channel name for printing the top-level payload after all transforms.
// NOTE(review): its use is not visible in this excerpt.
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
// Debug output stream prefixed with "[transform-dialect] ".
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
// Print `X` when the regular `DEBUG_TYPE` channel is enabled.
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))
// Print `X` when the verbose `DEBUG_TYPE_FULL` channel is enabled; the output
// still carries the regular `DEBUG_TYPE` prefix.
#define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
28 
29 using namespace mlir;
30 
31 //===----------------------------------------------------------------------===//
32 // Helper functions
33 //===----------------------------------------------------------------------===//
34 
35 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
36 /// properly dominates `b` and `b` is not inside `a`.
37 static bool happensBefore(Operation *a, Operation *b) {
38  do {
39  if (a->isProperAncestor(b))
40  return false;
41  if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
42  return a->isBeforeInBlock(bAncestor);
43  }
44  } while ((a = a->getParentOp()));
45  return false;
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // TransformState
50 //===----------------------------------------------------------------------===//
51 
// Out-of-line definition of the `kTopLevelValue` static member of
// TransformState. NOTE(review): presumably kept so the member has a definition
// when ODR-used under pre-C++17 rules; if the in-class declaration is also
// constexpr, this is redundant (but harmless) under C++17.
constexpr const Value transform::TransformState::kTopLevelValue;
53 
54 transform::TransformState::TransformState(
55  Region *region, Operation *payloadRoot,
56  const RaggedArray<MappedValue> &extraMappings,
57  const TransformOptions &options)
58  : topLevel(payloadRoot), options(options) {
59  topLevelMappedValues.reserve(extraMappings.size());
60  for (ArrayRef<MappedValue> mapping : extraMappings)
61  topLevelMappedValues.push_back(mapping);
62  if (region) {
63  RegionScope *scope = new RegionScope(*this, *region);
64  topLevelRegionScope.reset(scope);
65  }
66 }
67 
68 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
69 
71 transform::TransformState::getPayloadOpsView(Value value) const {
72  const TransformOpMapping &operationMapping = getMapping(value).direct;
73  auto iter = operationMapping.find(value);
74  assert(iter != operationMapping.end() &&
75  "cannot find mapping for payload handle (param/value handle "
76  "provided?)");
77  return iter->getSecond();
78 }
79 
81  const ParamMapping &mapping = getMapping(value).params;
82  auto iter = mapping.find(value);
83  assert(iter != mapping.end() && "cannot find mapping for param handle "
84  "(operation/value handle provided?)");
85  return iter->getSecond();
86 }
87 
89 transform::TransformState::getPayloadValuesView(Value handleValue) const {
90  const ValueMapping &mapping = getMapping(handleValue).values;
91  auto iter = mapping.find(handleValue);
92  assert(iter != mapping.end() && "cannot find mapping for value handle "
93  "(param/operation handle provided?)");
94  return iter->getSecond();
95 }
96 
98  Operation *op, SmallVectorImpl<Value> &handles,
99  bool includeOutOfScope) const {
100  bool found = false;
101  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
102  auto iterator = mapping->reverse.find(op);
103  if (iterator != mapping->reverse.end()) {
104  llvm::append_range(handles, iterator->getSecond());
105  found = true;
106  }
107  // Stop looking when reaching a region that is isolated from above.
108  if (!includeOutOfScope &&
110  break;
111  }
112 
113  return success(found);
114 }
115 
117  Value payloadValue, SmallVectorImpl<Value> &handles,
118  bool includeOutOfScope) const {
119  bool found = false;
120  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
121  auto iterator = mapping->reverseValues.find(payloadValue);
122  if (iterator != mapping->reverseValues.end()) {
123  llvm::append_range(handles, iterator->getSecond());
124  found = true;
125  }
126  // Stop looking when reaching a region that is isolated from above.
127  if (!includeOutOfScope &&
129  break;
130  }
131 
132  return success(found);
133 }
134 
135 /// Given a list of MappedValues, cast them to the value kind implied by the
136 /// interface of the handle type, and dispatch to one of the callbacks.
142  if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
143  SmallVector<Operation *> operations;
144  operations.reserve(values.size());
145  for (transform::MappedValue value : values) {
146  if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
147  operations.push_back(op);
148  continue;
149  }
150  return emitSilenceableFailure(handle.getLoc())
151  << "wrong kind of value provided for top-level operation handle";
152  }
153  if (failed(operationsFn(operations)))
156  }
157 
158  if (llvm::isa<transform::TransformValueHandleTypeInterface>(
159  handle.getType())) {
160  SmallVector<Value> payloadValues;
161  payloadValues.reserve(values.size());
162  for (transform::MappedValue value : values) {
163  if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
164  payloadValues.push_back(v);
165  continue;
166  }
167  return emitSilenceableFailure(handle.getLoc())
168  << "wrong kind of value provided for the top-level value handle";
169  }
170  if (failed(valuesFn(payloadValues)))
173  }
174 
175  assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
176  "unsupported kind of block argument");
178  parameters.reserve(values.size());
179  for (transform::MappedValue value : values) {
180  if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
181  parameters.push_back(attr);
182  continue;
183  }
184  return emitSilenceableFailure(handle.getLoc())
185  << "wrong kind of value provided for top-level parameter";
186  }
187  if (failed(paramsFn(parameters)))
190 }
191 
194  ArrayRef<MappedValue> values) {
195  return dispatchMappedValues(
196  argument, values,
197  [&](ArrayRef<Operation *> operations) {
198  return setPayloadOps(argument, operations);
199  },
200  [&](ArrayRef<Param> params) {
201  return setParams(argument, params);
202  },
203  [&](ValueRange payloadValues) {
204  return setPayloadValues(argument, payloadValues);
205  })
206  .checkAndReport();
207 }
208 
210 transform::TransformState::setPayloadOps(Value value,
211  ArrayRef<Operation *> targets) {
212  assert(value != kTopLevelValue &&
213  "attempting to reset the transformation root");
214  assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
215  "wrong handle type");
216 
217  for (Operation *target : targets) {
218  if (target)
219  continue;
220  return emitError(value.getLoc())
221  << "attempting to assign a null payload op to this transform value";
222  }
223 
224  auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
226  iface.checkPayload(value.getLoc(), targets);
227  if (failed(result.checkAndReport()))
228  return failure();
229 
230  // Setting new payload for the value without cleaning it first is a misuse of
231  // the API, assert here.
232  SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
233  Mappings &mappings = getMapping(value);
234  bool inserted =
235  mappings.direct.insert({value, std::move(storedTargets)}).second;
236  assert(inserted && "value is already associated with another list");
237  (void)inserted;
238 
239  for (Operation *op : targets)
240  mappings.reverse[op].push_back(value);
241 
242  return success();
243 }
244 
246 transform::TransformState::setPayloadValues(Value handle,
247  ValueRange payloadValues) {
248  assert(handle != nullptr && "attempting to set params for a null value");
249  assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
250  "wrong handle type");
251 
252  for (Value payload : payloadValues) {
253  if (payload)
254  continue;
255  return emitError(handle.getLoc()) << "attempting to assign a null payload "
256  "value to this transform handle";
257  }
258 
259  auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
260  SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
262  iface.checkPayload(handle.getLoc(), payloadValueVector);
263  if (failed(result.checkAndReport()))
264  return failure();
265 
266  Mappings &mappings = getMapping(handle);
267  bool inserted =
268  mappings.values.insert({handle, std::move(payloadValueVector)}).second;
269  assert(
270  inserted &&
271  "value handle is already associated with another list of payload values");
272  (void)inserted;
273 
274  for (Value payload : payloadValues)
275  mappings.reverseValues[payload].push_back(handle);
276 
277  return success();
278 }
279 
280 LogicalResult transform::TransformState::setParams(Value value,
281  ArrayRef<Param> params) {
282  assert(value != nullptr && "attempting to set params for a null value");
283 
284  for (Attribute attr : params) {
285  if (attr)
286  continue;
287  return emitError(value.getLoc())
288  << "attempting to assign a null parameter to this transform value";
289  }
290 
291  auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
292  assert(value &&
293  "cannot associate parameter with a value of non-parameter type");
295  valueType.checkPayload(value.getLoc(), params);
296  if (failed(result.checkAndReport()))
297  return failure();
298 
299  Mappings &mappings = getMapping(value);
300  bool inserted =
301  mappings.params.insert({value, llvm::to_vector(params)}).second;
302  assert(inserted && "value is already associated with another list of params");
303  (void)inserted;
304  return success();
305 }
306 
307 template <typename Mapping, typename Key, typename Mapped>
308 void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
309  auto it = mapping.find(key);
310  if (it == mapping.end())
311  return;
312 
313  llvm::erase(it->getSecond(), mapped);
314  if (it->getSecond().empty())
315  mapping.erase(it);
316 }
317 
318 void transform::TransformState::forgetMapping(Value opHandle,
319  ValueRange origOpFlatResults,
320  bool allowOutOfScope) {
321  Mappings &mappings = getMapping(opHandle, allowOutOfScope);
322  for (Operation *op : mappings.direct[opHandle])
323  dropMappingEntry(mappings.reverse, op, opHandle);
324  mappings.direct.erase(opHandle);
325 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
326  // Payload IR is removed from the mapping. This invalidates the respective
327  // iterators.
328  mappings.incrementTimestamp(opHandle);
329 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
330 
331  for (Value opResult : origOpFlatResults) {
332  SmallVector<Value> resultHandles;
333  (void)getHandlesForPayloadValue(opResult, resultHandles);
334  for (Value resultHandle : resultHandles) {
335  Mappings &localMappings = getMapping(resultHandle);
336  dropMappingEntry(localMappings.values, resultHandle, opResult);
337 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
338  // Payload IR is removed from the mapping. This invalidates the respective
339  // iterators.
340  mappings.incrementTimestamp(resultHandle);
341 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
342  dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
343  }
344  }
345 }
346 
347 void transform::TransformState::forgetValueMapping(
348  Value valueHandle, ArrayRef<Operation *> payloadOperations) {
349  Mappings &mappings = getMapping(valueHandle);
350  for (Value payloadValue : mappings.reverseValues[valueHandle])
351  dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
352  mappings.values.erase(valueHandle);
353 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
354  // Payload IR is removed from the mapping. This invalidates the respective
355  // iterators.
356  mappings.incrementTimestamp(valueHandle);
357 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
358 
359  for (Operation *payloadOp : payloadOperations) {
360  SmallVector<Value> opHandles;
361  (void)getHandlesForPayloadOp(payloadOp, opHandles);
362  for (Value opHandle : opHandles) {
363  Mappings &localMappings = getMapping(opHandle);
364  dropMappingEntry(localMappings.direct, opHandle, payloadOp);
365  dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
366 
367 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
368  // Payload IR is removed from the mapping. This invalidates the respective
369  // iterators.
370  localMappings.incrementTimestamp(opHandle);
371 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
372  }
373  }
374 }
375 
377 transform::TransformState::replacePayloadOp(Operation *op,
378  Operation *replacement) {
379  // TODO: consider invalidating the handles to nested objects here.
380 
381 #ifndef NDEBUG
382  for (Value opResult : op->getResults()) {
383  SmallVector<Value> valueHandles;
384  (void)getHandlesForPayloadValue(opResult, valueHandles,
385  /*includeOutOfScope=*/true);
386  assert(valueHandles.empty() && "expected no mapping to old results");
387  }
388 #endif // NDEBUG
389 
390  // Drop the mapping between the op and all handles that point to it. Fail if
391  // there are no handles.
392  SmallVector<Value> opHandles;
393  if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
394  return failure();
395  for (Value handle : opHandles) {
396  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
397  dropMappingEntry(mappings.reverse, op, handle);
398  }
399 
400  // Replace the pointed-to object of all handles with the replacement object.
401  // In case a payload op was erased (replacement object is nullptr), a nullptr
402  // is stored in the mapping. These nullptrs are removed after each transform.
403  // Furthermore, nullptrs are not enumerated by payload op iterators. The
404  // relative order of ops is preserved.
405  //
406  // Removing an op from the mapping would be problematic because removing an
407  // element from an array invalidates iterators; merely changing the value of
408  // elements does not.
409  for (Value handle : opHandles) {
410  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
411  auto it = mappings.direct.find(handle);
412  if (it == mappings.direct.end())
413  continue;
414 
415  SmallVector<Operation *, 2> &association = it->getSecond();
416  // Note that an operation may be associated with the handle more than once.
417  for (Operation *&mapped : association) {
418  if (mapped == op)
419  mapped = replacement;
420  }
421 
422  if (replacement) {
423  mappings.reverse[replacement].push_back(handle);
424  } else {
425  opHandlesToCompact.insert(handle);
426  }
427  }
428 
429  return success();
430 }
431 
433 transform::TransformState::replacePayloadValue(Value value, Value replacement) {
434  SmallVector<Value> valueHandles;
435  if (failed(getHandlesForPayloadValue(value, valueHandles,
436  /*includeOutOfScope=*/true)))
437  return failure();
438 
439  for (Value handle : valueHandles) {
440  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
441  dropMappingEntry(mappings.reverseValues, value, handle);
442 
443  // If replacing with null, that is erasing the mapping, drop the mapping
444  // between the handles and the IR objects
445  if (!replacement) {
446  dropMappingEntry(mappings.values, handle, value);
447 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
448  // Payload IR is removed from the mapping. This invalidates the respective
449  // iterators.
450  mappings.incrementTimestamp(handle);
451 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
452  } else {
453  auto it = mappings.values.find(handle);
454  if (it == mappings.values.end())
455  continue;
456 
457  SmallVector<Value> &association = it->getSecond();
458  for (Value &mapped : association) {
459  if (mapped == value)
460  mapped = replacement;
461  }
462  mappings.reverseValues[replacement].push_back(handle);
463  }
464  }
465 
466  return success();
467 }
468 
469 void transform::TransformState::recordOpHandleInvalidationOne(
470  OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
471  Operation *payloadOp, Value otherHandle, Value throughValue,
472  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
473  // If the op is associated with invalidated handle, skip the check as it
474  // may be reading invalid IR. This also ensures we report the first
475  // invalidation and not the last one.
476  if (invalidatedHandles.count(otherHandle) ||
477  newlyInvalidated.count(otherHandle))
478  return;
479 
480  FULL_LDBG("--recordOpHandleInvalidationOne\n");
481  DEBUG_WITH_TYPE(
483  llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ",
484  [](Operation *op) { llvm::dbgs() << *op; });
485  llvm::dbgs() << "\n");
486 
487  Operation *owner = consumingHandle.getOwner();
488  unsigned operandNo = consumingHandle.getOperandNumber();
489  for (Operation *ancestor : potentialAncestors) {
490  // clang-format off
491  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
492  { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
493  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
494  { (DBGS() << "----of payload with name: "
495  << payloadOp->getName().getIdentifier() << "\n"); });
496  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
497  { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
498  // clang-format on
499  if (!ancestor->isAncestor(payloadOp))
500  continue;
501 
502  // Make sure the error-reporting lambda doesn't capture anything
503  // by-reference because it will go out of scope. Additionally, extract
504  // location from Payload IR ops because the ops themselves may be
505  // deleted before the lambda gets called.
506  Location ancestorLoc = ancestor->getLoc();
507  Location opLoc = payloadOp->getLoc();
508  std::optional<Location> throughValueLoc =
509  throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
510  newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
511  otherHandle,
512  throughValueLoc](Location currentLoc) {
513  InFlightDiagnostic diag = emitError(currentLoc)
514  << "op uses a handle invalidated by a "
515  "previously executed transform op";
516  diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
517  diag.attachNote(owner->getLoc())
518  << "invalidated by this transform op that consumes its operand #"
519  << operandNo
520  << " and invalidates all handles to payload IR entities associated "
521  "with this operand and entities nested in them";
522  diag.attachNote(ancestorLoc) << "ancestor payload op";
523  diag.attachNote(opLoc) << "nested payload op";
524  if (throughValueLoc) {
525  diag.attachNote(*throughValueLoc)
526  << "consumed handle points to this payload value";
527  }
528  };
529  }
530 }
531 
532 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
533  OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
534  Value payloadValue, Value valueHandle,
535  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
536  // If the op is associated with invalidated handle, skip the check as it
537  // may be reading invalid IR. This also ensures we report the first
538  // invalidation and not the last one.
539  if (invalidatedHandles.count(valueHandle) ||
540  newlyInvalidated.count(valueHandle))
541  return;
542 
543  for (Operation *ancestor : potentialAncestors) {
544  Operation *definingOp;
545  std::optional<unsigned> resultNo;
546  unsigned argumentNo = std::numeric_limits<unsigned>::max();
547  unsigned blockNo = std::numeric_limits<unsigned>::max();
548  unsigned regionNo = std::numeric_limits<unsigned>::max();
549  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
550  definingOp = opResult.getOwner();
551  resultNo = opResult.getResultNumber();
552  } else {
553  auto arg = llvm::cast<BlockArgument>(payloadValue);
554  definingOp = arg.getParentBlock()->getParentOp();
555  argumentNo = arg.getArgNumber();
556  blockNo = std::distance(arg.getOwner()->getParent()->begin(),
557  arg.getOwner()->getIterator());
558  regionNo = arg.getOwner()->getParent()->getRegionNumber();
559  }
560  assert(definingOp && "expected the value to be defined by an op as result "
561  "or block argument");
562  if (!ancestor->isAncestor(definingOp))
563  continue;
564 
565  Operation *owner = opHandle.getOwner();
566  unsigned operandNo = opHandle.getOperandNumber();
567  Location ancestorLoc = ancestor->getLoc();
568  Location opLoc = definingOp->getLoc();
569  Location valueLoc = payloadValue.getLoc();
570  newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
571  argumentNo, blockNo, regionNo, ancestorLoc,
572  opLoc, valueLoc](Location currentLoc) {
573  InFlightDiagnostic diag = emitError(currentLoc)
574  << "op uses a handle invalidated by a "
575  "previously executed transform op";
576  diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
577  diag.attachNote(owner->getLoc())
578  << "invalidated by this transform op that consumes its operand #"
579  << operandNo
580  << " and invalidates all handles to payload IR entities "
581  "associated with this operand and entities nested in them";
582  diag.attachNote(ancestorLoc)
583  << "ancestor op associated with the consumed handle";
584  if (resultNo) {
585  diag.attachNote(opLoc)
586  << "op defining the value as result #" << *resultNo;
587  } else {
588  diag.attachNote(opLoc)
589  << "op defining the value as block argument #" << argumentNo
590  << " of block #" << blockNo << " in region #" << regionNo;
591  }
592  diag.attachNote(valueLoc) << "payload value";
593  };
594  }
595 }
596 
597 void transform::TransformState::recordOpHandleInvalidation(
598  OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
599  Value throughValue,
600  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
601 
602  if (potentialAncestors.empty()) {
603  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
604  (DBGS() << "----recording invalidation for empty handle: " << handle.get()
605  << "\n");
606  });
607 
608  Operation *owner = handle.getOwner();
609  unsigned operandNo = handle.getOperandNumber();
610  newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
611  InFlightDiagnostic diag = emitError(currentLoc)
612  << "op uses a handle associated with empty "
613  "payload and invalidated by a "
614  "previously executed transform op";
615  diag.attachNote(owner->getLoc())
616  << "invalidated by this transform op that consumes its operand #"
617  << operandNo;
618  };
619  return;
620  }
621 
622  // Iterate over the mapping and invalidate aliasing handles. This is quite
623  // expensive and only necessary for error reporting in case of transform
624  // dialect misuse with dangling handles. Iteration over the handles is based
625  // on the assumption that the number of handles is significantly less than the
626  // number of IR objects (operations and values). Alternatively, we could walk
627  // the IR nested in each payload op associated with the given handle and look
628  // for handles associated with each operation and value.
629  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
630  // Go over all op handle mappings and mark as invalidated any handle
631  // pointing to any of the payload ops associated with the given handle or
632  // any op nested in them.
633  for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
634  for (Value otherHandle : otherHandles)
635  recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
636  otherHandle, throughValue,
637  newlyInvalidated);
638  }
639  // Go over all value handle mappings and mark as invalidated any handle
640  // pointing to any result of the payload op associated with the given handle
641  // or any op nested in them. Similarly invalidate handles to argument of
642  // blocks belonging to any region of any payload op associated with the
643  // given handle or any op nested in them.
644  for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
645  for (Value valueHandle : valueHandles)
646  recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
647  payloadValue, valueHandle,
648  newlyInvalidated);
649  }
650 
651  // Stop lookup when reaching a region that is isolated from above.
653  break;
654  }
655 }
656 
657 void transform::TransformState::recordValueHandleInvalidation(
658  OpOperand &valueHandle,
659  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
660  // Invalidate other handles to the same value.
661  for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
662  SmallVector<Value> otherValueHandles;
663  (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
664  for (Value otherHandle : otherValueHandles) {
665  Operation *owner = valueHandle.getOwner();
666  unsigned operandNo = valueHandle.getOperandNumber();
667  Location valueLoc = payloadValue.getLoc();
668  newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
669  valueLoc](Location currentLoc) {
670  InFlightDiagnostic diag = emitError(currentLoc)
671  << "op uses a handle invalidated by a "
672  "previously executed transform op";
673  diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
674  diag.attachNote(owner->getLoc())
675  << "invalidated by this transform op that consumes its operand #"
676  << operandNo
677  << " and invalidates handles to the same values as associated with "
678  "it";
679  diag.attachNote(valueLoc) << "payload value";
680  };
681  }
682 
683  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
684  Operation *payloadOp = opResult.getOwner();
685  recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
686  newlyInvalidated);
687  } else {
688  auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
689  for (Operation &payloadOp : *arg.getOwner())
690  recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
691  newlyInvalidated);
692  }
693  }
694 }
695 
696 /// Checks that the operation does not use invalidated handles as operands.
697 /// Reports errors and returns failure if it does. Otherwise, invalidates the
698 /// handles consumed by the operation as well as any handles pointing to payload
699 /// IR operations nested in the operations associated with the consumed handles.
700 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
701  transform::TransformOpInterface transform,
702  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
703  FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
704  auto memoryEffectsIface =
705  cast<MemoryEffectOpInterface>(transform.getOperation());
707  memoryEffectsIface.getEffectsOnResource(
709 
710  for (OpOperand &target : transform->getOpOperands()) {
711  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
712  (DBGS() << "----iterate on handle: " << target.get() << "\n");
713  });
714  // If the operand uses an invalidated handle, report it. If the operation
715  // allows handles to point to repeated payload operations, only report
716  // pre-existing invalidation errors. Otherwise, also report invalidations
717  // caused by the current transform operation affecting its other operands.
718  auto it = invalidatedHandles.find(target.get());
719  auto nit = newlyInvalidated.find(target.get());
720  if (it != invalidatedHandles.end()) {
721  FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
722  "invalidated -> FAILURE\n");
723  return it->getSecond()(transform->getLoc()), failure();
724  }
725  if (!transform.allowsRepeatedHandleOperands() &&
726  nit != newlyInvalidated.end()) {
727  FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
728  "invalidated (by this op) -> FAILURE\n");
729  return nit->getSecond()(transform->getLoc()), failure();
730  }
731 
732  // Invalidate handles pointing to the operations nested in the operation
733  // associated with the handle consumed by this operation.
734  auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
735  return isa<MemoryEffects::Free>(effect.getEffect()) &&
736  effect.getValue() == target.get();
737  };
738  if (llvm::any_of(effects, consumesTarget)) {
739  FULL_LDBG("----found consume effect\n");
740  if (llvm::isa<transform::TransformHandleTypeInterface>(
741  target.get().getType())) {
742  FULL_LDBG("----recordOpHandleInvalidation\n");
743  SmallVector<Operation *> payloadOps =
744  llvm::to_vector(getPayloadOps(target.get()));
745  recordOpHandleInvalidation(target, payloadOps, nullptr,
746  newlyInvalidated);
747  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
748  target.get().getType())) {
749  FULL_LDBG("----recordValueHandleInvalidation\n");
750  recordValueHandleInvalidation(target, newlyInvalidated);
751  } else {
752  FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
753  }
754  } else {
755  FULL_LDBG("----no consume effect -> SKIP\n");
756  }
757  }
758 
759  FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
760  return success();
761 }
762 
763 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
764  transform::TransformOpInterface transform) {
765  InvalidatedHandleMap newlyInvalidated;
766  LogicalResult checkResult =
767  checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
768  invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
769  std::make_move_iterator(newlyInvalidated.end()));
770  return checkResult;
771 }
772 
773 template <typename T>
776  transform::TransformOpInterface transform,
777  unsigned operandNumber) {
778  DenseSet<T> seen;
779  for (T p : payload) {
780  if (!seen.insert(p).second) {
782  transform.emitSilenceableError()
783  << "a handle passed as operand #" << operandNumber
784  << " and consumed by this operation points to a payload "
785  "entity more than once";
786  if constexpr (std::is_pointer_v<T>)
787  diag.attachNote(p->getLoc()) << "repeated target op";
788  else
789  diag.attachNote(p.getLoc()) << "repeated target value";
790  return diag;
791  }
792  }
794 }
795 
796 void transform::TransformState::compactOpHandles() {
797  for (Value handle : opHandlesToCompact) {
798  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
799 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
800  if (llvm::find(mappings.direct[handle], nullptr) !=
801  mappings.direct[handle].end())
802  // Payload IR is removed from the mapping. This invalidates the respective
803  // iterators.
804  mappings.incrementTimestamp(handle);
805 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
806  llvm::erase(mappings.direct[handle], nullptr);
807  }
808  opHandlesToCompact.clear();
809 }
810 
// Applies `transform` to the payload IR tracked by this state. Optionally runs
// the expensive handle checks. Invokes the op's `apply` with a tracking
// rewriter and merges tracking-listener failures into the result. Then forgets
// the mappings of consumed handles and records the op's results.
// NOTE(review): listing line 811 (the return type) is missing from this
// rendering; `result` below is a DiagnosedSilenceableFailure and is returned,
// so the return type is presumably DiagnosedSilenceableFailure -- confirm.
812 transform::TransformState::applyTransform(TransformOpInterface transform) {
813  LLVM_DEBUG({
814  DBGS() << "applying: ";
815  transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
816  llvm::dbgs() << "\n";
817  });
818  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
819  DBGS() << "Top-level payload before application:\n"
820  << *getTopLevel() << "\n");
// Prints the top-level payload to the debug stream on every exit path unless
// released; it is only released right before the final `return result;`.
821  auto printOnFailureRAII = llvm::make_scope_exit([this] {
// Presumably keeps the `this` capture used when LLVM_DEBUG expands to nothing.
822  (void)this;
823  LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
824  llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
825  });
826 
827  // Set current transform op.
828  regionStack.back()->currentTransform = transform;
829 
830  // Expensive checks to detect invalid transform IR.
831  if (options.getExpensiveChecksEnabled()) {
832  FULL_LDBG("ExpensiveChecksEnabled\n");
// NOTE(review): the statement guarded by the following `if` (listing line 834)
// is missing from this rendering.
833  if (failed(checkAndRecordHandleInvalidation(transform)))
835 
// Reject consumed handles that point to the same payload entity more than
// once, unless the op declares that repeated handle operands are allowed.
836  for (OpOperand &operand : transform->getOpOperands()) {
837  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
838  (DBGS() << "iterate on handle: " << operand.get() << "\n");
839  });
840  if (!isHandleConsumed(operand.get(), transform)) {
841  FULL_LDBG("--handle not consumed -> SKIP\n");
842  continue;
843  }
844  if (transform.allowsRepeatedHandleOperands()) {
845  FULL_LDBG("--op allows repeated handles -> SKIP\n");
846  continue;
847  }
848  FULL_LDBG("--handle is consumed\n");
849 
850  Type operandType = operand.get().getType();
851  if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
852  FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
// NOTE(review): listing line 853 (the declaration of `check` that receives the
// following call's result) is missing from this rendering.
854  checkRepeatedConsumptionInOperand<Operation *>(
855  getPayloadOpsView(operand.get()), transform,
856  operand.getOperandNumber());
857  if (!check.succeeded()) {
858  FULL_LDBG("----FAILED\n");
859  return check;
860  }
861  } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
862  FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
// NOTE(review): listing line 863 (the declaration of `check`) is missing here.
864  checkRepeatedConsumptionInOperand<Value>(
865  getPayloadValuesView(operand.get()), transform,
866  operand.getOperandNumber());
867  if (!check.succeeded()) {
868  FULL_LDBG("----FAILED\n");
869  return check;
870  }
871  } else {
872  FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
873  }
874  }
875  }
876 
877  // Find which operands are consumed.
878  SmallVector<OpOperand *> consumedOperands =
879  transform.getConsumedHandleOpOperands();
880 
881  // Remember the results of the payload ops associated with the consumed
882  // op handles or the ops defining the value handles so we can drop the
883  // association with them later. This must happen here because the
884  // transformation may destroy or mutate them so we cannot traverse the payload
885  // IR after that.
// For a consumed value handle, an OpResult contributes its defining op, and a
// block argument contributes every op of its owner block.
886  SmallVector<Value> origOpFlatResults;
887  SmallVector<Operation *> origAssociatedOps;
888  for (OpOperand *opOperand : consumedOperands) {
889  Value operand = opOperand->get();
890  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
891  for (Operation *payloadOp : getPayloadOps(operand)) {
892  llvm::append_range(origOpFlatResults, payloadOp->getResults());
893  }
894  continue;
895  }
896  if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
897  for (Value payloadValue : getPayloadValuesView(operand)) {
898  if (llvm::isa<OpResult>(payloadValue)) {
899  origAssociatedOps.push_back(payloadValue.getDefiningOp());
900  continue;
901  }
902  llvm::append_range(
903  origAssociatedOps,
904  llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
905  [](Operation &op) { return &op; }));
906  }
907  continue;
908  }
// Reaching this point means a consumed operand is neither an op handle nor a
// value handle, which is reported as an error located at the transform op.
// NOTE(review): listing line 909 (the declaration of `diag`) is missing here.
910  emitDefiniteFailure(transform->getLoc())
911  << "unexpectedly consumed a value that is not a handle as operand #"
912  << opOperand->getOperandNumber();
913  diag.attachNote(operand.getLoc())
914  << "value defined here with type " << operand.getType();
915  return diag;
916  }
917 
918  // Prepare rewriter and listener.
// A handle is skipped (treated as dead) when every one of its users is either
// the current transform of the innermost region scope that owns the handle, or
// happens before that transform.
919  TrackingListenerConfig config;
920  config.skipHandleFn = [&](Value handle) {
921  // Skip handle if it is dead.
922  auto scopeIt =
923  llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
924  return handle.getParentRegion() == scope->region;
925  });
926  assert(scopeIt != regionStack.rend() &&
927  "could not find region scope for handle");
928  RegionScope *scope = *scopeIt;
929  for (Operation *user : handle.getUsers()) {
930  if (user != scope->currentTransform &&
931  !happensBefore(user, scope->currentTransform))
932  return false;
933  }
934  return true;
935  };
936  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
937  config);
938  transform::TransformRewriter rewriter(transform->getContext(),
939  &trackingListener);
940 
941  // Compute the result but do not short-circuit the silenceable failure case as
942  // we still want the handles to propagate properly so the "suppress" mode can
943  // proceed on a best effort basis.
944  transform::TransformResults results(transform->getNumResults());
945  DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
// Remove the null entries from the op handles queued for compaction.
946  compactOpHandles();
947 
948  // Error handling: fail if transform or listener failed.
949  DiagnosedSilenceableFailure trackingFailure =
950  trackingListener.checkAndResetError();
951  if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
952  transform->hasAttr(FindPayloadReplacementOpInterface::
953  kSilenceTrackingFailuresAttrName)) {
954  // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
955  // do not report failures if the above mentioned attribute is set.
956  if (trackingFailure.isSilenceableFailure())
957  (void)trackingFailure.silence();
958  trackingFailure = DiagnosedSilenceableFailure::success();
959  }
960  if (!trackingFailure.succeeded()) {
961  if (result.succeeded()) {
962  result = std::move(trackingFailure);
963  } else {
964  // Transform op errors have precedence, report those first.
965  if (result.isSilenceableFailure())
966  result.attachNote() << "tracking listener also failed: "
967  << trackingFailure.getMessage();
968  (void)trackingFailure.silence();
969  }
970  }
971  if (result.isDefiniteFailure())
972  return result;
973 
974  // If a silenceable failure was produced, some results may be unset, set them
975  // to empty lists.
976  if (result.isSilenceableFailure())
977  results.setRemainingToEmpty(transform);
978 
979  // Remove the mapping for the operand if it is consumed by the operation. This
980  // allows us to catch use-after-free with assertions later on.
981  for (OpOperand *opOperand : consumedOperands) {
982  Value operand = opOperand->get();
983  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
984  forgetMapping(operand, origOpFlatResults);
985  } else if (llvm::isa<TransformValueHandleTypeInterface>(
986  operand.getType())) {
987  forgetValueMapping(operand, origAssociatedOps);
988  }
989  }
990 
// NOTE(review): the statement guarded by the following `if` (listing line 992)
// is missing from this rendering.
991  if (failed(updateStateFromResults(results, transform->getResults())))
993 
994  printOnFailureRAII.release();
995  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
996  DBGS() << "Top-level payload:\n";
997  getTopLevel()->print(llvm::dbgs());
998  });
999  return result;
1000 }
1001 
1002 LogicalResult transform::TransformState::updateStateFromResults(
1003  const TransformResults &results, ResultRange opResults) {
1004  for (OpResult result : opResults) {
1005  if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1006  assert(results.isParam(result.getResultNumber()) &&
1007  "expected parameters for the parameter-typed result");
1008  if (failed(
1009  setParams(result, results.getParams(result.getResultNumber())))) {
1010  return failure();
1011  }
1012  } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1013  assert(results.isValue(result.getResultNumber()) &&
1014  "expected values for value-type-result");
1015  if (failed(setPayloadValues(
1016  result, results.getValues(result.getResultNumber())))) {
1017  return failure();
1018  }
1019  } else {
1020  assert(!results.isParam(result.getResultNumber()) &&
1021  "expected payload ops for the non-parameter typed result");
1022  if (failed(
1023  setPayloadOps(result, results.get(result.getResultNumber())))) {
1024  return failure();
1025  }
1026  }
1027  }
1028  return success();
1029 }
1030 
1031 //===----------------------------------------------------------------------===//
1032 // TransformState::Extension
1033 //===----------------------------------------------------------------------===//
1034 
1036 
// Forwards the replacement of payload op `op` by `replacement` to the owning
// `state`.
// NOTE(review): the start of the signature (listing lines 1037-1038) is missing
// from this rendering.
1039  Operation *replacement) {
1040  // TODO: we may need to invalidate handles to operations and values nested in
1041  // the operation being replaced.
1042  return state.replacePayloadOp(op, replacement);
1043 }
1044 
// Forwards the replacement of payload value `value` by `replacement` to the
// owning `state`.
// NOTE(review): the start of the signature (listing lines 1045-1046) is missing
// from this rendering.
1047  Value replacement) {
1048  return state.replacePayloadValue(value, replacement);
1049 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // TransformState::RegionScope
1053 //===----------------------------------------------------------------------===//
1054 
// Runs when the region scope ends. It clears the invalidation notices of the
// handles defined directly in the blocks of `region` (block arguments and
// results of ops in those blocks). It then erases the region's mappings and
// pops the scope from the state's region stack.
// NOTE(review): listing line 1055 (the signature) is missing from this
// rendering; the section header and the `regionStack.pop_back()` below suggest
// this is the RegionScope destructor -- confirm against the source.
1056  // Remove handle invalidation notices as handles are going out of scope.
1057  // The same region may be re-entered leading to incorrect invalidation
1058  // errors.
1059  for (Block &block : *region) {
1060  for (Value handle : block.getArguments()) {
1061  state.invalidatedHandles.erase(handle);
1062  }
1063  for (Operation &op : block) {
1064  for (Value handle : op.getResults()) {
1065  state.invalidatedHandles.erase(handle);
1066  }
1067  }
1068  }
1069 
1070 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1071  // Remember pointers to payload ops referenced by the handles going out of
1072  // scope.
// NOTE(review): `referencedOps` is not read anywhere in the lines shown here.
1073  SmallVector<Operation *> referencedOps =
1074  llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1075 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1076 
1077  state.mappings.erase(region);
1078  state.regionStack.pop_back();
1079 }
1080 
1081 //===----------------------------------------------------------------------===//
1082 // TransformResults
1083 //===----------------------------------------------------------------------===//
1084 
1085 transform::TransformResults::TransformResults(unsigned numSegments) {
1086  operations.appendEmptyRows(numSegments);
1087  params.appendEmptyRows(numSegments);
1088  values.appendEmptyRows(numSegments);
1089 }
1090 
// Stores `params` as the parameter payload of result `value`; the result number
// selects the row. Asserts that the row exists and that no params, ops or
// values were set for it before.
// NOTE(review): the signature (listing lines 1091-1092) is missing from this
// rendering.
1093  int64_t position = value.getResultNumber();
1094  assert(position < static_cast<int64_t>(this->params.size()) &&
1095  "setting params for a non-existent handle");
1096  assert(this->params[position].data() == nullptr && "params already set");
1097  assert(operations[position].data() == nullptr &&
1098  "another kind of results already set");
1099  assert(values[position].data() == nullptr &&
1100  "another kind of results already set");
// `this->` distinguishes the member table from the same-named argument.
1101  this->params.replace(position, params);
1102 }
1103 
// Records `values` as the payload of `handle`. It forwards to `set`, `setParams`
// or `setValues`, presumably picked by the kind of the mapped entries. A failed
// dispatch is printed and asserted on in builds without NDEBUG, and the status
// is always silenced.
// NOTE(review): listing lines 1104 (start of the signature) and 1106 (the
// declaration of `diag` and the dispatch call taking the arguments below) are
// missing from this rendering.
1105  OpResult handle, ArrayRef<MappedValue> values) {
1107  handle, values,
1108  [&](ArrayRef<Operation *> operations) {
1109  return set(handle, operations), success();
1110  },
1111  [&](ArrayRef<Param> params) {
1112  return setParams(handle, params), success();
1113  },
1114  [&](ValueRange payloadValues) {
1115  return setValues(handle, payloadValues), success();
1116  });
1117 #ifndef NDEBUG
1118  if (!diag.succeeded())
1119  llvm::dbgs() << diag.getStatusString() << "\n";
1120  assert(diag.succeeded() && "incorrect mapping");
1121 #endif // NDEBUG
1122  (void)diag.silence();
1123 }
1124 
// Assigns an empty mapping to every result of `transform` that has not been set
// yet.
// NOTE(review): listing line 1125 (start of the signature) is missing here.
1126  transform::TransformOpInterface transform) {
1127  for (OpResult opResult : transform->getResults()) {
1128  if (!isSet(opResult.getResultNumber()))
1129  setMappedValues(opResult, {});
1130  }
1131 }
1132 
// Returns the payload ops recorded for result `resultNumber`. Asserts that the
// result exists and that its op row has been set (not params/values).
// NOTE(review): listing line 1133 (the return type) is missing here.
1134 transform::TransformResults::get(unsigned resultNumber) const {
1135  assert(resultNumber < operations.size() &&
1136  "querying results for a non-existent handle");
1137  assert(operations[resultNumber].data() != nullptr &&
1138  "querying unset results (values or params expected?)");
1139  return operations[resultNumber];
1140 }
1141 
// Returns the params recorded for result `resultNumber`. Asserts that the
// result exists and that its param row has been set.
// NOTE(review): listing line 1142 (the return type) is missing here.
1143 transform::TransformResults::getParams(unsigned resultNumber) const {
1144  assert(resultNumber < params.size() &&
1145  "querying params for a non-existent handle");
1146  assert(params[resultNumber].data() != nullptr &&
1147  "querying unset params (ops or values expected?)");
1148  return params[resultNumber];
1149 }
1150 
// Returns the payload values recorded for result `resultNumber`. Asserts that
// the result exists and that its value row has been set.
// NOTE(review): listing line 1151 (the return type) is missing here.
1152 transform::TransformResults::getValues(unsigned resultNumber) const {
1153  assert(resultNumber < values.size() &&
1154  "querying values for a non-existent handle");
1155  assert(values[resultNumber].data() != nullptr &&
1156  "querying unset values (ops or params expected?)");
1157  return values[resultNumber];
1158 }
1159 
1160 bool transform::TransformResults::isParam(unsigned resultNumber) const {
1161  assert(resultNumber < params.size() &&
1162  "querying association for a non-existent handle");
1163  return params[resultNumber].data() != nullptr;
1164 }
1165 
1166 bool transform::TransformResults::isValue(unsigned resultNumber) const {
1167  assert(resultNumber < values.size() &&
1168  "querying association for a non-existent handle");
1169  return values[resultNumber].data() != nullptr;
1170 }
1171 
1172 bool transform::TransformResults::isSet(unsigned resultNumber) const {
1173  assert(resultNumber < params.size() &&
1174  "querying association for a non-existent handle");
1175  return params[resultNumber].data() != nullptr ||
1176  operations[resultNumber].data() != nullptr ||
1177  values[resultNumber].data() != nullptr;
1178 }
1179 
1180 //===----------------------------------------------------------------------===//
1181 // TrackingListener
1182 //===----------------------------------------------------------------------===//
1183 
// Creates a listener for transform op `op` with the given `config`. When `op`
// is non-null, the handles it consumes are recorded in `consumedHandles`.
// NOTE(review): listing line 1184 (start of the signature) is missing here.
1185  TransformOpInterface op,
1186  TrackingListenerConfig config)
1187  : TransformState::Extension(state), transformOp(op), config(config) {
1188  if (op) {
1189  for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1190  consumedHandles.insert(opOperand->get());
1191  }
1192  }
1193 }
1194 
// Returns the op that defines the non-null entries of `values`. Returns nullptr
// if two of them are defined by different ops, or if no non-null entry has a
// defining op.
// NOTE(review): `defOp` is (re)initialized whenever it is still null. As a
// result, values without a defining op (e.g. block arguments) that precede the
// first op result are skipped instead of producing nullptr. A block argument
// after an op result does produce nullptr. Confirm the asymmetry is intended.
// NOTE(review): listing line 1195 (the signature) is missing here.
1196  Operation *defOp = nullptr;
1197  for (Value v : values) {
1198  // Skip empty values.
1199  if (!v)
1200  continue;
1201  if (!defOp) {
1202  defOp = v.getDefiningOp();
1203  continue;
1204  }
1205  if (defOp != v.getDefiningOp())
1206  return nullptr;
1207  }
1208  return defOp;
1209 }
1210 
// Searches for an op that replaces `op`, starting from `newValues`, and stores
// it in `result`. Each iteration takes the common defining op of the current
// values and then does the first of these that applies:
//  - skips through it if it is a CastOpInterface op (when
//    config.skipCastOps), continuing with its operands;
//  - accepts it if names need not match or its name equals `op`'s name;
//  - accepts it if it is ConstantLike;
//  - skips through it if it implements FindPayloadReplacementOpInterface,
//    continuing with its next operands;
//  - otherwise stops.
// Every failure returns `diag` with notes describing the steps taken.
// Note that `continue` in a do-while jumps to the `!values.empty()` check.
// NOTE(review): listing lines 1211 (start of the signature), 1217 (the
// declaration of `diag`), and 1243 and 1250 (the statements following each
// `result = defOp;`) are missing from this rendering. The latter presumably
// return success -- confirm.
1212  Operation *&result, Operation *op, ValueRange newValues) const {
1213  assert(op->getNumResults() == newValues.size() &&
1214  "invalid number of replacement values");
1215  SmallVector<Value> values(newValues.begin(), newValues.end());
1216 
1218  getTransformOp(), "tracking listener failed to find replacement op "
1219  "during application of this transform op");
1220 
1221  do {
1222  // If the replacement values belong to different ops, drop the mapping.
1223  Operation *defOp = getCommonDefiningOp(values);
1224  if (!defOp) {
1225  diag.attachNote() << "replacement values belong to different ops";
1226  return diag;
1227  }
1228 
1229  // Skip through ops that implement CastOpInterface.
1230  if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1231  values.clear();
1232  values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1233  diag.attachNote(defOp->getLoc())
1234  << "using output of 'CastOpInterface' op";
1235  continue;
1236  }
1237 
1238  // If the defining op has the same name or we do not care about the name of
1239  // op replacements at all, we take it as a replacement.
1240  if (!config.requireMatchingReplacementOpName ||
1241  op->getName() == defOp->getName()) {
1242  result = defOp;
1244  }
1245 
1246  // Replacing an op with a constant-like equivalent is a common
1247  // canonicalization.
1248  if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1249  result = defOp;
1251  }
1252 
1253  values.clear();
1254 
1255  // Skip through ops that implement FindPayloadReplacementOpInterface.
1256  if (auto findReplacementOpInterface =
1257  dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1258  values.assign(findReplacementOpInterface.getNextOperands());
1259  diag.attachNote(defOp->getLoc()) << "using operands provided by "
1260  "'FindPayloadReplacementOpInterface'";
1261  continue;
1262  }
1263  } while (!values.empty());
1264 
1265  diag.attachNote() << "ran out of suitable replacement values";
1266  return diag;
1267 }
1268 
// In debug builds, writes the failure reason produced by `reasonCallback` for
// `loc` to the debug stream. The whole body is inside LLVM_DEBUG, so it does
// nothing otherwise.
// NOTE(review): listing lines 1269 (start of the signature) and 1272 (the
// declaration of `diag`) are missing here.
1270  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1271  LLVM_DEBUG({
1273  reasonCallback(diag);
1274  DBGS() << "Match Failure : " << diag.str() << "\n";
1275  });
1276 }
1277 
1278 void transform::TrackingListener::notifyOperationErased(Operation *op) {
1279  // Remove mappings for result values.
1280  for (OpResult value : op->getResults())
1281  (void)replacePayloadValue(value, nullptr);
1282  // Remove mapping for op.
1283  (void)replacePayloadOp(op, nullptr);
1284 }
1285 
// Called when `op` is replaced by `newValues`. First, value handles are updated
// result-by-result. If `op` is tracked by some op handle, the op is dropped from
// the mapping in two cases: every handle is skipped (dead), or the current
// transform consumes one of the handles. Otherwise a replacement op is searched
// for with findReplacementOp. On failure, a "replacement not found"
// notification is sent and the op is dropped.
1286 void transform::TrackingListener::notifyOperationReplaced(
1287  Operation *op, ValueRange newValues) {
1288  assert(op->getNumResults() == newValues.size() &&
1289  "invalid number of replacement values");
1290 
1291  // Replace value handles.
1292  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
1293  (void)replacePayloadValue(oldValue, newValue);
1294 
1295  // Replace op handle.
1296  SmallVector<Value> opHandles;
1297  if (failed(getTransformState().getHandlesForPayloadOp(
1298  op, opHandles, /*includeOutOfScope=*/true))) {
1299  // Op is not tracked.
1300  return;
1301  }
1302 
1303  // Helper function to check if the current transform op consumes any handle
1304  // that is mapped to `op`.
1305  //
1306  // Note: If a handle was consumed, there shouldn't be any alive users, so it
1307  // is not really necessary to check for consumed handles. However, in case
1308  // there are indeed alive handles that were consumed (which is undefined
1309  // behavior) and a replacement op could not be found, we want to fail with a
1310  // nicer error message: "op uses a handle invalidated..." instead of "could
1311  // not find replacement op". This nicer error is produced later.
1312  auto handleWasConsumed = [&] {
1313  return llvm::any_of(opHandles,
1314  [&](Value h) { return consumedHandles.contains(h); });
1315  };
1316 
1317  // Check if there are any handles that must be updated.
1318  Value aliveHandle;
1319  if (config.skipHandleFn) {
1320  auto it = llvm::find_if(opHandles,
1321  [&](Value v) { return !config.skipHandleFn(v); });
1322  if (it != opHandles.end())
1323  aliveHandle = *it;
1324  } else if (!opHandles.empty()) {
1325  aliveHandle = opHandles.front();
1326  }
1327  if (!aliveHandle || handleWasConsumed()) {
1328  // The op is tracked but the corresponding handles are dead or were
1329  // consumed. Drop the op from the mapping.
1330  (void)replacePayloadOp(op, nullptr);
1331  return;
1332  }
1333 
// NOTE(review): `replacement` is left uninitialized here; it is presumably
// assigned by findReplacementOp on success. Listing line 1335 (the declaration
// of `diag` that receives the next call's result) is missing from this
// rendering.
1334  Operation *replacement;
1336  findReplacementOp(replacement, op, newValues);
1337  // If the op is tracked but no replacement op was found, send a
1338  // notification.
1339  if (!diag.succeeded()) {
1340  diag.attachNote(aliveHandle.getLoc())
1341  << "replacement is required because this handle must be updated";
1342  notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
1343  (void)replacePayloadOp(op, nullptr);
1344  return;
1345  }
1346 
1347  (void)replacePayloadOp(op, replacement);
1348 }
1349 
// NOTE(review): listing line 1350 (the signature) is missing from this
// rendering. The comment below suggests this is the listener's destructor --
// confirm.
1351  // The state of the ErrorCheckingTrackingListener must be checked and reset
1352  // if there was an error. This is to prevent errors from accidentally being
1353  // missed.
1354  assert(status.succeeded() && "listener state was not checked");
1355 }
1356 
// Moves the accumulated `status` out to the caller and resets the error
// counter.
// NOTE(review): listing lines 1357-1358 (the signature) and 1360 (presumably
// the reset of `status`) are missing here.
1359  DiagnosedSilenceableFailure s = std::move(status);
1361  errorCounter = 0;
1362  return s;
1363 }
1364 
// Returns true if `status` currently holds a failure.
// NOTE(review): listing line 1365 (the signature) is missing here.
1366  return !status.succeeded();
1367 }
1368 
// Records a "replacement not found" failure for `op`. It merges `diag` with any
// previously stored failure into a new silenceable failure, then attaches notes
// for the replaced op and each replacement value in `values`. Each note is
// tagged with the running `errorCounter`, which is then incremented.
// NOTE(review): listing lines 1369-1370 (the signature) and 1373 (presumably
// the declaration of `diags`) are missing here.
1371 
1372  // Merge potentially existing diags and store the result in the listener.
1374  diag.takeDiagnostics(diags);
1375  if (!status.succeeded())
1376  status.takeDiagnostics(diags);
1377  status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
1378 
1379  // Report more details.
1380  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
1381  for (auto &&[index, value] : llvm::enumerate(values))
1382  status.attachNote(value.getLoc())
1383  << "[" << errorCounter << "] replacement value " << index;
1384  ++errorCounter;
1385 }
1386 
1387 //===----------------------------------------------------------------------===//
1388 // TransformRewriter
1389 //===----------------------------------------------------------------------===//
1390 
// Creates a rewriter over `ctx` whose listener (and stored `listener`) is the
// given tracking listener.
// NOTE(review): listing lines 1391-1392 (the signature) are missing here.
1393  : RewriterBase(ctx), listener(listener) {
1394  setListener(listener);
1395 }
1396 
// Returns true if the attached tracking listener has recorded a failure.
// NOTE(review): listing line 1397 (the signature) is missing here.
1398  return listener->failed();
1399 }
1400 
1401 /// Silence all tracking failures that have been encountered so far.
/// Fetching the listener's error also resets it, so later checks start clean.
// NOTE(review): listing line 1402 (the signature) is missing here.
1403  if (hasTrackingFailures()) {
1404  DiagnosedSilenceableFailure status = listener->checkAndResetError();
1405  (void)status.silence();
1406  }
1407 }
1408 
// Forwards the replacement of payload op `op` by `replacement` to the tracking
// listener.
// NOTE(review): listing line 1409 (start of the signature) is missing here.
1410  Operation *op, Operation *replacement) {
1411  return listener->replacePayloadOp(op, replacement);
1412 }
1413 
1414 //===----------------------------------------------------------------------===//
1415 // Utilities for TransformEachOpTrait.
1416 //===----------------------------------------------------------------------===//
1417 
// Emits an error at `loc` if some target is an ancestor of a target that
// appears later in `targets`; otherwise succeeds. Every ordered pair is
// compared, so this is quadratic in the number of targets.
// NOTE(review): listing lines 1418-1419 (the signature) and 1424 (the
// declaration of `diag`) are missing here.
1420  ArrayRef<Operation *> targets) {
1421  for (auto &&[position, parent] : llvm::enumerate(targets)) {
1422  for (Operation *child : targets.drop_front(position + 1)) {
1423  if (parent->isAncestor(child)) {
1425  emitError(loc)
1426  << "transform operation consumes a handle pointing to an ancestor "
1427  "payload operation before its descendant";
1428  diag.attachNote()
1429  << "the ancestor is likely erased or rewritten before the "
1430  "descendant is accessed, leading to undefined behavior";
1431  diag.attachNote(parent->getLoc()) << "ancestor payload op";
1432  diag.attachNote(child->getLoc()) << "descendant payload op";
1433  return diag;
1434  }
1435  }
1436  }
1437  return success();
1438 }
1439 
// Validates one `applyToOne` result list. `partialResult` must have exactly one
// entry per result of `transformOp`. Each non-null entry must be an
// Operation *, Attribute or Value, matching the handle, param or value-handle
// type of the corresponding result. Mismatches are reported as errors at the
// transform op, with a note at `payloadOpLoc`.
// NOTE(review): listing lines 1440-1441 (the signature) are missing here.
1442  Location payloadOpLoc,
1443  const ApplyToEachResultList &partialResult) {
1444  Location transformOpLoc = transformOp->getLoc();
1445  StringRef transformOpName = transformOp->getName().getStringRef();
1446  unsigned expectedNumResults = transformOp->getNumResults();
1447 
1448  // Reuse the emission of the diagnostic note.
1449  auto emitDiag = [&]() {
1450  auto diag = mlir::emitError(transformOpLoc);
1451  diag.attachNote(payloadOpLoc) << "when applied to this op";
1452  return diag;
1453  };
1454 
1455  if (partialResult.size() != expectedNumResults) {
1456  auto diag = emitDiag() << "application of " << transformOpName
1457  << " expected to produce " << expectedNumResults
1458  << " results (actually produced "
1459  << partialResult.size() << ").";
1460  diag.attachNote(transformOpLoc)
1461  << "if you need variadic results, consider a generic `apply` "
1462  << "instead of the specialized `applyToOne`.";
1463  return failure();
1464  }
1465 
1466  // Check that the right kind of value was produced.
1467  for (const auto &[ptr, res] :
1468  llvm::zip(partialResult, transformOp->getResults())) {
1469  if (ptr.isNull())
1470  continue;
1471  if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1472  !ptr.is<Operation *>()) {
1473  return emitDiag() << "application of " << transformOpName
1474  << " expected to produce an Operation * for result #"
1475  << res.getResultNumber();
1476  }
1477  if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1478  !ptr.is<Attribute>()) {
1479  return emitDiag() << "application of " << transformOpName
1480  << " expected to produce an Attribute for result #"
1481  << res.getResultNumber();
1482  }
1483  if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1484  !ptr.is<Value>()) {
1485  return emitDiag() << "application of " << transformOpName
1486  << " expected to produce a Value for result #"
1487  << res.getResultNumber();
1488  }
1489  }
1490  return success();
1491 }
1492 
// Extracts the `T` alternative of every MappedValue in `range` into a new
// vector. `get<T>` presumably asserts if an entry holds another kind -- confirm.
1493 template <typename T>
// NOTE(review): listing line 1494 (the rest of the signature) is missing here.
1495  return llvm::to_vector(llvm::map_range(
1496  range, [](transform::MappedValue value) { return value.get<T>(); }));
1497 }
1498 
// Transposes the per-payload result lists in `results` into per-result lists.
// Each list is stored in `transformResults` as params, values or ops, depending
// on the type of the corresponding result of `transformOp`. Any partial result
// list that contains a null entry is skipped entirely.
// NOTE(review): listing lines 1499 and 1501-1502 (parts of the signature and
// the declaration of `transposed`) are missing here.
1500  Operation *transformOp, TransformResults &transformResults,
1503  transposed.resize(transformOp->getNumResults());
1504  for (const ApplyToEachResultList &partialResults : results) {
1505  if (llvm::any_of(partialResults,
1506  [](MappedValue value) { return value.isNull(); }))
1507  continue;
1508  assert(transformOp->getNumResults() == partialResults.size() &&
1509  "expected as many partial results as op as results");
1510  for (auto [i, value] : llvm::enumerate(partialResults))
1511  transposed[i].push_back(value);
1512  }
1513 
1514  for (OpResult r : transformOp->getResults()) {
1515  unsigned position = r.getResultNumber();
1516  if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1517  transformResults.setParams(r,
1518  castVector<Attribute>(transposed[position]));
1519  } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1520  transformResults.setValues(r, castVector<Value>(transposed[position]));
1521  } else {
1522  transformResults.set(r, castVector<Operation *>(transposed[position]));
1523  }
1524  }
1525 }
1526 
1527 //===----------------------------------------------------------------------===//
1528 // Utilities for implementing transform ops with regions.
1529 //===----------------------------------------------------------------------===//
1530 
// Appends to `mappings` one list per value in `values`. Each list holds the
// payload ops, payload values or params currently associated with that value
// in `state`, selected by the value's transform type.
// NOTE(review): listing lines 1531-1532 (the signature) are missing here.
1533  ValueRange values, const transform::TransformState &state) {
1534  for (Value operand : values) {
1535  SmallVector<MappedValue> &mapped = mappings.emplace_back();
1536  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1537  llvm::append_range(mapped, state.getPayloadOps(operand));
1538  } else if (llvm::isa<TransformValueHandleTypeInterface>(
1539  operand.getType())) {
1540  llvm::append_range(mapped, state.getPayloadValues(operand));
1541  } else {
1542  assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1543  "unsupported kind of transform dialect value");
1544  llvm::append_range(mapped, state.getParams(operand));
1545  }
1546  }
1547 }
1548 
// Sets each result of `block`'s parent op to the payload currently associated
// with the corresponding operand of `block`'s terminator. The result's type
// decides whether ops, values or params are forwarded. llvm::zip stops at the
// shorter of the two ranges.
// NOTE(review): listing line 1549 (start of the signature) is missing here.
1550  Block *block, transform::TransformState &state,
1551  transform::TransformResults &results) {
1552  for (auto &&[terminatorOperand, result] :
1553  llvm::zip(block->getTerminator()->getOperands(),
1554  block->getParentOp()->getOpResults())) {
1555  if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1556  results.set(result, state.getPayloadOps(terminatorOperand));
1557  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1558  result.getType())) {
1559  results.setValues(result, state.getPayloadValues(terminatorOperand));
1560  } else {
1561  assert(
1562  llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1563  "unhandled transform type interface");
1564  results.setParams(result, state.getParams(terminatorOperand));
1565  }
1566  }
1567 }
1568 
// Constructs and returns a TransformState for `region` rooted at
// `payloadRoot`.
// NOTE(review): listing lines 1569-1570 (the signature) are missing here.
1571  Operation *payloadRoot) {
1572  return TransformState(region, payloadRoot);
1573 }
1574 
1575 //===----------------------------------------------------------------------===//
1576 // Utilities for PossibleTopLevelTransformOpTrait.
1577 //===----------------------------------------------------------------------===//
1578 
1579 /// Appends to `effects` the memory effect instances on `target` with the same
1580 /// resource and effect as the ones the operation `iface` has on `source`.
// NOTE(review): listing lines 1583-1584 (the last parameter and the
// declaration of `nestedEffects`) are missing from this rendering.
1581 static void
1582 remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
1585  iface.getEffectsOnValue(source, nestedEffects);
1586  for (const auto &effect : nestedEffects)
1587  effects.emplace_back(effect.getEffect(), target, effect.getResource());
1588 }
1589 
1590 /// Appends to `effects` the same effects as the operations of `block` have on
1591 /// block arguments but associated with `operands`.
/// Effects of those operations on the payload IR resource are appended as-is.
// NOTE(review): listing lines 1593-1594 (the signature) and 1604 (the
// declaration of `nestedEffects`) are missing from this rendering.
1592 static void
1595  for (Operation &op : block) {
1596  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1597  if (!iface)
1598  continue;
1599 
1600  for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
1601  remapEffects(iface, source, target, effects);
1602  }
1603 
1605  iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1606  nestedEffects);
1607  llvm::append_range(effects, nestedEffects);
1608  }
1609 }
1610 
// Computes the memory effects of a possible top-level transform op. The op only
// reads its operands and produces its results. Without a `root`, the effects of
// all ops in `body` are collected as-is. With a `root`, the effects of `body`'s
// ops on its block arguments are remapped onto the op's operands.
// NOTE(review): listing lines 1611, 1613 (parts of the signature) and 1623 are
// missing from this rendering.
1612  Operation *operation, Value root, Block &body,
1614  transform::onlyReadsHandle(operation->getOperands(), effects);
1615  transform::producesHandle(operation->getResults(), effects);
1616 
1617  if (!root) {
1618  for (Operation &op : body) {
1619  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1620  if (!iface)
1621  continue;
1622 
1624  iface.getEffects(effects);
1625  }
1626  return;
1627  }
1628 
1629  // Carry over all effects on arguments of the entry block as those on the
1630  // operands, this is the same value just remapped.
1631  remapArgumentEffects(body, operation->getOperands(), effects);
1632 }
1633 
// Binds the entry block arguments of `region`. Argument 0 is bound to the
// payload ops of `op`'s first operand, or to the top-level payload op when `op`
// has no operands. The remaining arguments are bound to the mappings of the
// other operands, or to the top-level mappings provided to the interpreter. In
// the operand-less case, the number of top-level mappings must equal the number
// of extra block arguments.
// NOTE(review): listing line 1634 (start of the signature) is missing here.
1635  TransformState &state, Operation *op, Region &region) {
1636  SmallVector<Operation *> targets;
1637  SmallVector<SmallVector<MappedValue>> extraMappings;
1638  if (op->getNumOperands() != 0) {
1639  llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
1640  prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
1641  } else {
1642  if (state.getNumTopLevelMappings() !=
1643  region.front().getNumArguments() - 1) {
1644  return emitError(op->getLoc())
1645  << "operation expects " << region.front().getNumArguments() - 1
1646  << " extra value bindings, but " << state.getNumTopLevelMappings()
1647  << " were provided to the interpreter";
1648  }
1649 
1650  targets.push_back(state.getTopLevel());
1651 
1652  for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1653  extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1654  }
1655 
1656  if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
1657  return failure();
1658 
1659  for (BlockArgument argument : region.front().getArguments().drop_front()) {
1660  if (failed(state.mapBlockArgument(
1661  argument, extraMappings[argument.getArgNumber() - 1])))
1662  return failure();
1663  }
1664 
1665  return success();
1666 }
1667 
// Verifies the structural constraints of ops carrying
// PossibleTopLevelTransformOpTrait:
//  - at least one region, whose first region has a single block;
//  - a first block argument of handle type, matching operand 0's type if
//    present;
//  - trailing arguments of handle, value-handle or param type;
//  - when nested in another such op, one operand per block argument.
// NOTE(review): listing lines 1668-1669 (the signature), 1708 and 1720 (the
// declarations of `diag`) and 1718 (the rest of the `parent` lookup) are
// missing from this rendering.
1670  // Attaching this trait without the interface is a misuse of the API, but it
1671  // cannot be caught via a static_assert because interface registration is
1672  // dynamic.
1673  assert(isa<TransformOpInterface>(op) &&
1674  "should implement TransformOpInterface to have "
1675  "PossibleTopLevelTransformOpTrait");
1676 
1677  if (op->getNumRegions() < 1)
1678  return op->emitOpError() << "expects at least one region";
1679 
1680  Region *bodyRegion = &op->getRegion(0);
1681  if (!llvm::hasNItems(*bodyRegion, 1))
1682  return op->emitOpError() << "expects a single-block region";
1683 
1684  Block *body = &bodyRegion->front();
1685  if (body->getNumArguments() == 0) {
1686  return op->emitOpError()
1687  << "expects the entry block to have at least one argument";
1688  }
1689  if (!llvm::isa<TransformHandleTypeInterface>(
1690  body->getArgument(0).getType())) {
1691  return op->emitOpError()
1692  << "expects the first entry block argument to be of type "
1693  "implementing TransformHandleTypeInterface";
1694  }
1695  BlockArgument arg = body->getArgument(0);
1696  if (op->getNumOperands() != 0) {
1697  if (arg.getType() != op->getOperand(0).getType()) {
1698  return op->emitOpError()
1699  << "expects the type of the block argument to match "
1700  "the type of the operand";
1701  }
1702  }
1703  for (BlockArgument arg : body->getArguments().drop_front()) {
1704  if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1705  TransformValueHandleTypeInterface>(arg.getType()))
1706  continue;
1707 
1709  op->emitOpError()
1710  << "expects trailing entry block arguments to be of type implementing "
1711  "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1712  "TransformParamTypeInterface";
1713  diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1714  return diag;
1715  }
1716 
1717  if (auto *parent =
1719  if (op->getNumOperands() != body->getNumArguments()) {
1721  op->emitOpError()
1722  << "expects operands to be provided for a nested op";
1723  diag.attachNote(parent->getLoc())
1724  << "nested in another possible top-level op";
1725  return diag;
1726  }
1727  }
1728 
1729  return success();
1730 }
1731 
1732 //===----------------------------------------------------------------------===//
1733 // Utilities for ParamProducerTransformOpTrait.
1734 //===----------------------------------------------------------------------===//
1735 
// Effects of a param-producing op: it produces its results and only reads its
// operands. It also reads the payload IR, but only if at least one operand is an
// op handle or a value handle.
// NOTE(review): listing lines 1736-1737 (the signature) are missing here.
1738  producesHandle(op->getResults(), effects);
1739  bool hasPayloadOperands = false;
1740  for (Value operand : op->getOperands()) {
1741  onlyReadsHandle(operand, effects);
1742  if (llvm::isa<TransformHandleTypeInterface,
1743  TransformValueHandleTypeInterface>(operand.getType()))
1744  hasPayloadOperands = true;
1745  }
1746  if (hasPayloadOperands)
1747  onlyReadsPayload(effects);
1748 }
1749 
1752  // Interfaces can be attached dynamically, so this cannot be a static
1753  // assert.
1754  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1755  llvm::report_fatal_error(
1756  Twine("ParamProducerTransformOpTrait must be attached to an op that "
1757  "implements MemoryEffectsOpInterface, found on ") +
1758  op->getName().getStringRef());
1759  }
1760  for (Value result : op->getResults()) {
1761  if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1762  continue;
1763  return op->emitOpError()
1764  << "ParamProducerTransformOpTrait attached to this op expects "
1765  "result types to implement TransformParamTypeInterface";
1766  }
1767  return success();
1768 }
1769 
1770 //===----------------------------------------------------------------------===//
1771 // Memory effects.
1772 //===----------------------------------------------------------------------===//
1773 
1775  ValueRange handles,
1777  for (Value handle : handles) {
1778  effects.emplace_back(MemoryEffects::Read::get(), handle,
1780  effects.emplace_back(MemoryEffects::Free::get(), handle,
1782  }
1783 }
1784 
1785 /// Returns `true` if the given list of effects instances contains an instance
1786 /// with the effect type specified as template parameter.
1787 template <typename EffectTy, typename ResourceTy, typename Range>
1788 static bool hasEffect(Range &&effects) {
1789  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1790  return isa<EffectTy>(effect.getEffect()) &&
1791  isa<ResourceTy>(effect.getResource());
1792  });
1793 }
1794 
1796  transform::TransformOpInterface transform) {
1797  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1799  iface.getEffectsOnValue(handle, effects);
1800  return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1801  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1802 }
1803 
1805  ValueRange handles,
1807  for (Value handle : handles) {
1808  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1810  effects.emplace_back(MemoryEffects::Write::get(), handle,
1812  }
1813 }
1814 
1816  ValueRange handles,
1818  for (Value handle : handles) {
1819  effects.emplace_back(MemoryEffects::Read::get(), handle,
1821  }
1822 }
1823 
1826  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1827  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
1828 }
1829 
1832  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1833 }
1834 
1835 bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1836  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1838  iface.getEffects(effects);
1839  return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1840 }
1841 
1842 bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1843  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1845  iface.getEffects(effects);
1846  return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1847 }
1848 
1850  Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1852  for (Operation &nested : block) {
1853  auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1854  if (!iface)
1855  continue;
1856 
1857  effects.clear();
1858  iface.getEffects(effects);
1859  for (const MemoryEffects::EffectInstance &effect : effects) {
1860  BlockArgument argument =
1861  dyn_cast_or_null<BlockArgument>(effect.getValue());
1862  if (!argument || argument.getOwner() != &block ||
1863  !isa<MemoryEffects::Free>(effect.getEffect()) ||
1864  effect.getResource() != transform::TransformMappingResource::get()) {
1865  continue;
1866  }
1867  consumedArguments.insert(argument.getArgNumber());
1868  }
1869  }
1870 }
1871 
1872 //===----------------------------------------------------------------------===//
1873 // Utilities for TransformOpInterface.
1874 //===----------------------------------------------------------------------===//
1875 
1877  TransformOpInterface transformOp) {
1878  SmallVector<OpOperand *> consumedOperands;
1879  consumedOperands.reserve(transformOp->getNumOperands());
1880  auto memEffectInterface =
1881  cast<MemoryEffectOpInterface>(transformOp.getOperation());
1883  for (OpOperand &target : transformOp->getOpOperands()) {
1884  effects.clear();
1885  memEffectInterface.getEffectsOnValue(target.get(), effects);
1886  if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1887  return isa<transform::TransformMappingResource>(
1888  effect.getResource()) &&
1889  isa<MemoryEffects::Free>(effect.getEffect());
1890  })) {
1891  consumedOperands.push_back(&target);
1892  }
1893  }
1894  return consumedOperands;
1895 }
1896 
1898  auto iface = cast<MemoryEffectOpInterface>(op);
1900  iface.getEffects(effects);
1901 
1902  auto effectsOn = [&](Value value) {
1903  return llvm::make_filter_range(
1904  effects, [value](const MemoryEffects::EffectInstance &instance) {
1905  return instance.getValue() == value;
1906  });
1907  };
1908 
1909  std::optional<unsigned> firstConsumedOperand;
1910  for (OpOperand &operand : op->getOpOperands()) {
1911  auto range = effectsOn(operand.get());
1912  if (range.empty()) {
1914  op->emitError() << "TransformOpInterface requires memory effects "
1915  "on operands to be specified";
1916  diag.attachNote() << "no effects specified for operand #"
1917  << operand.getOperandNumber();
1918  return diag;
1919  }
1920  if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1922  << "TransformOpInterface did not expect "
1923  "'allocate' memory effect on an operand";
1924  diag.attachNote() << "specified for operand #"
1925  << operand.getOperandNumber();
1926  return diag;
1927  }
1928  if (!firstConsumedOperand &&
1929  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1930  firstConsumedOperand = operand.getOperandNumber();
1931  }
1932  }
1933 
1934  if (firstConsumedOperand &&
1935  !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1937  op->emitError()
1938  << "TransformOpInterface expects ops consuming operands to have a "
1939  "'write' effect on the payload resource";
1940  diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
1941  return diag;
1942  }
1943 
1944  for (OpResult result : op->getResults()) {
1945  auto range = effectsOn(result);
1946  if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1947  range)) {
1949  op->emitError() << "TransformOpInterface requires 'allocate' memory "
1950  "effect to be specified for results";
1951  diag.attachNote() << "no 'allocate' effect specified for result #"
1952  << result.getResultNumber();
1953  return diag;
1954  }
1955  }
1956 
1957  return success();
1958 }
1959 
1960 //===----------------------------------------------------------------------===//
1961 // Entry point.
1962 //===----------------------------------------------------------------------===//
1963 
1965  Operation *payloadRoot, TransformOpInterface transform,
1966  const RaggedArray<MappedValue> &extraMapping,
1967  const TransformOptions &options, bool enforceToplevelTransformOp) {
1968  if (enforceToplevelTransformOp) {
1969  if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
1970  transform->getNumOperands() != 0) {
1971  return transform->emitError()
1972  << "expected transform to start at the top-level transform op";
1973  }
1974  } else if (failed(
1976  return failure();
1977  }
1978 
1979  TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
1980  options);
1981  return state.applyTransform(transform).checkAndReport();
1982 }
1983 
1984 //===----------------------------------------------------------------------===//
1985 // Generated interface implementation.
1986 //===----------------------------------------------------------------------===//
1987 
1988 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
1989 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define DEBUG_TYPE_FULL
static void remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the memory effect instances on target with the same resource and effect as the one...
#define FULL_LDBG(X)
static void remapArgumentEffects(Block &block, ValueRange operands, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the same effects as the operations of block have on block arguments but associated...
static bool happensBefore(Operation *a, Operation *b)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not ...
void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped)
#define DEBUG_PRINT_AFTER_ALL
DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef< T > payload, transform::TransformOpInterface transform, unsigned operandNumber)
#define DBGS()
static DiagnosedSilenceableFailure dispatchMappedValues(Value handle, ArrayRef< transform::MappedValue > values, function_ref< LogicalResult(ArrayRef< Operation * >)> operationsFn, function_ref< LogicalResult(ArrayRef< transform::Param >)> paramsFn, function_ref< LogicalResult(ValueRange)> valuesFn)
Given a list of MappedValues, cast them to the value kind implied by the interface of the handle type...
static SmallVector< T > castVector(ArrayRef< transform::MappedValue > range)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition: Block.cpp:73
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgListType getArguments()
Definition: Block.h:84
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:318
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getOpResults()
Definition: Operation.h:415
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A 2D array where each row may have different length.
Definition: RaggedArray.h:18
size_t size() const
Returns the number of rows in the 2D array.
Definition: RaggedArray.h:21
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
size_t size() const
Returns the number of elements in the list.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
TrackingListener failures are reported only for ops that have this trait.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Options controlling the application of transform operations by the TransformState.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult matches the number, kind (payload op or parameter) and nulli...
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true)
Entry point to the Transform dialect infrastructure.
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool hasEffect(Operation *op, Value value=nullptr)
Returns true if op has an effect of type EffectTy on value.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A configuration object for customizing a TrackingListener.
SkipHandleFn skipHandleFn
An optional function that returns "true" for handles that do not have to be updated.