MLIR  19.0.0git
TransformInterfaces.cpp
Go to the documentation of this file.
1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/IR/PatternMatch.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/ScopeExit.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/ErrorHandling.h"
23 
24 #define DEBUG_TYPE "transform-dialect"
25 #define DEBUG_TYPE_FULL "transform-dialect-full"
26 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
27 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
28 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
29 #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
30 
31 using namespace mlir;
32 
33 //===----------------------------------------------------------------------===//
34 // Helper functions
35 //===----------------------------------------------------------------------===//
36 
37 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
38 /// properly dominates `b` and `b` is not inside `a`.
39 static bool happensBefore(Operation *a, Operation *b) {
40  do {
41  if (a->isProperAncestor(b))
42  return false;
43  if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
44  return a->isBeforeInBlock(bAncestor);
45  }
46  } while ((a = a->getParentOp()));
47  return false;
48 }
49 
50 //===----------------------------------------------------------------------===//
51 // TransformState
52 //===----------------------------------------------------------------------===//
53 
// Out-of-class definition of the static constexpr member `kTopLevelValue`.
// NOTE(review): presumably kept so the member can be ODR-used under pre-C++17
// rules — confirm whether it is still required for the supported toolchains.
54 constexpr const Value transform::TransformState::kTopLevelValue;
55 
56 transform::TransformState::TransformState(
57  Region *region, Operation *payloadRoot,
58  const RaggedArray<MappedValue> &extraMappings,
59  const TransformOptions &options)
60  : topLevel(payloadRoot), options(options) {
61  topLevelMappedValues.reserve(extraMappings.size());
62  for (ArrayRef<MappedValue> mapping : extraMappings)
63  topLevelMappedValues.push_back(mapping);
64  if (region) {
65  RegionScope *scope = new RegionScope(*this, *region);
66  topLevelRegionScope.reset(scope);
67  }
68 }
69 
70 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
71 
// Returns the list of payload operations stored for the handle `value` in the
// `direct` map of the mappings selected by getMapping(value). Asserts that an
// entry exists; per the assertion text, a missing entry usually means a
// parameter or value handle was passed instead of an operation handle.
// NOTE(review): the line carrying the return type (embedded line 72) is absent
// from this listing — restore it from the upstream file before compiling.
73 transform::TransformState::getPayloadOpsView(Value value) const {
74  const TransformOpMapping &operationMapping = getMapping(value).direct;
75  auto iter = operationMapping.find(value);
76  assert(iter != operationMapping.end() &&
77  "cannot find mapping for payload handle (param/value handle "
78  "provided?)");
79  return iter->getSecond();
80 }
81 
// Body of the parameter-handle accessor: looks `value` up in the `params` map
// of the mappings selected by getMapping(value) and returns the stored list.
// Asserts that an entry exists; per the assertion text, a missing entry
// usually means an operation or value handle was passed instead.
// NOTE(review): the signature line (embedded line 82) is absent from this
// listing, so the function name and return type are not visible here —
// restore it from the upstream file before compiling.
83  const ParamMapping &mapping = getMapping(value).params;
84  auto iter = mapping.find(value);
85  assert(iter != mapping.end() && "cannot find mapping for param handle "
86  "(operation/value handle provided?)");
87  return iter->getSecond();
88 }
89 
// Returns the list of payload values stored for `handleValue` in the `values`
// map of the mappings selected by getMapping(handleValue). Asserts that an
// entry exists; per the assertion text, a missing entry usually means a
// parameter or operation handle was passed instead of a value handle.
// NOTE(review): the line carrying the return type (embedded line 90) is absent
// from this listing — restore it from the upstream file before compiling.
91 transform::TransformState::getPayloadValuesView(Value handleValue) const {
92  const ValueMapping &mapping = getMapping(handleValue).values;
93  auto iter = mapping.find(handleValue);
94  assert(iter != mapping.end() && "cannot find mapping for value handle "
95  "(param/operation handle provided?)");
96  return iter->getSecond();
97 }
98 
// Appends to `handles` every handle recorded in a `reverse` map for the payload
// operation `op`, visiting the entries of `mappings` from last to first.
// Returns success if at least one handle was found and failure otherwise.
// Unless `includeOutOfScope` is set, the scan stops early at an isolation
// boundary (see the inline comment below).
// NOTE(review): two lines are absent from this listing — the first line of the
// signature (embedded line 99) and the second operand of the `&&` condition
// (embedded line 111). Restore both from the upstream file before compiling.
100  Operation *op, SmallVectorImpl<Value> &handles,
101  bool includeOutOfScope) const {
102  bool found = false;
103  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
104  auto iterator = mapping->reverse.find(op);
105  if (iterator != mapping->reverse.end()) {
106  llvm::append_range(handles, iterator->getSecond());
107  found = true;
108  }
109  // Stop looking when reaching a region that is isolated from above.
110  if (!includeOutOfScope &&
112  break;
113  }
114 
115  return success(found);
116 }
117 
// Appends to `handles` every handle recorded in a `reverseValues` map for the
// payload value `payloadValue`, visiting the entries of `mappings` from last
// to first. Returns success if at least one handle was found and failure
// otherwise. Unless `includeOutOfScope` is set, the scan stops early at an
// isolation boundary (see the inline comment below).
// NOTE(review): two lines are absent from this listing — the first line of the
// signature (embedded line 118) and the second operand of the `&&` condition
// (embedded line 130). Restore both from the upstream file before compiling.
119  Value payloadValue, SmallVectorImpl<Value> &handles,
120  bool includeOutOfScope) const {
121  bool found = false;
122  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
123  auto iterator = mapping->reverseValues.find(payloadValue);
124  if (iterator != mapping->reverseValues.end()) {
125  llvm::append_range(handles, iterator->getSecond());
126  found = true;
127  }
128  // Stop looking when reaching a region that is isolated from above.
129  if (!includeOutOfScope &&
131  break;
132  }
133 
134  return success(found);
135 }
136 
137 /// Given a list of MappedValues, cast them to the value kind implied by the
138 /// interface of the handle type, and dispatch to one of the callbacks.
// The handle type selects exactly one branch: operation handles collect
// `Operation *` entries for `operationsFn`, value handles collect `Value`
// entries for `valuesFn`, and anything else is asserted to be a parameter type
// and collects `Attribute` entries for `paramsFn`. An entry of the wrong kind
// (or a null one, since dyn_cast_if_present yields null for it) produces a
// silenceable failure located at the handle.
// NOTE(review): this listing is missing the signature (embedded lines 139-143),
// the statements following each `if (failed(...))` (embedded lines 156-157,
// 173-174, 190-191) and the declaration of `parameters` (embedded line 179).
// Restore them from the upstream file before compiling.
144  if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
145  SmallVector<Operation *> operations;
146  operations.reserve(values.size());
147  for (transform::MappedValue value : values) {
148  if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
149  operations.push_back(op);
150  continue;
151  }
152  return emitSilenceableFailure(handle.getLoc())
153  << "wrong kind of value provided for top-level operation handle";
154  }
155  if (failed(operationsFn(operations)))
158  }
159 
160  if (llvm::isa<transform::TransformValueHandleTypeInterface>(
161  handle.getType())) {
162  SmallVector<Value> payloadValues;
163  payloadValues.reserve(values.size());
164  for (transform::MappedValue value : values) {
165  if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
166  payloadValues.push_back(v);
167  continue;
168  }
169  return emitSilenceableFailure(handle.getLoc())
170  << "wrong kind of value provided for the top-level value handle";
171  }
172  if (failed(valuesFn(payloadValues)))
175  }
176 
// Only parameter-typed handles may reach this point.
177  assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
178  "unsupported kind of block argument");
180  parameters.reserve(values.size());
181  for (transform::MappedValue value : values) {
182  if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
183  parameters.push_back(attr);
184  continue;
185  }
186  return emitSilenceableFailure(handle.getLoc())
187  << "wrong kind of value provided for top-level parameter";
188  }
189  if (failed(paramsFn(parameters)))
192 }
193 
// Associates `argument` with `values` by routing through dispatchMappedValues:
// operations go to setPayloadOps, parameters to setParams and payload values to
// setPayloadValues. The dispatch result is converted with checkAndReport()
// before being returned.
// NOTE(review): the opening lines of the signature (embedded lines 194-195) are
// absent from this listing, so the function name, return type and the
// declaration of `argument` are not visible here — restore them from the
// upstream file before compiling.
196  ArrayRef<MappedValue> values) {
197  return dispatchMappedValues(
198  argument, values,
199  [&](ArrayRef<Operation *> operations) {
200  return setPayloadOps(argument, operations);
201  },
202  [&](ArrayRef<Param> params) {
203  return setParams(argument, params);
204  },
205  [&](ValueRange payloadValues) {
206  return setPayloadValues(argument, payloadValues);
207  })
208  .checkAndReport();
209 }
210 
// Associates the operation handle `value` with the payload operations in
// `targets`. Emits an error if any target is null or if the handle type's
// checkPayload rejects the list. On success, records the list in the `direct`
// map and registers `value` in the `reverse` map for each target. Asserts that
// `value` is not the top-level value, has an operation-handle type, and has no
// prior association.
// NOTE(review): this listing is missing the return-type line (embedded line
// 211) and the declaration of `result` (embedded line 227) — restore them from
// the upstream file before compiling.
212 transform::TransformState::setPayloadOps(Value value,
213  ArrayRef<Operation *> targets) {
214  assert(value != kTopLevelValue &&
215  "attempting to reset the transformation root");
216  assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
217  "wrong handle type");
218 
// Reject null payload operations up front.
219  for (Operation *target : targets) {
220  if (target)
221  continue;
222  return emitError(value.getLoc())
223  << "attempting to assign a null payload op to this transform value";
224  }
225 
// Let the handle type validate the payload.
226  auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
228  iface.checkPayload(value.getLoc(), targets);
229  if (failed(result.checkAndReport()))
230  return failure();
231 
232  // Setting new payload for the value without cleaning it first is a misuse of
233  // the API, assert here.
234  SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
235  Mappings &mappings = getMapping(value);
236  bool inserted =
237  mappings.direct.insert({value, std::move(storedTargets)}).second;
238  assert(inserted && "value is already associated with another list");
239  (void)inserted;
240 
// Keep the reverse (payload op -> handles) map in sync.
241  for (Operation *op : targets)
242  mappings.reverse[op].push_back(value);
243 
244  return success();
245 }
246 
// Associates the value handle `handle` with the payload values in
// `payloadValues`. Emits an error if any payload value is null or if the
// handle type's checkPayload rejects the list. On success, records the list in
// the `values` map and registers `handle` in the `reverseValues` map for each
// payload value. Asserts that `handle` is non-null, has a value-handle type,
// and has no prior association.
// NOTE(review): this listing is missing the return-type line (embedded line
// 247) and the declaration of `result` (embedded line 263) — restore them from
// the upstream file before compiling.
// NOTE(review): the first assertion message mentions "params" although this
// function sets payload values; it looks copy-pasted from setParams.
248 transform::TransformState::setPayloadValues(Value handle,
249  ValueRange payloadValues) {
250  assert(handle != nullptr && "attempting to set params for a null value");
251  assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
252  "wrong handle type");
253 
// Reject null payload values up front.
254  for (Value payload : payloadValues) {
255  if (payload)
256  continue;
257  return emitError(handle.getLoc()) << "attempting to assign a null payload "
258  "value to this transform handle";
259  }
260 
// Materialize the range once; the vector is both checked and then stored.
261  auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
262  SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
264  iface.checkPayload(handle.getLoc(), payloadValueVector);
265  if (failed(result.checkAndReport()))
266  return failure();
267 
268  Mappings &mappings = getMapping(handle);
269  bool inserted =
270  mappings.values.insert({handle, std::move(payloadValueVector)}).second;
271  assert(
272  inserted &&
273  "value handle is already associated with another list of payload values");
274  (void)inserted;
275 
// Keep the reverse (payload value -> handles) map in sync.
276  for (Value payload : payloadValues)
277  mappings.reverseValues[payload].push_back(handle);
278 
279  return success();
280 }
281 
282 LogicalResult transform::TransformState::setParams(Value value,
283  ArrayRef<Param> params) {
284  assert(value != nullptr && "attempting to set params for a null value");
285 
286  for (Attribute attr : params) {
287  if (attr)
288  continue;
289  return emitError(value.getLoc())
290  << "attempting to assign a null parameter to this transform value";
291  }
292 
293  auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
294  assert(value &&
295  "cannot associate parameter with a value of non-parameter type");
297  valueType.checkPayload(value.getLoc(), params);
298  if (failed(result.checkAndReport()))
299  return failure();
300 
301  Mappings &mappings = getMapping(value);
302  bool inserted =
303  mappings.params.insert({value, llvm::to_vector(params)}).second;
304  assert(inserted && "value is already associated with another list of params");
305  (void)inserted;
306  return success();
307 }
308 
309 template <typename Mapping, typename Key, typename Mapped>
310 void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
311  auto it = mapping.find(key);
312  if (it == mapping.end())
313  return;
314 
315  llvm::erase(it->getSecond(), mapped);
316  if (it->getSecond().empty())
317  mapping.erase(it);
318 }
319 
320 void transform::TransformState::forgetMapping(Value opHandle,
321  ValueRange origOpFlatResults,
322  bool allowOutOfScope) {
323  Mappings &mappings = getMapping(opHandle, allowOutOfScope);
324  for (Operation *op : mappings.direct[opHandle])
325  dropMappingEntry(mappings.reverse, op, opHandle);
326  mappings.direct.erase(opHandle);
327 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
328  // Payload IR is removed from the mapping. This invalidates the respective
329  // iterators.
330  mappings.incrementTimestamp(opHandle);
331 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
332 
333  for (Value opResult : origOpFlatResults) {
334  SmallVector<Value> resultHandles;
335  (void)getHandlesForPayloadValue(opResult, resultHandles);
336  for (Value resultHandle : resultHandles) {
337  Mappings &localMappings = getMapping(resultHandle);
338  dropMappingEntry(localMappings.values, resultHandle, opResult);
339 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
340  // Payload IR is removed from the mapping. This invalidates the respective
341  // iterators.
342  mappings.incrementTimestamp(resultHandle);
343 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
344  dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
345  }
346  }
347 }
348 
349 void transform::TransformState::forgetValueMapping(
350  Value valueHandle, ArrayRef<Operation *> payloadOperations) {
351  Mappings &mappings = getMapping(valueHandle);
352  for (Value payloadValue : mappings.reverseValues[valueHandle])
353  dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
354  mappings.values.erase(valueHandle);
355 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
356  // Payload IR is removed from the mapping. This invalidates the respective
357  // iterators.
358  mappings.incrementTimestamp(valueHandle);
359 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
360 
361  for (Operation *payloadOp : payloadOperations) {
362  SmallVector<Value> opHandles;
363  (void)getHandlesForPayloadOp(payloadOp, opHandles);
364  for (Value opHandle : opHandles) {
365  Mappings &localMappings = getMapping(opHandle);
366  dropMappingEntry(localMappings.direct, opHandle, payloadOp);
367  dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
368 
369 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
370  // Payload IR is removed from the mapping. This invalidates the respective
371  // iterators.
372  localMappings.incrementTimestamp(opHandle);
373 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
374  }
375  }
376 }
377 
// Redirects every handle that points to the payload operation `op` so that it
// points to `replacement` instead. Returns failure when no handle refers to
// `op`. A null `replacement` stores a null entry and queues the handle in
// `opHandlesToCompact` (see the block comment below for why entries are not
// removed immediately).
// NOTE(review): the line carrying the return type (embedded line 378) is
// absent from this listing — restore it from the upstream file before
// compiling.
379 transform::TransformState::replacePayloadOp(Operation *op,
380  Operation *replacement) {
381  // TODO: consider invalidating the handles to nested objects here.
382 
// Debug-only sanity check: no value handle may still refer to a result of `op`.
383 #ifndef NDEBUG
384  for (Value opResult : op->getResults()) {
385  SmallVector<Value> valueHandles;
386  (void)getHandlesForPayloadValue(opResult, valueHandles,
387  /*includeOutOfScope=*/true);
388  assert(valueHandles.empty() && "expected no mapping to old results");
389  }
390 #endif // NDEBUG
391 
392  // Drop the mapping between the op and all handles that point to it. Fail if
393  // there are no handles.
394  SmallVector<Value> opHandles;
395  if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
396  return failure();
397  for (Value handle : opHandles) {
398  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
399  dropMappingEntry(mappings.reverse, op, handle);
400  }
401 
402  // Replace the pointed-to object of all handles with the replacement object.
403  // In case a payload op was erased (replacement object is nullptr), a nullptr
404  // is stored in the mapping. These nullptrs are removed after each transform.
405  // Furthermore, nullptrs are not enumerated by payload op iterators. The
406  // relative order of ops is preserved.
407  //
408  // Removing an op from the mapping would be problematic because removing an
409  // element from an array invalidates iterators; merely changing the value of
410  // elements does not.
411  for (Value handle : opHandles) {
412  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
413  auto it = mappings.direct.find(handle);
414  if (it == mappings.direct.end())
415  continue;
416 
417  SmallVector<Operation *, 2> &association = it->getSecond();
418  // Note that an operation may be associated with the handle more than once.
419  for (Operation *&mapped : association) {
420  if (mapped == op)
421  mapped = replacement;
422  }
423 
// Either register the reverse entry for the new op, or remember that this
// handle now contains null entries to be compacted later.
424  if (replacement) {
425  mappings.reverse[replacement].push_back(handle);
426  } else {
427  opHandlesToCompact.insert(handle);
428  }
429  }
430 
431  return success();
432 }
433 
// Redirects every value handle that points to the payload value `value` so
// that it points to `replacement` instead. Returns failure when no handle
// refers to `value`. With a null `replacement`, the value is removed from the
// handle's list rather than replaced.
// NOTE(review): the line carrying the return type (embedded line 434) is
// absent from this listing — restore it from the upstream file before
// compiling.
435 transform::TransformState::replacePayloadValue(Value value, Value replacement) {
436  SmallVector<Value> valueHandles;
437  if (failed(getHandlesForPayloadValue(value, valueHandles,
438  /*includeOutOfScope=*/true)))
439  return failure();
440 
441  for (Value handle : valueHandles) {
442  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
// The old value no longer maps back to this handle in either case.
443  dropMappingEntry(mappings.reverseValues, value, handle);
444 
445  // If replacing with null, that is erasing the mapping, drop the mapping
446  // between the handles and the IR objects
447  if (!replacement) {
448  dropMappingEntry(mappings.values, handle, value);
449 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
450  // Payload IR is removed from the mapping. This invalidates the respective
451  // iterators.
452  mappings.incrementTimestamp(handle);
453 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
454  } else {
455  auto it = mappings.values.find(handle);
456  if (it == mappings.values.end())
457  continue;
458 
// Overwrite in place so the list keeps its length and order.
459  SmallVector<Value> &association = it->getSecond();
460  for (Value &mapped : association) {
461  if (mapped == value)
462  mapped = replacement;
463  }
464  mappings.reverseValues[replacement].push_back(handle);
465  }
466  }
467 
468  return success();
469 }
470 
// Records in `newlyInvalidated` that `otherHandle` is invalidated by the
// consumption of `consumingHandle`, provided `payloadOp` (an op `otherHandle`
// points to) is equal to or nested in one of `potentialAncestors`. The stored
// callback emits the diagnostic when later invoked with the location of an op
// using the invalidated handle. `throughValue`, when non-null, adds a note
// pointing at the payload value the consumed handle referred to.
// NOTE(review): the first argument line of the DEBUG_WITH_TYPE call that
// prints the ancestors (embedded line 484) is absent from this listing —
// restore it from the upstream file before compiling.
471 void transform::TransformState::recordOpHandleInvalidationOne(
472  OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
473  Operation *payloadOp, Value otherHandle, Value throughValue,
474  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
475  // If the op is associated with invalidated handle, skip the check as it
476  // may be reading invalid IR. This also ensures we report the first
477  // invalidation and not the last one.
478  if (invalidatedHandles.count(otherHandle) ||
479  newlyInvalidated.count(otherHandle))
480  return;
481 
482  FULL_LDBG("--recordOpHandleInvalidationOne\n");
483  DEBUG_WITH_TYPE(
485  llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ",
486  [](Operation *op) { llvm::dbgs() << *op; });
487  llvm::dbgs() << "\n");
488 
489  Operation *owner = consumingHandle.getOwner();
490  unsigned operandNo = consumingHandle.getOperandNumber();
491  for (Operation *ancestor : potentialAncestors) {
492  // clang-format off
493  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
494  { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
495  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
496  { (DBGS() << "----of payload with name: "
497  << payloadOp->getName().getIdentifier() << "\n"); });
498  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
499  { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
500  // clang-format on
// Only ancestors that contain (or are) the payload op are relevant.
501  if (!ancestor->isAncestor(payloadOp))
502  continue;
503 
504  // Make sure the error-reporting lambda doesn't capture anything
505  // by-reference because it will go out of scope. Additionally, extract
506  // location from Payload IR ops because the ops themselves may be
507  // deleted before the lambda gets called.
508  Location ancestorLoc = ancestor->getLoc();
509  Location opLoc = payloadOp->getLoc();
510  std::optional<Location> throughValueLoc =
511  throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
512  newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
513  otherHandle,
514  throughValueLoc](Location currentLoc) {
515  InFlightDiagnostic diag = emitError(currentLoc)
516  << "op uses a handle invalidated by a "
517  "previously executed transform op";
518  diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
519  diag.attachNote(owner->getLoc())
520  << "invalidated by this transform op that consumes its operand #"
521  << operandNo
522  << " and invalidates all handles to payload IR entities associated "
523  "with this operand and entities nested in them";
524  diag.attachNote(ancestorLoc) << "ancestor payload op";
525  diag.attachNote(opLoc) << "nested payload op";
526  if (throughValueLoc) {
527  diag.attachNote(*throughValueLoc)
528  << "consumed handle points to this payload value";
529  }
530  };
531  }
532 }
533 
534 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
535  OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
536  Value payloadValue, Value valueHandle,
537  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
538  // If the op is associated with invalidated handle, skip the check as it
539  // may be reading invalid IR. This also ensures we report the first
540  // invalidation and not the last one.
541  if (invalidatedHandles.count(valueHandle) ||
542  newlyInvalidated.count(valueHandle))
543  return;
544 
545  for (Operation *ancestor : potentialAncestors) {
546  Operation *definingOp;
547  std::optional<unsigned> resultNo;
548  unsigned argumentNo = std::numeric_limits<unsigned>::max();
549  unsigned blockNo = std::numeric_limits<unsigned>::max();
550  unsigned regionNo = std::numeric_limits<unsigned>::max();
551  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
552  definingOp = opResult.getOwner();
553  resultNo = opResult.getResultNumber();
554  } else {
555  auto arg = llvm::cast<BlockArgument>(payloadValue);
556  definingOp = arg.getParentBlock()->getParentOp();
557  argumentNo = arg.getArgNumber();
558  blockNo = std::distance(arg.getOwner()->getParent()->begin(),
559  arg.getOwner()->getIterator());
560  regionNo = arg.getOwner()->getParent()->getRegionNumber();
561  }
562  assert(definingOp && "expected the value to be defined by an op as result "
563  "or block argument");
564  if (!ancestor->isAncestor(definingOp))
565  continue;
566 
567  Operation *owner = opHandle.getOwner();
568  unsigned operandNo = opHandle.getOperandNumber();
569  Location ancestorLoc = ancestor->getLoc();
570  Location opLoc = definingOp->getLoc();
571  Location valueLoc = payloadValue.getLoc();
572  newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
573  argumentNo, blockNo, regionNo, ancestorLoc,
574  opLoc, valueLoc](Location currentLoc) {
575  InFlightDiagnostic diag = emitError(currentLoc)
576  << "op uses a handle invalidated by a "
577  "previously executed transform op";
578  diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
579  diag.attachNote(owner->getLoc())
580  << "invalidated by this transform op that consumes its operand #"
581  << operandNo
582  << " and invalidates all handles to payload IR entities "
583  "associated with this operand and entities nested in them";
584  diag.attachNote(ancestorLoc)
585  << "ancestor op associated with the consumed handle";
586  if (resultNo) {
587  diag.attachNote(opLoc)
588  << "op defining the value as result #" << *resultNo;
589  } else {
590  diag.attachNote(opLoc)
591  << "op defining the value as block argument #" << argumentNo
592  << " of block #" << blockNo << " in region #" << regionNo;
593  }
594  diag.attachNote(valueLoc) << "payload value";
595  };
596  }
597 }
598 
// Records in `newlyInvalidated` every handle invalidated by consuming
// `handle`, whose payload operations are given as `potentialAncestors`. With
// an empty payload only the consumed handle itself is recorded. Otherwise the
// `reverse` and `reverseValues` maps of `mappings` are scanned, last entry
// first, and each stored handle is passed to the per-handle helpers, which
// decide whether it aliases or is nested in the consumed payload.
// `throughValue` is forwarded to the op-handle helper for its diagnostic note.
// NOTE(review): the condition line preceding the final `break` (embedded line
// 654) is absent from this listing — restore it from the upstream file before
// compiling.
599 void transform::TransformState::recordOpHandleInvalidation(
600  OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
601  Value throughValue,
602  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
603 
// Empty payload: there is nothing to alias with, so only the consumed handle
// is marked, with a dedicated message.
604  if (potentialAncestors.empty()) {
605  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
606  (DBGS() << "----recording invalidation for empty handle: " << handle.get()
607  << "\n");
608  });
609 
610  Operation *owner = handle.getOwner();
611  unsigned operandNo = handle.getOperandNumber();
612  newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
613  InFlightDiagnostic diag = emitError(currentLoc)
614  << "op uses a handle associated with empty "
615  "payload and invalidated by a "
616  "previously executed transform op";
617  diag.attachNote(owner->getLoc())
618  << "invalidated by this transform op that consumes its operand #"
619  << operandNo;
620  };
621  return;
622  }
623 
624  // Iterate over the mapping and invalidate aliasing handles. This is quite
625  // expensive and only necessary for error reporting in case of transform
626  // dialect misuse with dangling handles. Iteration over the handles is based
627  // on the assumption that the number of handles is significantly less than the
628  // number of IR objects (operations and values). Alternatively, we could walk
629  // the IR nested in each payload op associated with the given handle and look
630  // for handles associated with each operation and value.
631  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
632  // Go over all op handle mappings and mark as invalidated any handle
633  // pointing to any of the payload ops associated with the given handle or
634  // any op nested in them.
635  for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
636  for (Value otherHandle : otherHandles)
637  recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
638  otherHandle, throughValue,
639  newlyInvalidated);
640  }
641  // Go over all value handle mappings and mark as invalidated any handle
642  // pointing to any result of the payload op associated with the given handle
643  // or any op nested in them. Similarly invalidate handles to argument of
644  // blocks belonging to any region of any payload op associated with the
645  // given handle or any op nested in them.
646  for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
647  for (Value valueHandle : valueHandles)
648  recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
649  payloadValue, valueHandle,
650  newlyInvalidated);
651  }
652 
653  // Stop lookup when reaching a region that is isolated from above.
655  break;
656  }
657 }
658 
659 void transform::TransformState::recordValueHandleInvalidation(
660  OpOperand &valueHandle,
661  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
662  // Invalidate other handles to the same value.
663  for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
664  SmallVector<Value> otherValueHandles;
665  (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
666  for (Value otherHandle : otherValueHandles) {
667  Operation *owner = valueHandle.getOwner();
668  unsigned operandNo = valueHandle.getOperandNumber();
669  Location valueLoc = payloadValue.getLoc();
670  newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
671  valueLoc](Location currentLoc) {
672  InFlightDiagnostic diag = emitError(currentLoc)
673  << "op uses a handle invalidated by a "
674  "previously executed transform op";
675  diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
676  diag.attachNote(owner->getLoc())
677  << "invalidated by this transform op that consumes its operand #"
678  << operandNo
679  << " and invalidates handles to the same values as associated with "
680  "it";
681  diag.attachNote(valueLoc) << "payload value";
682  };
683  }
684 
685  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
686  Operation *payloadOp = opResult.getOwner();
687  recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
688  newlyInvalidated);
689  } else {
690  auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
691  for (Operation &payloadOp : *arg.getOwner())
692  recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
693  newlyInvalidated);
694  }
695  }
696 }
697 
698 /// Checks that the operation does not use invalidated handles as operands.
699 /// Reports errors and returns failure if it does. Otherwise, invalidates the
700 /// handles consumed by the operation as well as any handles pointing to payload
701 /// IR operations nested in the operations associated with the consumed handles.
// New invalidations are written to `newlyInvalidated` only; merging them into
// `invalidatedHandles` is left to the caller. A handle counts as consumed when
// the op declares a MemoryEffects::Free effect on it.
// NOTE(review): this listing is missing the declaration of `effects` (embedded
// line 708) and the argument line of getEffectsOnResource (embedded line 710)
// — restore them from the upstream file before compiling.
702 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
703  transform::TransformOpInterface transform,
704  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
705  FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
706  auto memoryEffectsIface =
707  cast<MemoryEffectOpInterface>(transform.getOperation());
709  memoryEffectsIface.getEffectsOnResource(
711 
712  for (OpOperand &target : transform->getOpOperands()) {
713  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
714  (DBGS() << "----iterate on handle: " << target.get() << "\n");
715  });
716  // If the operand uses an invalidated handle, report it. If the operation
717  // allows handles to point to repeated payload operations, only report
718  // pre-existing invalidation errors. Otherwise, also report invalidations
719  // caused by the current transform operation affecting its other operands.
720  auto it = invalidatedHandles.find(target.get());
721  auto nit = newlyInvalidated.find(target.get());
722  if (it != invalidatedHandles.end()) {
723  FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
724  "invalidated -> FAILURE\n");
// Comma expression: invoke the stored reporter, then yield failure().
725  return it->getSecond()(transform->getLoc()), failure();
726  }
727  if (!transform.allowsRepeatedHandleOperands() &&
728  nit != newlyInvalidated.end()) {
729  FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
730  "invalidated (by this op) -> FAILURE\n");
731  return nit->getSecond()(transform->getLoc()), failure();
732  }
733 
734  // Invalidate handles pointing to the operations nested in the operation
735  // associated with the handle consumed by this operation.
736  auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
737  return isa<MemoryEffects::Free>(effect.getEffect()) &&
738  effect.getValue() == target.get();
739  };
740  if (llvm::any_of(effects, consumesTarget)) {
741  FULL_LDBG("----found consume effect\n");
742  if (llvm::isa<transform::TransformHandleTypeInterface>(
743  target.get().getType())) {
744  FULL_LDBG("----recordOpHandleInvalidation\n");
745  SmallVector<Operation *> payloadOps =
746  llvm::to_vector(getPayloadOps(target.get()));
747  recordOpHandleInvalidation(target, payloadOps, nullptr,
748  newlyInvalidated);
749  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
750  target.get().getType())) {
751  FULL_LDBG("----recordValueHandleInvalidation\n");
752  recordValueHandleInvalidation(target, newlyInvalidated);
753  } else {
754  FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
755  }
756  } else {
757  FULL_LDBG("----no consume effect -> SKIP\n");
758  }
759  }
760 
761  FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
762  return success();
763 }
764 
765 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
766  transform::TransformOpInterface transform) {
767  InvalidatedHandleMap newlyInvalidated;
768  LogicalResult checkResult =
769  checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
770  invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
771  std::make_move_iterator(newlyInvalidated.end()));
772  return checkResult;
773 }
774 
// Scans `payload` for an entity that occurs more than once and, on the first
// repetition found, returns a silenceable error emitted on `transform` that
// names operand #`operandNumber`, with a note at the repeated entity. `T` is
// either a pointer type (the note reads "repeated target op") or a value-like
// type with getLoc() (the note reads "repeated target value").
// NOTE(review): this listing is missing the return type and function-name
// lines (embedded lines 776-777), the declaration of `diag` (embedded line
// 783) and the final return statement (embedded line 795). The name
// checkRepeatedConsumptionInOperand is taken from the call sites further down.
// Restore the missing lines from the upstream file before compiling.
775 template <typename T>
778  transform::TransformOpInterface transform,
779  unsigned operandNumber) {
780  DenseSet<T> seen;
781  for (T p : payload) {
// insert().second is false when `p` was already in the set.
782  if (!seen.insert(p).second) {
784  transform.emitSilenceableError()
785  << "a handle passed as operand #" << operandNumber
786  << " and consumed by this operation points to a payload "
787  "entity more than once";
788  if constexpr (std::is_pointer_v<T>)
789  diag.attachNote(p->getLoc()) << "repeated target op";
790  else
791  diag.attachNote(p.getLoc()) << "repeated target value";
792  return diag;
793  }
794  }
796 }
797 
798 void transform::TransformState::compactOpHandles() {
799  for (Value handle : opHandlesToCompact) {
800  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
801 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
802  if (llvm::find(mappings.direct[handle], nullptr) !=
803  mappings.direct[handle].end())
804  // Payload IR is removed from the mapping. This invalidates the respective
805  // iterators.
806  mappings.incrementTimestamp(handle);
807 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
808  llvm::erase(mappings.direct[handle], nullptr);
809  }
810  opHandlesToCompact.clear();
811 }
812 
814 transform::TransformState::applyTransform(TransformOpInterface transform) {
815  LLVM_DEBUG({
816  DBGS() << "applying: ";
817  transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
818  llvm::dbgs() << "\n";
819  });
820  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
821  DBGS() << "Top-level payload before application:\n"
822  << *getTopLevel() << "\n");
823  auto printOnFailureRAII = llvm::make_scope_exit([this] {
824  (void)this;
825  LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
826  llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
827  });
828 
829  // Set current transform op.
830  regionStack.back()->currentTransform = transform;
831 
832  // Expensive checks to detect invalid transform IR.
833  if (options.getExpensiveChecksEnabled()) {
834  FULL_LDBG("ExpensiveChecksEnabled\n");
835  if (failed(checkAndRecordHandleInvalidation(transform)))
837 
838  for (OpOperand &operand : transform->getOpOperands()) {
839  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
840  (DBGS() << "iterate on handle: " << operand.get() << "\n");
841  });
842  if (!isHandleConsumed(operand.get(), transform)) {
843  FULL_LDBG("--handle not consumed -> SKIP\n");
844  continue;
845  }
846  if (transform.allowsRepeatedHandleOperands()) {
847  FULL_LDBG("--op allows repeated handles -> SKIP\n");
848  continue;
849  }
850  FULL_LDBG("--handle is consumed\n");
851 
852  Type operandType = operand.get().getType();
853  if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
854  FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
856  checkRepeatedConsumptionInOperand<Operation *>(
857  getPayloadOpsView(operand.get()), transform,
858  operand.getOperandNumber());
859  if (!check.succeeded()) {
860  FULL_LDBG("----FAILED\n");
861  return check;
862  }
863  } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
864  FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
866  checkRepeatedConsumptionInOperand<Value>(
867  getPayloadValuesView(operand.get()), transform,
868  operand.getOperandNumber());
869  if (!check.succeeded()) {
870  FULL_LDBG("----FAILED\n");
871  return check;
872  }
873  } else {
874  FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
875  }
876  }
877  }
878 
879  // Find which operands are consumed.
880  SmallVector<OpOperand *> consumedOperands =
881  transform.getConsumedHandleOpOperands();
882 
883  // Remember the results of the payload ops associated with the consumed
884  // op handles or the ops defining the value handles so we can drop the
885  // association with them later. This must happen here because the
886  // transformation may destroy or mutate them so we cannot traverse the payload
887  // IR after that.
888  SmallVector<Value> origOpFlatResults;
889  SmallVector<Operation *> origAssociatedOps;
890  for (OpOperand *opOperand : consumedOperands) {
891  Value operand = opOperand->get();
892  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
893  for (Operation *payloadOp : getPayloadOps(operand)) {
894  llvm::append_range(origOpFlatResults, payloadOp->getResults());
895  }
896  continue;
897  }
898  if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
899  for (Value payloadValue : getPayloadValuesView(operand)) {
900  if (llvm::isa<OpResult>(payloadValue)) {
901  origAssociatedOps.push_back(payloadValue.getDefiningOp());
902  continue;
903  }
904  llvm::append_range(
905  origAssociatedOps,
906  llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
907  [](Operation &op) { return &op; }));
908  }
909  continue;
910  }
912  emitDefiniteFailure(transform->getLoc())
913  << "unexpectedly consumed a value that is not a handle as operand #"
914  << opOperand->getOperandNumber();
915  diag.attachNote(operand.getLoc())
916  << "value defined here with type " << operand.getType();
917  return diag;
918  }
919 
920  // Prepare rewriter and listener.
921  TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
922  // Skip handle if it is dead.
923  auto scopeIt =
924  llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
925  return handle.getParentRegion() == scope->region;
926  });
927  assert(scopeIt != regionStack.rend() &&
928  "could not find region scope for handle");
929  RegionScope *scope = *scopeIt;
930  for (Operation *user : handle.getUsers()) {
931  if (user != scope->currentTransform &&
932  !happensBefore(user, scope->currentTransform))
933  return false;
934  }
935  return true;
936  };
937  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
938  skipHandleFn);
939  transform::TransformRewriter rewriter(transform->getContext(),
940  &trackingListener);
941 
942  // Compute the result but do not short-circuit the silenceable failure case as
943  // we still want the handles to propagate properly so the "suppress" mode can
944  // proceed on a best effort basis.
945  transform::TransformResults results(transform->getNumResults());
946  DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
947  compactOpHandles();
948 
949  // Error handling: fail if transform or listener failed.
950  DiagnosedSilenceableFailure trackingFailure =
951  trackingListener.checkAndResetError();
952  if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
953  transform->hasAttr(
954  transform::TransformDialect::kSilenceTrackingFailuresAttrName)) {
955  // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
956  // do not report failures if the above mentioned attribute is set.
957  if (trackingFailure.isSilenceableFailure())
958  (void)trackingFailure.silence();
959  trackingFailure = DiagnosedSilenceableFailure::success();
960  }
961  if (!trackingFailure.succeeded()) {
962  if (result.succeeded()) {
963  result = std::move(trackingFailure);
964  } else {
965  // Transform op errors have precedence, report those first.
966  if (result.isSilenceableFailure())
967  result.attachNote() << "tracking listener also failed: "
968  << trackingFailure.getMessage();
969  (void)trackingFailure.silence();
970  }
971  }
972  if (result.isDefiniteFailure())
973  return result;
974 
975  // If a silenceable failure was produced, some results may be unset, set them
976  // to empty lists.
977  if (result.isSilenceableFailure())
978  results.setRemainingToEmpty(transform);
979 
980  // Remove the mapping for the operand if it is consumed by the operation. This
981  // allows us to catch use-after-free with assertions later on.
982  for (OpOperand *opOperand : consumedOperands) {
983  Value operand = opOperand->get();
984  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
985  forgetMapping(operand, origOpFlatResults);
986  } else if (llvm::isa<TransformValueHandleTypeInterface>(
987  operand.getType())) {
988  forgetValueMapping(operand, origAssociatedOps);
989  }
990  }
991 
992  if (failed(updateStateFromResults(results, transform->getResults())))
994 
995  printOnFailureRAII.release();
996  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
997  DBGS() << "Top-level payload:\n";
998  getTopLevel()->print(llvm::dbgs());
999  });
1000  return result;
1001 }
1002 
1003 LogicalResult transform::TransformState::updateStateFromResults(
1004  const TransformResults &results, ResultRange opResults) {
1005  for (OpResult result : opResults) {
1006  if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1007  assert(results.isParam(result.getResultNumber()) &&
1008  "expected parameters for the parameter-typed result");
1009  if (failed(
1010  setParams(result, results.getParams(result.getResultNumber())))) {
1011  return failure();
1012  }
1013  } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1014  assert(results.isValue(result.getResultNumber()) &&
1015  "expected values for value-type-result");
1016  if (failed(setPayloadValues(
1017  result, results.getValues(result.getResultNumber())))) {
1018  return failure();
1019  }
1020  } else {
1021  assert(!results.isParam(result.getResultNumber()) &&
1022  "expected payload ops for the non-parameter typed result");
1023  if (failed(
1024  setPayloadOps(result, results.get(result.getResultNumber())))) {
1025  return failure();
1026  }
1027  }
1028  }
1029  return success();
1030 }
1031 
1032 //===----------------------------------------------------------------------===//
1033 // TransformState::Extension
1034 //===----------------------------------------------------------------------===//
1035 
1037 
1040  Operation *replacement) {
1041  // TODO: we may need to invalidate handles to operations and values nested in
1042  // the operation being replaced.
1043  return state.replacePayloadOp(op, replacement);
1044 }
1045 
1048  Value replacement) {
1049  return state.replacePayloadValue(value, replacement);
1050 }
1051 
1052 //===----------------------------------------------------------------------===//
1053 // TransformState::RegionScope
1054 //===----------------------------------------------------------------------===//
1055 
1057  // Remove handle invalidation notices as handles are going out of scope.
1058  // The same region may be re-entered leading to incorrect invalidation
1059  // errors.
1060  for (Block &block : *region) {
1061  for (Value handle : block.getArguments()) {
1062  state.invalidatedHandles.erase(handle);
1063  }
1064  for (Operation &op : block) {
1065  for (Value handle : op.getResults()) {
1066  state.invalidatedHandles.erase(handle);
1067  }
1068  }
1069  }
1070 
1071 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1072  // Remember pointers to payload ops referenced by the handles going out of
1073  // scope.
1074  SmallVector<Operation *> referencedOps =
1075  llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1076 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1077 
1078  state.mappings.erase(region);
1079  state.regionStack.pop_back();
1080 }
1081 
1082 //===----------------------------------------------------------------------===//
1083 // TransformResults
1084 //===----------------------------------------------------------------------===//
1085 
1086 transform::TransformResults::TransformResults(unsigned numSegments) {
1087  operations.appendEmptyRows(numSegments);
1088  params.appendEmptyRows(numSegments);
1089  values.appendEmptyRows(numSegments);
1090 }
1091 
1094  int64_t position = value.getResultNumber();
1095  assert(position < static_cast<int64_t>(this->params.size()) &&
1096  "setting params for a non-existent handle");
1097  assert(this->params[position].data() == nullptr && "params already set");
1098  assert(operations[position].data() == nullptr &&
1099  "another kind of results already set");
1100  assert(values[position].data() == nullptr &&
1101  "another kind of results already set");
1102  this->params.replace(position, params);
1103 }
1104 
1106  OpResult handle, ArrayRef<MappedValue> values) {
1108  handle, values,
1109  [&](ArrayRef<Operation *> operations) {
1110  return set(handle, operations), success();
1111  },
1112  [&](ArrayRef<Param> params) {
1113  return setParams(handle, params), success();
1114  },
1115  [&](ValueRange payloadValues) {
1116  return setValues(handle, payloadValues), success();
1117  });
1118 #ifndef NDEBUG
1119  if (!diag.succeeded())
1120  llvm::dbgs() << diag.getStatusString() << "\n";
1121  assert(diag.succeeded() && "incorrect mapping");
1122 #endif // NDEBUG
1123  (void)diag.silence();
1124 }
1125 
1127  transform::TransformOpInterface transform) {
1128  for (OpResult opResult : transform->getResults()) {
1129  if (!isSet(opResult.getResultNumber()))
1130  setMappedValues(opResult, {});
1131  }
1132 }
1133 
1135 transform::TransformResults::get(unsigned resultNumber) const {
1136  assert(resultNumber < operations.size() &&
1137  "querying results for a non-existent handle");
1138  assert(operations[resultNumber].data() != nullptr &&
1139  "querying unset results (values or params expected?)");
1140  return operations[resultNumber];
1141 }
1142 
1144 transform::TransformResults::getParams(unsigned resultNumber) const {
1145  assert(resultNumber < params.size() &&
1146  "querying params for a non-existent handle");
1147  assert(params[resultNumber].data() != nullptr &&
1148  "querying unset params (ops or values expected?)");
1149  return params[resultNumber];
1150 }
1151 
1153 transform::TransformResults::getValues(unsigned resultNumber) const {
1154  assert(resultNumber < values.size() &&
1155  "querying values for a non-existent handle");
1156  assert(values[resultNumber].data() != nullptr &&
1157  "querying unset values (ops or params expected?)");
1158  return values[resultNumber];
1159 }
1160 
1161 bool transform::TransformResults::isParam(unsigned resultNumber) const {
1162  assert(resultNumber < params.size() &&
1163  "querying association for a non-existent handle");
1164  return params[resultNumber].data() != nullptr;
1165 }
1166 
1167 bool transform::TransformResults::isValue(unsigned resultNumber) const {
1168  assert(resultNumber < values.size() &&
1169  "querying association for a non-existent handle");
1170  return values[resultNumber].data() != nullptr;
1171 }
1172 
1173 bool transform::TransformResults::isSet(unsigned resultNumber) const {
1174  assert(resultNumber < params.size() &&
1175  "querying association for a non-existent handle");
1176  return params[resultNumber].data() != nullptr ||
1177  operations[resultNumber].data() != nullptr ||
1178  values[resultNumber].data() != nullptr;
1179 }
1180 
1181 //===----------------------------------------------------------------------===//
1182 // TrackingListener
1183 //===----------------------------------------------------------------------===//
1184 
1186  TransformOpInterface op,
1187  SkipHandleFn skipHandleFn)
1188  : TransformState::Extension(state), transformOp(op),
1189  skipHandleFn(skipHandleFn) {
1190  if (op) {
1191  for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1192  consumedHandles.insert(opOperand->get());
1193  }
1194  }
1195 }
1196 
1198  Operation *defOp = nullptr;
1199  for (Value v : values) {
1200  // Skip empty values.
1201  if (!v)
1202  continue;
1203  if (!defOp) {
1204  defOp = v.getDefiningOp();
1205  continue;
1206  }
1207  if (defOp != v.getDefiningOp())
1208  return nullptr;
1209  }
1210  return defOp;
1211 }
1212 
1214  Operation *&result, Operation *op, ValueRange newValues) const {
1215  assert(op->getNumResults() == newValues.size() &&
1216  "invalid number of replacement values");
1217  SmallVector<Value> values(newValues.begin(), newValues.end());
1218 
1220  getTransformOp(), "tracking listener failed to find replacement op "
1221  "during application of this transform op");
1222 
1223  do {
1224  // If the replacement values belong to different ops, drop the mapping.
1225  Operation *defOp = getCommonDefiningOp(values);
1226  if (!defOp) {
1227  diag.attachNote() << "replacement values belong to different ops";
1228  return diag;
1229  }
1230 
1231  // If the defining op has the same type, we take it as a replacement.
1232  if (op->getName() == defOp->getName()) {
1233  result = defOp;
1235  }
1236 
1237  // Replacing an op with a constant-like equivalent is a common
1238  // canonicalization.
1239  if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1240  result = defOp;
1242  }
1243 
1244  values.clear();
1245 
1246  // Skip through ops that implement FindPayloadReplacementOpInterface.
1247  if (auto findReplacementOpInterface =
1248  dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1249  values.assign(findReplacementOpInterface.getNextOperands());
1250  diag.attachNote(defOp->getLoc()) << "using operands provided by "
1251  "'FindPayloadReplacementOpInterface'";
1252  continue;
1253  }
1254 
1255  // Skip through ops that implement CastOpInterface.
1256  if (isa<CastOpInterface>(defOp)) {
1257  values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1258  diag.attachNote(defOp->getLoc())
1259  << "using output of 'CastOpInterface' op";
1260  continue;
1261  }
1262  } while (!values.empty());
1263 
1264  diag.attachNote() << "ran out of suitable replacement values";
1265  return diag;
1266 }
1267 
1269  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1270  LLVM_DEBUG({
1272  reasonCallback(diag);
1273  DBGS() << "Match Failure : " << diag.str() << "\n";
1274  });
1275 }
1276 
1277 void transform::TrackingListener::notifyOperationErased(Operation *op) {
1278  // TODO: Walk can be removed when D144193 has landed.
1279  op->walk([&](Operation *op) {
1280  // Remove mappings for result values.
1281  for (OpResult value : op->getResults())
1282  (void)replacePayloadValue(value, nullptr);
1283  // Remove mapping for op.
1284  (void)replacePayloadOp(op, nullptr);
1285  });
1286 }
1287 
1288 void transform::TrackingListener::notifyOperationReplaced(
1289  Operation *op, ValueRange newValues) {
1290  assert(op->getNumResults() == newValues.size() &&
1291  "invalid number of replacement values");
1292 
1293  // Replace value handles.
1294  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
1295  (void)replacePayloadValue(oldValue, newValue);
1296 
1297  // Replace op handle.
1298  SmallVector<Value> opHandles;
1299  if (failed(getTransformState().getHandlesForPayloadOp(
1300  op, opHandles, /*includeOutOfScope=*/true))) {
1301  // Op is not tracked.
1302  return;
1303  }
1304 
1305  // Helper function to check if the current transform op consumes any handle
1306  // that is mapped to `op`.
1307  //
1308  // Note: If a handle was consumed, there shouldn't be any alive users, so it
1309  // is not really necessary to check for consumed handles. However, in case
1310  // there are indeed alive handles that were consumed (which is undefined
1311  // behavior) and a replacement op could not be found, we want to fail with a
1312  // nicer error message: "op uses a handle invalidated..." instead of "could
1313  // not find replacement op". This nicer error is produced later.
1314  auto handleWasConsumed = [&] {
1315  return llvm::any_of(opHandles,
1316  [&](Value h) { return consumedHandles.contains(h); });
1317  };
1318 
1319  // Check if there are any handles that must be updated.
1320  Value aliveHandle;
1321  if (skipHandleFn) {
1322  auto it =
1323  llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); });
1324  if (it != opHandles.end())
1325  aliveHandle = *it;
1326  } else if (!opHandles.empty()) {
1327  aliveHandle = opHandles.front();
1328  }
1329  if (!aliveHandle || handleWasConsumed()) {
1330  // The op is tracked but the corresponding handles are dead or were
1331  // consumed. Drop the op from the mapping.
1332  (void)replacePayloadOp(op, nullptr);
1333  return;
1334  }
1335 
1336  Operation *replacement;
1338  findReplacementOp(replacement, op, newValues);
1339  // If the op is tracked but no replacement op was found, send a
1340  // notification.
1341  if (!diag.succeeded()) {
1342  diag.attachNote(aliveHandle.getLoc())
1343  << "replacement is required because this handle must be updated";
1344  notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
1345  (void)replacePayloadOp(op, nullptr);
1346  return;
1347  }
1348 
1349  (void)replacePayloadOp(op, replacement);
1350 }
1351 
1353  // The state of the ErrorCheckingTrackingListener must be checked and reset
1354  // if there was an error. This is to prevent errors from accidentally being
1355  // missed.
1356  assert(status.succeeded() && "listener state was not checked");
1357 }
1358 
1361  DiagnosedSilenceableFailure s = std::move(status);
1363  errorCounter = 0;
1364  return s;
1365 }
1366 
1368  return !status.succeeded();
1369 }
1370 
1373 
1374  // Merge potentially existing diags and store the result in the listener.
1376  diag.takeDiagnostics(diags);
1377  if (!status.succeeded())
1378  status.takeDiagnostics(diags);
1379  status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
1380 
1381  // Report more details.
1382  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
1383  for (auto &&[index, value] : llvm::enumerate(values))
1384  status.attachNote(value.getLoc())
1385  << "[" << errorCounter << "] replacement value " << index;
1386  ++errorCounter;
1387 }
1388 
1389 //===----------------------------------------------------------------------===//
1390 // TransformRewriter
1391 //===----------------------------------------------------------------------===//
1392 
1395  : RewriterBase(ctx), listener(listener) {
1396  setListener(listener);
1397 }
1398 
1400  return listener->failed();
1401 }
1402 
1403 /// Silence all tracking failures that have been encountered so far.
1405  if (hasTrackingFailures()) {
1406  DiagnosedSilenceableFailure status = listener->checkAndResetError();
1407  (void)status.silence();
1408  }
1409 }
1410 
1412  Operation *op, Operation *replacement) {
1413  return listener->replacePayloadOp(op, replacement);
1414 }
1415 
1416 //===----------------------------------------------------------------------===//
1417 // Utilities for TransformEachOpTrait.
1418 //===----------------------------------------------------------------------===//
1419 
1422  ArrayRef<Operation *> targets) {
1423  for (auto &&[position, parent] : llvm::enumerate(targets)) {
1424  for (Operation *child : targets.drop_front(position + 1)) {
1425  if (parent->isAncestor(child)) {
1427  emitError(loc)
1428  << "transform operation consumes a handle pointing to an ancestor "
1429  "payload operation before its descendant";
1430  diag.attachNote()
1431  << "the ancestor is likely erased or rewritten before the "
1432  "descendant is accessed, leading to undefined behavior";
1433  diag.attachNote(parent->getLoc()) << "ancestor payload op";
1434  diag.attachNote(child->getLoc()) << "descendant payload op";
1435  return diag;
1436  }
1437  }
1438  }
1439  return success();
1440 }
1441 
1444  Location payloadOpLoc,
1445  const ApplyToEachResultList &partialResult) {
1446  Location transformOpLoc = transformOp->getLoc();
1447  StringRef transformOpName = transformOp->getName().getStringRef();
1448  unsigned expectedNumResults = transformOp->getNumResults();
1449 
1450  // Reuse the emission of the diagnostic note.
1451  auto emitDiag = [&]() {
1452  auto diag = mlir::emitError(transformOpLoc);
1453  diag.attachNote(payloadOpLoc) << "when applied to this op";
1454  return diag;
1455  };
1456 
1457  if (partialResult.size() != expectedNumResults) {
1458  auto diag = emitDiag() << "application of " << transformOpName
1459  << " expected to produce " << expectedNumResults
1460  << " results (actually produced "
1461  << partialResult.size() << ").";
1462  diag.attachNote(transformOpLoc)
1463  << "if you need variadic results, consider a generic `apply` "
1464  << "instead of the specialized `applyToOne`.";
1465  return failure();
1466  }
1467 
1468  // Check that the right kind of value was produced.
1469  for (const auto &[ptr, res] :
1470  llvm::zip(partialResult, transformOp->getResults())) {
1471  if (ptr.isNull())
1472  continue;
1473  if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1474  !ptr.is<Operation *>()) {
1475  return emitDiag() << "application of " << transformOpName
1476  << " expected to produce an Operation * for result #"
1477  << res.getResultNumber();
1478  }
1479  if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1480  !ptr.is<Attribute>()) {
1481  return emitDiag() << "application of " << transformOpName
1482  << " expected to produce an Attribute for result #"
1483  << res.getResultNumber();
1484  }
1485  if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1486  !ptr.is<Value>()) {
1487  return emitDiag() << "application of " << transformOpName
1488  << " expected to produce a Value for result #"
1489  << res.getResultNumber();
1490  }
1491  }
1492  return success();
1493 }
1494 
1495 template <typename T>
1497  return llvm::to_vector(llvm::map_range(
1498  range, [](transform::MappedValue value) { return value.get<T>(); }));
1499 }
1500 
1502  Operation *transformOp, TransformResults &transformResults,
1505  transposed.resize(transformOp->getNumResults());
1506  for (const ApplyToEachResultList &partialResults : results) {
1507  if (llvm::any_of(partialResults,
1508  [](MappedValue value) { return value.isNull(); }))
1509  continue;
1510  assert(transformOp->getNumResults() == partialResults.size() &&
1511  "expected as many partial results as op as results");
1512  for (auto [i, value] : llvm::enumerate(partialResults))
1513  transposed[i].push_back(value);
1514  }
1515 
1516  for (OpResult r : transformOp->getResults()) {
1517  unsigned position = r.getResultNumber();
1518  if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1519  transformResults.setParams(r,
1520  castVector<Attribute>(transposed[position]));
1521  } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1522  transformResults.setValues(r, castVector<Value>(transposed[position]));
1523  } else {
1524  transformResults.set(r, castVector<Operation *>(transposed[position]));
1525  }
1526  }
1527 }
1528 
1529 //===----------------------------------------------------------------------===//
1530 // Utilities for implementing transform ops with regions.
1531 //===----------------------------------------------------------------------===//
1532 
1535  ValueRange values, const transform::TransformState &state) {
1536  for (Value operand : values) {
1537  SmallVector<MappedValue> &mapped = mappings.emplace_back();
1538  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1539  llvm::append_range(mapped, state.getPayloadOps(operand));
1540  } else if (llvm::isa<TransformValueHandleTypeInterface>(
1541  operand.getType())) {
1542  llvm::append_range(mapped, state.getPayloadValues(operand));
1543  } else {
1544  assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1545  "unsupported kind of transform dialect value");
1546  llvm::append_range(mapped, state.getParams(operand));
1547  }
1548  }
1549 }
1550 
1552  Block *block, transform::TransformState &state,
1553  transform::TransformResults &results) {
1554  for (auto &&[terminatorOperand, result] :
1555  llvm::zip(block->getTerminator()->getOperands(),
1556  block->getParentOp()->getOpResults())) {
1557  if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1558  results.set(result, state.getPayloadOps(terminatorOperand));
1559  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1560  result.getType())) {
1561  results.setValues(result, state.getPayloadValues(terminatorOperand));
1562  } else {
1563  assert(
1564  llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1565  "unhandled transform type interface");
1566  results.setParams(result, state.getParams(terminatorOperand));
1567  }
1568  }
1569 }
1570 
1573  Operation *payloadRoot) {
1574  return TransformState(region, payloadRoot);
1575 }
1576 
1577 //===----------------------------------------------------------------------===//
1578 // Utilities for PossibleTopLevelTransformOpTrait.
1579 //===----------------------------------------------------------------------===//
1580 
1581 /// Appends to `effects` the memory effect instances on `target` with the same
1582 /// resource and effect as the ones the operation `iface` has on `source`.
1583 static void
1584 remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
1587  iface.getEffectsOnValue(source, nestedEffects);
1588  for (const auto &effect : nestedEffects)
1589  effects.emplace_back(effect.getEffect(), target, effect.getResource());
1590 }
1591 
1592 /// Appends to `effects` the same effects as the operations of `block` have on
1593 /// block arguments but associated with `operands`.
1594 static void
1597  for (Operation &op : block) {
1598  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1599  if (!iface)
1600  continue;
1601 
1602  for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
1603  remapEffects(iface, source, target, effects);
1604  }
1605 
1607  iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1608  nestedEffects);
1609  llvm::append_range(effects, nestedEffects);
1610  }
1611 }
1612 
1614  Operation *operation, Value root, Block &body,
1616  transform::onlyReadsHandle(operation->getOperands(), effects);
1617  transform::producesHandle(operation->getResults(), effects);
1618 
1619  if (!root) {
1620  for (Operation &op : body) {
1621  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1622  if (!iface)
1623  continue;
1624 
1626  iface.getEffects(effects);
1627  }
1628  return;
1629  }
1630 
1631  // Carry over all effects on arguments of the entry block as those on the
1632  // operands, this is the same value just remapped.
1633  remapArgumentEffects(body, operation->getOperands(), effects);
1634 }
1635 
1637  TransformState &state, Operation *op, Region &region) {
1638  SmallVector<Operation *> targets;
1639  SmallVector<SmallVector<MappedValue>> extraMappings;
1640  if (op->getNumOperands() != 0) {
1641  llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
1642  prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
1643  } else {
1644  if (state.getNumTopLevelMappings() !=
1645  region.front().getNumArguments() - 1) {
1646  return emitError(op->getLoc())
1647  << "operation expects " << region.front().getNumArguments() - 1
1648  << " extra value bindings, but " << state.getNumTopLevelMappings()
1649  << " were provided to the interpreter";
1650  }
1651 
1652  // Top-level transforms can be used for matching. If no concrete operation
1653  // type is specified, the block argument is mapped to the top-level op.
1654  // Otherwise, it is mapped to all ops of the specified type within the
1655  // top-level op (including the top-level op itself). Once an op is added as
1656  // a target, its descendants are not explored any further.
1657  BlockArgument bbArg = region.front().getArgument(0);
1658  if (auto bbArgType = dyn_cast<transform::OperationType>(bbArg.getType())) {
1659  state.getTopLevel()->walk<WalkOrder::PreOrder>([&](Operation *op) {
1660  if (op->getName().getStringRef() == bbArgType.getOperationName()) {
1661  targets.push_back(op);
1662  return WalkResult::skip();
1663  }
1664  return WalkResult::advance();
1665  });
1666  } else {
1667  targets.push_back(state.getTopLevel());
1668  }
1669 
1670  for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1671  extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1672  }
1673 
1674  if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
1675  return failure();
1676 
1677  for (BlockArgument argument : region.front().getArguments().drop_front()) {
1678  if (failed(state.mapBlockArgument(
1679  argument, extraMappings[argument.getArgNumber() - 1])))
1680  return failure();
1681  }
1682 
1683  return success();
1684 }
1685 
1688  // Attaching this trait without the interface is a misuse of the API, but it
1689  // cannot be caught via a static_assert because interface registration is
1690  // dynamic.
1691  assert(isa<TransformOpInterface>(op) &&
1692  "should implement TransformOpInterface to have "
1693  "PossibleTopLevelTransformOpTrait");
1694 
1695  if (op->getNumRegions() < 1)
1696  return op->emitOpError() << "expects at least one region";
1697 
1698  Region *bodyRegion = &op->getRegion(0);
1699  if (!llvm::hasNItems(*bodyRegion, 1))
1700  return op->emitOpError() << "expects a single-block region";
1701 
1702  Block *body = &bodyRegion->front();
1703  if (body->getNumArguments() == 0) {
1704  return op->emitOpError()
1705  << "expects the entry block to have at least one argument";
1706  }
1707  if (!llvm::isa<TransformHandleTypeInterface>(
1708  body->getArgument(0).getType())) {
1709  return op->emitOpError()
1710  << "expects the first entry block argument to be of type "
1711  "implementing TransformHandleTypeInterface";
1712  }
1713  BlockArgument arg = body->getArgument(0);
1714  if (op->getNumOperands() != 0) {
1715  if (arg.getType() != op->getOperand(0).getType()) {
1716  return op->emitOpError()
1717  << "expects the type of the block argument to match "
1718  "the type of the operand";
1719  }
1720  }
1721  for (BlockArgument arg : body->getArguments().drop_front()) {
1722  if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1723  TransformValueHandleTypeInterface>(arg.getType()))
1724  continue;
1725 
1727  op->emitOpError()
1728  << "expects trailing entry block arguments to be of type implementing "
1729  "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1730  "TransformParamTypeInterface";
1731  diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1732  return diag;
1733  }
1734 
1735  if (auto *parent =
1737  if (op->getNumOperands() != body->getNumArguments()) {
1739  op->emitOpError()
1740  << "expects operands to be provided for a nested op";
1741  diag.attachNote(parent->getLoc())
1742  << "nested in another possible top-level op";
1743  return diag;
1744  }
1745  }
1746 
1747  return success();
1748 }
1749 
1750 //===----------------------------------------------------------------------===//
1751 // Utilities for ParamProducedTransformOpTrait.
1752 //===----------------------------------------------------------------------===//
1753 
1756  producesHandle(op->getResults(), effects);
1757  bool hasPayloadOperands = false;
1758  for (Value operand : op->getOperands()) {
1759  onlyReadsHandle(operand, effects);
1760  if (llvm::isa<TransformHandleTypeInterface,
1761  TransformValueHandleTypeInterface>(operand.getType()))
1762  hasPayloadOperands = true;
1763  }
1764  if (hasPayloadOperands)
1765  onlyReadsPayload(effects);
1766 }
1767 
1770  // Interfaces can be attached dynamically, so this cannot be a static
1771  // assert.
1772  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1773  llvm::report_fatal_error(
1774  Twine("ParamProducerTransformOpTrait must be attached to an op that "
1775  "implements MemoryEffectsOpInterface, found on ") +
1776  op->getName().getStringRef());
1777  }
1778  for (Value result : op->getResults()) {
1779  if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1780  continue;
1781  return op->emitOpError()
1782  << "ParamProducerTransformOpTrait attached to this op expects "
1783  "result types to implement TransformParamTypeInterface";
1784  }
1785  return success();
1786 }
1787 
1788 //===----------------------------------------------------------------------===//
1789 // Memory effects.
1790 //===----------------------------------------------------------------------===//
1791 
1793  ValueRange handles,
1795  for (Value handle : handles) {
1796  effects.emplace_back(MemoryEffects::Read::get(), handle,
1798  effects.emplace_back(MemoryEffects::Free::get(), handle,
1800  }
1801 }
1802 
1803 /// Returns `true` if the given list of effects instances contains an instance
1804 /// with the effect type specified as template parameter.
1805 template <typename EffectTy, typename ResourceTy, typename Range>
1806 static bool hasEffect(Range &&effects) {
1807  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1808  return isa<EffectTy>(effect.getEffect()) &&
1809  isa<ResourceTy>(effect.getResource());
1810  });
1811 }
1812 
1814  transform::TransformOpInterface transform) {
1815  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1817  iface.getEffectsOnValue(handle, effects);
1818  return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1819  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1820 }
1821 
1823  ValueRange handles,
1825  for (Value handle : handles) {
1826  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1828  effects.emplace_back(MemoryEffects::Write::get(), handle,
1830  }
1831 }
1832 
1834  ValueRange handles,
1836  for (Value handle : handles) {
1837  effects.emplace_back(MemoryEffects::Read::get(), handle,
1839  }
1840 }
1841 
1844  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1845  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
1846 }
1847 
1850  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1851 }
1852 
1853 bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1854  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1856  iface.getEffects(effects);
1857  return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1858 }
1859 
1860 bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1861  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1863  iface.getEffects(effects);
1864  return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1865 }
1866 
1868  Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1870  for (Operation &nested : block) {
1871  auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1872  if (!iface)
1873  continue;
1874 
1875  effects.clear();
1876  iface.getEffects(effects);
1877  for (const MemoryEffects::EffectInstance &effect : effects) {
1878  BlockArgument argument =
1879  dyn_cast_or_null<BlockArgument>(effect.getValue());
1880  if (!argument || argument.getOwner() != &block ||
1881  !isa<MemoryEffects::Free>(effect.getEffect()) ||
1882  effect.getResource() != transform::TransformMappingResource::get()) {
1883  continue;
1884  }
1885  consumedArguments.insert(argument.getArgNumber());
1886  }
1887  }
1888 }
1889 
1890 //===----------------------------------------------------------------------===//
1891 // Utilities for TransformOpInterface.
1892 //===----------------------------------------------------------------------===//
1893 
1895  TransformOpInterface transformOp) {
1896  SmallVector<OpOperand *> consumedOperands;
1897  consumedOperands.reserve(transformOp->getNumOperands());
1898  auto memEffectInterface =
1899  cast<MemoryEffectOpInterface>(transformOp.getOperation());
1901  for (OpOperand &target : transformOp->getOpOperands()) {
1902  effects.clear();
1903  memEffectInterface.getEffectsOnValue(target.get(), effects);
1904  if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1905  return isa<transform::TransformMappingResource>(
1906  effect.getResource()) &&
1907  isa<MemoryEffects::Free>(effect.getEffect());
1908  })) {
1909  consumedOperands.push_back(&target);
1910  }
1911  }
1912  return consumedOperands;
1913 }
1914 
1916  auto iface = cast<MemoryEffectOpInterface>(op);
1918  iface.getEffects(effects);
1919 
1920  auto effectsOn = [&](Value value) {
1921  return llvm::make_filter_range(
1922  effects, [value](const MemoryEffects::EffectInstance &instance) {
1923  return instance.getValue() == value;
1924  });
1925  };
1926 
1927  std::optional<unsigned> firstConsumedOperand;
1928  for (OpOperand &operand : op->getOpOperands()) {
1929  auto range = effectsOn(operand.get());
1930  if (range.empty()) {
1932  op->emitError() << "TransformOpInterface requires memory effects "
1933  "on operands to be specified";
1934  diag.attachNote() << "no effects specified for operand #"
1935  << operand.getOperandNumber();
1936  return diag;
1937  }
1938  if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1940  << "TransformOpInterface did not expect "
1941  "'allocate' memory effect on an operand";
1942  diag.attachNote() << "specified for operand #"
1943  << operand.getOperandNumber();
1944  return diag;
1945  }
1946  if (!firstConsumedOperand &&
1947  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1948  firstConsumedOperand = operand.getOperandNumber();
1949  }
1950  }
1951 
1952  if (firstConsumedOperand &&
1953  !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1955  op->emitError()
1956  << "TransformOpInterface expects ops consuming operands to have a "
1957  "'write' effect on the payload resource";
1958  diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
1959  return diag;
1960  }
1961 
1962  for (OpResult result : op->getResults()) {
1963  auto range = effectsOn(result);
1964  if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1965  range)) {
1967  op->emitError() << "TransformOpInterface requires 'allocate' memory "
1968  "effect to be specified for results";
1969  diag.attachNote() << "no 'allocate' effect specified for result #"
1970  << result.getResultNumber();
1971  return diag;
1972  }
1973  }
1974 
1975  return success();
1976 }
1977 
1978 //===----------------------------------------------------------------------===//
1979 // Entry point.
1980 //===----------------------------------------------------------------------===//
1981 
1983  Operation *payloadRoot, TransformOpInterface transform,
1984  const RaggedArray<MappedValue> &extraMapping,
1985  const TransformOptions &options, bool enforceToplevelTransformOp) {
1986  if (enforceToplevelTransformOp) {
1987  if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
1988  transform->getNumOperands() != 0) {
1989  return transform->emitError()
1990  << "expected transform to start at the top-level transform op";
1991  }
1992  } else if (failed(
1994  return failure();
1995  }
1996 
1997  TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
1998  options);
1999  return state.applyTransform(transform).checkAndReport();
2000 }
2001 
2002 //===----------------------------------------------------------------------===//
2003 // Generated interface implementation.
2004 //===----------------------------------------------------------------------===//
2005 
2006 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define DEBUG_TYPE_FULL
static void remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the memory effect instances on target with the same resource and effect as the one...
#define FULL_LDBG(X)
static void remapArgumentEffects(Block &block, ValueRange operands, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the same effects as the operations of block have on block arguments but associated...
static bool happensBefore(Operation *a, Operation *b)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not ...
void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped)
#define DEBUG_PRINT_AFTER_ALL
DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef< T > payload, transform::TransformOpInterface transform, unsigned operandNumber)
#define DBGS()
static DiagnosedSilenceableFailure dispatchMappedValues(Value handle, ArrayRef< transform::MappedValue > values, function_ref< LogicalResult(ArrayRef< Operation * >)> operationsFn, function_ref< LogicalResult(ArrayRef< transform::Param >)> paramsFn, function_ref< LogicalResult(ValueRange)> valuesFn)
Given a list of MappedValues, cast them to the value kind implied by the interface of the handle type...
static SmallVector< T > castVector(ArrayRef< transform::MappedValue > range)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:315
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:324
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:327
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition: Block.cpp:73
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgListType getArguments()
Definition: Block.h:84
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:318
This class represents an operand of an operation.
Definition: Value.h:263
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition: Value.h:453
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:465
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getOpResults()
Definition: Operation.h:415
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A 2D array where each row may have different length.
Definition: RaggedArray.h:18
size_t size() const
Returns the number of rows in the 2D array.
Definition: RaggedArray.h:21
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:239
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
size_t size() const
Returns the number of elements in the list.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
TrackingListener failures are reported only for ops that have this trait.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
std::function< bool(Value)> SkipHandleFn
A function that returns "true" for handles that do not have to be updated.
TrackingListener(TransformState &state, TransformOpInterface op, SkipHandleFn skipHandleFn=nullptr)
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Options controlling the application of transform operations by the TransformState.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult matches the number, kind (payload op or parameter) and nulli...
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true)
Entry point to the Transform dialect infrastructure.
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool hasEffect(Operation *op, Value value=nullptr)
Returns true if op has an effect of type EffectTy on value.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...