MLIR  21.0.0git
TransformInterfaces.cpp
Go to the documentation of this file.
1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/IR/PatternMatch.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/ScopeExit.h"
18 #include "llvm/ADT/iterator.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include "llvm/Support/InterleavedRange.h"
22 
// Debug-output channels used by this file:
//  - DEBUG_TYPE ("transform-dialect"): regular LLVM_DEBUG output.
//  - DEBUG_TYPE_FULL ("transform-dialect-full"): very verbose tracing, emitted
//    through DEBUG_WITH_TYPE (see FULL_LDBG).
//  - DEBUG_PRINT_AFTER_ALL: a separate debug-type name; NOTE(review): it is not
//    referenced in the visible part of this file, presumably used further down.
23 #define DEBUG_TYPE "transform-dialect"
24 #define DEBUG_TYPE_FULL "transform-dialect-full"
25 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
// DBGS() prefixes each debug line with "[transform-dialect] "; FULL_LDBG reuses
// the same prefix even though it is gated on DEBUG_TYPE_FULL.
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
27 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
28 #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Helper functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
37 /// properly dominates `b` and `b` is not inside `a`.
38 static bool happensBefore(Operation *a, Operation *b) {
39  do {
40  if (a->isProperAncestor(b))
41  return false;
42  if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
43  return a->isBeforeInBlock(bAncestor);
44  }
45  } while ((a = a->getParentOp()));
46  return false;
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // TransformState
51 //===----------------------------------------------------------------------===//
52 
// Out-of-line definition of the static constexpr data member `kTopLevelValue`;
// its initializer is in the in-class declaration, which is not visible here.
// Since C++17 static constexpr data members are implicitly inline, so this
// redeclaration is redundant (and deprecated). NOTE(review): it only matters
// for pre-C++17 ODR-uses and could likely be removed — confirm build standard.
53 constexpr const Value transform::TransformState::kTopLevelValue;
54 
55 transform::TransformState::TransformState(
56  Region *region, Operation *payloadRoot,
57  const RaggedArray<MappedValue> &extraMappings,
58  const TransformOptions &options)
59  : topLevel(payloadRoot), options(options) {
60  topLevelMappedValues.reserve(extraMappings.size());
61  for (ArrayRef<MappedValue> mapping : extraMappings)
62  topLevelMappedValues.push_back(mapping);
63  if (region) {
64  RegionScope *scope = new RegionScope(*this, *region);
65  topLevelRegionScope.reset(scope);
66  }
67 }
68 
69 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
70 
// Return the payload operations currently associated with the operation
// handle `value`, taken from the direct op mapping of the scope selected by
// getMapping(value). Asserts that such a mapping exists, i.e. that `value` is
// an operation handle (not a param/value handle) whose payload has been set.
// The list stored in the mapping itself is returned (`iter->getSecond()`).
// NOTE(review): if the return type is a non-owning view, as the name
// suggests, it should not be held across updates of that mapping — confirm.
72 transform::TransformState::getPayloadOpsView(Value value) const {
73  const TransformOpMapping &operationMapping = getMapping(value).direct;
74  auto iter = operationMapping.find(value);
75  assert(iter != operationMapping.end() &&
76  "cannot find mapping for payload handle (param/value handle "
77  "provided?)");
78  return iter->getSecond();
79 }
80 
// Return the parameters associated with the parameter handle `value`, taken
// from the params mapping of the scope selected by getMapping(value). Asserts
// that a mapping exists, i.e. that `value` is a parameter handle (not an
// operation or value handle) whose parameters have been set.
82  const ParamMapping &mapping = getMapping(value).params;
83  auto iter = mapping.find(value);
84  assert(iter != mapping.end() && "cannot find mapping for param handle "
85  "(operation/value handle provided?)");
86  return iter->getSecond();
87 }
88 
// Return the payload values currently associated with the value handle
// `handleValue`, taken from the value mapping of the scope selected by
// getMapping(handleValue). Asserts that a mapping exists, i.e. that
// `handleValue` is a value handle (not a param/operation handle) whose payload
// has been set. The list stored in the mapping itself is returned.
90 transform::TransformState::getPayloadValuesView(Value handleValue) const {
91  const ValueMapping &mapping = getMapping(handleValue).values;
92  auto iter = mapping.find(handleValue);
93  assert(iter != mapping.end() && "cannot find mapping for value handle "
94  "(param/operation handle provided?)");
95  return iter->getSecond();
96 }
97 
// Append to `handles` (without clearing it) every operation handle that points
// to payload operation `op`, using the reverse op mapping of each region scope.
// Scopes are visited in reverse order of `mappings` (presumably innermost
// first — TODO confirm). Unless `includeOutOfScope` is set, the search stops
// after the first region that is isolated from above. Returns success iff at
// least one handle was found.
99  Operation *op, SmallVectorImpl<Value> &handles,
100  bool includeOutOfScope) const {
101  bool found = false;
102  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
103  auto iterator = mapping->reverse.find(op);
104  if (iterator != mapping->reverse.end()) {
105  llvm::append_range(handles, iterator->getSecond());
106  found = true;
107  }
108  // Stop looking when reaching a region that is isolated from above.
109  if (!includeOutOfScope &&
111  break;
112  }
113 
114  return success(found);
115 }
116 
// Append to `handles` (without clearing it) every value handle that points to
// `payloadValue`, using the reverse value mapping of each region scope. Scopes
// are visited in reverse order of `mappings` (presumably innermost first —
// TODO confirm). Unless `includeOutOfScope` is set, the search stops after the
// first region that is isolated from above. Returns success iff at least one
// handle was found.
118  Value payloadValue, SmallVectorImpl<Value> &handles,
119  bool includeOutOfScope) const {
120  bool found = false;
121  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
122  auto iterator = mapping->reverseValues.find(payloadValue);
123  if (iterator != mapping->reverseValues.end()) {
124  llvm::append_range(handles, iterator->getSecond());
125  found = true;
126  }
127  // Stop looking when reaching a region that is isolated from above.
128  if (!includeOutOfScope &&
130  break;
131  }
132 
133  return success(found);
134 }
135 
136 /// Given a list of MappedValues, cast them to the value kind implied by the
137 /// interface of the handle type, and dispatch to one of the callbacks.
///
/// The handle type selects the expected kind of every entry: operation handles
/// (TransformHandleTypeInterface) take Operation*, value handles
/// (TransformValueHandleTypeInterface) take Value, and any other handle must
/// be a parameter handle (asserted) taking Attribute. A null entry or an entry
/// of the wrong kind yields a silenceable failure at the handle's location
/// before any callback is invoked; otherwise exactly one callback is called
/// with the converted list.
140  function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
141  function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
142  function_ref<LogicalResult(ValueRange)> valuesFn) {
// Operation handles: every entry must be a non-null Operation*.
143  if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
144  SmallVector<Operation *> operations;
145  operations.reserve(values.size());
146  for (transform::MappedValue value : values) {
147  if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
148  operations.push_back(op);
149  continue;
150  }
151  return emitSilenceableFailure(handle.getLoc())
152  << "wrong kind of value provided for top-level operation handle";
153  }
154  if (failed(operationsFn(operations)))
157  }
158 
// Value handles: every entry must be a non-null Value.
159  if (llvm::isa<transform::TransformValueHandleTypeInterface>(
160  handle.getType())) {
161  SmallVector<Value> payloadValues;
162  payloadValues.reserve(values.size());
163  for (transform::MappedValue value : values) {
164  if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
165  payloadValues.push_back(v);
166  continue;
167  }
168  return emitSilenceableFailure(handle.getLoc())
169  << "wrong kind of value provided for the top-level value handle";
170  }
171  if (failed(valuesFn(payloadValues)))
174  }
175 
// Anything else must be a parameter handle: every entry must be a non-null
// Attribute.
176  assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
177  "unsupported kind of block argument");
179  parameters.reserve(values.size());
180  for (transform::MappedValue value : values) {
181  if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
182  parameters.push_back(attr);
183  continue;
184  }
185  return emitSilenceableFailure(handle.getLoc())
186  << "wrong kind of value provided for top-level parameter";
187  }
188  if (failed(paramsFn(parameters)))
191 }
192 
// Associate the block argument `argument` with `values`. dispatchMappedValues
// picks setPayloadOps, setParams or setPayloadValues according to the
// argument's handle type, and checkAndReport() converts the dispatch result
// into a LogicalResult, reporting any diagnostic it carries.
193 LogicalResult
195  ArrayRef<MappedValue> values) {
196  return dispatchMappedValues(
197  argument, values,
198  [&](ArrayRef<Operation *> operations) {
199  return setPayloadOps(argument, operations);
200  },
201  [&](ArrayRef<Param> params) {
202  return setParams(argument, params);
203  },
204  [&](ValueRange payloadValues) {
205  return setPayloadValues(argument, payloadValues);
206  })
207  .checkAndReport();
208 }
209 
// Map each block argument to the corresponding entry of `mapping`, pairwise
// (llvm::zip_equal requires both ranges to have the same length). Stops at
// the first argument that fails to map and returns failure; arguments mapped
// before that point stay mapped.
211  Block::BlockArgListType arguments,
213  for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
214  if (failed(mapBlockArgument(argument, values)))
215  return failure();
216  return success();
217 }
218 
219 LogicalResult
220 transform::TransformState::setPayloadOps(Value value,
221  ArrayRef<Operation *> targets) {
222  assert(value != kTopLevelValue &&
223  "attempting to reset the transformation root");
224  assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
225  "wrong handle type");
226 
227  for (Operation *target : targets) {
228  if (target)
229  continue;
230  return emitError(value.getLoc())
231  << "attempting to assign a null payload op to this transform value";
232  }
233 
234  auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
236  iface.checkPayload(value.getLoc(), targets);
237  if (failed(result.checkAndReport()))
238  return failure();
239 
240  // Setting new payload for the value without cleaning it first is a misuse of
241  // the API, assert here.
242  SmallVector<Operation *> storedTargets(targets);
243  Mappings &mappings = getMapping(value);
244  bool inserted =
245  mappings.direct.insert({value, std::move(storedTargets)}).second;
246  assert(inserted && "value is already associated with another list");
247  (void)inserted;
248 
249  for (Operation *op : targets)
250  mappings.reverse[op].push_back(value);
251 
252  return success();
253 }
254 
255 LogicalResult
256 transform::TransformState::setPayloadValues(Value handle,
257  ValueRange payloadValues) {
258  assert(handle != nullptr && "attempting to set params for a null value");
259  assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
260  "wrong handle type");
261 
262  for (Value payload : payloadValues) {
263  if (payload)
264  continue;
265  return emitError(handle.getLoc()) << "attempting to assign a null payload "
266  "value to this transform handle";
267  }
268 
269  auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
270  SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
272  iface.checkPayload(handle.getLoc(), payloadValueVector);
273  if (failed(result.checkAndReport()))
274  return failure();
275 
276  Mappings &mappings = getMapping(handle);
277  bool inserted =
278  mappings.values.insert({handle, std::move(payloadValueVector)}).second;
279  assert(
280  inserted &&
281  "value handle is already associated with another list of payload values");
282  (void)inserted;
283 
284  for (Value payload : payloadValues)
285  mappings.reverseValues[payload].push_back(handle);
286 
287  return success();
288 }
289 
290 LogicalResult transform::TransformState::setParams(Value value,
291  ArrayRef<Param> params) {
292  assert(value != nullptr && "attempting to set params for a null value");
293 
294  for (Attribute attr : params) {
295  if (attr)
296  continue;
297  return emitError(value.getLoc())
298  << "attempting to assign a null parameter to this transform value";
299  }
300 
301  auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
302  assert(value &&
303  "cannot associate parameter with a value of non-parameter type");
305  valueType.checkPayload(value.getLoc(), params);
306  if (failed(result.checkAndReport()))
307  return failure();
308 
309  Mappings &mappings = getMapping(value);
310  bool inserted =
311  mappings.params.insert({value, llvm::to_vector(params)}).second;
312  assert(inserted && "value is already associated with another list of params");
313  (void)inserted;
314  return success();
315 }
316 
317 template <typename Mapping, typename Key, typename Mapped>
318 void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
319  auto it = mapping.find(key);
320  if (it == mapping.end())
321  return;
322 
323  llvm::erase(it->getSecond(), mapped);
324  if (it->getSecond().empty())
325  mapping.erase(it);
326 }
327 
328 void transform::TransformState::forgetMapping(Value opHandle,
329  ValueRange origOpFlatResults,
330  bool allowOutOfScope) {
331  Mappings &mappings = getMapping(opHandle, allowOutOfScope);
332  for (Operation *op : mappings.direct[opHandle])
333  dropMappingEntry(mappings.reverse, op, opHandle);
334  mappings.direct.erase(opHandle);
335 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
336  // Payload IR is removed from the mapping. This invalidates the respective
337  // iterators.
338  mappings.incrementTimestamp(opHandle);
339 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
340 
341  for (Value opResult : origOpFlatResults) {
342  SmallVector<Value> resultHandles;
343  (void)getHandlesForPayloadValue(opResult, resultHandles);
344  for (Value resultHandle : resultHandles) {
345  Mappings &localMappings = getMapping(resultHandle);
346  dropMappingEntry(localMappings.values, resultHandle, opResult);
347 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
348  // Payload IR is removed from the mapping. This invalidates the respective
349  // iterators.
350  mappings.incrementTimestamp(resultHandle);
351 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
352  dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
353  }
354  }
355 }
356 
357 void transform::TransformState::forgetValueMapping(
358  Value valueHandle, ArrayRef<Operation *> payloadOperations) {
359  Mappings &mappings = getMapping(valueHandle);
360  for (Value payloadValue : mappings.reverseValues[valueHandle])
361  dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
362  mappings.values.erase(valueHandle);
363 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
364  // Payload IR is removed from the mapping. This invalidates the respective
365  // iterators.
366  mappings.incrementTimestamp(valueHandle);
367 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
368 
369  for (Operation *payloadOp : payloadOperations) {
370  SmallVector<Value> opHandles;
371  (void)getHandlesForPayloadOp(payloadOp, opHandles);
372  for (Value opHandle : opHandles) {
373  Mappings &localMappings = getMapping(opHandle);
374  dropMappingEntry(localMappings.direct, opHandle, payloadOp);
375  dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
376 
377 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
378  // Payload IR is removed from the mapping. This invalidates the respective
379  // iterators.
380  localMappings.incrementTimestamp(opHandle);
381 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
382  }
383  }
384 }
385 
// Redirect every handle, in any scope (includeOutOfScope/allowOutOfScope are
// true), that points to payload op `op` so that it points to `replacement`.
// A null `replacement` records that `op` was erased: the null entries stay in
// place (preserving positions) and the affected handles are queued in
// `opHandlesToCompact` for later removal of the nulls. Returns failure if no
// handle points to `op`. In debug builds, asserts that no value handle still
// points to a result of `op` (presumably callers must replace/forget result
// values first — TODO confirm).
386 LogicalResult
387 transform::TransformState::replacePayloadOp(Operation *op,
388  Operation *replacement) {
389  // TODO: consider invalidating the handles to nested objects here.
390 
391 #ifndef NDEBUG
392  for (Value opResult : op->getResults()) {
393  SmallVector<Value> valueHandles;
394  (void)getHandlesForPayloadValue(opResult, valueHandles,
395  /*includeOutOfScope=*/true);
396  assert(valueHandles.empty() && "expected no mapping to old results");
397  }
398 #endif // NDEBUG
399 
400  // Drop the mapping between the op and all handles that point to it. Fail if
401  // there are no handles.
402  SmallVector<Value> opHandles;
403  if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
404  return failure();
405  for (Value handle : opHandles) {
406  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
407  dropMappingEntry(mappings.reverse, op, handle);
408  }
409 
410  // Replace the pointed-to object of all handles with the replacement object.
411  // In case a payload op was erased (replacement object is nullptr), a nullptr
412  // is stored in the mapping. These nullptrs are removed after each transform.
413  // Furthermore, nullptrs are not enumerated by payload op iterators. The
414  // relative order of ops is preserved.
415  //
416  // Removing an op from the mapping would be problematic because removing an
417  // element from an array invalidates iterators; merely changing the value of
418  // elements does not.
419  for (Value handle : opHandles) {
420  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
421  auto it = mappings.direct.find(handle);
422  if (it == mappings.direct.end())
423  continue;
424 
425  SmallVector<Operation *, 2> &association = it->getSecond();
426  // Note that an operation may be associated with the handle more than once.
427  for (Operation *&mapped : association) {
428  if (mapped == op)
429  mapped = replacement;
430  }
431 
// The reverse entry for `op` was dropped in the first loop: register the
// handle under its new payload op, or, for an erased op, queue the handle so
// the null placeholders can be compacted away later (compactOpHandles).
432  if (replacement) {
433  mappings.reverse[replacement].push_back(handle);
434  } else {
435  opHandlesToCompact.insert(handle);
436  }
437  }
438 
439  return success();
440 }
441 
442 LogicalResult
443 transform::TransformState::replacePayloadValue(Value value, Value replacement) {
444  SmallVector<Value> valueHandles;
445  if (failed(getHandlesForPayloadValue(value, valueHandles,
446  /*includeOutOfScope=*/true)))
447  return failure();
448 
449  for (Value handle : valueHandles) {
450  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
451  dropMappingEntry(mappings.reverseValues, value, handle);
452 
453  // If replacing with null, that is erasing the mapping, drop the mapping
454  // between the handles and the IR objects
455  if (!replacement) {
456  dropMappingEntry(mappings.values, handle, value);
457 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
458  // Payload IR is removed from the mapping. This invalidates the respective
459  // iterators.
460  mappings.incrementTimestamp(handle);
461 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
462  } else {
463  auto it = mappings.values.find(handle);
464  if (it == mappings.values.end())
465  continue;
466 
467  SmallVector<Value> &association = it->getSecond();
468  for (Value &mapped : association) {
469  if (mapped == value)
470  mapped = replacement;
471  }
472  mappings.reverseValues[replacement].push_back(handle);
473  }
474  }
475 
476  return success();
477 }
478 
// Record in `newlyInvalidated` that `otherHandle` is invalidated when
// `payloadOp` (one payload op it points to) is one of `potentialAncestors`
// (the payload ops of the consumed operand `consumingHandle`) or nested in
// one of them. The stored callback emits the diagnostic later, at the location
// of the op that uses the invalidated handle. `throughValue`, when non-null,
// is a payload value the consumed handle points to and is mentioned in a note.
// NOTE(review): the loop does not stop at the first match, so when several
// ancestors contain `payloadOp` the note names the last matching one.
479 void transform::TransformState::recordOpHandleInvalidationOne(
480  OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
481  Operation *payloadOp, Value otherHandle, Value throughValue,
482  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
483  // If the op is associated with invalidated handle, skip the check as it
484  // may be reading invalid IR. This also ensures we report the first
485  // invalidation and not the last one.
486  if (invalidatedHandles.count(otherHandle) ||
487  newlyInvalidated.count(otherHandle))
488  return;
489 
490  FULL_LDBG("--recordOpHandleInvalidationOne\n");
491  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
492  (DBGS() << "--ancestors: "
493  << llvm::interleaved(llvm::make_pointee_range(potentialAncestors))
494  << "\n");
495  });
496 
497  Operation *owner = consumingHandle.getOwner();
498  unsigned operandNo = consumingHandle.getOperandNumber();
499  for (Operation *ancestor : potentialAncestors) {
500  // clang-format off
501  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
502  { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
503  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
504  { (DBGS() << "----of payload with name: "
505  << payloadOp->getName().getIdentifier() << "\n"); });
506  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
507  { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
508  // clang-format on
509  if (!ancestor->isAncestor(payloadOp))
510  continue;
511 
512  // Make sure the error-reporting lambda doesn't capture anything
513  // by-reference because it will go out of scope. Additionally, extract
514  // location from Payload IR ops because the ops themselves may be
515  // deleted before the lambda gets called.
516  Location ancestorLoc = ancestor->getLoc();
517  Location opLoc = payloadOp->getLoc();
518  std::optional<Location> throughValueLoc =
519  throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
520  newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
521  otherHandle,
522  throughValueLoc](Location currentLoc) {
523  InFlightDiagnostic diag = emitError(currentLoc)
524  << "op uses a handle invalidated by a "
525  "previously executed transform op";
526  diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
527  diag.attachNote(owner->getLoc())
528  << "invalidated by this transform op that consumes its operand #"
529  << operandNo
530  << " and invalidates all handles to payload IR entities associated "
531  "with this operand and entities nested in them";
532  diag.attachNote(ancestorLoc) << "ancestor payload op";
533  diag.attachNote(opLoc) << "nested payload op";
534  if (throughValueLoc) {
535  diag.attachNote(*throughValueLoc)
536  << "consumed handle points to this payload value";
537  }
538  };
539  }
540 }
541 
542 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
543  OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
544  Value payloadValue, Value valueHandle,
545  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
546  // If the op is associated with invalidated handle, skip the check as it
547  // may be reading invalid IR. This also ensures we report the first
548  // invalidation and not the last one.
549  if (invalidatedHandles.count(valueHandle) ||
550  newlyInvalidated.count(valueHandle))
551  return;
552 
553  for (Operation *ancestor : potentialAncestors) {
554  Operation *definingOp;
555  std::optional<unsigned> resultNo;
556  unsigned argumentNo = std::numeric_limits<unsigned>::max();
557  unsigned blockNo = std::numeric_limits<unsigned>::max();
558  unsigned regionNo = std::numeric_limits<unsigned>::max();
559  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
560  definingOp = opResult.getOwner();
561  resultNo = opResult.getResultNumber();
562  } else {
563  auto arg = llvm::cast<BlockArgument>(payloadValue);
564  definingOp = arg.getParentBlock()->getParentOp();
565  argumentNo = arg.getArgNumber();
566  blockNo = std::distance(arg.getOwner()->getParent()->begin(),
567  arg.getOwner()->getIterator());
568  regionNo = arg.getOwner()->getParent()->getRegionNumber();
569  }
570  assert(definingOp && "expected the value to be defined by an op as result "
571  "or block argument");
572  if (!ancestor->isAncestor(definingOp))
573  continue;
574 
575  Operation *owner = opHandle.getOwner();
576  unsigned operandNo = opHandle.getOperandNumber();
577  Location ancestorLoc = ancestor->getLoc();
578  Location opLoc = definingOp->getLoc();
579  Location valueLoc = payloadValue.getLoc();
580  newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
581  argumentNo, blockNo, regionNo, ancestorLoc,
582  opLoc, valueLoc](Location currentLoc) {
583  InFlightDiagnostic diag = emitError(currentLoc)
584  << "op uses a handle invalidated by a "
585  "previously executed transform op";
586  diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
587  diag.attachNote(owner->getLoc())
588  << "invalidated by this transform op that consumes its operand #"
589  << operandNo
590  << " and invalidates all handles to payload IR entities "
591  "associated with this operand and entities nested in them";
592  diag.attachNote(ancestorLoc)
593  << "ancestor op associated with the consumed handle";
594  if (resultNo) {
595  diag.attachNote(opLoc)
596  << "op defining the value as result #" << *resultNo;
597  } else {
598  diag.attachNote(opLoc)
599  << "op defining the value as block argument #" << argumentNo
600  << " of block #" << blockNo << " in region #" << regionNo;
601  }
602  diag.attachNote(valueLoc) << "payload value";
603  };
604  }
605 }
606 
// Record in `newlyInvalidated` every handle invalidated by consuming the
// operand `handle`, whose payload operations are `potentialAncestors`:
//  - if there are no payload ops, only the consumed handle itself is recorded
//    (with a diagnostic mentioning the empty payload);
//  - otherwise, every op handle pointing to one of the ancestors or an op
//    nested in them, and every value handle pointing to a value defined by
//    such an op, is recorded, scanning the per-region mappings in reverse
//    order until the stop condition below (isolated-from-above region).
// `throughValue` is forwarded for the op-handle diagnostics only.
607 void transform::TransformState::recordOpHandleInvalidation(
608  OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
609  Value throughValue,
610  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
611 
612  if (potentialAncestors.empty()) {
613  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
614  (DBGS() << "----recording invalidation for empty handle: " << handle.get()
615  << "\n");
616  });
617 
618  Operation *owner = handle.getOwner();
619  unsigned operandNo = handle.getOperandNumber();
620  newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
621  InFlightDiagnostic diag = emitError(currentLoc)
622  << "op uses a handle associated with empty "
623  "payload and invalidated by a "
624  "previously executed transform op";
625  diag.attachNote(owner->getLoc())
626  << "invalidated by this transform op that consumes its operand #"
627  << operandNo;
628  };
629  return;
630  }
631 
632  // Iterate over the mapping and invalidate aliasing handles. This is quite
633  // expensive and only necessary for error reporting in case of transform
634  // dialect misuse with dangling handles. Iteration over the handles is based
635  // on the assumption that the number of handles is significantly less than the
636  // number of IR objects (operations and values). Alternatively, we could walk
637  // the IR nested in each payload op associated with the given handle and look
638  // for handles associated with each operation and value.
639  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
640  // Go over all op handle mappings and mark as invalidated any handle
641  // pointing to any of the payload ops associated with the given handle or
642  // any op nested in them.
643  for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
644  for (Value otherHandle : otherHandles)
645  recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
646  otherHandle, throughValue,
647  newlyInvalidated);
648  }
649  // Go over all value handle mappings and mark as invalidated any handle
650  // pointing to any result of the payload op associated with the given handle
651  // or any op nested in them. Similarly invalidate handles to argument of
652  // blocks belonging to any region of any payload op associated with the
653  // given handle or any op nested in them.
654  for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
655  for (Value valueHandle : valueHandles)
656  recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
657  payloadValue, valueHandle,
658  newlyInvalidated);
659  }
660 
661  // Stop lookup when reaching a region that is isolated from above.
663  break;
664  }
665 }
666 
667 void transform::TransformState::recordValueHandleInvalidation(
668  OpOperand &valueHandle,
669  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
670  // Invalidate other handles to the same value.
671  for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
672  SmallVector<Value> otherValueHandles;
673  (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
674  for (Value otherHandle : otherValueHandles) {
675  Operation *owner = valueHandle.getOwner();
676  unsigned operandNo = valueHandle.getOperandNumber();
677  Location valueLoc = payloadValue.getLoc();
678  newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
679  valueLoc](Location currentLoc) {
680  InFlightDiagnostic diag = emitError(currentLoc)
681  << "op uses a handle invalidated by a "
682  "previously executed transform op";
683  diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
684  diag.attachNote(owner->getLoc())
685  << "invalidated by this transform op that consumes its operand #"
686  << operandNo
687  << " and invalidates handles to the same values as associated with "
688  "it";
689  diag.attachNote(valueLoc) << "payload value";
690  };
691  }
692 
693  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
694  Operation *payloadOp = opResult.getOwner();
695  recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
696  newlyInvalidated);
697  } else {
698  auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
699  for (Operation &payloadOp : *arg.getOwner())
700  recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
701  newlyInvalidated);
702  }
703  }
704 }
705 
706 /// Checks that the operation does not use invalidated handles as operands.
707 /// Reports errors and returns failure if it does. Otherwise, invalidates the
708 /// handles consumed by the operation as well as any handles pointing to payload
709 /// IR operations nested in the operations associated with the consumed handles.
///
/// Operands are processed in order, and invalidations caused by earlier
/// operands are accumulated in `newlyInvalidated` as they go. A later operand
/// that uses such a handle is therefore reported unless the op allows repeated
/// handle operands. An operand counts as consumed when the op declares a
/// MemoryEffects::Free effect on it (the effect list `effects` is gathered on
/// lines not visible here, presumably for the transform mapping resource).
710 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
711  transform::TransformOpInterface transform,
712  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
713  FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
714  auto memoryEffectsIface =
715  cast<MemoryEffectOpInterface>(transform.getOperation());
717  memoryEffectsIface.getEffectsOnResource(
719 
720  for (OpOperand &target : transform->getOpOperands()) {
721  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
722  (DBGS() << "----iterate on handle: " << target.get() << "\n");
723  });
724  // If the operand uses an invalidated handle, report it. If the operation
725  // allows handles to point to repeated payload operations, only report
726  // pre-existing invalidation errors. Otherwise, also report invalidations
727  // caused by the current transform operation affecting its other operands.
728  auto it = invalidatedHandles.find(target.get());
729  auto nit = newlyInvalidated.find(target.get());
730  if (it != invalidatedHandles.end()) {
731  FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
732  "invalidated -> FAILURE\n");
// Comma operator: emit the stored diagnostic at this op's location, then
// return failure().
733  return it->getSecond()(transform->getLoc()), failure();
734  }
735  if (!transform.allowsRepeatedHandleOperands() &&
736  nit != newlyInvalidated.end()) {
737  FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
738  "invalidated (by this op) -> FAILURE\n");
739  return nit->getSecond()(transform->getLoc()), failure();
740  }
741 
742  // Invalidate handles pointing to the operations nested in the operation
743  // associated with the handle consumed by this operation.
744  auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
745  return isa<MemoryEffects::Free>(effect.getEffect()) &&
746  effect.getValue() == target.get();
747  };
748  if (llvm::any_of(effects, consumesTarget)) {
749  FULL_LDBG("----found consume effect\n");
750  if (llvm::isa<transform::TransformHandleTypeInterface>(
751  target.get().getType())) {
752  FULL_LDBG("----recordOpHandleInvalidation\n");
// Copy the payload ops: recordOpHandleInvalidation takes an ArrayRef.
753  SmallVector<Operation *> payloadOps =
754  llvm::to_vector(getPayloadOps(target.get()));
755  recordOpHandleInvalidation(target, payloadOps, nullptr,
756  newlyInvalidated);
757  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
758  target.get().getType())) {
759  FULL_LDBG("----recordValueHandleInvalidation\n");
760  recordValueHandleInvalidation(target, newlyInvalidated);
761  } else {
762  FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
763  }
764  } else {
765  FULL_LDBG("----no consume effect -> SKIP\n");
766  }
767  }
768 
769  FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
770  return success();
771 }
772 
773 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
774  transform::TransformOpInterface transform) {
775  InvalidatedHandleMap newlyInvalidated;
776  LogicalResult checkResult =
777  checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
778  invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
779  std::make_move_iterator(newlyInvalidated.end()));
780  return checkResult;
781 }
782 
// Check that the payload of a consumed operand (`operandNumber` of
// `transform`) does not contain the same entity twice. T is either an
// operation pointer (std::is_pointer_v) or a Value. On the first repeated
// entity, a silenceable error is emitted on `transform` with a note at the
// repeated entity's location, and that diagnostic is returned. The success
// return and the diagnostic's declaration are on lines not visible here.
783 template <typename T>
786  transform::TransformOpInterface transform,
787  unsigned operandNumber) {
788  DenseSet<T> seen;
789  for (T p : payload) {
790  if (!seen.insert(p).second) {
792  transform.emitSilenceableError()
793  << "a handle passed as operand #" << operandNumber
794  << " and consumed by this operation points to a payload "
795  "entity more than once";
796  if constexpr (std::is_pointer_v<T>)
797  diag.attachNote(p->getLoc()) << "repeated target op";
798  else
799  diag.attachNote(p.getLoc()) << "repeated target value";
800  return diag;
801  }
802  }
804 }
805 
806 void transform::TransformState::compactOpHandles() {
807  for (Value handle : opHandlesToCompact) {
808  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
809 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
810  if (llvm::find(mappings.direct[handle], nullptr) !=
811  mappings.direct[handle].end())
812  // Payload IR is removed from the mapping. This invalidates the respective
813  // iterators.
814  mappings.incrementTimestamp(handle);
815 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
816  llvm::erase(mappings.direct[handle], nullptr);
817  }
818  opHandlesToCompact.clear();
819 }
820 
822 transform::TransformState::applyTransform(TransformOpInterface transform) {
823  LLVM_DEBUG({
824  DBGS() << "applying: ";
825  transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
826  llvm::dbgs() << "\n";
827  });
828  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
829  DBGS() << "Top-level payload before application:\n"
830  << *getTopLevel() << "\n");
831  auto printOnFailureRAII = llvm::make_scope_exit([this] {
832  (void)this;
833  LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
834  llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
835  });
836 
837  // Set current transform op.
838  regionStack.back()->currentTransform = transform;
839 
840  // Expensive checks to detect invalid transform IR.
841  if (options.getExpensiveChecksEnabled()) {
842  FULL_LDBG("ExpensiveChecksEnabled\n");
843  if (failed(checkAndRecordHandleInvalidation(transform)))
845 
846  for (OpOperand &operand : transform->getOpOperands()) {
847  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
848  (DBGS() << "iterate on handle: " << operand.get() << "\n");
849  });
850  if (!isHandleConsumed(operand.get(), transform)) {
851  FULL_LDBG("--handle not consumed -> SKIP\n");
852  continue;
853  }
854  if (transform.allowsRepeatedHandleOperands()) {
855  FULL_LDBG("--op allows repeated handles -> SKIP\n");
856  continue;
857  }
858  FULL_LDBG("--handle is consumed\n");
859 
860  Type operandType = operand.get().getType();
861  if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
862  FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
864  checkRepeatedConsumptionInOperand<Operation *>(
865  getPayloadOpsView(operand.get()), transform,
866  operand.getOperandNumber());
867  if (!check.succeeded()) {
868  FULL_LDBG("----FAILED\n");
869  return check;
870  }
871  } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
872  FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
874  checkRepeatedConsumptionInOperand<Value>(
875  getPayloadValuesView(operand.get()), transform,
876  operand.getOperandNumber());
877  if (!check.succeeded()) {
878  FULL_LDBG("----FAILED\n");
879  return check;
880  }
881  } else {
882  FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
883  }
884  }
885  }
886 
887  // Find which operands are consumed.
888  SmallVector<OpOperand *> consumedOperands =
889  transform.getConsumedHandleOpOperands();
890 
891  // Remember the results of the payload ops associated with the consumed
892  // op handles or the ops defining the value handles so we can drop the
893  // association with them later. This must happen here because the
894  // transformation may destroy or mutate them so we cannot traverse the payload
895  // IR after that.
896  SmallVector<Value> origOpFlatResults;
897  SmallVector<Operation *> origAssociatedOps;
898  for (OpOperand *opOperand : consumedOperands) {
899  Value operand = opOperand->get();
900  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
901  for (Operation *payloadOp : getPayloadOps(operand)) {
902  llvm::append_range(origOpFlatResults, payloadOp->getResults());
903  }
904  continue;
905  }
906  if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
907  for (Value payloadValue : getPayloadValuesView(operand)) {
908  if (llvm::isa<OpResult>(payloadValue)) {
909  origAssociatedOps.push_back(payloadValue.getDefiningOp());
910  continue;
911  }
912  llvm::append_range(
913  origAssociatedOps,
914  llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
915  [](Operation &op) { return &op; }));
916  }
917  continue;
918  }
920  emitDefiniteFailure(transform->getLoc())
921  << "unexpectedly consumed a value that is not a handle as operand #"
922  << opOperand->getOperandNumber();
923  diag.attachNote(operand.getLoc())
924  << "value defined here with type " << operand.getType();
925  return diag;
926  }
927 
928  // Prepare rewriter and listener.
930  config.skipHandleFn = [&](Value handle) {
931  // Skip handle if it is dead.
932  auto scopeIt =
933  llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
934  return handle.getParentRegion() == scope->region;
935  });
936  assert(scopeIt != regionStack.rend() &&
937  "could not find region scope for handle");
938  RegionScope *scope = *scopeIt;
939  return llvm::all_of(handle.getUsers(), [&](Operation *user) {
940  return user == scope->currentTransform ||
941  happensBefore(user, scope->currentTransform);
942  });
943  };
944  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
945  config);
946  transform::TransformRewriter rewriter(transform->getContext(),
947  &trackingListener);
948 
949  // Compute the result but do not short-circuit the silenceable failure case as
950  // we still want the handles to propagate properly so the "suppress" mode can
951  // proceed on a best effort basis.
952  transform::TransformResults results(transform->getNumResults());
953  DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
954  compactOpHandles();
955 
956  // Error handling: fail if transform or listener failed.
957  DiagnosedSilenceableFailure trackingFailure =
958  trackingListener.checkAndResetError();
959  if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
960  transform->hasAttr(FindPayloadReplacementOpInterface::
961  kSilenceTrackingFailuresAttrName)) {
962  // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
963  // do not report failures if the above mentioned attribute is set.
964  if (trackingFailure.isSilenceableFailure())
965  (void)trackingFailure.silence();
966  trackingFailure = DiagnosedSilenceableFailure::success();
967  }
968  if (!trackingFailure.succeeded()) {
969  if (result.succeeded()) {
970  result = std::move(trackingFailure);
971  } else {
972  // Transform op errors have precedence, report those first.
973  if (result.isSilenceableFailure())
974  result.attachNote() << "tracking listener also failed: "
975  << trackingFailure.getMessage();
976  (void)trackingFailure.silence();
977  }
978  }
979  if (result.isDefiniteFailure())
980  return result;
981 
982  // If a silenceable failure was produced, some results may be unset, set them
983  // to empty lists.
984  if (result.isSilenceableFailure())
985  results.setRemainingToEmpty(transform);
986 
987  // Remove the mapping for the operand if it is consumed by the operation. This
988  // allows us to catch use-after-free with assertions later on.
989  for (OpOperand *opOperand : consumedOperands) {
990  Value operand = opOperand->get();
991  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
992  forgetMapping(operand, origOpFlatResults);
993  } else if (llvm::isa<TransformValueHandleTypeInterface>(
994  operand.getType())) {
995  forgetValueMapping(operand, origAssociatedOps);
996  }
997  }
998 
999  if (failed(updateStateFromResults(results, transform->getResults())))
1001 
1002  printOnFailureRAII.release();
1003  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
1004  DBGS() << "Top-level payload:\n";
1005  getTopLevel()->print(llvm::dbgs());
1006  });
1007  return result;
1008 }
1009 
1010 LogicalResult transform::TransformState::updateStateFromResults(
1011  const TransformResults &results, ResultRange opResults) {
1012  for (OpResult result : opResults) {
1013  if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1014  assert(results.isParam(result.getResultNumber()) &&
1015  "expected parameters for the parameter-typed result");
1016  if (failed(
1017  setParams(result, results.getParams(result.getResultNumber())))) {
1018  return failure();
1019  }
1020  } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1021  assert(results.isValue(result.getResultNumber()) &&
1022  "expected values for value-type-result");
1023  if (failed(setPayloadValues(
1024  result, results.getValues(result.getResultNumber())))) {
1025  return failure();
1026  }
1027  } else {
1028  assert(!results.isParam(result.getResultNumber()) &&
1029  "expected payload ops for the non-parameter typed result");
1030  if (failed(
1031  setPayloadOps(result, results.get(result.getResultNumber())))) {
1032  return failure();
1033  }
1034  }
1035  }
1036  return success();
1037 }
1038 
1039 //===----------------------------------------------------------------------===//
1040 // TransformState::Extension
1041 //===----------------------------------------------------------------------===//
1042 
1044 
1045 LogicalResult
1047  Operation *replacement) {
1048  // TODO: we may need to invalidate handles to operations and values nested in
1049  // the operation being replaced.
1050  return state.replacePayloadOp(op, replacement);
1051 }
1052 
1053 LogicalResult
1055  Value replacement) {
1056  return state.replacePayloadValue(value, replacement);
1057 }
1058 
1059 //===----------------------------------------------------------------------===//
1060 // TransformState::RegionScope
1061 //===----------------------------------------------------------------------===//
1062 
1064  // Remove handle invalidation notices as handles are going out of scope.
1065  // The same region may be re-entered leading to incorrect invalidation
1066  // errors.
1067  for (Block &block : *region) {
1068  for (Value handle : block.getArguments()) {
1069  state.invalidatedHandles.erase(handle);
1070  }
1071  for (Operation &op : block) {
1072  for (Value handle : op.getResults()) {
1073  state.invalidatedHandles.erase(handle);
1074  }
1075  }
1076  }
1077 
1078 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1079  // Remember pointers to payload ops referenced by the handles going out of
1080  // scope.
1081  SmallVector<Operation *> referencedOps =
1082  llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1083 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1084 
1085  state.mappings.erase(region);
1086  state.regionStack.pop_back();
1087 }
1088 
1089 //===----------------------------------------------------------------------===//
1090 // TransformResults
1091 //===----------------------------------------------------------------------===//
1092 
1093 transform::TransformResults::TransformResults(unsigned numSegments) {
1094  operations.appendEmptyRows(numSegments);
1095  params.appendEmptyRows(numSegments);
1096  values.appendEmptyRows(numSegments);
1097 }
1098 
1101  int64_t position = value.getResultNumber();
1102  assert(position < static_cast<int64_t>(this->params.size()) &&
1103  "setting params for a non-existent handle");
1104  assert(this->params[position].data() == nullptr && "params already set");
1105  assert(operations[position].data() == nullptr &&
1106  "another kind of results already set");
1107  assert(values[position].data() == nullptr &&
1108  "another kind of results already set");
1109  this->params.replace(position, params);
1110 }
1111 
1113  OpResult handle, ArrayRef<MappedValue> values) {
1115  handle, values,
1116  [&](ArrayRef<Operation *> operations) {
1117  return set(handle, operations), success();
1118  },
1119  [&](ArrayRef<Param> params) {
1120  return setParams(handle, params), success();
1121  },
1122  [&](ValueRange payloadValues) {
1123  return setValues(handle, payloadValues), success();
1124  });
1125 #ifndef NDEBUG
1126  if (!diag.succeeded())
1127  llvm::dbgs() << diag.getStatusString() << "\n";
1128  assert(diag.succeeded() && "incorrect mapping");
1129 #endif // NDEBUG
1130  (void)diag.silence();
1131 }
1132 
1134  transform::TransformOpInterface transform) {
1135  for (OpResult opResult : transform->getResults()) {
1136  if (!isSet(opResult.getResultNumber()))
1137  setMappedValues(opResult, {});
1138  }
1139 }
1140 
1142 transform::TransformResults::get(unsigned resultNumber) const {
1143  assert(resultNumber < operations.size() &&
1144  "querying results for a non-existent handle");
1145  assert(operations[resultNumber].data() != nullptr &&
1146  "querying unset results (values or params expected?)");
1147  return operations[resultNumber];
1148 }
1149 
1151 transform::TransformResults::getParams(unsigned resultNumber) const {
1152  assert(resultNumber < params.size() &&
1153  "querying params for a non-existent handle");
1154  assert(params[resultNumber].data() != nullptr &&
1155  "querying unset params (ops or values expected?)");
1156  return params[resultNumber];
1157 }
1158 
1160 transform::TransformResults::getValues(unsigned resultNumber) const {
1161  assert(resultNumber < values.size() &&
1162  "querying values for a non-existent handle");
1163  assert(values[resultNumber].data() != nullptr &&
1164  "querying unset values (ops or params expected?)");
1165  return values[resultNumber];
1166 }
1167 
1168 bool transform::TransformResults::isParam(unsigned resultNumber) const {
1169  assert(resultNumber < params.size() &&
1170  "querying association for a non-existent handle");
1171  return params[resultNumber].data() != nullptr;
1172 }
1173 
1174 bool transform::TransformResults::isValue(unsigned resultNumber) const {
1175  assert(resultNumber < values.size() &&
1176  "querying association for a non-existent handle");
1177  return values[resultNumber].data() != nullptr;
1178 }
1179 
1180 bool transform::TransformResults::isSet(unsigned resultNumber) const {
1181  assert(resultNumber < params.size() &&
1182  "querying association for a non-existent handle");
1183  return params[resultNumber].data() != nullptr ||
1184  operations[resultNumber].data() != nullptr ||
1185  values[resultNumber].data() != nullptr;
1186 }
1187 
1188 //===----------------------------------------------------------------------===//
1189 // TrackingListener
1190 //===----------------------------------------------------------------------===//
1191 
1193  TransformOpInterface op,
1195  : TransformState::Extension(state), transformOp(op), config(config) {
1196  if (op) {
1197  for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1198  consumedHandles.insert(opOperand->get());
1199  }
1200  }
1201 }
1202 
1204  Operation *defOp = nullptr;
1205  for (Value v : values) {
1206  // Skip empty values.
1207  if (!v)
1208  continue;
1209  if (!defOp) {
1210  defOp = v.getDefiningOp();
1211  continue;
1212  }
1213  if (defOp != v.getDefiningOp())
1214  return nullptr;
1215  }
1216  return defOp;
1217 }
1218 
1220  Operation *&result, Operation *op, ValueRange newValues) const {
1221  assert(op->getNumResults() == newValues.size() &&
1222  "invalid number of replacement values");
1223  SmallVector<Value> values(newValues.begin(), newValues.end());
1224 
1226  getTransformOp(), "tracking listener failed to find replacement op "
1227  "during application of this transform op");
1228 
1229  do {
1230  // If the replacement values belong to different ops, drop the mapping.
1231  Operation *defOp = getCommonDefiningOp(values);
1232  if (!defOp) {
1233  diag.attachNote() << "replacement values belong to different ops";
1234  return diag;
1235  }
1236 
1237  // Skip through ops that implement CastOpInterface.
1238  if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1239  values.clear();
1240  values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1241  diag.attachNote(defOp->getLoc())
1242  << "using output of 'CastOpInterface' op";
1243  continue;
1244  }
1245 
1246  // If the defining op has the same name or we do not care about the name of
1247  // op replacements at all, we take it as a replacement.
1248  if (!config.requireMatchingReplacementOpName ||
1249  op->getName() == defOp->getName()) {
1250  result = defOp;
1252  }
1253 
1254  // Replacing an op with a constant-like equivalent is a common
1255  // canonicalization.
1256  if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1257  result = defOp;
1259  }
1260 
1261  values.clear();
1262 
1263  // Skip through ops that implement FindPayloadReplacementOpInterface.
1264  if (auto findReplacementOpInterface =
1265  dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1266  values.assign(findReplacementOpInterface.getNextOperands());
1267  diag.attachNote(defOp->getLoc()) << "using operands provided by "
1268  "'FindPayloadReplacementOpInterface'";
1269  continue;
1270  }
1271  } while (!values.empty());
1272 
1273  diag.attachNote() << "ran out of suitable replacement values";
1274  return diag;
1275 }
1276 
1278  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1279  LLVM_DEBUG({
1281  reasonCallback(diag);
1282  DBGS() << "Match Failure : " << diag.str() << "\n";
1283  });
1284 }
1285 
1286 void transform::TrackingListener::notifyOperationErased(Operation *op) {
1287  // Remove mappings for result values.
1288  for (OpResult value : op->getResults())
1289  (void)replacePayloadValue(value, nullptr);
1290  // Remove mapping for op.
1291  (void)replacePayloadOp(op, nullptr);
1292 }
1293 
1294 void transform::TrackingListener::notifyOperationReplaced(
1295  Operation *op, ValueRange newValues) {
1296  assert(op->getNumResults() == newValues.size() &&
1297  "invalid number of replacement values");
1298 
1299  // Replace value handles.
1300  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
1301  (void)replacePayloadValue(oldValue, newValue);
1302 
1303  // Replace op handle.
1304  SmallVector<Value> opHandles;
1305  if (failed(getTransformState().getHandlesForPayloadOp(
1306  op, opHandles, /*includeOutOfScope=*/true))) {
1307  // Op is not tracked.
1308  return;
1309  }
1310 
1311  // Helper function to check if the current transform op consumes any handle
1312  // that is mapped to `op`.
1313  //
1314  // Note: If a handle was consumed, there shouldn't be any alive users, so it
1315  // is not really necessary to check for consumed handles. However, in case
1316  // there are indeed alive handles that were consumed (which is undefined
1317  // behavior) and a replacement op could not be found, we want to fail with a
1318  // nicer error message: "op uses a handle invalidated..." instead of "could
1319  // not find replacement op". This nicer error is produced later.
1320  auto handleWasConsumed = [&] {
1321  return llvm::any_of(opHandles,
1322  [&](Value h) { return consumedHandles.contains(h); });
1323  };
1324 
1325  // Check if there are any handles that must be updated.
1326  Value aliveHandle;
1327  if (config.skipHandleFn) {
1328  auto it = llvm::find_if(opHandles,
1329  [&](Value v) { return !config.skipHandleFn(v); });
1330  if (it != opHandles.end())
1331  aliveHandle = *it;
1332  } else if (!opHandles.empty()) {
1333  aliveHandle = opHandles.front();
1334  }
1335  if (!aliveHandle || handleWasConsumed()) {
1336  // The op is tracked but the corresponding handles are dead or were
1337  // consumed. Drop the op form the mapping.
1338  (void)replacePayloadOp(op, nullptr);
1339  return;
1340  }
1341 
1342  Operation *replacement;
1344  findReplacementOp(replacement, op, newValues);
1345  // If the op is tracked but no replacement op was found, send a
1346  // notification.
1347  if (!diag.succeeded()) {
1348  diag.attachNote(aliveHandle.getLoc())
1349  << "replacement is required because this handle must be updated";
1350  notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
1351  (void)replacePayloadOp(op, nullptr);
1352  return;
1353  }
1354 
1355  (void)replacePayloadOp(op, replacement);
1356 }
1357 
1359  // The state of the ErrorCheckingTrackingListener must be checked and reset
1360  // if there was an error. This is to prevent errors from accidentally being
1361  // missed.
1362  assert(status.succeeded() && "listener state was not checked");
1363 }
1364 
1367  DiagnosedSilenceableFailure s = std::move(status);
1369  errorCounter = 0;
1370  return s;
1371 }
1372 
1374  return !status.succeeded();
1375 }
1376 
1379 
1380  // Merge potentially existing diags and store the result in the listener.
1382  diag.takeDiagnostics(diags);
1383  if (!status.succeeded())
1384  status.takeDiagnostics(diags);
1385  status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
1386 
1387  // Report more details.
1388  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
1389  for (auto &&[index, value] : llvm::enumerate(values))
1390  status.attachNote(value.getLoc())
1391  << "[" << errorCounter << "] replacement value " << index;
1392  ++errorCounter;
1393 }
1394 
1395 std::string
1397  if (!matchFailure) {
1398  return "";
1399  }
1400  return matchFailure->str();
1401 }
1402 
1404  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1406  reasonCallback(diag);
1407  matchFailure = std::move(diag);
1408 }
1409 
1410 //===----------------------------------------------------------------------===//
1411 // TransformRewriter
1412 //===----------------------------------------------------------------------===//
1413 
1416  : RewriterBase(ctx), listener(listener) {
1417  setListener(listener);
1418 }
1419 
1421  return listener->failed();
1422 }
1423 
1424 /// Silence all tracking failures that have been encountered so far.
1426  if (hasTrackingFailures()) {
1427  DiagnosedSilenceableFailure status = listener->checkAndResetError();
1428  (void)status.silence();
1429  }
1430 }
1431 
1433  Operation *op, Operation *replacement) {
1434  return listener->replacePayloadOp(op, replacement);
1435 }
1436 
1437 //===----------------------------------------------------------------------===//
1438 // Utilities for TransformEachOpTrait.
1439 //===----------------------------------------------------------------------===//
1440 
1441 LogicalResult
1443  ArrayRef<Operation *> targets) {
1444  for (auto &&[position, parent] : llvm::enumerate(targets)) {
1445  for (Operation *child : targets.drop_front(position + 1)) {
1446  if (parent->isAncestor(child)) {
1448  emitError(loc)
1449  << "transform operation consumes a handle pointing to an ancestor "
1450  "payload operation before its descendant";
1451  diag.attachNote()
1452  << "the ancestor is likely erased or rewritten before the "
1453  "descendant is accessed, leading to undefined behavior";
1454  diag.attachNote(parent->getLoc()) << "ancestor payload op";
1455  diag.attachNote(child->getLoc()) << "descendant payload op";
1456  return diag;
1457  }
1458  }
1459  }
1460  return success();
1461 }
1462 
1463 LogicalResult
1465  Location payloadOpLoc,
1466  const ApplyToEachResultList &partialResult) {
1467  Location transformOpLoc = transformOp->getLoc();
1468  StringRef transformOpName = transformOp->getName().getStringRef();
1469  unsigned expectedNumResults = transformOp->getNumResults();
1470 
1471  // Reuse the emission of the diagnostic note.
1472  auto emitDiag = [&]() {
1473  auto diag = mlir::emitError(transformOpLoc);
1474  diag.attachNote(payloadOpLoc) << "when applied to this op";
1475  return diag;
1476  };
1477 
1478  if (partialResult.size() != expectedNumResults) {
1479  auto diag = emitDiag() << "application of " << transformOpName
1480  << " expected to produce " << expectedNumResults
1481  << " results (actually produced "
1482  << partialResult.size() << ").";
1483  diag.attachNote(transformOpLoc)
1484  << "if you need variadic results, consider a generic `apply` "
1485  << "instead of the specialized `applyToOne`.";
1486  return failure();
1487  }
1488 
1489  // Check that the right kind of value was produced.
1490  for (const auto &[ptr, res] :
1491  llvm::zip(partialResult, transformOp->getResults())) {
1492  if (ptr.isNull())
1493  continue;
1494  if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1495  !isa<Operation *>(ptr)) {
1496  return emitDiag() << "application of " << transformOpName
1497  << " expected to produce an Operation * for result #"
1498  << res.getResultNumber();
1499  }
1500  if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1501  !isa<Attribute>(ptr)) {
1502  return emitDiag() << "application of " << transformOpName
1503  << " expected to produce an Attribute for result #"
1504  << res.getResultNumber();
1505  }
1506  if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1507  !isa<Value>(ptr)) {
1508  return emitDiag() << "application of " << transformOpName
1509  << " expected to produce a Value for result #"
1510  << res.getResultNumber();
1511  }
1512  }
1513  return success();
1514 }
1515 
1516 template <typename T>
1518  return llvm::to_vector(llvm::map_range(
1519  range, [](transform::MappedValue value) { return cast<T>(value); }));
1520 }
1521 
1523  Operation *transformOp, TransformResults &transformResults,
1526  transposed.resize(transformOp->getNumResults());
1527  for (const ApplyToEachResultList &partialResults : results) {
1528  if (llvm::any_of(partialResults,
1529  [](MappedValue value) { return value.isNull(); }))
1530  continue;
1531  assert(transformOp->getNumResults() == partialResults.size() &&
1532  "expected as many partial results as op as results");
1533  for (auto [i, value] : llvm::enumerate(partialResults))
1534  transposed[i].push_back(value);
1535  }
1536 
1537  for (OpResult r : transformOp->getResults()) {
1538  unsigned position = r.getResultNumber();
1539  if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1540  transformResults.setParams(r,
1541  castVector<Attribute>(transposed[position]));
1542  } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1543  transformResults.setValues(r, castVector<Value>(transposed[position]));
1544  } else {
1545  transformResults.set(r, castVector<Operation *>(transposed[position]));
1546  }
1547  }
1548 }
1549 
1550 //===----------------------------------------------------------------------===//
1551 // Utilities for implementing transform ops with regions.
1552 //===----------------------------------------------------------------------===//
1553 
1556  ValueRange values, const transform::TransformState &state, bool flatten) {
1557  assert(mappings.size() == values.size() && "mismatching number of mappings");
1558  for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1559  size_t mappedSize = mapped.size();
1560  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1561  llvm::append_range(mapped, state.getPayloadOps(operand));
1562  } else if (llvm::isa<TransformValueHandleTypeInterface>(
1563  operand.getType())) {
1564  llvm::append_range(mapped, state.getPayloadValues(operand));
1565  } else {
1566  assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1567  "unsupported kind of transform dialect value");
1568  llvm::append_range(mapped, state.getParams(operand));
1569  }
1570 
1571  if (mapped.size() - mappedSize != 1 && !flatten)
1572  return failure();
1573  }
1574  return success();
1575 }
1576 
1579  ValueRange values, const transform::TransformState &state) {
1580  mappings.resize(mappings.size() + values.size());
1581  (void)appendValueMappings(
1583  values.size()),
1584  values, state);
1585 }
1586 
1588  Block *block, transform::TransformState &state,
1589  transform::TransformResults &results) {
1590  for (auto &&[terminatorOperand, result] :
1591  llvm::zip(block->getTerminator()->getOperands(),
1592  block->getParentOp()->getOpResults())) {
1593  if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1594  results.set(result, state.getPayloadOps(terminatorOperand));
1595  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1596  result.getType())) {
1597  results.setValues(result, state.getPayloadValues(terminatorOperand));
1598  } else {
1599  assert(
1600  llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1601  "unhandled transform type interface");
1602  results.setParams(result, state.getParams(terminatorOperand));
1603  }
1604  }
1605 }
1606 
1609  Operation *payloadRoot) {
1610  return TransformState(region, payloadRoot);
1611 }
1612 
1613 //===----------------------------------------------------------------------===//
1614 // Utilities for PossibleTopLevelTransformOpTrait.
1615 //===----------------------------------------------------------------------===//
1616 
1617 /// Appends to `effects` the memory effect instances on `target` with the same
1618 /// resource and effect as the ones the operation `iface` has on `source`.
1619 static void
1620 remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
1621  OpOperand *target,
1624  iface.getEffectsOnValue(source, nestedEffects);
1625  for (const auto &effect : nestedEffects)
1626  effects.emplace_back(effect.getEffect(), target, effect.getResource());
1627 }
1628 
1629 /// Appends to `effects` the same effects that the operations of `block` have
1630 /// on its block arguments, but associated with the corresponding `operands`.
1631 static void
1634  for (Operation &op : block) {
1635  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1636  if (!iface)
1637  continue;
1638 
1639  for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
1640  remapEffects(iface, source, &target, effects);
1641  }
1642 
1644  iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1645  nestedEffects);
1646  llvm::append_range(effects, nestedEffects);
1647  }
1648 }
1649 
1651  Operation *operation, Value root, Block &body,
1653  transform::onlyReadsHandle(operation->getOpOperands(), effects);
1654  transform::producesHandle(operation->getOpResults(), effects);
1655 
1656  if (!root) {
1657  for (Operation &op : body) {
1658  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1659  if (!iface)
1660  continue;
1661 
1663  iface.getEffects(effects);
1664  }
1665  return;
1666  }
1667 
1668  // Carry over all effects on arguments of the entry block as those on the
1669  // operands, this is the same value just remapped.
1670  remapArgumentEffects(body, operation->getOpOperands(), effects);
1671 }
1672 
1674  TransformState &state, Operation *op, Region &region) {
1675  SmallVector<Operation *> targets;
1676  SmallVector<SmallVector<MappedValue>> extraMappings;
1677  if (op->getNumOperands() != 0) {
1678  llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
1679  prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
1680  } else {
1681  if (state.getNumTopLevelMappings() !=
1682  region.front().getNumArguments() - 1) {
1683  return emitError(op->getLoc())
1684  << "operation expects " << region.front().getNumArguments() - 1
1685  << " extra value bindings, but " << state.getNumTopLevelMappings()
1686  << " were provided to the interpreter";
1687  }
1688 
1689  targets.push_back(state.getTopLevel());
1690 
1691  for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1692  extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1693  }
1694 
1695  if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
1696  return failure();
1697 
1698  for (BlockArgument argument : region.front().getArguments().drop_front()) {
1699  if (failed(state.mapBlockArgument(
1700  argument, extraMappings[argument.getArgNumber() - 1])))
1701  return failure();
1702  }
1703 
1704  return success();
1705 }
1706 
/// Verification hook for PossibleTopLevelTransformOpTrait. Checks the shape of
/// the op's body region and of its entry block arguments:
///   - `op` must have at least one region, and region #0 must contain exactly
///     one block;
///   - that block needs at least one argument. The first argument's type must
///     implement TransformHandleTypeInterface and, when `op` has operands,
///     must equal the type of operand #0;
///   - every trailing argument must implement TransformHandleTypeInterface,
///     TransformValueHandleTypeInterface or TransformParamTypeInterface;
///   - when `op` is nested in another possible top-level op, it must have as
///     many operands as its entry block has arguments.
/// Emits an op error and returns failure on the first violation.
/// NOTE(review): this listing drops several source lines (1708, the function
/// name/parameter line, plus 1747, 1757 and 1759). The comments below mark
/// where they belong.
1707 LogicalResult
1709  // Attaching this trait without the interface is a misuse of the API, but it
1710  // cannot be caught via a static_assert because interface registration is
1711  // dynamic.
1712  assert(isa<TransformOpInterface>(op) &&
1713  "should implement TransformOpInterface to have "
1714  "PossibleTopLevelTransformOpTrait");
1715 
1716  if (op->getNumRegions() < 1)
1717  return op->emitOpError() << "expects at least one region";
1718 
  // Region #0 is the body; only a single-block body is accepted.
