MLIR  21.0.0git
TransformInterfaces.cpp
Go to the documentation of this file.
1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/IR/PatternMatch.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/ScopeExit.h"
18 #include "llvm/ADT/iterator.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include "llvm/Support/InterleavedRange.h"
22 
// Debug type used by LLVM_DEBUG / LDBG for output from this file.
23 #define DEBUG_TYPE "transform-dialect"
// Separate, more verbose debug type used by FULL_LDBG and by the
// DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, ...) blocks in this file.
24 #define DEBUG_TYPE_FULL "transform-dialect-full"
// Debug type whose name suggests printing the top-level payload after every
// transform. NOTE(review): confirm where this is consumed.
25 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
// Debug stream with a "[transform-dialect] " prefix.
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
// Streams X to the prefixed debug stream when DEBUG_TYPE debugging is enabled.
27 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
// Streams X to the prefixed debug stream when DEBUG_TYPE_FULL is enabled.
28 #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Helper functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
37 /// properly dominates `b` and `b` is not inside `a`.
38 static bool happensBefore(Operation *a, Operation *b) {
39  do {
40  if (a->isProperAncestor(b))
41  return false;
42  if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
43  return a->isBeforeInBlock(bAncestor);
44  }
45  } while ((a = a->getParentOp()));
46  return false;
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // TransformState
51 //===----------------------------------------------------------------------===//
52 
// Out-of-line definition of the static constexpr member `kTopLevelValue`,
// which is initialized in the class definition (a constexpr redefinition
// without an initializer is only valid in that case).
// NOTE(review): since C++17 such members are implicitly inline, which makes
// this redefinition redundant (and deprecated). Confirm the minimum language
// standard before removing it.
53 constexpr const Value transform::TransformState::kTopLevelValue;
54 
55 transform::TransformState::TransformState(
56  Region *region, Operation *payloadRoot,
57  const RaggedArray<MappedValue> &extraMappings,
58  const TransformOptions &options)
59  : topLevel(payloadRoot), options(options) {
60  topLevelMappedValues.reserve(extraMappings.size());
61  for (ArrayRef<MappedValue> mapping : extraMappings)
62  topLevelMappedValues.push_back(mapping);
63  if (region) {
64  RegionScope *scope = new RegionScope(*this, *region);
65  topLevelRegionScope.reset(scope);
66  }
67 }
68 
69 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
70 
// Returns the list of payload operations associated with the operation handle
// `value`, as stored in the `direct` map of `getMapping(value)`. Asserts in
// debug builds if there is no entry, e.g. when a param or value handle is
// passed instead.
72 transform::TransformState::getPayloadOpsView(Value value) const {
73  const TransformOpMapping &operationMapping = getMapping(value).direct;
74  auto iter = operationMapping.find(value);
75  assert(iter != operationMapping.end() &&
76  "cannot find mapping for payload handle (param/value handle "
77  "provided?)");
78  return iter->getSecond();
79 }
80 
// Looks up the parameters associated with the param handle `value` in the
// `params` map of `getMapping(value)`. Asserts in debug builds if there is no
// entry, e.g. when an operation or value handle is passed instead.
82  const ParamMapping &mapping = getMapping(value).params;
83  auto iter = mapping.find(value);
84  assert(iter != mapping.end() && "cannot find mapping for param handle "
85  "(operation/value handle provided?)");
86  return iter->getSecond();
87 }
88 
// Returns the list of payload values associated with the value handle
// `handleValue`, as stored in the `values` map of `getMapping(handleValue)`.
// Asserts in debug builds if there is no entry, e.g. when a param or operation
// handle is passed instead.
90 transform::TransformState::getPayloadValuesView(Value handleValue) const {
91  const ValueMapping &mapping = getMapping(handleValue).values;
92  auto iter = mapping.find(handleValue);
93  assert(iter != mapping.end() && "cannot find mapping for value handle "
94  "(param/operation handle provided?)");
95  return iter->getSecond();
96 }
97 
// Appends to `handles` every handle recorded in a `reverse` map as pointing to
// payload op `op`. The per-region entries of `mappings` are visited in reverse
// order. Unless `includeOutOfScope` is set, the walk stops at a region that is
// isolated from above. Returns success iff at least one handle was found.
99  Operation *op, SmallVectorImpl<Value> &handles,
100  bool includeOutOfScope) const {
101  bool found = false;
102  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
103  auto iterator = mapping->reverse.find(op);
104  if (iterator != mapping->reverse.end()) {
105  llvm::append_range(handles, iterator->getSecond());
106  found = true;
107  }
108  // Stop looking when reaching a region that is isolated from above.
109  if (!includeOutOfScope &&
111  break;
112  }
113 
114  return success(found);
115 }
116 
// Appends to `handles` every handle recorded in a `reverseValues` map as
// pointing to `payloadValue`. The per-region entries of `mappings` are visited
// in reverse order. Unless `includeOutOfScope` is set, the walk stops at a
// region that is isolated from above. Returns success iff at least one handle
// was found.
118  Value payloadValue, SmallVectorImpl<Value> &handles,
119  bool includeOutOfScope) const {
120  bool found = false;
121  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
122  auto iterator = mapping->reverseValues.find(payloadValue);
123  if (iterator != mapping->reverseValues.end()) {
124  llvm::append_range(handles, iterator->getSecond());
125  found = true;
126  }
127  // Stop looking when reaching a region that is isolated from above.
128  if (!includeOutOfScope &&
130  break;
131  }
132 
133  return success(found);
134 }
135 
136 /// Given a list of MappedValues, cast them to the value kind implied by the
137 /// interface of the handle type, and dispatch to one of the callbacks.
///
/// The type of `handle` selects the kind:
///   - TransformHandleTypeInterface: entries must be operations;
///     `operationsFn` is called;
///   - TransformValueHandleTypeInterface: entries must be values; `valuesFn`
///     is called;
///   - otherwise the type must be a TransformParamTypeInterface (asserted):
///     entries must be attributes, and `paramsFn` is called.
/// An entry of the wrong kind, including a null entry (for which
/// `dyn_cast_if_present` yields null), produces a silenceable failure at
/// `handle`'s location before any callback runs.
140  function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
141  function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
142  function_ref<LogicalResult(ValueRange)> valuesFn) {
143  if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
144  SmallVector<Operation *> operations;
145  operations.reserve(values.size());
146  for (transform::MappedValue value : values) {
147  if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
148  operations.push_back(op);
149  continue;
150  }
151  return emitSilenceableFailure(handle.getLoc())
152  << "wrong kind of value provided for top-level operation handle";
153  }
154  if (failed(operationsFn(operations)))
157  }
158 
159  if (llvm::isa<transform::TransformValueHandleTypeInterface>(
160  handle.getType())) {
161  SmallVector<Value> payloadValues;
162  payloadValues.reserve(values.size());
163  for (transform::MappedValue value : values) {
164  if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
165  payloadValues.push_back(v);
166  continue;
167  }
168  return emitSilenceableFailure(handle.getLoc())
169  << "wrong kind of value provided for the top-level value handle";
170  }
171  if (failed(valuesFn(payloadValues)))
174  }
175 
// Neither an operation handle nor a value handle: only param handles remain.
176  assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
177  "unsupported kind of block argument");
179  parameters.reserve(values.size());
180  for (transform::MappedValue value : values) {
181  if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
182  parameters.push_back(attr);
183  continue;
184  }
185  return emitSilenceableFailure(handle.getLoc())
186  << "wrong kind of value provided for top-level parameter";
187  }
188  if (failed(paramsFn(parameters)))
191 }
192 
// Associates the block argument `argument` with `values`. Based on the
// argument's handle type, this forwards to setPayloadOps, setParams or
// setPayloadValues (via dispatchMappedValues). Any silenceable failure from
// the dispatch is reported through checkAndReport().
193 LogicalResult
195  ArrayRef<MappedValue> values) {
196  return dispatchMappedValues(
197  argument, values,
198  [&](ArrayRef<Operation *> operations) {
199  return setPayloadOps(argument, operations);
200  },
201  [&](ArrayRef<Param> params) {
202  return setParams(argument, params);
203  },
204  [&](ValueRange payloadValues) {
205  return setPayloadValues(argument, payloadValues);
206  })
207  .checkAndReport();
208 }
209 
// Maps each block argument in `arguments` to the corresponding entry of
// `mapping` via mapBlockArgument, stopping at the first failure.
// llvm::zip_equal requires both ranges to have the same length.
211  Block::BlockArgListType arguments,
213  for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
214  if (failed(mapBlockArgument(argument, values)))
215  return failure();
216  return success();
217 }
218 
// Associates the operation handle `value` with the payload operations
// `targets`. It records both the forward entry (`direct`: handle -> ops) and
// one back-edge per target (`reverse`: op -> handles).
// Emits an error and fails if any target is null, or if the handle type's
// checkPayload rejects the targets. Re-associating a handle that already has a
// payload list is an API misuse and asserts in debug builds.
219 LogicalResult
220 transform::TransformState::setPayloadOps(Value value,
221  ArrayRef<Operation *> targets) {
222  assert(value != kTopLevelValue &&
223  "attempting to reset the transformation root");
224  assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
225  "wrong handle type");
226 
227  for (Operation *target : targets) {
228  if (target)
229  continue;
230  return emitError(value.getLoc())
231  << "attempting to assign a null payload op to this transform value";
232  }
233 
234  auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
236  iface.checkPayload(value.getLoc(), targets);
237  if (failed(result.checkAndReport()))
238  return failure();
239 
240  // Setting new payload for the value without cleaning it first is a misuse of
241  // the API, assert here.
242  SmallVector<Operation *> storedTargets(targets);
243  Mappings &mappings = getMapping(value);
244  bool inserted =
245  mappings.direct.insert({value, std::move(storedTargets)}).second;
246  assert(inserted && "value is already associated with another list");
247  (void)inserted;
248 
// An op listed several times in `targets` gets one back-edge per occurrence.
249  for (Operation *op : targets)
250  mappings.reverse[op].push_back(value);
251 
252  return success();
253 }
254 
// Associates the value handle `handle` with `payloadValues`. It records both
// the forward entry (`values`: handle -> payload values) and one back-edge per
// payload value (`reverseValues`: value -> handles).
// Emits an error and fails if any payload value is null, or if the handle
// type's checkPayload rejects them. Re-associating a handle that already has a
// payload list asserts in debug builds.
255 LogicalResult
256 transform::TransformState::setPayloadValues(Value handle,
257  ValueRange payloadValues) {
// NOTE(review): this assertion message says "params" although this function
// sets payload values; it looks copied from setParams.
258  assert(handle != nullptr && "attempting to set params for a null value");
259  assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
260  "wrong handle type");
261 
262  for (Value payload : payloadValues) {
263  if (payload)
264  continue;
265  return emitError(handle.getLoc()) << "attempting to assign a null payload "
266  "value to this transform handle";
267  }
268 
269  auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
270  SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
272  iface.checkPayload(handle.getLoc(), payloadValueVector);
273  if (failed(result.checkAndReport()))
274  return failure();
275 
276  Mappings &mappings = getMapping(handle);
277  bool inserted =
278  mappings.values.insert({handle, std::move(payloadValueVector)}).second;
279  assert(
280  inserted &&
281  "value handle is already associated with another list of payload values");
282  (void)inserted;
283 
284  for (Value payload : payloadValues)
285  mappings.reverseValues[payload].push_back(handle);
286 
287  return success();
288 }
289 
290 LogicalResult transform::TransformState::setParams(Value value,
291  ArrayRef<Param> params) {
292  assert(value != nullptr && "attempting to set params for a null value");
293 
294  for (Attribute attr : params) {
295  if (attr)
296  continue;
297  return emitError(value.getLoc())
298  << "attempting to assign a null parameter to this transform value";
299  }
300 
301  auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
302  assert(value &&
303  "cannot associate parameter with a value of non-parameter type");
305  valueType.checkPayload(value.getLoc(), params);
306  if (failed(result.checkAndReport()))
307  return failure();
308 
309  Mappings &mappings = getMapping(value);
310  bool inserted =
311  mappings.params.insert({value, llvm::to_vector(params)}).second;
312  assert(inserted && "value is already associated with another list of params");
313  (void)inserted;
314  return success();
315 }
316 
317 template <typename Mapping, typename Key, typename Mapped>
318 void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
319  auto it = mapping.find(key);
320  if (it == mapping.end())
321  return;
322 
323  llvm::erase(it->getSecond(), mapped);
324  if (it->getSecond().empty())
325  mapping.erase(it);
326 }
327 
328 void transform::TransformState::forgetMapping(Value opHandle,
329  ValueRange origOpFlatResults,
330  bool allowOutOfScope) {
331  Mappings &mappings = getMapping(opHandle, allowOutOfScope);
332  for (Operation *op : mappings.direct[opHandle])
333  dropMappingEntry(mappings.reverse, op, opHandle);
334  mappings.direct.erase(opHandle);
335 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
336  // Payload IR is removed from the mapping. This invalidates the respective
337  // iterators.
338  mappings.incrementTimestamp(opHandle);
339 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
340 
341  for (Value opResult : origOpFlatResults) {
342  SmallVector<Value> resultHandles;
343  (void)getHandlesForPayloadValue(opResult, resultHandles);
344  for (Value resultHandle : resultHandles) {
345  Mappings &localMappings = getMapping(resultHandle);
346  dropMappingEntry(localMappings.values, resultHandle, opResult);
347 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
348  // Payload IR is removed from the mapping. This invalidates the respective
349  // iterators.
350  mappings.incrementTimestamp(resultHandle);
351 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
352  dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
353  }
354  }
355 }
356 
357 void transform::TransformState::forgetValueMapping(
358  Value valueHandle, ArrayRef<Operation *> payloadOperations) {
359  Mappings &mappings = getMapping(valueHandle);
360  for (Value payloadValue : mappings.reverseValues[valueHandle])
361  dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
362  mappings.values.erase(valueHandle);
363 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
364  // Payload IR is removed from the mapping. This invalidates the respective
365  // iterators.
366  mappings.incrementTimestamp(valueHandle);
367 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
368 
369  for (Operation *payloadOp : payloadOperations) {
370  SmallVector<Value> opHandles;
371  (void)getHandlesForPayloadOp(payloadOp, opHandles);
372  for (Value opHandle : opHandles) {
373  Mappings &localMappings = getMapping(opHandle);
374  dropMappingEntry(localMappings.direct, opHandle, payloadOp);
375  dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
376 
377 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
378  // Payload IR is removed from the mapping. This invalidates the respective
379  // iterators.
380  localMappings.incrementTimestamp(opHandle);
381 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
382  }
383  }
384 }
385 
386 LogicalResult
387 transform::TransformState::replacePayloadOp(Operation *op,
388  Operation *replacement) {
389  // TODO: consider invalidating the handles to nested objects here.
390 
391 #ifndef NDEBUG
392  for (Value opResult : op->getResults()) {
393  SmallVector<Value> valueHandles;
394  (void)getHandlesForPayloadValue(opResult, valueHandles,
395  /*includeOutOfScope=*/true);
396  assert(valueHandles.empty() && "expected no mapping to old results");
397  }
398 #endif // NDEBUG
399 
400  // Drop the mapping between the op and all handles that point to it. Fail if
401  // there are no handles.
402  SmallVector<Value> opHandles;
403  if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
404  return failure();
405  for (Value handle : opHandles) {
406  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
407  dropMappingEntry(mappings.reverse, op, handle);
408  }
409 
410  // Replace the pointed-to object of all handles with the replacement object.
411  // In case a payload op was erased (replacement object is nullptr), a nullptr
412  // is stored in the mapping. These nullptrs are removed after each transform.
413  // Furthermore, nullptrs are not enumerated by payload op iterators. The
414  // relative order of ops is preserved.
415  //
416  // Removing an op from the mapping would be problematic because removing an
417  // element from an array invalidates iterators; merely changing the value of
418  // elements does not.
419  for (Value handle : opHandles) {
420  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
421  auto it = mappings.direct.find(handle);
422  if (it == mappings.direct.end())
423  continue;
424 
425  SmallVector<Operation *, 2> &association = it->getSecond();
426  // Note that an operation may be associated with the handle more than once.
427  for (Operation *&mapped : association) {
428  if (mapped == op)
429  mapped = replacement;
430  }
431 
432  if (replacement) {
433  mappings.reverse[replacement].push_back(handle);
434  } else {
435  opHandlesToCompact.insert(handle);
436  }
437  }
438 
439  return success();
440 }
441 
442 LogicalResult
443 transform::TransformState::replacePayloadValue(Value value, Value replacement) {
444  SmallVector<Value> valueHandles;
445  if (failed(getHandlesForPayloadValue(value, valueHandles,
446  /*includeOutOfScope=*/true)))
447  return failure();
448 
449  for (Value handle : valueHandles) {
450  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
451  dropMappingEntry(mappings.reverseValues, value, handle);
452 
453  // If replacing with null, that is erasing the mapping, drop the mapping
454  // between the handles and the IR objects
455  if (!replacement) {
456  dropMappingEntry(mappings.values, handle, value);
457 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
458  // Payload IR is removed from the mapping. This invalidates the respective
459  // iterators.
460  mappings.incrementTimestamp(handle);
461 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
462  } else {
463  auto it = mappings.values.find(handle);
464  if (it == mappings.values.end())
465  continue;
466 
467  SmallVector<Value> &association = it->getSecond();
468  for (Value &mapped : association) {
469  if (mapped == value)
470  mapped = replacement;
471  }
472  mappings.reverseValues[replacement].push_back(handle);
473  }
474  }
475 
476  return success();
477 }
478 
// Records, in `newlyInvalidated`, that the op handle `otherHandle` is
// invalidated by the consumption of `consumingHandle`. This happens only if
// `payloadOp` (one of the payload ops of `otherHandle`) is nested in, or equal
// to, one of `potentialAncestors`. When later invoked with a location, the
// stored callback emits an error with notes pointing at:
//   - the invalidated handle;
//   - the consuming op and operand number;
//   - the ancestor and nested payload ops;
//   - when `throughValue` is set, the consumed payload value.
479 void transform::TransformState::recordOpHandleInvalidationOne(
480  OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
481  Operation *payloadOp, Value otherHandle, Value throughValue,
482  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
483  // If the op is associated with invalidated handle, skip the check as it
484  // may be reading invalid IR. This also ensures we report the first
485  // invalidation and not the last one.
486  if (invalidatedHandles.count(otherHandle) ||
487  newlyInvalidated.count(otherHandle))
488  return;
489 
490  FULL_LDBG("--recordOpHandleInvalidationOne\n");
491  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
492  (DBGS() << "--ancestors: "
493  << llvm::interleaved(llvm::make_pointee_range(potentialAncestors))
494  << "\n");
495  });
496 
497  Operation *owner = consumingHandle.getOwner();
498  unsigned operandNo = consumingHandle.getOperandNumber();
// NOTE(review): there is no `break` after a match. If several ancestors
// contain `payloadOp`, each match overwrites the entry, so the reported
// ancestor is the last matching one in `potentialAncestors`.
499  for (Operation *ancestor : potentialAncestors) {
500  // clang-format off
501  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
502  { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
503  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
504  { (DBGS() << "----of payload with name: "
505  << payloadOp->getName().getIdentifier() << "\n"); });
506  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
507  { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
508  // clang-format on
509  if (!ancestor->isAncestor(payloadOp))
510  continue;
511 
512  // Make sure the error-reporting lambda doesn't capture anything
513  // by-reference because it will go out of scope. Additionally, extract
514  // location from Payload IR ops because the ops themselves may be
515  // deleted before the lambda gets called.
516  Location ancestorLoc = ancestor->getLoc();
517  Location opLoc = payloadOp->getLoc();
518  std::optional<Location> throughValueLoc =
519  throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
520  newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
521  otherHandle,
522  throughValueLoc](Location currentLoc) {
523  InFlightDiagnostic diag = emitError(currentLoc)
524  << "op uses a handle invalidated by a "
525  "previously executed transform op";
526  diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
527  diag.attachNote(owner->getLoc())
528  << "invalidated by this transform op that consumes its operand #"
529  << operandNo
530  << " and invalidates all handles to payload IR entities associated "
531  "with this operand and entities nested in them";
532  diag.attachNote(ancestorLoc) << "ancestor payload op";
533  diag.attachNote(opLoc) << "nested payload op";
534  if (throughValueLoc) {
535  diag.attachNote(*throughValueLoc)
536  << "consumed handle points to this payload value";
537  }
538  };
539  }
540 }
541 
542 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
543  OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
544  Value payloadValue, Value valueHandle,
545  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
546  // If the op is associated with invalidated handle, skip the check as it
547  // may be reading invalid IR. This also ensures we report the first
548  // invalidation and not the last one.
549  if (invalidatedHandles.count(valueHandle) ||
550  newlyInvalidated.count(valueHandle))
551  return;
552 
553  for (Operation *ancestor : potentialAncestors) {
554  Operation *definingOp;
555  std::optional<unsigned> resultNo;
556  unsigned argumentNo = std::numeric_limits<unsigned>::max();
557  unsigned blockNo = std::numeric_limits<unsigned>::max();
558  unsigned regionNo = std::numeric_limits<unsigned>::max();
559  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
560  definingOp = opResult.getOwner();
561  resultNo = opResult.getResultNumber();
562  } else {
563  auto arg = llvm::cast<BlockArgument>(payloadValue);
564  definingOp = arg.getParentBlock()->getParentOp();
565  argumentNo = arg.getArgNumber();
566  blockNo = std::distance(arg.getOwner()->getParent()->begin(),
567  arg.getOwner()->getIterator());
568  regionNo = arg.getOwner()->getParent()->getRegionNumber();
569  }
570  assert(definingOp && "expected the value to be defined by an op as result "
571  "or block argument");
572  if (!ancestor->isAncestor(definingOp))
573  continue;
574 
575  Operation *owner = opHandle.getOwner();
576  unsigned operandNo = opHandle.getOperandNumber();
577  Location ancestorLoc = ancestor->getLoc();
578  Location opLoc = definingOp->getLoc();
579  Location valueLoc = payloadValue.getLoc();
580  newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
581  argumentNo, blockNo, regionNo, ancestorLoc,
582  opLoc, valueLoc](Location currentLoc) {
583  InFlightDiagnostic diag = emitError(currentLoc)
584  << "op uses a handle invalidated by a "
585  "previously executed transform op";
586  diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
587  diag.attachNote(owner->getLoc())
588  << "invalidated by this transform op that consumes its operand #"
589  << operandNo
590  << " and invalidates all handles to payload IR entities "
591  "associated with this operand and entities nested in them";
592  diag.attachNote(ancestorLoc)
593  << "ancestor op associated with the consumed handle";
594  if (resultNo) {
595  diag.attachNote(opLoc)
596  << "op defining the value as result #" << *resultNo;
597  } else {
598  diag.attachNote(opLoc)
599  << "op defining the value as block argument #" << argumentNo
600  << " of block #" << blockNo << " in region #" << regionNo;
601  }
602  diag.attachNote(valueLoc) << "payload value";
603  };
604  }
605 }
606 
// Records, in `newlyInvalidated`, the handles invalidated by consuming the
// handle operand `handle`, whose payload ops are `potentialAncestors`.
// If that list is empty, only `handle` itself is recorded, with an
// "empty payload" diagnostic. Otherwise, every op handle and value handle in
// the per-region mappings (visited in reverse, stopping at a region isolated
// from above) is checked against `potentialAncestors`. `throughValue`, when
// non-null, is forwarded so that the diagnostic can point at the consumed
// payload value.
607 void transform::TransformState::recordOpHandleInvalidation(
608  OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
609  Value throughValue,
610  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
611 
612  if (potentialAncestors.empty()) {
613  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
614  (DBGS() << "----recording invalidation for empty handle: " << handle.get()
615  << "\n");
616  });
617 
618  Operation *owner = handle.getOwner();
619  unsigned operandNo = handle.getOperandNumber();
620  newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
621  InFlightDiagnostic diag = emitError(currentLoc)
622  << "op uses a handle associated with empty "
623  "payload and invalidated by a "
624  "previously executed transform op";
625  diag.attachNote(owner->getLoc())
626  << "invalidated by this transform op that consumes its operand #"
627  << operandNo;
628  };
629  return;
630  }
631 
632  // Iterate over the mapping and invalidate aliasing handles. This is quite
633  // expensive and only necessary for error reporting in case of transform
634  // dialect misuse with dangling handles. Iteration over the handles is based
635  // on the assumption that the number of handles is significantly less than the
636  // number of IR objects (operations and values). Alternatively, we could walk
637  // the IR nested in each payload op associated with the given handle and look
638  // for handles associated with each operation and value.
639  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
640  // Go over all op handle mappings and mark as invalidated any handle
641  // pointing to any of the payload ops associated with the given handle or
642  // any op nested in them.
643  for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
644  for (Value otherHandle : otherHandles)
645  recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
646  otherHandle, throughValue,
647  newlyInvalidated);
648  }
649  // Go over all value handle mappings and mark as invalidated any handle
650  // pointing to any result of the payload op associated with the given handle
651  // or any op nested in them. Similarly invalidate handles to argument of
652  // blocks belonging to any region of any payload op associated with the
653  // given handle or any op nested in them.
654  for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
655  for (Value valueHandle : valueHandles)
656  recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
657  payloadValue, valueHandle,
658  newlyInvalidated);
659  }
660 
661  // Stop lookup when reaching a region that is isolated from above.
663  break;
664  }
665 }
666 
667 void transform::TransformState::recordValueHandleInvalidation(
668  OpOperand &valueHandle,
669  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
670  // Invalidate other handles to the same value.
671  for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
672  SmallVector<Value> otherValueHandles;
673  (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
674  for (Value otherHandle : otherValueHandles) {
675  Operation *owner = valueHandle.getOwner();
676  unsigned operandNo = valueHandle.getOperandNumber();
677  Location valueLoc = payloadValue.getLoc();
678  newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
679  valueLoc](Location currentLoc) {
680  InFlightDiagnostic diag = emitError(currentLoc)
681  << "op uses a handle invalidated by a "
682  "previously executed transform op";
683  diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
684  diag.attachNote(owner->getLoc())
685  << "invalidated by this transform op that consumes its operand #"
686  << operandNo
687  << " and invalidates handles to the same values as associated with "
688  "it";
689  diag.attachNote(valueLoc) << "payload value";
690  };
691  }
692 
693  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
694  Operation *payloadOp = opResult.getOwner();
695  recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
696  newlyInvalidated);
697  } else {
698  auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
699  for (Operation &payloadOp : *arg.getOwner())
700  recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
701  newlyInvalidated);
702  }
703  }
704 }
705 
706 /// Checks that the operation does not use invalidated handles as operands.
707 /// Reports errors and returns failure if it does. Otherwise, invalidates the
708 /// handles consumed by the operation as well as any handles pointing to payload
709 /// IR operations nested in the operations associated with the consumed handles.
///
/// Being `const`, this only records new invalidations in `newlyInvalidated`.
/// Merging them into the state's `invalidatedHandles` is up to the caller.
/// An operand counts as consumed when `transform` reports a
/// `MemoryEffects::Free` effect on it.
710 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
711  transform::TransformOpInterface transform,
712  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
713  FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
714  auto memoryEffectsIface =
715  cast<MemoryEffectOpInterface>(transform.getOperation());
717  memoryEffectsIface.getEffectsOnResource(
719 
720  for (OpOperand &target : transform->getOpOperands()) {
721  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
722  (DBGS() << "----iterate on handle: " << target.get() << "\n");
723  });
724  // If the operand uses an invalidated handle, report it. If the operation
725  // allows handles to point to repeated payload operations, only report
726  // pre-existing invalidation errors. Otherwise, also report invalidations
727  // caused by the current transform operation affecting its other operands.
728  auto it = invalidatedHandles.find(target.get());
729  auto nit = newlyInvalidated.find(target.get());
730  if (it != invalidatedHandles.end()) {
731  FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
732  "invalidated -> FAILURE\n");
// Comma operator: emit the stored diagnostic at this op's location, then
// return failure.
733  return it->getSecond()(transform->getLoc()), failure();
734  }
735  if (!transform.allowsRepeatedHandleOperands() &&
736  nit != newlyInvalidated.end()) {
737  FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
738  "invalidated (by this op) -> FAILURE\n");
739  return nit->getSecond()(transform->getLoc()), failure();
740  }
741 
742  // Invalidate handles pointing to the operations nested in the operation
743  // associated with the handle consumed by this operation.
744  auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
745  return isa<MemoryEffects::Free>(effect.getEffect()) &&
746  effect.getValue() == target.get();
747  };
748  if (llvm::any_of(effects, consumesTarget)) {
749  FULL_LDBG("----found consume effect\n");
750  if (llvm::isa<transform::TransformHandleTypeInterface>(
751  target.get().getType())) {
752  FULL_LDBG("----recordOpHandleInvalidation\n");
753  SmallVector<Operation *> payloadOps =
754  llvm::to_vector(getPayloadOps(target.get()));
755  recordOpHandleInvalidation(target, payloadOps, nullptr,
756  newlyInvalidated);
757  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
758  target.get().getType())) {
759  FULL_LDBG("----recordValueHandleInvalidation\n");
760  recordValueHandleInvalidation(target, newlyInvalidated);
761  } else {
// Consumed operands of other types (e.g. params) need no invalidation.
762  FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
763  }
764  } else {
765  FULL_LDBG("----no consume effect -> SKIP\n");
766  }
767  }
768 
769  FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
770  return success();
771 }
772 
773 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
774  transform::TransformOpInterface transform) {
775  InvalidatedHandleMap newlyInvalidated;
776  LogicalResult checkResult =
777  checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
778  invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
779  std::make_move_iterator(newlyInvalidated.end()));
780  return checkResult;
781 }
782 
// Checks that the payload of a consumed handle operand (operand number
// `operandNumber` of `transform`) does not list the same payload entity twice.
// On the first repetition, emits a silenceable error on `transform` with a note
// at the repeated entity, and returns it. `T` is either an operation pointer
// (`p->getLoc()`) or a value-like type (`p.getLoc()`), selected with
// `if constexpr`.
783 template <typename T>
786  transform::TransformOpInterface transform,
787  unsigned operandNumber) {
788  DenseSet<T> seen;
789  for (T p : payload) {
// `insert(...).second` is false when `p` was already seen.
790  if (!seen.insert(p).second) {
792  transform.emitSilenceableError()
793  << "a handle passed as operand #" << operandNumber
794  << " and consumed by this operation points to a payload "
795  "entity more than once";
796  if constexpr (std::is_pointer_v<T>)
797  diag.attachNote(p->getLoc()) << "repeated target op";
798  else
799  diag.attachNote(p.getLoc()) << "repeated target value";
800  return diag;
801  }
802  }
804 }
805 
806 void transform::TransformState::compactOpHandles() {
807  for (Value handle : opHandlesToCompact) {
808  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
809 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
810  if (llvm::is_contained(mappings.direct[handle], nullptr))
811  // Payload IR is removed from the mapping. This invalidates the respective
812  // iterators.
813  mappings.incrementTimestamp(handle);
814 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
815  llvm::erase(mappings.direct[handle], nullptr);
816  }
817  opHandlesToCompact.clear();
818 }
819 
/// Applies `transform` to the payload IR and updates the transform state.
///
/// The function works in these steps:
///  1. Record `transform` as the current transform of the innermost region
///     scope.
///  2. If expensive checks are enabled, check handle invalidation. Then, for
///     every consumed op or value handle operand, check that it does not point
///     to the same payload entity twice. This per-operand check is skipped
///     when the op allows repeated handle operands.
///  3. Snapshot the results of the payload ops behind consumed op handles, and
///     the ops associated with consumed value handles, before the transform
///     can mutate or erase them.
///  4. Apply the op through a TransformRewriter whose tracking listener keeps
///     handles up to date when payload is replaced.
///  5. Merge any tracking-listener failure into the result. Transform-op
///     errors take precedence. Listener failures are only reported for ops
///     with ReportTrackingListenerFailuresOpTrait that do not carry the
///     silencing attribute.
///  6. Return immediately on a definite failure. Otherwise, set unset results
///     to empty lists (on a silenceable failure), forget the mappings of
///     consumed handles and record the op's results in the state.
/// Under LLVM_DEBUG, the top-level payload is dumped if the function exits
/// before `printOnFailureRAII` is released near the end.
821 transform::TransformState::applyTransform(TransformOpInterface transform) {
822  LLVM_DEBUG({
823  DBGS() << "applying: ";
824  transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
825  llvm::dbgs() << "\n";
826  });
827  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
828  DBGS() << "Top-level payload before application:\n"
829  << *getTopLevel() << "\n");
830  auto printOnFailureRAII = llvm::make_scope_exit([this] {
// Presumably keeps the `this` capture used when LLVM_DEBUG expands to nothing.
831  (void)this;
832  LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
833  llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
834  });
835 
836  // Set current transform op.
837  regionStack.back()->currentTransform = transform;
838 
839  // Expensive checks to detect invalid transform IR.
840  if (options.getExpensiveChecksEnabled()) {
841  FULL_LDBG("ExpensiveChecksEnabled\n");
842  if (failed(checkAndRecordHandleInvalidation(transform)))
// NOTE(review): source line 843 (the statement guarded by this `if`) is not
// shown in this listing.
844 
845  for (OpOperand &operand : transform->getOpOperands()) {
846  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
847  (DBGS() << "iterate on handle: " << operand.get() << "\n");
848  });
849  if (!isHandleConsumed(operand.get(), transform)) {
850  FULL_LDBG("--handle not consumed -> SKIP\n");
851  continue;
852  }
853  if (transform.allowsRepeatedHandleOperands()) {
854  FULL_LDBG("--op allows repeated handles -> SKIP\n");
855  continue;
856  }
857  FULL_LDBG("--handle is consumed\n");
858 
// Dispatch on the handle kind: op handles and value handles are checked for
// duplicated payload. Any other operand type (e.g. params) is not checked.
859  Type operandType = operand.get().getType();
860  if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
861  FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
863  checkRepeatedConsumptionInOperand<Operation *>(
864  getPayloadOpsView(operand.get()), transform,
865  operand.getOperandNumber());
866  if (!check.succeeded()) {
867  FULL_LDBG("----FAILED\n");
868  return check;
869  }
870  } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
871  FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
873  checkRepeatedConsumptionInOperand<Value>(
874  getPayloadValuesView(operand.get()), transform,
875  operand.getOperandNumber());
876  if (!check.succeeded()) {
877  FULL_LDBG("----FAILED\n");
878  return check;
879  }
880  } else {
881  FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
882  }
883  }
884  }
885 
886  // Find which operands are consumed.
887  SmallVector<OpOperand *> consumedOperands =
888  transform.getConsumedHandleOpOperands();
889 
890  // Remember the results of the payload ops associated with the consumed
891  // op handles or the ops defining the value handles so we can drop the
892  // association with them later. This must happen here because the
893  // transformation may destroy or mutate them so we cannot traverse the payload
894  // IR after that.
895  SmallVector<Value> origOpFlatResults;
896  SmallVector<Operation *> origAssociatedOps;
897  for (OpOperand *opOperand : consumedOperands) {
898  Value operand = opOperand->get();
899  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
900  for (Operation *payloadOp : getPayloadOps(operand)) {
901  llvm::append_range(origOpFlatResults, payloadOp->getResults());
902  }
903  continue;
904  }
905  if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
906  for (Value payloadValue : getPayloadValuesView(operand)) {
907  if (llvm::isa<OpResult>(payloadValue)) {
908  origAssociatedOps.push_back(payloadValue.getDefiningOp());
909  continue;
910  }
// For a block argument, every op in the owning block is recorded.
911  llvm::append_range(
912  origAssociatedOps,
913  llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
914  [](Operation &op) { return &op; }));
915  }
916  continue;
917  }
// A consumed operand that is neither an op handle nor a value handle is
// reported as a definite failure.
919  emitDefiniteFailure(transform->getLoc())
920  << "unexpectedly consumed a value that is not a handle as operand #"
921  << opOperand->getOperandNumber();
922  diag.attachNote(operand.getLoc())
923  << "value defined here with type " << operand.getType();
924  return diag;
925  }
926 
927  // Prepare rewriter and listener.
// The tracking listener consults `skipHandleFn` to decide which handles need
// not be updated. A handle is skipped (treated as dead) when every user is the
// current transform of the handle's region scope or happens before it.
929  config.skipHandleFn = [&](Value handle) {
930  // Skip handle if it is dead.
931  auto scopeIt =
932  llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
933  return handle.getParentRegion() == scope->region;
934  });
935  assert(scopeIt != regionStack.rend() &&
936  "could not find region scope for handle");
937  RegionScope *scope = *scopeIt;
938  return llvm::all_of(handle.getUsers(), [&](Operation *user) {
939  return user == scope->currentTransform ||
940  happensBefore(user, scope->currentTransform);
941  });
942  };
943  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
944  config);
945  transform::TransformRewriter rewriter(transform->getContext(),
946  &trackingListener);
947 
948  // Compute the result but do not short-circuit the silenceable failure case as
949  // we still want the handles to propagate properly so the "suppress" mode can
950  // proceed on a best effort basis.
951  transform::TransformResults results(transform->getNumResults());
952  DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
953  compactOpHandles();
954 
955  // Error handling: fail if transform or listener failed.
956  DiagnosedSilenceableFailure trackingFailure =
957  trackingListener.checkAndResetError();
958  if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
959  transform->hasAttr(FindPayloadReplacementOpInterface::
960  kSilenceTrackingFailuresAttrName)) {
961  // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
962  // do not report failures if the above mentioned attribute is set.
963  if (trackingFailure.isSilenceableFailure())
964  (void)trackingFailure.silence();
965  trackingFailure = DiagnosedSilenceableFailure::success();
966  }
967  if (!trackingFailure.succeeded()) {
968  if (result.succeeded()) {
969  result = std::move(trackingFailure);
970  } else {
971  // Transform op errors have precedence, report those first.
972  if (result.isSilenceableFailure())
973  result.attachNote() << "tracking listener also failed: "
974  << trackingFailure.getMessage();
975  (void)trackingFailure.silence();
976  }
977  }
978  if (result.isDefiniteFailure())
979  return result;
980 
981  // If a silenceable failure was produced, some results may be unset, set them
982  // to empty lists.
983  if (result.isSilenceableFailure())
984  results.setRemainingToEmpty(transform);
985 
986  // Remove the mapping for the operand if it is consumed by the operation. This
987  // allows us to catch use-after-free with assertions later on.
988  for (OpOperand *opOperand : consumedOperands) {
989  Value operand = opOperand->get();
990  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
991  forgetMapping(operand, origOpFlatResults);
992  } else if (llvm::isa<TransformValueHandleTypeInterface>(
993  operand.getType())) {
994  forgetValueMapping(operand, origAssociatedOps);
995  }
996  }
997 
998  if (failed(updateStateFromResults(results, transform->getResults())))
// NOTE(review): source line 999 (the statement guarded by this `if`) is not
// shown in this listing.
1000 
1001  printOnFailureRAII.release();
1002  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
1003  DBGS() << "Top-level payload:\n";
1004  getTopLevel()->print(llvm::dbgs());
1005  });
1006  return result;
1007 }
1008 
1009 LogicalResult transform::TransformState::updateStateFromResults(
1010  const TransformResults &results, ResultRange opResults) {
1011  for (OpResult result : opResults) {
1012  if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1013  assert(results.isParam(result.getResultNumber()) &&
1014  "expected parameters for the parameter-typed result");
1015  if (failed(
1016  setParams(result, results.getParams(result.getResultNumber())))) {
1017  return failure();
1018  }
1019  } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1020  assert(results.isValue(result.getResultNumber()) &&
1021  "expected values for value-type-result");
1022  if (failed(setPayloadValues(
1023  result, results.getValues(result.getResultNumber())))) {
1024  return failure();
1025  }
1026  } else {
1027  assert(!results.isParam(result.getResultNumber()) &&
1028  "expected payload ops for the non-parameter typed result");
1029  if (failed(
1030  setPayloadOps(result, results.get(result.getResultNumber())))) {
1031  return failure();
1032  }
1033  }
1034  }
1035  return success();
1036 }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // TransformState::Extension
1040 //===----------------------------------------------------------------------===//
1041 
1043 
/// Replaces payload op `op` with `replacement` by forwarding to
/// `state.replacePayloadOp`. Returns the forwarded result.
1044 LogicalResult
1046  Operation *replacement) {
1047  // TODO: we may need to invalidate handles to operations and values nested in
1048  // the operation being replaced.
1049  return state.replacePayloadOp(op, replacement);
1050 }
1051 
/// Replaces payload value `value` with `replacement` by forwarding to
/// `state.replacePayloadValue`. Returns the forwarded result.
1052 LogicalResult
1054  Value replacement) {
1055  return state.replacePayloadValue(value, replacement);
1056 }
1057 
1058 //===----------------------------------------------------------------------===//
1059 // TransformState::RegionScope
1060 //===----------------------------------------------------------------------===//
1061 
/// Body of the RegionScope destructor (its signature line is not shown in this
/// listing). It erases the invalidation records of the handles defined directly
/// in `region`, i.e. the block arguments and op results of its blocks. It then
/// drops the region's entry from `state.mappings` and pops `state.regionStack`.
1063  // Remove handle invalidation notices as handles are going out of scope.
1064  // The same region may be re-entered leading to incorrect invalidation
1065  // errors.
1066  for (Block &block : *region) {
1067  for (Value handle : block.getArguments()) {
1068  state.invalidatedHandles.erase(handle);
1069  }
1070  for (Operation &op : block) {
1071  for (Value handle : op.getResults()) {
1072  state.invalidatedHandles.erase(handle);
1073  }
1074  }
1075  }
1076 
1077 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1078  // Remember pointers to payload ops referenced by the handles going out of
1079  // scope.
// NOTE(review): `referencedOps` is not read anywhere in the visible code of
// this destructor; confirm whether it is still needed.
1080  SmallVector<Operation *> referencedOps =
1081  llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1082 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1083 
1084  state.mappings.erase(region);
1085  state.regionStack.pop_back();
1086 }
1087 
1088 //===----------------------------------------------------------------------===//
1089 // TransformResults
1090 //===----------------------------------------------------------------------===//
1091 
1092 transform::TransformResults::TransformResults(unsigned numSegments) {
1093  operations.appendEmptyRows(numSegments);
1094  params.appendEmptyRows(numSegments);
1095  values.appendEmptyRows(numSegments);
1096 }
1097 
/// Body of TransformResults::setParams (signature lines not shown). It stores
/// `params` as the parameters of result `value`. Debug builds assert that the
/// result exists and that no operations, parameters or values were already
/// stored for it.
1100  int64_t position = value.getResultNumber();
1101  assert(position < static_cast<int64_t>(this->params.size()) &&
1102  "setting params for a non-existent handle");
1103  assert(this->params[position].data() == nullptr && "params already set");
1104  assert(operations[position].data() == nullptr &&
1105  "another kind of results already set");
1106  assert(values[position].data() == nullptr &&
1107  "another kind of results already set");
1108  this->params.replace(position, params);
1109 }
1110 
/// Stores `values` for result `handle`. Depending on their kind, the values
/// are routed to `set` (operations), `setParams` (parameters) or `setValues`
/// (payload values). The dispatching call sits on a line not shown in this
/// listing. In debug builds, a dispatch failure is printed and triggers an
/// assertion. The diagnostic is silenced either way.
1112  OpResult handle, ArrayRef<MappedValue> values) {
1114  handle, values,
1115  [&](ArrayRef<Operation *> operations) {
1116  return set(handle, operations), success();
1117  },
1118  [&](ArrayRef<Param> params) {
1119  return setParams(handle, params), success();
1120  },
1121  [&](ValueRange payloadValues) {
1122  return setValues(handle, payloadValues), success();
1123  });
1124 #ifndef NDEBUG
1125  if (!diag.succeeded())
1126  llvm::dbgs() << diag.getStatusString() << "\n";
1127  assert(diag.succeeded() && "incorrect mapping");
1128 #endif // NDEBUG
1129  (void)diag.silence();
1130 }
1131 
/// Gives every result of `transform` that has no payload recorded yet an
/// empty mapping.
1133  transform::TransformOpInterface transform) {
1134  for (OpResult opResult : transform->getResults()) {
1135  if (!isSet(opResult.getResultNumber()))
1136  setMappedValues(opResult, {});
1137  }
1138 }
1139 
/// Returns the payload operations stored for result `resultNumber`. Debug
/// builds assert that the result exists and that operations were stored for it.
1141 transform::TransformResults::get(unsigned resultNumber) const {
1142  assert(resultNumber < operations.size() &&
1143  "querying results for a non-existent handle");
1144  assert(operations[resultNumber].data() != nullptr &&
1145  "querying unset results (values or params expected?)");
1146  return operations[resultNumber];
1147 }
1148 
/// Returns the parameters stored for result `resultNumber`. Debug builds
/// assert that the result exists and that parameters were stored for it.
1150 transform::TransformResults::getParams(unsigned resultNumber) const {
1151  assert(resultNumber < params.size() &&
1152  "querying params for a non-existent handle");
1153  assert(params[resultNumber].data() != nullptr &&
1154  "querying unset params (ops or values expected?)");
1155  return params[resultNumber];
1156 }
1157 
/// Returns the payload values stored for result `resultNumber`. Debug builds
/// assert that the result exists and that values were stored for it.
1159 transform::TransformResults::getValues(unsigned resultNumber) const {
1160  assert(resultNumber < values.size() &&
1161  "querying values for a non-existent handle");
1162  assert(values[resultNumber].data() != nullptr &&
1163  "querying unset values (ops or params expected?)");
1164  return values[resultNumber];
1165 }
1166 
1167 bool transform::TransformResults::isParam(unsigned resultNumber) const {
1168  assert(resultNumber < params.size() &&
1169  "querying association for a non-existent handle");
1170  return params[resultNumber].data() != nullptr;
1171 }
1172 
1173 bool transform::TransformResults::isValue(unsigned resultNumber) const {
1174  assert(resultNumber < values.size() &&
1175  "querying association for a non-existent handle");
1176  return values[resultNumber].data() != nullptr;
1177 }
1178 
1179 bool transform::TransformResults::isSet(unsigned resultNumber) const {
1180  assert(resultNumber < params.size() &&
1181  "querying association for a non-existent handle");
1182  return params[resultNumber].data() != nullptr ||
1183  operations[resultNumber].data() != nullptr ||
1184  values[resultNumber].data() != nullptr;
1185 }
1186 
1187 //===----------------------------------------------------------------------===//
1188 // TrackingListener
1189 //===----------------------------------------------------------------------===//
1190 
/// Constructs a tracking listener bound to `state`, transform op `op` and
/// `config`. When `op` is non-null, every handle it consumes is recorded in
/// `consumedHandles`.
1192  TransformOpInterface op,
1194  : TransformState::Extension(state), transformOp(op), config(config) {
1195  if (op) {
1196  for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1197  consumedHandles.insert(opOperand->get());
1198  }
1199  }
1200 }
1201 
/// Returns the op that defines the non-null values in `values`, or nullptr if
/// two of them are defined by different ops. Null values are skipped. If no
/// value yields a defining op, the result is nullptr.
/// NOTE(review): a block argument has no defining op, so it leaves `defOp`
/// null when it comes first and is then effectively ignored like a null value.
/// Confirm this is intended.
1203  Operation *defOp = nullptr;
1204  for (Value v : values) {
1205  // Skip empty values.
1206  if (!v)
1207  continue;
1208  if (!defOp) {
1209  defOp = v.getDefiningOp();
1210  continue;
1211  }
1212  if (defOp != v.getDefiningOp())
1213  return nullptr;
1214  }
1215  return defOp;
1216 }
1217 
/// Searches for an op that can replace `op`, starting from the defining op of
/// `newValues` and writing the match to `result`. Each iteration:
///  - fails if the current values are defined by different ops;
///  - if `config.skipCastOps` is set, walks through CastOpInterface ops to
///    their operands;
///  - accepts the defining op if names need not match or they match;
///  - accepts a ConstantLike defining op;
///  - otherwise follows FindPayloadReplacementOpInterface::getNextOperands.
/// Returns the accumulated diagnostic when it runs out of values. The success
/// returns after the two "accept" branches (source lines 1250 and 1257) and
/// the diagnostic declaration (1224) are not shown in this listing.
1219  Operation *&result, Operation *op, ValueRange newValues) const {
1220  assert(op->getNumResults() == newValues.size() &&
1221  "invalid number of replacement values");
1222  SmallVector<Value> values(newValues.begin(), newValues.end());
1223 
1225  getTransformOp(), "tracking listener failed to find replacement op "
1226  "during application of this transform op");
1227 
1228  do {
1229  // If the replacement values belong to different ops, drop the mapping.
1230  Operation *defOp = getCommonDefiningOp(values);
1231  if (!defOp) {
1232  diag.attachNote() << "replacement values belong to different ops";
1233  return diag;
1234  }
1235 
1236  // Skip through ops that implement CastOpInterface.
1237  if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1238  values.clear();
1239  values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1240  diag.attachNote(defOp->getLoc())
1241  << "using output of 'CastOpInterface' op";
1242  continue;
1243  }
1244 
1245  // If the defining op has the same name or we do not care about the name of
1246  // op replacements at all, we take it as a replacement.
1247  if (!config.requireMatchingReplacementOpName ||
1248  op->getName() == defOp->getName()) {
1249  result = defOp;
1251  }
1252 
1253  // Replacing an op with a constant-like equivalent is a common
1254  // canonicalization.
1255  if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1256  result = defOp;
1258  }
1259 
1260  values.clear();
1261 
1262  // Skip through ops that implement FindPayloadReplacementOpInterface.
1263  if (auto findReplacementOpInterface =
1264  dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1265  values.assign(findReplacementOpInterface.getNextOperands());
1266  diag.attachNote(defOp->getLoc()) << "using operands provided by "
1267  "'FindPayloadReplacementOpInterface'";
1268  continue;
1269  }
1270  } while (!values.empty());
1271 
1272  diag.attachNote() << "ran out of suitable replacement values";
1273  return diag;
1274 }
1275 
/// Debug-only hook: under LLVM_DEBUG, fills a diagnostic for `loc` (declared on
/// a line not shown in this listing) through `reasonCallback` and prints it.
/// It has no other effect.
1277  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1278  LLVM_DEBUG({
1280  reasonCallback(diag);
1281  DBGS() << "Match Failure : " << diag.str() << "\n";
1282  });
1283 }
1284 
1285 void transform::TrackingListener::notifyOperationErased(Operation *op) {
1286  // Remove mappings for result values.
1287  for (OpResult value : op->getResults())
1288  (void)replacePayloadValue(value, nullptr);
1289  // Remove mapping for op.
1290  (void)replacePayloadOp(op, nullptr);
1291 }
1292 
/// Updates the tracked mappings when `op` is replaced by `newValues`. Value
/// handles always follow the replacement values. For op handles:
///  - if `op` is not tracked, nothing else happens;
///  - if no mapped handle is alive (per `config.skipHandleFn`, or simply none
///    exists) or one of them is consumed by the current transform, `op` is
///    dropped from the mapping;
///  - otherwise a replacement op is searched for; on failure,
///    notifyPayloadReplacementNotFound is called and `op` is dropped.
1293 void transform::TrackingListener::notifyOperationReplaced(
1294  Operation *op, ValueRange newValues) {
1295  assert(op->getNumResults() == newValues.size() &&
1296  "invalid number of replacement values");
1297 
1298  // Replace value handles.
1299  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
1300  (void)replacePayloadValue(oldValue, newValue);
1301 
1302  // Replace op handle.
1303  SmallVector<Value> opHandles;
1304  if (failed(getTransformState().getHandlesForPayloadOp(
1305  op, opHandles, /*includeOutOfScope=*/true))) {
1306  // Op is not tracked.
1307  return;
1308  }
1309 
1310  // Helper function to check if the current transform op consumes any handle
1311  // that is mapped to `op`.
1312  //
1313  // Note: If a handle was consumed, there shouldn't be any alive users, so it
1314  // is not really necessary to check for consumed handles. However, in case
1315  // there are indeed alive handles that were consumed (which is undefined
1316  // behavior) and a replacement op could not be found, we want to fail with a
1317  // nicer error message: "op uses a handle invalidated..." instead of "could
1318  // not find replacement op". This nicer error is produced later.
1319  auto handleWasConsumed = [&] {
1320  return llvm::any_of(opHandles,
1321  [&](Value h) { return consumedHandles.contains(h); });
1322  };
1323 
1324  // Check if there are any handles that must be updated.
1325  Value aliveHandle;
1326  if (config.skipHandleFn) {
1327  auto it = llvm::find_if(opHandles,
1328  [&](Value v) { return !config.skipHandleFn(v); });
1329  if (it != opHandles.end())
1330  aliveHandle = *it;
1331  } else if (!opHandles.empty()) {
1332  aliveHandle = opHandles.front();
1333  }
1334  if (!aliveHandle || handleWasConsumed()) {
1335  // The op is tracked but the corresponding handles are dead or were
1336  // consumed. Drop the op from the mapping.
1337  (void)replacePayloadOp(op, nullptr);
1338  return;
1339  }
1340 
// `replacement` starts uninitialized. It is only read below after
// findReplacementOp has succeeded (the declaration of `diag` on source line
// 1342 is not shown in this listing).
1341  Operation *replacement;
1343  findReplacementOp(replacement, op, newValues);
1344  // If the op is tracked but no replacement op was found, send a
1345  // notification.
1346  if (!diag.succeeded()) {
1347  diag.attachNote(aliveHandle.getLoc())
1348  << "replacement is required because this handle must be updated";
1349  notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
1350  (void)replacePayloadOp(op, nullptr);
1351  return;
1352  }
1353 
1354  (void)replacePayloadOp(op, replacement);
1355 }
1356 
/// Destructor body (signature not shown in this listing): asserts that no
/// unchecked error status is left behind.
1358  // The state of the ErrorCheckingTrackingListener must be checked and reset
1359  // if there was an error. This is to prevent errors from accidentally being
1360  // missed.
1361  assert(status.succeeded() && "listener state was not checked");
1362 }
1363 
/// Moves the accumulated status out for the caller to inspect, resets the
/// error counter and returns the moved status. The reset of `status` itself,
/// if any, is on a line not shown in this listing.
1366  DiagnosedSilenceableFailure s = std::move(status);
1368  errorCounter = 0;
1369  return s;
1370 }
1371 
/// Returns true if the listener currently holds a non-success status.
1373  return !status.succeeded();
1374 }
1375 
1378 
// Collects the diagnostics of `diag` and of any previously held failure into a
// single silenceable failure. It then attaches notes for the replaced `op` and
// each replacement value, tagged with the current `errorCounter`, and
// increments the counter.
1379  // Merge potentially existing diags and store the result in the listener.
1381  diag.takeDiagnostics(diags);
1382  if (!status.succeeded())
1383  status.takeDiagnostics(diags);
1384  status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
1385 
1386  // Report more details.
1387  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
1388  for (auto &&[index, value] : llvm::enumerate(values))
1389  status.attachNote(value.getLoc())
1390  << "[" << errorCounter << "] replacement value " << index;
1391  ++errorCounter;
1392 }
1393 
/// Returns the text of the most recently recorded match failure, or an empty
/// string if none has been recorded.
1394 std::string
1396  if (!matchFailure) {
1397  return "";
1398  }
1399  return matchFailure->str();
1400 }
1401 
/// Records a diagnostic for `loc`, filled in by `reasonCallback`, as the latest
/// match failure, replacing any previous one. The diagnostic's declaration is
/// on a line not shown in this listing.
1403  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1405  reasonCallback(diag);
1406  matchFailure = std::move(diag);
1407 }
1408 
1409 //===----------------------------------------------------------------------===//
1410 // TransformRewriter
1411 //===----------------------------------------------------------------------===//
1412 
/// Initializes the rewriter for `ctx`, keeps `listener` and installs it as the
/// rewriter's listener.
1415  : RewriterBase(ctx), listener(listener) {
1416  setListener(listener);
1417 }
1418 
/// Returns true if the tracking listener currently holds a failure.
1420  return listener->failed();
1421 }
1422 
1423 /// Silence all tracking failures that have been encountered so far.
/// Does nothing when there are none. Otherwise the listener's status is
/// checked and reset, and the returned failure is silenced.
1425  if (hasTrackingFailures()) {
1426  DiagnosedSilenceableFailure status = listener->checkAndResetError();
1427  (void)status.silence();
1428  }
1429 }
1430 
/// Forwards the replacement of payload op `op` by `replacement` to the tracking
/// listener and returns its result.
1432  Operation *op, Operation *replacement) {
1433  return listener->replacePayloadOp(op, replacement);
1434 }
1435 
1436 //===----------------------------------------------------------------------===//
1437 // Utilities for TransformEachOpTrait.
1438 //===----------------------------------------------------------------------===//
1439 
/// Returns failure, with an error at `loc`, if some op in `targets` is an
/// ancestor of an op that appears later in `targets`. The diagnostic carries
/// notes at both the ancestor and the descendant. Every pair is compared, so
/// the cost is quadratic in the number of targets.
1440 LogicalResult
1442  ArrayRef<Operation *> targets) {
1443  for (auto &&[position, parent] : llvm::enumerate(targets)) {
1444  for (Operation *child : targets.drop_front(position + 1)) {
1445  if (parent->isAncestor(child)) {
1447  emitError(loc)
1448  << "transform operation consumes a handle pointing to an ancestor "
1449  "payload operation before its descendant";
1450  diag.attachNote()
1451  << "the ancestor is likely erased or rewritten before the "
1452  "descendant is accessed, leading to undefined behavior";
1453  diag.attachNote(parent->getLoc()) << "ancestor payload op";
1454  diag.attachNote(child->getLoc()) << "descendant payload op";
1455  return diag;
1456  }
1457  }
1458  }
1459  return success();
1460 }
1461 
/// Checks the results produced by applying `transformOp` to one payload op
/// (located at `payloadOpLoc`):
///  - `partialResult` must have exactly as many entries as `transformOp` has
///    results;
///  - every non-null entry must have the kind its result type requires:
///    Operation * for op handles, Attribute for params, Value for value
///    handles.
/// On violation, an error is emitted at the transform op with a note at the
/// payload op, and failure is returned.
1462 LogicalResult
1464  Location payloadOpLoc,
1465  const ApplyToEachResultList &partialResult) {
1466  Location transformOpLoc = transformOp->getLoc();
1467  StringRef transformOpName = transformOp->getName().getStringRef();
1468  unsigned expectedNumResults = transformOp->getNumResults();
1469 
1470  // Reuse the emission of the diagnostic note.
1471  auto emitDiag = [&]() {
1472  auto diag = mlir::emitError(transformOpLoc);
1473  diag.attachNote(payloadOpLoc) << "when applied to this op";
1474  return diag;
1475  };
1476 
1477  if (partialResult.size() != expectedNumResults) {
1478  auto diag = emitDiag() << "application of " << transformOpName
1479  << " expected to produce " << expectedNumResults
1480  << " results (actually produced "
1481  << partialResult.size() << ").";
1482  diag.attachNote(transformOpLoc)
1483  << "if you need variadic results, consider a generic `apply` "
1484  << "instead of the specialized `applyToOne`.";
1485  return failure();
1486  }
1487 
1488  // Check that the right kind of value was produced.
1489  for (const auto &[ptr, res] :
1490  llvm::zip(partialResult, transformOp->getResults())) {
1491  if (ptr.isNull())
1492  continue;
1493  if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1494  !isa<Operation *>(ptr)) {
1495  return emitDiag() << "application of " << transformOpName
1496  << " expected to produce an Operation * for result #"
1497  << res.getResultNumber();
1498  }
1499  if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1500  !isa<Attribute>(ptr)) {
1501  return emitDiag() << "application of " << transformOpName
1502  << " expected to produce an Attribute for result #"
1503  << res.getResultNumber();
1504  }
1505  if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1506  !isa<Value>(ptr)) {
1507  return emitDiag() << "application of " << transformOpName
1508  << " expected to produce a Value for result #"
1509  << res.getResultNumber();
1510  }
1511  }
1512  return success();
1513 }
1514 
/// Casts every MappedValue in `range` to T and returns the results as a vector.
/// The cast is checked (`cast<T>`), so each element must actually hold a T.
1515 template <typename T>
1517  return llvm::to_vector(llvm::map_range(
1518  range, [](transform::MappedValue value) { return cast<T>(value); }));
1519 }
1520 
/// Transposes the per-target partial results into one list per result of
/// `transformOp`. A target whose partial result list contains any null entry
/// is skipped entirely. Each list is then stored in `transformResults`: cast
/// to Attribute for param results, to Value for value-handle results, and to
/// Operation * otherwise.
1522  Operation *transformOp, TransformResults &transformResults,
1525  transposed.resize(transformOp->getNumResults());
1526  for (const ApplyToEachResultList &partialResults : results) {
1527  if (llvm::any_of(partialResults,
1528  [](MappedValue value) { return value.isNull(); }))
1529  continue;
1530  assert(transformOp->getNumResults() == partialResults.size() &&
1531  "expected as many partial results as op as results");
1532  for (auto [i, value] : llvm::enumerate(partialResults))
1533  transposed[i].push_back(value);
1534  }
1535 
1536  for (OpResult r : transformOp->getResults()) {
1537  unsigned position = r.getResultNumber();
1538  if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1539  transformResults.setParams(r,
1540  castVector<Attribute>(transposed[position]));
1541  } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1542  transformResults.setValues(r, castVector<Value>(transposed[position]));
1543  } else {
1544  transformResults.set(r, castVector<Operation *>(transposed[position]));
1545  }
1546  }
1547 }
1548 
1549 //===----------------------------------------------------------------------===//
1550 // Utilities for implementing transform ops with regions.
1551 //===----------------------------------------------------------------------===//
1552 
/// For each value in `values`, appends what it is associated with in `state`
/// to the matching entry of `mappings`: payload ops for op handles, payload
/// values for value handles, params otherwise. When `flatten` is false, returns
/// failure as soon as a value contributed a number of entries other than
/// exactly one. Entries appended before the failure are kept.
1555  ValueRange values, const transform::TransformState &state, bool flatten) {
1556  assert(mappings.size() == values.size() && "mismatching number of mappings");
1557  for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1558  size_t mappedSize = mapped.size();
1559  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1560  llvm::append_range(mapped, state.getPayloadOps(operand));
1561  } else if (llvm::isa<TransformValueHandleTypeInterface>(
1562  operand.getType())) {
1563  llvm::append_range(mapped, state.getPayloadValues(operand));
1564  } else {
1565  assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1566  "unsupported kind of transform dialect value");
1567  llvm::append_range(mapped, state.getParams(operand));
1568  }
1569 
1570  if (mapped.size() - mappedSize != 1 && !flatten)
1571  return failure();
1572  }
1573  return success();
1574 }
1575 
/// Appends `values.size()` new rows to `mappings` and fills the trailing rows
/// via appendValueMappings. The result of that call is discarded.
1578  ValueRange values, const transform::TransformState &state) {
1579  mappings.resize(mappings.size() + values.size());
1580  (void)appendValueMappings(
1582  values.size()),
1583  values, state);
1584 }
1585 
/// Sets each result of `block`'s parent op to the payload associated in
/// `state` with the corresponding operand of `block`'s terminator. Op handles
/// get payload ops, value handles get payload values, and params get params.
1587  Block *block, transform::TransformState &state,
1588  transform::TransformResults &results) {
1589  for (auto &&[terminatorOperand, result] :
1590  llvm::zip(block->getTerminator()->getOperands(),
1591  block->getParentOp()->getOpResults())) {
1592  if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1593  results.set(result, state.getPayloadOps(terminatorOperand));
1594  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1595  result.getType())) {
1596  results.setValues(result, state.getPayloadValues(terminatorOperand));
1597  } else {
1598  assert(
1599  llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1600  "unhandled transform type interface");
1601  results.setParams(result, state.getParams(terminatorOperand));
1602  }
1603  }
1604 }
1605 
/// Builds a TransformState for `region` rooted at `payloadRoot` (the signature
/// lines are not shown in this listing).
1608  Operation *payloadRoot) {
1609  return TransformState(region, payloadRoot);
1610 }
1611 
1612 //===----------------------------------------------------------------------===//
1613 // Utilities for PossibleTopLevelTransformOpTrait.
1614 //===----------------------------------------------------------------------===//
1615 
1616 /// Appends to `effects` the memory effect instances on `target` with the same
1617 /// resource and effect as the ones the operation `iface` has on `source`.
1618 static void
1619 remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
1620  OpOperand *target,
1623  iface.getEffectsOnValue(source, nestedEffects);
1624  for (const auto &effect : nestedEffects)
1625  effects.emplace_back(effect.getEffect(), target, effect.getResource());
1626 }
1627 
1628 /// Appends to `effects` the same effects as the operations of `block` have on
1629 /// block arguments but associated with `operands`.
/// For each op with MemoryEffectOpInterface, it also appends that op's effects
/// on the transform PayloadIRResource unchanged.
1630 static void
1633  for (Operation &op : block) {
1634  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1635  if (!iface)
1636  continue;
1637 
1638  for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
1639  remapEffects(iface, source, &target, effects);
1640  }
1641 
1643  iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1644  nestedEffects);
1645  llvm::append_range(effects, nestedEffects);
1646  }
1647 }
1648 
/// Declares that `operation` only reads its operand handles and produces its
/// result handles. The remaining effects depend on `root`:
///  - if `root` is null, the full effects of every op in `body` that
///    implements MemoryEffectOpInterface are added;
///  - otherwise, effects on `body`'s arguments are remapped onto the
///    corresponding operands.
1650  Operation *operation, Value root, Block &body,
1652  transform::onlyReadsHandle(operation->getOpOperands(), effects);
1653  transform::producesHandle(operation->getOpResults(), effects);
1654 
1655  if (!root) {
1656  for (Operation &op : body) {
1657  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1658  if (!iface)
1659  continue;
1660 
1661  iface.getEffects(effects);
1662  }
1663  return;
1664  }
1665 
1666  // Carry over all effects on arguments of the entry block as those on the
1667  // operands, this is the same value just remapped.
1668  remapArgumentEffects(body, operation->getOpOperands(), effects);
1669 }
1670 
/// Maps the entry-block arguments of `region`.
///  - If `op` has operands, the first argument maps to the payload ops of
///    operand 0, and the remaining arguments map to the payload of the
///    remaining operands.
///  - Otherwise, the first argument maps to the top-level payload op and the
///    remaining arguments map to the state's top-level mappings. In this case
///    there must be exactly one top-level mapping per extra argument.
1672  TransformState &state, Operation *op, Region &region) {
1673  SmallVector<Operation *> targets;
1674  SmallVector<SmallVector<MappedValue>> extraMappings;
1675  if (op->getNumOperands() != 0) {
1676  llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
1677  prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
1678  } else {
1679  if (state.getNumTopLevelMappings() !=
1680  region.front().getNumArguments() - 1) {
1681  return emitError(op->getLoc())
1682  << "operation expects " << region.front().getNumArguments() - 1
1683  << " extra value bindings, but " << state.getNumTopLevelMappings()
1684  << " were provided to the interpreter";
1685  }
1686 
1687  targets.push_back(state.getTopLevel());
1688 
1689  for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1690  extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1691  }
1692 
1693  if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
1694  return failure();
1695 
// Argument #k (k >= 1) maps to extraMappings[k - 1].
1696  for (BlockArgument argument : region.front().getArguments().drop_front()) {
1697  if (failed(state.mapBlockArgument(
1698  argument, extraMappings[argument.getArgNumber() - 1])))
1699  return failure();
1700  }
1701 
1702  return success();
1703 }
1704 
/// Verifies the structural requirements of PossibleTopLevelTransformOpTrait:
///  - at least one region, whose first region has a single block;
///  - the entry block has at least one argument, and the first argument's type
///    implements TransformHandleTypeInterface (and matches operand 0's type if
///    there are operands);
///  - trailing arguments are op handles, value handles or params;
///  - when nested inside another possible top-level op (the lookup is on a
///    line not shown in this listing), the op must have one operand per block
///    argument.
1705 LogicalResult
1707  // Attaching this trait without the interface is a misuse of the API, but it
1708  // cannot be caught via a static_assert because interface registration is
1709  // dynamic.
1710  assert(isa<TransformOpInterface>(op) &&
1711  "should implement TransformOpInterface to have "
1712  "PossibleTopLevelTransformOpTrait");
1713 
1714  if (op->getNumRegions() < 1)
1715  return op->emitOpError() << "expects at least one region";
1716 
1717  Region *bodyRegion = &op->getRegion(0);
1718  if (!llvm::hasNItems(*bodyRegion, 1))
1719  return op->emitOpError() << "expects a single-block region";
1720 
1721  Block *body = &bodyRegion->front();
1722  if (body->getNumArguments() == 0) {
1723  return op->emitOpError()
1724  << "expects the entry block to have at least one argument";
1725  }
1726  if (!llvm::isa<TransformHandleTypeInterface>(
1727  body->getArgument(0).getType())) {
1728  return op->emitOpError()
1729  << "expects the first entry block argument to be of type "
1730  "implementing TransformHandleTypeInterface";
1731  }
1732  BlockArgument arg = body->getArgument(0);
1733  if (op->getNumOperands() != 0) {
1734  if (arg.getType() != op->getOperand(0).getType()) {
1735  return op->emitOpError()
1736  << "expects the type of the block argument to match "
1737  "the type of the operand";
1738  }
1739  }
// NOTE(review): the loop variable `arg` shadows the `arg` declared above.
1740  for (BlockArgument arg : body->getArguments().drop_front()) {
1741  if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1742  TransformValueHandleTypeInterface>(arg.getType()))
1743  continue;
1744 
1746  op->emitOpError()
1747  << "expects trailing entry block arguments to be of type implementing "
1748  "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1749  "TransformParamTypeInterface";
1750  diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1751  return diag;
1752  }
1753 
1754  if (auto *parent =
1756  if (op->getNumOperands() != body->getNumArguments()) {
1758  op->emitOpError()
1759  << "expects operands to be provided for a nested op";
1760  diag.attachNote(parent->getLoc())
1761  << "nested in another possible top-level op";
1762  return diag;
1763  }
1764  }
1765 
1766  return success();
1767 }
1768 
1769 //===----------------------------------------------------------------------===//
1770 // Utilities for ParamProducedTransformOpTrait.
1771 //===----------------------------------------------------------------------===//
1772 
/// Memory effects of ops with ParamProducerTransformOpTrait:
///   - every result is a produced handle;
///   - every operand is only read;
///   - the payload is marked as read only when at least one operand is an op
///     handle or a value handle (operands of other types, e.g. params, do not
///     trigger it).
/// NOTE(review): original lines 1773-1774 (the signature, presumably
/// getParamProducerTransformOpTraitEffects) are missing from this listing.
1775  producesHandle(op->getResults(), effects);
1776  bool hasPayloadOperands = false;
1777  for (OpOperand &operand : op->getOpOperands()) {
1778  onlyReadsHandle(operand, effects);
1779  if (llvm::isa<TransformHandleTypeInterface,
1780  TransformValueHandleTypeInterface>(operand.get().getType()))
1781  hasPayloadOperands = true;
1782  }
1783  if (hasPayloadOperands)
1784  onlyReadsPayload(effects);
1785 }
1786 
/// Verifies ops with ParamProducerTransformOpTrait. It aborts via
/// report_fatal_error if the op does not implement MemoryEffectOpInterface
/// (an API misuse). It emits an op error if any result type does not
/// implement TransformParamTypeInterface.
/// NOTE(review): original line 1788 (the signature) is missing from this
/// listing.
1787 LogicalResult
1789  // Interfaces can be attached dynamically, so this cannot be a static
1790  // assert.
1791  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
// NOTE(review): the message below says "MemoryEffectsOpInterface", while the
// interface checked above is MemoryEffectOpInterface; consider fixing the text.
1792  llvm::report_fatal_error(
1793  Twine("ParamProducerTransformOpTrait must be attached to an op that "
1794  "implements MemoryEffectsOpInterface, found on ") +
1795  op->getName().getStringRef());
1796  }
// Every result must be a param; report the first one that is not.
1797  for (Value result : op->getResults()) {
1798  if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1799  continue;
1800  return op->emitOpError()
1801  << "ParamProducerTransformOpTrait attached to this op expects "
1802  "result types to implement TransformParamTypeInterface";
1803  }
1804  return success();
1805 }
1806 
1807 //===----------------------------------------------------------------------===//
1808 // Memory effects.
1809 //===----------------------------------------------------------------------===//
1810 
/// Body of consumesHandle: for every handle operand, records a Read effect
/// followed by a Free effect.
/// NOTE(review): the signature (lines 1811-1813) and the resource argument of
/// both emplace_back calls (lines 1816, 1818) are missing from this listing.
/// isHandleConsumed checks Read+Free on TransformMappingResource, so that is
/// presumably the resource; confirm against the original source.
1814  for (OpOperand &handle : handles) {
1815  effects.emplace_back(MemoryEffects::Read::get(), &handle,
1817  effects.emplace_back(MemoryEffects::Free::get(), &handle,
1819  }
1820 }
1821 
1822 /// Returns `true` if the given list of effects instances contains an instance
1823 /// with the effect type specified as template parameter.
1824 template <typename EffectTy, typename ResourceTy, typename Range>
1825 static bool hasEffect(Range &&effects) {
1826  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1827  return isa<EffectTy>(effect.getEffect()) &&
1828  isa<ResourceTy>(effect.getResource());
1829  });
1830 }
1831 
/// Body of isHandleConsumed: a handle counts as consumed iff the transform op
/// reports both a Read and a Free effect on TransformMappingResource for it.
/// NOTE(review): original lines 1832 (signature start) and 1835 (presumably
/// the `effects` vector declaration) are missing from this listing.
1833  transform::TransformOpInterface transform) {
1834  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1836  iface.getEffectsOnValue(handle, effects);
1837  return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1838  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1839 }
1840 
/// Body of producesHandle for op results: for every result, records an
/// Allocate effect followed by a Write effect.
/// NOTE(review): original lines 1841, 1843 (signature) and 1846, 1848 (the
/// resource arguments) are missing from this listing.
1842  ResultRange handles,
1844  for (OpResult handle : handles) {
1845  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1847  effects.emplace_back(MemoryEffects::Write::get(), handle,
1849  }
1850 }
1851 
/// Overload for block arguments (presumably producesHandle, given the
/// identical pattern; confirm). For every block argument, records an Allocate
/// effect followed by a Write effect.
/// NOTE(review): the signature (lines 1852-1854) and the resource arguments
/// (lines 1857, 1859) are missing from this listing.
1855  for (BlockArgument handle : handles) {
1856  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1858  effects.emplace_back(MemoryEffects::Write::get(), handle,
1860  }
1861 }
1862 
/// Body of onlyReadsHandle (presumably; the signature is not in this listing):
/// records a single Read effect per handle operand, with no Free, so the
/// handle is not consumed.
/// NOTE(review): lines 1863-1865 (signature) and 1868 (resource argument) are
/// missing from this listing.
1866  for (OpOperand &handle : handles) {
1867  effects.emplace_back(MemoryEffects::Read::get(), &handle,
1869  }
1870 }
1871 
/// Records both a Read and a Write effect on PayloadIRResource. This is
/// presumably the body of modifiesPayload; its signature (lines 1872-1873) is
/// missing from this listing.
1874  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1875  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
1876 }
1877 
/// Records only a Read effect on PayloadIRResource. This is presumably the
/// body of onlyReadsPayload; its signature (lines 1878-1879) is missing from
/// this listing.
1880  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1881 }
1882 
/// Returns true iff the transform op declares a Write effect on
/// PayloadIRResource.
/// NOTE(review): original line 1885 (presumably the `effects` vector
/// declaration) is missing from this listing.
1883 bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1884  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1886  iface.getEffects(effects);
1887  return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1888 }
1889 
/// Returns true iff the transform op declares a Read effect on
/// PayloadIRResource. The payload-modifying helper in this file also records a
/// Read, so ops declared with it count as reading too.
/// NOTE(review): original line 1892 (presumably the `effects` vector
/// declaration) is missing from this listing.
1890 bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1891  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1893  iface.getEffects(effects);
1894  return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1895 }
1896 
/// Body of getConsumedBlockArguments: collects into `consumedArguments` the
/// positions of `block`'s own arguments on which any operation directly in
/// `block` declares a Free effect on TransformMappingResource. Operations
/// without MemoryEffectOpInterface are skipped. Ops nested inside those
/// operations' regions are not visited by this loop.
/// NOTE(review): original lines 1897 (signature start) and 1899 (presumably
/// the `effects` vector declaration) are missing from this listing.
1898  Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1900  for (Operation &nested : block) {
1901  auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1902  if (!iface)
1903  continue;
1904 
// Reuse one buffer across operations instead of reallocating per op.
1905  effects.clear();
1906  iface.getEffects(effects);
1907  for (const MemoryEffects::EffectInstance &effect : effects) {
1908  BlockArgument argument =
1909  dyn_cast_or_null<BlockArgument>(effect.getValue());
// Keep only Free effects on the mapping resource that target an argument of
// this very block (not of some other block).
1910  if (!argument || argument.getOwner() != &block ||
1911  !isa<MemoryEffects::Free>(effect.getEffect()) ||
1912  effect.getResource() != transform::TransformMappingResource::get()) {
1913  continue;
1914  }
1915  consumedArguments.insert(argument.getArgNumber());
1916  }
1917  }
1918 }
1919 
1920 //===----------------------------------------------------------------------===//
1921 // Utilities for TransformOpInterface.
1922 //===----------------------------------------------------------------------===//
1923 
/// Body of getConsumedHandleOpOperands: returns, in operand order, every
/// operand on whose value the op declares a Free effect on
/// TransformMappingResource.
/// NOTE(review): effects are queried per Value (getEffectsOnValue), not per
/// OpOperand. If the same value is passed as several operands, presumably all
/// of them are reported as consumed; confirm this is intended.
/// NOTE(review): original lines 1924 (signature start) and 1930 (presumably
/// the `effects` vector declaration) are missing from this listing.
1925  TransformOpInterface transformOp) {
1926  SmallVector<OpOperand *> consumedOperands;
1927  consumedOperands.reserve(transformOp->getNumOperands());
1928  auto memEffectInterface =
1929  cast<MemoryEffectOpInterface>(transformOp.getOperation());
1931  for (OpOperand &target : transformOp->getOpOperands()) {
1932  effects.clear();
1933  memEffectInterface.getEffectsOnValue(target.get(), effects);
1934  if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1935  return isa<transform::TransformMappingResource>(
1936  effect.getResource()) &&
1937  isa<MemoryEffects::Free>(effect.getEffect());
1938  })) {
1939  consumedOperands.push_back(&target);
1940  }
1941  }
1942  return consumedOperands;
1943 }
1944 
/// Body of verifyTransformOpInterface. It checks that the op's declared memory
/// effects are consistent:
///   - every operand has at least one effect;
///   - no operand has an Allocate effect on TransformMappingResource;
///   - if any operand is consumed (Free on TransformMappingResource), the op
///     also declares a Write effect on PayloadIRResource;
///   - every result has an Allocate effect on TransformMappingResource.
/// NOTE(review): original lines 1945 (signature), 1947 (presumably the
/// `effects` vector declaration) and 1961/1969/1984/1996 (presumably the
/// `diag` declarations) are missing from this listing.
1946  auto iface = cast<MemoryEffectOpInterface>(op);
1948  iface.getEffects(effects);
1949 
// Lazily selects the effect instances attached to `value`.
1950  auto effectsOn = [&](Value value) {
1951  return llvm::make_filter_range(
1952  effects, [value](const MemoryEffects::EffectInstance &instance) {
1953  return instance.getValue() == value;
1954  });
1955  };
1956 
// Only the first consumed operand is remembered, for the diagnostic note.
1957  std::optional<unsigned> firstConsumedOperand;
1958  for (OpOperand &operand : op->getOpOperands()) {
1959  auto range = effectsOn(operand.get());
1960  if (range.empty()) {
1962  op->emitError() << "TransformOpInterface requires memory effects "
1963  "on operands to be specified";
1964  diag.attachNote() << "no effects specified for operand #"
1965  << operand.getOperandNumber();
1966  return diag;
1967  }
1968  if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1970  << "TransformOpInterface did not expect "
1971  "'allocate' memory effect on an operand";
1972  diag.attachNote() << "specified for operand #"
1973  << operand.getOperandNumber();
1974  return diag;
1975  }
1976  if (!firstConsumedOperand &&
1977  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1978  firstConsumedOperand = operand.getOperandNumber();
1979  }
1980  }
1981 
1982  if (firstConsumedOperand &&
1983  !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1985  op->emitError()
1986  << "TransformOpInterface expects ops consuming operands to have a "
1987  "'write' effect on the payload resource";
1988  diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
1989  return diag;
1990  }
1991 
1992  for (OpResult result : op->getResults()) {
1993  auto range = effectsOn(result);
1994  if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1995  range)) {
1997  op->emitError() << "TransformOpInterface requires 'allocate' memory "
1998  "effect to be specified for results";
1999  diag.attachNote() << "no 'allocate' effect specified for result #"
2000  << result.getResultNumber();
2001  return diag;
2002  }
2003  }
2004 
2005  return success();
2006 }
2007 
2008 //===----------------------------------------------------------------------===//
2009 // Entry point.
2010 //===----------------------------------------------------------------------===//
2011 
/// Body of applyTransforms, the interpreter entry point. It:
///   1. optionally requires `transform` to be a possible top-level op with no
///      operands;
///   2. builds a TransformState rooted at the transform op's parent region
///      and `payloadRoot`;
///   3. lets `stateInitializer` customize the state;
///   4. applies the transform, reporting any failure;
///   5. finally delegates to `stateExporter` when provided.
/// NOTE(review): original lines 2012 (signature start) and 2025 (the
/// verification call used when enforceToplevelTransformOp is false) are missing
/// from this listing and are not reconstructed here.
2013  Operation *payloadRoot, TransformOpInterface transform,
2014  const RaggedArray<MappedValue> &extraMapping,
2015  const TransformOptions &options, bool enforceToplevelTransformOp,
2016  function_ref<void(TransformState &)> stateInitializer,
2017  function_ref<LogicalResult(TransformState &)> stateExporter) {
2018  if (enforceToplevelTransformOp) {
2019  if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
2020  transform->getNumOperands() != 0) {
2021  return transform->emitError()
2022  << "expected transform to start at the top-level transform op";
2023  }
2024  } else if (failed(
2026  return failure();
2027  }
2028 
2029  TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2030  options);
2031  if (stateInitializer)
2032  stateInitializer(state);
// checkAndReport() emits the diagnostic for both silenceable and definite
// failures before converting to LogicalResult.
2033  if (state.applyTransform(transform).checkAndReport().failed())
2034  return failure();
2035  if (stateExporter)
2036  return stateExporter(state);
2037  return success();
2038 }
2039 
2040 //===----------------------------------------------------------------------===//
2041 // Generated interface implementation.
2042 //===----------------------------------------------------------------------===//
2043 
2044 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2045 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define DEBUG_TYPE_FULL
#define FULL_LDBG(X)
static void remapEffects(MemoryEffectOpInterface iface, BlockArgument source, OpOperand *target, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the memory effect instances on target with the same resource and effect as the one...
static void remapArgumentEffects(Block &block, MutableArrayRef< OpOperand > operands, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the same effects as the operations of block have on block arguments but associated...
static bool happensBefore(Operation *a, Operation *b)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not inside a.
void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped)
#define DEBUG_PRINT_AFTER_ALL
DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef< T > payload, transform::TransformOpInterface transform, unsigned operandNumber)
#define DBGS()
static DiagnosedSilenceableFailure dispatchMappedValues(Value handle, ArrayRef< transform::MappedValue > values, function_ref< LogicalResult(ArrayRef< Operation * >)> operationsFn, function_ref< LogicalResult(ArrayRef< transform::Param >)> paramsFn, function_ref< LogicalResult(ValueRange)> valuesFn)
Given a list of MappedValues, cast them to the value kind implied by the interface of the handle type...
static SmallVector< T > castVector(ArrayRef< transform::MappedValue > range)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition: Block.cpp:76
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:314
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition: Value.h:447
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A 2D array where each row may have different length.
Definition: RaggedArray.h:18
size_t size() const
Returns the number of rows in the 2D array.
Definition: RaggedArray.h:21
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:247
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
size_t size() const
Returns the number of elements in the list.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
std::string getLatestMatchFailureMessage()
Return the latest match notification message.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
TrackingListener failures are reported only for ops that have this trait.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Options controlling the application of transform operations by the TransformState.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult matches the number, kind (payload op or parameter) and nulli...
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true, function_ref< void(TransformState &)> stateInitializer=nullptr, function_ref< LogicalResult(TransformState &)> stateExporter=nullptr)
Entry point to the Transform dialect infrastructure.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns true if op has an effect of type EffectTy.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A configuration object for customizing a TrackingListener.