1719  Region *bodyRegion = &op->getRegion(0);
1720  if (!llvm::hasNItems(*bodyRegion, 1))
1721  return op->emitOpError() << "expects a single-block region";
1722 
1723  Block *body = &bodyRegion->front();
1724  if (body->getNumArguments() == 0) {
1725  return op->emitOpError()
1726  << "expects the entry block to have at least one argument";
1727  }
1728  if (!llvm::isa<TransformHandleTypeInterface>(
1729  body->getArgument(0).getType())) {
1730  return op->emitOpError()
1731  << "expects the first entry block argument to be of type "
1732  "implementing TransformHandleTypeInterface";
1733  }
  // When operands are present, operand #0 and the first block argument must
  // have identical types.
1734  BlockArgument arg = body->getArgument(0);
1735  if (op->getNumOperands() != 0) {
1736  if (arg.getType() != op->getOperand(0).getType()) {
1737  return op->emitOpError()
1738  << "expects the type of the block argument to match "
1739  "the type of the operand";
1740  }
1741  }
  // Each trailing argument may be any kind of transform value: an op handle,
  // a param, or a value handle.
  // NOTE(review): this loop variable shadows `arg` declared above; consider
  // renaming one of them.
1742  for (BlockArgument arg : body->getArguments().drop_front()) {
1743  if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1744  TransformValueHandleTypeInterface>(arg.getType()))
1745  continue;
1746 
  // NOTE(review): source line 1747, which begins this statement, is missing
  // from this listing. The later uses of `diag` suggest it declares
  // `InFlightDiagnostic diag =`; confirm.
1748  op->emitOpError()
1749  << "expects trailing entry block arguments to be of type implementing "
1750  "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1751  "TransformParamTypeInterface";
1752  diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1753  return diag;
1754  }
1755 
  // A nested possible top-level op must provide one operand per entry block
  // argument.
  // NOTE(review): source line 1757 (the initializer of `parent`) and line 1759
  // (presumably `InFlightDiagnostic diag =`) are missing from this listing.
  // Judging by the note text, `parent` is presumably the enclosing op that
  // carries PossibleTopLevelTransformOpTrait; confirm.
1756  if (auto *parent =
1758  if (op->getNumOperands() != body->getNumArguments()) {
1760  op->emitOpError()
1761  << "expects operands to be provided for a nested op";
1762  diag.attachNote(parent->getLoc())
1763  << "nested in another possible top-level op";
1764  return diag;
1765  }
1766  }
1767 
1768  return success();
1769 }
1770 
1771 //===----------------------------------------------------------------------===//
1772 // Utilities for ParamProducerTransformOpTrait.
1773 //===----------------------------------------------------------------------===//
1774 
  // Non-template implementation of ParamProducerTransformOpTrait::getEffects()
  // (getParamProducerTransformOpTraitEffects(op, effects) in the generated
  // index).
  // NOTE(review): the signature (source lines 1775-1776) is missing from this
  // listing.
  // Every result is registered as a produced handle.
1777  producesHandle(op->getResults(), effects);
  // Every operand is only read. The payload IR is additionally marked as read
  // when at least one operand is an op handle or a value handle. Operands of
  // other types (e.g. params) do not cause a payload read.
1778  bool hasPayloadOperands = false;
1779  for (OpOperand &operand : op->getOpOperands()) {
1780  onlyReadsHandle(operand, effects);
1781  if (llvm::isa<TransformHandleTypeInterface,
1782  TransformValueHandleTypeInterface>(operand.get().getType()))
1783  hasPayloadOperands = true;
1784  }
1785  if (hasPayloadOperands)
1786  onlyReadsPayload(effects);
1787 }
1788 
/// Non-template implementation of ParamProducerTransformOpTrait::verify().
/// Aborts via llvm::report_fatal_error if the op's registered name does not
/// provide MemoryEffectOpInterface, since attaching the trait without it is a
/// misuse. Otherwise emits an op error and returns failure if any result type
/// does not implement TransformParamTypeInterface.
/// NOTE(review): the function-name line (source line 1790) is missing from
/// this listing.
1789 LogicalResult
1791  // Interfaces can be attached dynamically, so this cannot be a static
1792  // assert.
1793  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
  // NOTE(review): the message below says "MemoryEffectsOpInterface", but the
  // interface checked above is MemoryEffectOpInterface. Consider fixing the
  // spelling in the message.
1794  llvm::report_fatal_error(
1795  Twine("ParamProducerTransformOpTrait must be attached to an op that "
1796  "implements MemoryEffectsOpInterface, found on ") +
1797  op->getName().getStringRef());
1798  }
  // Every result must be a param; the first non-param result is reported.
1799  for (Value result : op->getResults()) {
1800  if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1801  continue;
1802  return op->emitOpError()
1803  << "ParamProducerTransformOpTrait attached to this op expects "
1804  "result types to implement TransformParamTypeInterface";
1805  }
1806  return success();
1807 }
1808 
1809 //===----------------------------------------------------------------------===//
1810 // Memory effects.
1811 //===----------------------------------------------------------------------===//
1812 
  // consumesHandle(handles, effects): every handle operand gets a Read effect
  // followed by a Free effect. This is the pair that isHandleConsumed() looks
  // for.
  // NOTE(review): the signature (source lines 1813-1815) and the resource
  // argument of both emplace_back calls (lines 1818 and 1820) are missing
  // from this listing. Since isHandleConsumed() checks for
  // TransformMappingResource, that is presumably the resource here; confirm.
1816  for (OpOperand &handle : handles) {
1817  effects.emplace_back(MemoryEffects::Read::get(), &handle,
1819  effects.emplace_back(MemoryEffects::Free::get(), &handle,
1821  }
1822 }
1823 
1824 /// Returns `true` if the given list of effects instances contains an instance
1825 /// with the effect type specified as template parameter.
1826 template <typename EffectTy, typename ResourceTy, typename Range>
1827 static bool hasEffect(Range &&effects) {
1828  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1829  return isa<EffectTy>(effect.getEffect()) &&
1830  isa<ResourceTy>(effect.getResource());
1831  });
1832 }
1833 
1835  transform::TransformOpInterface transform) {
  // isHandleConsumed(handle, transform): true iff `transform` reports both a
  // Read and a Free effect on TransformMappingResource for `handle`.
  // `cast` assumes the op implements MemoryEffectOpInterface.
  // NOTE(review): the first signature line (source line 1834) and the
  // declaration of `effects` (line 1837) are missing from this listing.
1836  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1838  iface.getEffectsOnValue(handle, effects);
1839  return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1840  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1841 }
1842 
1844  ResultRange handles,
1846  for (OpResult handle : handles) {
  // producesHandle(ResultRange, effects): each result handle gets an Allocate
  // effect followed by a Write effect.
  // NOTE(review): the signature lines around `handles` (source lines 1843 and
  // 1845) and the resource arguments (lines 1848 and 1850) are missing from
  // this listing.
1847  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1849  effects.emplace_back(MemoryEffects::Write::get(), handle,
1851  }
1852 }
1853 
  // producesHandle overload for block arguments: each handle gets an Allocate
  // effect followed by a Write effect, mirroring the result-range overload.
  // NOTE(review): the signature (source lines 1854-1856) and the resource
  // arguments (lines 1859 and 1861) are missing from this listing.
1857  for (BlockArgument handle : handles) {
1858  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1860  effects.emplace_back(MemoryEffects::Write::get(), handle,
1862  }
1863 }
1864 
  // onlyReadsHandle(handles, effects): each handle operand gets a single Read
  // effect. No Free effect is added, unlike consumesHandle().
  // NOTE(review): the signature (source lines 1865-1867) and the resource
  // argument (line 1870) are missing from this listing.
1868  for (OpOperand &handle : handles) {
1869  effects.emplace_back(MemoryEffects::Read::get(), &handle,
1871  }
1872 }
1873 
  // modifiesPayload(effects): records both a Read and a Write effect on the
  // payload IR resource.
  // NOTE(review): the signature (source lines 1874-1875) is missing from this
  // listing.
1876  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1877  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
1878 }
1879 
  // onlyReadsPayload(effects): records only a Read effect on the payload IR
  // resource, with no Write.
  // NOTE(review): the signature (source lines 1880-1881) is missing from this
  // listing.
1882  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1883 }
1884 
/// Returns true if `transform` reports a Write effect on PayloadIRResource.
/// The op must implement MemoryEffectOpInterface (`cast` asserts this).
/// NOTE(review): the declaration of `effects` (source line 1887) is missing
/// from this listing.
1885 bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1886  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1888  iface.getEffects(effects);
1889  return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1890 }
1891 
/// Returns true if `transform` reports a Read effect on PayloadIRResource.
/// The op must implement MemoryEffectOpInterface (`cast` asserts this).
/// NOTE(review): the declaration of `effects` (source line 1894) is missing
/// from this listing.
1892 bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1893  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1895  iface.getEffects(effects);
1896  return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1897 }
1898 
1900  Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
  // getConsumedBlockArguments: inserts into `consumedArguments` the positions
  // of arguments of `block` on which some op directly in `block` reports a
  // Free effect on TransformMappingResource. Ops that do not implement
  // MemoryEffectOpInterface are skipped. The set is only inserted into, so
  // existing entries are kept.
  // NOTE(review): the first signature line (source line 1899) and the
  // declaration of `effects` (line 1901) are missing from this listing.
  // `effects` is reused across ops and cleared before each query.
1902  for (Operation &nested : block) {
1903  auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1904  if (!iface)
1905  continue;
1906 
1907  effects.clear();
1908  iface.getEffects(effects);
1909  for (const MemoryEffects::EffectInstance &effect : effects) {
  // An effect counts as consumption only if it is a Free on the transform
  // mapping resource, applied to an argument owned by this very block.
1910  BlockArgument argument =
1911  dyn_cast_or_null<BlockArgument>(effect.getValue());
1912  if (!argument || argument.getOwner() != &block ||
1913  !isa<MemoryEffects::Free>(effect.getEffect()) ||
1914  effect.getResource() != transform::TransformMappingResource::get()) {
1915  continue;
1916  }
1917  consumedArguments.insert(argument.getArgNumber());
1918  }
1919  }
1920 }
1921 
1922 //===----------------------------------------------------------------------===//
1923 // Utilities for TransformOpInterface.
1924 //===----------------------------------------------------------------------===//
1925 
1927  TransformOpInterface transformOp) {
  // getConsumedHandleOpOperands: returns, in operand order, every operand of
  // `transformOp` whose value carries a Free effect on
  // TransformMappingResource.
  // NOTE(review): the first signature line (source line 1926) and the
  // declaration of `effects` (line 1932) are missing from this listing.
  // Effects are queried per operand *value* (getEffectsOnValue), so if the
  // same value is passed as several operands, presumably all of them are
  // reported; confirm if that matters to callers.
1928  SmallVector<OpOperand *> consumedOperands;
  // At most one entry per operand.
1929  consumedOperands.reserve(transformOp->getNumOperands());
1930  auto memEffectInterface =
1931  cast<MemoryEffectOpInterface>(transformOp.getOperation());
1933  for (OpOperand &target : transformOp->getOpOperands()) {
1934  effects.clear();
1935  memEffectInterface.getEffectsOnValue(target.get(), effects);
1936  if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1937  return isa<transform::TransformMappingResource>(
1938  effect.getResource()) &&
1939  isa<MemoryEffects::Free>(effect.getEffect());
1940  })) {
1941  consumedOperands.push_back(&target);
1942  }
1943  }
1944  return consumedOperands;
1945 }
1946 
  // verifyTransformOpInterface(op): checks that the memory effects declared by
  // a TransformOpInterface op are consistent:
  //   - every operand has at least one effect on its value and no 'allocate'
  //     effect on TransformMappingResource;
  //   - if any operand is freed (consumed), the op writes PayloadIRResource;
  //   - every result has an 'allocate' effect on TransformMappingResource.
  // On the first violation, emits an error with a note naming the offending
  // operand or result and returns failure. `cast` assumes `op` implements
  // MemoryEffectOpInterface.
  // NOTE(review): the signature (source line 1947), the declaration of
  // `effects` (line 1949) and the lines starting each diagnostic (1963, 1971,
  // 1986 and 1998, presumably `InFlightDiagnostic diag = ...`, judging by the
  // uses of `diag`) are missing from this listing.
1948  auto iface = cast<MemoryEffectOpInterface>(op);
1950  iface.getEffects(effects);
1951 
  // Lazily filtered view over `effects`, keeping only instances applied to
  // `value`.
1952  auto effectsOn = [&](Value value) {
1953  return llvm::make_filter_range(
1954  effects, [value](const MemoryEffects::EffectInstance &instance) {
1955  return instance.getValue() == value;
1956  });
1957  };
1958 
  // Only the first consumed operand is remembered; it is named in the note
  // of the payload-write check below.
1959  std::optional<unsigned> firstConsumedOperand;
1960  for (OpOperand &operand : op->getOpOperands()) {
1961  auto range = effectsOn(operand.get());
1962  if (range.empty()) {
1964  op->emitError() << "TransformOpInterface requires memory effects "
1965  "on operands to be specified";
1966  diag.attachNote() << "no effects specified for operand #"
1967  << operand.getOperandNumber();
1968  return diag;
1969  }
1970  if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1972  << "TransformOpInterface did not expect "
1973  "'allocate' memory effect on an operand";
1974  diag.attachNote() << "specified for operand #"
1975  << operand.getOperandNumber();
1976  return diag;
1977  }
1978  if (!firstConsumedOperand &&
1979  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1980  firstConsumedOperand = operand.getOperandNumber();
1981  }
1982  }
1983 
  // Consuming a handle requires declaring a write to the payload IR.
1984  if (firstConsumedOperand &&
1985  !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1987  op->emitError()
1988  << "TransformOpInterface expects ops consuming operands to have a "
1989  "'write' effect on the payload resource";
1990  diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
1991  return diag;
1992  }
1993 
1994  for (OpResult result : op->getResults()) {
1995  auto range = effectsOn(result);
1996  if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1997  range)) {
1999  op->emitError() << "TransformOpInterface requires 'allocate' memory "
2000  "effect to be specified for results";
2001  diag.attachNote() << "no 'allocate' effect specified for result #"
2002  << result.getResultNumber();
2003  return diag;
2004  }
2005  }
2006 
2007  return success();
2008 }
2009 
2010 //===----------------------------------------------------------------------===//
2011 // Entry point.
2012 //===----------------------------------------------------------------------===//
2013 
2015  Operation *payloadRoot, TransformOpInterface transform,
2016  const RaggedArray<MappedValue> &extraMapping,
2017  const TransformOptions &options, bool enforceToplevelTransformOp,
2018  function_ref<void(TransformState &)> stateInitializer,
2019  function_ref<LogicalResult(TransformState &)> stateExporter) {
  // applyTransforms: entry point to the Transform dialect infrastructure.
  // It performs these steps:
  //   1. Validates `transform` as a starting point (see below).
  //   2. Builds a TransformState from the region containing `transform`,
  //      `payloadRoot`, `extraMapping` and `options`.
  //   3. Lets `stateInitializer` (if any) customize the state.
  //   4. Applies `transform`.
  //   5. Returns the result of `stateExporter` if one is given, else success.
  // Any failure of the transform, silenceable or definite, is reported via
  // checkAndReport() and turned into failure().
  // NOTE(review): the first signature line (source line 2014) and line 2027
  // (the argument of `failed(` below) are missing from this listing.
2020  if (enforceToplevelTransformOp) {
  // Enforced mode: the op must carry PossibleTopLevelTransformOpTrait and
  // take no operands.
2021  if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
2022  transform->getNumOperands() != 0) {
2023  return transform->emitError()
2024  << "expected transform to start at the top-level transform op";
2025  }
  // Otherwise a weaker check is performed. The elided line 2027 is presumably
  // the PossibleTopLevelTransformOpTrait verifier; confirm.
2026  } else if (failed(
2028  return failure();
2029  }
2030 
2031  TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2032  options);
2033  if (stateInitializer)
2034  stateInitializer(state);
2035  if (state.applyTransform(transform).checkAndReport().failed())
2036  return failure();
2037  if (stateExporter)
2038  return stateExporter(state);
2039  return success();
2040 }
2041 
2042 //===----------------------------------------------------------------------===//
2043 // Generated interface implementation.
2044 //===----------------------------------------------------------------------===//
2045 
2046 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2047 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define DEBUG_TYPE_FULL
#define FULL_LDBG(X)
static void remapEffects(MemoryEffectOpInterface iface, BlockArgument source, OpOperand *target, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the memory effect instances on target with the same resource and effect as the one...
static void remapArgumentEffects(Block &block, MutableArrayRef< OpOperand > operands, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the same effects as the operations of block have on block arguments but associated...
static bool happensBefore(Operation *a, Operation *b)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not ...
void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped)
#define DEBUG_PRINT_AFTER_ALL
DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef< T > payload, transform::TransformOpInterface transform, unsigned operandNumber)
#define DBGS()
static DiagnosedSilenceableFailure dispatchMappedValues(Value handle, ArrayRef< transform::MappedValue > values, function_ref< LogicalResult(ArrayRef< Operation * >)> operationsFn, function_ref< LogicalResult(ArrayRef< transform::Param >)> paramsFn, function_ref< LogicalResult(ValueRange)> valuesFn)
Given a list of MappedValues, cast them to the value kind implied by the interface of the handle type...
static SmallVector< T > castVector(ArrayRef< transform::MappedValue > range)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:295
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:304
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:307
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition: Block.cpp:76
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:314
This class represents an operand of an operation.
Definition: Value.h:243
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition: Value.h:433
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:445
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A 2D array where each row may have different length.
Definition: RaggedArray.h:18
size_t size() const
Returns the number of rows in the 2D array.
Definition: RaggedArray.h:21
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:247
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
size_t size() const
Returns the number of elements in the list.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
std::string getLatestMatchFailureMessage()
Return the latest match notification message.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
TrackingListener failures are reported only for ops that have this trait.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Options controlling the application of transform operations by the TransformState.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult matches the number, kind (payload op or parameter) and nulli...
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true, function_ref< void(TransformState &)> stateInitializer=nullptr, function_ref< LogicalResult(TransformState &)> stateExporter=nullptr)
Entry point to the Transform dialect infrastructure.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns true if op has an effect of type EffectTy.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A configuration object for customizing a TrackingListener.