MLIR  22.0.0git
TransformInterfaces.cpp
Go to the documentation of this file.
1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/IR/PatternMatch.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/ScopeExit.h"
17 #include "llvm/ADT/iterator.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/DebugLog.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include "llvm/Support/InterleavedRange.h"
22 
23 #define DEBUG_TYPE "transform-dialect"
24 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
25 #define FULL_LDBG() LDBG(4)
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 // Helper functions
31 //===----------------------------------------------------------------------===//
32 
33 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
34 /// properly dominates `b` and `b` is not inside `a`.
35 static bool happensBefore(Operation *a, Operation *b) {
36  do {
37  if (a->isProperAncestor(b))
38  return false;
39  if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
40  return a->isBeforeInBlock(bAncestor);
41  }
42  } while ((a = a->getParentOp()));
43  return false;
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // TransformState
48 //===----------------------------------------------------------------------===//
49 
50 constexpr const Value transform::TransformState::kTopLevelValue;
51 
52 transform::TransformState::TransformState(
53  Region *region, Operation *payloadRoot,
54  const RaggedArray<MappedValue> &extraMappings,
55  const TransformOptions &options)
56  : topLevel(payloadRoot), options(options) {
57  topLevelMappedValues.reserve(extraMappings.size());
58  for (ArrayRef<MappedValue> mapping : extraMappings)
59  topLevelMappedValues.push_back(mapping);
60  if (region) {
61  RegionScope *scope = new RegionScope(*this, *region);
62  topLevelRegionScope.reset(scope);
63  }
64 }
65 
66 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
67 
69 transform::TransformState::getPayloadOpsView(Value value) const {
70  const TransformOpMapping &operationMapping = getMapping(value).direct;
71  auto iter = operationMapping.find(value);
72  assert(iter != operationMapping.end() &&
73  "cannot find mapping for payload handle (param/value handle "
74  "provided?)");
75  return iter->getSecond();
76 }
77 
79  const ParamMapping &mapping = getMapping(value).params;
80  auto iter = mapping.find(value);
81  assert(iter != mapping.end() && "cannot find mapping for param handle "
82  "(operation/value handle provided?)");
83  return iter->getSecond();
84 }
85 
87 transform::TransformState::getPayloadValuesView(Value handleValue) const {
88  const ValueMapping &mapping = getMapping(handleValue).values;
89  auto iter = mapping.find(handleValue);
90  assert(iter != mapping.end() && "cannot find mapping for value handle "
91  "(param/operation handle provided?)");
92  return iter->getSecond();
93 }
94 
96  Operation *op, SmallVectorImpl<Value> &handles,
97  bool includeOutOfScope) const {
98  bool found = false;
99  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
100  auto iterator = mapping->reverse.find(op);
101  if (iterator != mapping->reverse.end()) {
102  llvm::append_range(handles, iterator->getSecond());
103  found = true;
104  }
105  // Stop looking when reaching a region that is isolated from above.
106  if (!includeOutOfScope &&
108  break;
109  }
110 
111  return success(found);
112 }
113 
115  Value payloadValue, SmallVectorImpl<Value> &handles,
116  bool includeOutOfScope) const {
117  bool found = false;
118  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
119  auto iterator = mapping->reverseValues.find(payloadValue);
120  if (iterator != mapping->reverseValues.end()) {
121  llvm::append_range(handles, iterator->getSecond());
122  found = true;
123  }
124  // Stop looking when reaching a region that is isolated from above.
125  if (!includeOutOfScope &&
127  break;
128  }
129 
130  return success(found);
131 }
132 
133 /// Given a list of MappedValues, cast them to the value kind implied by the
134 /// interface of the handle type, and dispatch to one of the callbacks.
137  function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
138  function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
139  function_ref<LogicalResult(ValueRange)> valuesFn) {
140  if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
141  SmallVector<Operation *> operations;
142  operations.reserve(values.size());
143  for (transform::MappedValue value : values) {
144  if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
145  operations.push_back(op);
146  continue;
147  }
148  return emitSilenceableFailure(handle.getLoc())
149  << "wrong kind of value provided for top-level operation handle";
150  }
151  if (failed(operationsFn(operations)))
154  }
155 
156  if (llvm::isa<transform::TransformValueHandleTypeInterface>(
157  handle.getType())) {
158  SmallVector<Value> payloadValues;
159  payloadValues.reserve(values.size());
160  for (transform::MappedValue value : values) {
161  if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
162  payloadValues.push_back(v);
163  continue;
164  }
165  return emitSilenceableFailure(handle.getLoc())
166  << "wrong kind of value provided for the top-level value handle";
167  }
168  if (failed(valuesFn(payloadValues)))
171  }
172 
173  assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
174  "unsupported kind of block argument");
176  parameters.reserve(values.size());
177  for (transform::MappedValue value : values) {
178  if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
179  parameters.push_back(attr);
180  continue;
181  }
182  return emitSilenceableFailure(handle.getLoc())
183  << "wrong kind of value provided for top-level parameter";
184  }
185  if (failed(paramsFn(parameters)))
188 }
189 
190 LogicalResult
192  ArrayRef<MappedValue> values) {
193  return dispatchMappedValues(
194  argument, values,
195  [&](ArrayRef<Operation *> operations) {
196  return setPayloadOps(argument, operations);
197  },
198  [&](ArrayRef<Param> params) {
199  return setParams(argument, params);
200  },
201  [&](ValueRange payloadValues) {
202  return setPayloadValues(argument, payloadValues);
203  })
204  .checkAndReport();
205 }
206 
208  Block::BlockArgListType arguments,
210  for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
211  if (failed(mapBlockArgument(argument, values)))
212  return failure();
213  return success();
214 }
215 
216 LogicalResult
217 transform::TransformState::setPayloadOps(Value value,
218  ArrayRef<Operation *> targets) {
219  assert(value != kTopLevelValue &&
220  "attempting to reset the transformation root");
221  assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
222  "wrong handle type");
223 
224  for (Operation *target : targets) {
225  if (target)
226  continue;
227  return emitError(value.getLoc())
228  << "attempting to assign a null payload op to this transform value";
229  }
230 
231  auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
233  iface.checkPayload(value.getLoc(), targets);
234  if (failed(result.checkAndReport()))
235  return failure();
236 
237  // Setting new payload for the value without cleaning it first is a misuse of
238  // the API, assert here.
239  SmallVector<Operation *> storedTargets(targets);
240  Mappings &mappings = getMapping(value);
241  bool inserted =
242  mappings.direct.insert({value, std::move(storedTargets)}).second;
243  assert(inserted && "value is already associated with another list");
244  (void)inserted;
245 
246  for (Operation *op : targets)
247  mappings.reverse[op].push_back(value);
248 
249  return success();
250 }
251 
252 LogicalResult
253 transform::TransformState::setPayloadValues(Value handle,
254  ValueRange payloadValues) {
255  assert(handle != nullptr && "attempting to set params for a null value");
256  assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
257  "wrong handle type");
258 
259  for (Value payload : payloadValues) {
260  if (payload)
261  continue;
262  return emitError(handle.getLoc()) << "attempting to assign a null payload "
263  "value to this transform handle";
264  }
265 
266  auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
267  SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
269  iface.checkPayload(handle.getLoc(), payloadValueVector);
270  if (failed(result.checkAndReport()))
271  return failure();
272 
273  Mappings &mappings = getMapping(handle);
274  bool inserted =
275  mappings.values.insert({handle, std::move(payloadValueVector)}).second;
276  assert(
277  inserted &&
278  "value handle is already associated with another list of payload values");
279  (void)inserted;
280 
281  for (Value payload : payloadValues)
282  mappings.reverseValues[payload].push_back(handle);
283 
284  return success();
285 }
286 
287 LogicalResult transform::TransformState::setParams(Value value,
288  ArrayRef<Param> params) {
289  assert(value != nullptr && "attempting to set params for a null value");
290 
291  for (Attribute attr : params) {
292  if (attr)
293  continue;
294  return emitError(value.getLoc())
295  << "attempting to assign a null parameter to this transform value";
296  }
297 
298  auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
299  assert(value &&
300  "cannot associate parameter with a value of non-parameter type");
302  valueType.checkPayload(value.getLoc(), params);
303  if (failed(result.checkAndReport()))
304  return failure();
305 
306  Mappings &mappings = getMapping(value);
307  bool inserted =
308  mappings.params.insert({value, llvm::to_vector(params)}).second;
309  assert(inserted && "value is already associated with another list of params");
310  (void)inserted;
311  return success();
312 }
313 
314 template <typename Mapping, typename Key, typename Mapped>
315 static void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
316  auto it = mapping.find(key);
317  if (it == mapping.end())
318  return;
319 
320  llvm::erase(it->getSecond(), mapped);
321  if (it->getSecond().empty())
322  mapping.erase(it);
323 }
324 
325 void transform::TransformState::forgetMapping(Value opHandle,
326  ValueRange origOpFlatResults,
327  bool allowOutOfScope) {
328  Mappings &mappings = getMapping(opHandle, allowOutOfScope);
329  for (Operation *op : mappings.direct[opHandle])
330  dropMappingEntry(mappings.reverse, op, opHandle);
331  mappings.direct.erase(opHandle);
332 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
333  // Payload IR is removed from the mapping. This invalidates the respective
334  // iterators.
335  mappings.incrementTimestamp(opHandle);
336 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
337 
338  for (Value opResult : origOpFlatResults) {
339  SmallVector<Value> resultHandles;
340  (void)getHandlesForPayloadValue(opResult, resultHandles);
341  for (Value resultHandle : resultHandles) {
342  Mappings &localMappings = getMapping(resultHandle);
343  dropMappingEntry(localMappings.values, resultHandle, opResult);
344 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
345  // Payload IR is removed from the mapping. This invalidates the respective
346  // iterators.
347  mappings.incrementTimestamp(resultHandle);
348 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
349  dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
350  }
351  }
352 }
353 
354 void transform::TransformState::forgetValueMapping(
355  Value valueHandle, ArrayRef<Operation *> payloadOperations) {
356  Mappings &mappings = getMapping(valueHandle);
357  for (Value payloadValue : mappings.reverseValues[valueHandle])
358  dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
359  mappings.values.erase(valueHandle);
360 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
361  // Payload IR is removed from the mapping. This invalidates the respective
362  // iterators.
363  mappings.incrementTimestamp(valueHandle);
364 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
365 
366  for (Operation *payloadOp : payloadOperations) {
367  SmallVector<Value> opHandles;
368  (void)getHandlesForPayloadOp(payloadOp, opHandles);
369  for (Value opHandle : opHandles) {
370  Mappings &localMappings = getMapping(opHandle);
371  dropMappingEntry(localMappings.direct, opHandle, payloadOp);
372  dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
373 
374 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
375  // Payload IR is removed from the mapping. This invalidates the respective
376  // iterators.
377  localMappings.incrementTimestamp(opHandle);
378 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
379  }
380  }
381 }
382 
383 LogicalResult
384 transform::TransformState::replacePayloadOp(Operation *op,
385  Operation *replacement) {
386  // TODO: consider invalidating the handles to nested objects here.
387 
388 #ifndef NDEBUG
389  for (Value opResult : op->getResults()) {
390  SmallVector<Value> valueHandles;
391  (void)getHandlesForPayloadValue(opResult, valueHandles,
392  /*includeOutOfScope=*/true);
393  assert(valueHandles.empty() && "expected no mapping to old results");
394  }
395 #endif // NDEBUG
396 
397  // Drop the mapping between the op and all handles that point to it. Fail if
398  // there are no handles.
399  SmallVector<Value> opHandles;
400  if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
401  return failure();
402  for (Value handle : opHandles) {
403  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
404  dropMappingEntry(mappings.reverse, op, handle);
405  }
406 
407  // Replace the pointed-to object of all handles with the replacement object.
408  // In case a payload op was erased (replacement object is nullptr), a nullptr
409  // is stored in the mapping. These nullptrs are removed after each transform.
410  // Furthermore, nullptrs are not enumerated by payload op iterators. The
411  // relative order of ops is preserved.
412  //
413  // Removing an op from the mapping would be problematic because removing an
414  // element from an array invalidates iterators; merely changing the value of
415  // elements does not.
416  for (Value handle : opHandles) {
417  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
418  auto it = mappings.direct.find(handle);
419  if (it == mappings.direct.end())
420  continue;
421 
422  SmallVector<Operation *, 2> &association = it->getSecond();
423  // Note that an operation may be associated with the handle more than once.
424  for (Operation *&mapped : association) {
425  if (mapped == op)
426  mapped = replacement;
427  }
428 
429  if (replacement) {
430  mappings.reverse[replacement].push_back(handle);
431  } else {
432  opHandlesToCompact.insert(handle);
433  }
434  }
435 
436  return success();
437 }
438 
439 LogicalResult
440 transform::TransformState::replacePayloadValue(Value value, Value replacement) {
441  SmallVector<Value> valueHandles;
442  if (failed(getHandlesForPayloadValue(value, valueHandles,
443  /*includeOutOfScope=*/true)))
444  return failure();
445 
446  for (Value handle : valueHandles) {
447  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
448  dropMappingEntry(mappings.reverseValues, value, handle);
449 
450  // If replacing with null, that is erasing the mapping, drop the mapping
451  // between the handles and the IR objects
452  if (!replacement) {
453  dropMappingEntry(mappings.values, handle, value);
454 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
455  // Payload IR is removed from the mapping. This invalidates the respective
456  // iterators.
457  mappings.incrementTimestamp(handle);
458 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
459  } else {
460  auto it = mappings.values.find(handle);
461  if (it == mappings.values.end())
462  continue;
463 
464  SmallVector<Value> &association = it->getSecond();
465  for (Value &mapped : association) {
466  if (mapped == value)
467  mapped = replacement;
468  }
469  mappings.reverseValues[replacement].push_back(handle);
470  }
471  }
472 
473  return success();
474 }
475 
476 void transform::TransformState::recordOpHandleInvalidationOne(
477  OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
478  Operation *payloadOp, Value otherHandle, Value throughValue,
479  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
480  // If the op is associated with invalidated handle, skip the check as it
481  // may be reading invalid IR. This also ensures we report the first
482  // invalidation and not the last one.
483  if (invalidatedHandles.count(otherHandle) ||
484  newlyInvalidated.count(otherHandle))
485  return;
486 
487  FULL_LDBG() << "--recordOpHandleInvalidationOne";
488  FULL_LDBG() << "--ancestors: "
489  << llvm::interleaved(
490  llvm::make_pointee_range(potentialAncestors));
491 
492  Operation *owner = consumingHandle.getOwner();
493  unsigned operandNo = consumingHandle.getOperandNumber();
494  for (Operation *ancestor : potentialAncestors) {
495  // clang-format off
496  FULL_LDBG() << "----handle one ancestor: " << *ancestor;;
497 
498  FULL_LDBG() << "----of payload with name: "
499  << payloadOp->getName().getIdentifier();
500  FULL_LDBG() << "----of payload: " << *payloadOp;
501  // clang-format on
502  if (!ancestor->isAncestor(payloadOp))
503  continue;
504 
505  // Make sure the error-reporting lambda doesn't capture anything
506  // by-reference because it will go out of scope. Additionally, extract
507  // location from Payload IR ops because the ops themselves may be
508  // deleted before the lambda gets called.
509  Location ancestorLoc = ancestor->getLoc();
510  Location opLoc = payloadOp->getLoc();
511  std::optional<Location> throughValueLoc =
512  throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
513  newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
514  otherHandle,
515  throughValueLoc](Location currentLoc) {
516  InFlightDiagnostic diag = emitError(currentLoc)
517  << "op uses a handle invalidated by a "
518  "previously executed transform op";
519  diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
520  diag.attachNote(owner->getLoc())
521  << "invalidated by this transform op that consumes its operand #"
522  << operandNo
523  << " and invalidates all handles to payload IR entities associated "
524  "with this operand and entities nested in them";
525  diag.attachNote(ancestorLoc) << "ancestor payload op";
526  diag.attachNote(opLoc) << "nested payload op";
527  if (throughValueLoc) {
528  diag.attachNote(*throughValueLoc)
529  << "consumed handle points to this payload value";
530  }
531  };
532  }
533 }
534 
535 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
536  OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
537  Value payloadValue, Value valueHandle,
538  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
539  // If the op is associated with invalidated handle, skip the check as it
540  // may be reading invalid IR. This also ensures we report the first
541  // invalidation and not the last one.
542  if (invalidatedHandles.count(valueHandle) ||
543  newlyInvalidated.count(valueHandle))
544  return;
545 
546  for (Operation *ancestor : potentialAncestors) {
547  Operation *definingOp;
548  std::optional<unsigned> resultNo;
549  unsigned argumentNo = std::numeric_limits<unsigned>::max();
550  unsigned blockNo = std::numeric_limits<unsigned>::max();
551  unsigned regionNo = std::numeric_limits<unsigned>::max();
552  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
553  definingOp = opResult.getOwner();
554  resultNo = opResult.getResultNumber();
555  } else {
556  auto arg = llvm::cast<BlockArgument>(payloadValue);
557  definingOp = arg.getParentBlock()->getParentOp();
558  argumentNo = arg.getArgNumber();
559  blockNo = std::distance(arg.getOwner()->getParent()->begin(),
560  arg.getOwner()->getIterator());
561  regionNo = arg.getOwner()->getParent()->getRegionNumber();
562  }
563  assert(definingOp && "expected the value to be defined by an op as result "
564  "or block argument");
565  if (!ancestor->isAncestor(definingOp))
566  continue;
567 
568  Operation *owner = opHandle.getOwner();
569  unsigned operandNo = opHandle.getOperandNumber();
570  Location ancestorLoc = ancestor->getLoc();
571  Location opLoc = definingOp->getLoc();
572  Location valueLoc = payloadValue.getLoc();
573  newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
574  argumentNo, blockNo, regionNo, ancestorLoc,
575  opLoc, valueLoc](Location currentLoc) {
576  InFlightDiagnostic diag = emitError(currentLoc)
577  << "op uses a handle invalidated by a "
578  "previously executed transform op";
579  diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
580  diag.attachNote(owner->getLoc())
581  << "invalidated by this transform op that consumes its operand #"
582  << operandNo
583  << " and invalidates all handles to payload IR entities "
584  "associated with this operand and entities nested in them";
585  diag.attachNote(ancestorLoc)
586  << "ancestor op associated with the consumed handle";
587  if (resultNo) {
588  diag.attachNote(opLoc)
589  << "op defining the value as result #" << *resultNo;
590  } else {
591  diag.attachNote(opLoc)
592  << "op defining the value as block argument #" << argumentNo
593  << " of block #" << blockNo << " in region #" << regionNo;
594  }
595  diag.attachNote(valueLoc) << "payload value";
596  };
597  }
598 }
599 
600 void transform::TransformState::recordOpHandleInvalidation(
601  OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
602  Value throughValue,
603  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
604 
605  if (potentialAncestors.empty()) {
606  FULL_LDBG() << "----recording invalidation for empty handle: "
607  << handle.get();
608 
609  Operation *owner = handle.getOwner();
610  unsigned operandNo = handle.getOperandNumber();
611  newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
612  InFlightDiagnostic diag = emitError(currentLoc)
613  << "op uses a handle associated with empty "
614  "payload and invalidated by a "
615  "previously executed transform op";
616  diag.attachNote(owner->getLoc())
617  << "invalidated by this transform op that consumes its operand #"
618  << operandNo;
619  };
620  return;
621  }
622 
623  // Iterate over the mapping and invalidate aliasing handles. This is quite
624  // expensive and only necessary for error reporting in case of transform
625  // dialect misuse with dangling handles. Iteration over the handles is based
626  // on the assumption that the number of handles is significantly less than the
627  // number of IR objects (operations and values). Alternatively, we could walk
628  // the IR nested in each payload op associated with the given handle and look
629  // for handles associated with each operation and value.
630  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
631  // Go over all op handle mappings and mark as invalidated any handle
632  // pointing to any of the payload ops associated with the given handle or
633  // any op nested in them.
634  for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
635  for (Value otherHandle : otherHandles)
636  recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
637  otherHandle, throughValue,
638  newlyInvalidated);
639  }
640  // Go over all value handle mappings and mark as invalidated any handle
641  // pointing to any result of the payload op associated with the given handle
642  // or any op nested in them. Similarly invalidate handles to argument of
643  // blocks belonging to any region of any payload op associated with the
644  // given handle or any op nested in them.
645  for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
646  for (Value valueHandle : valueHandles)
647  recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
648  payloadValue, valueHandle,
649  newlyInvalidated);
650  }
651 
652  // Stop lookup when reaching a region that is isolated from above.
654  break;
655  }
656 }
657 
658 void transform::TransformState::recordValueHandleInvalidation(
659  OpOperand &valueHandle,
660  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
661  // Invalidate other handles to the same value.
662  for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
663  SmallVector<Value> otherValueHandles;
664  (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
665  for (Value otherHandle : otherValueHandles) {
666  Operation *owner = valueHandle.getOwner();
667  unsigned operandNo = valueHandle.getOperandNumber();
668  Location valueLoc = payloadValue.getLoc();
669  newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
670  valueLoc](Location currentLoc) {
671  InFlightDiagnostic diag = emitError(currentLoc)
672  << "op uses a handle invalidated by a "
673  "previously executed transform op";
674  diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
675  diag.attachNote(owner->getLoc())
676  << "invalidated by this transform op that consumes its operand #"
677  << operandNo
678  << " and invalidates handles to the same values as associated with "
679  "it";
680  diag.attachNote(valueLoc) << "payload value";
681  };
682  }
683 
684  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
685  Operation *payloadOp = opResult.getOwner();
686  recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
687  newlyInvalidated);
688  } else {
689  auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
690  for (Operation &payloadOp : *arg.getOwner())
691  recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
692  newlyInvalidated);
693  }
694  }
695 }
696 
697 /// Checks that the operation does not use invalidated handles as operands.
698 /// Reports errors and returns failure if it does. Otherwise, invalidates the
699 /// handles consumed by the operation as well as any handles pointing to payload
700 /// IR operations nested in the operations associated with the consumed handles.
701 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
702  transform::TransformOpInterface transform,
703  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
704  FULL_LDBG() << "--Start checkAndRecordHandleInvalidation";
705  auto memoryEffectsIface =
706  cast<MemoryEffectOpInterface>(transform.getOperation());
708  memoryEffectsIface.getEffectsOnResource(
710 
711  for (OpOperand &target : transform->getOpOperands()) {
712  FULL_LDBG() << "----iterate on handle: " << target.get();
713  // If the operand uses an invalidated handle, report it. If the operation
714  // allows handles to point to repeated payload operations, only report
715  // pre-existing invalidation errors. Otherwise, also report invalidations
716  // caused by the current transform operation affecting its other operands.
717  auto it = invalidatedHandles.find(target.get());
718  auto nit = newlyInvalidated.find(target.get());
719  if (it != invalidatedHandles.end()) {
720  FULL_LDBG() << "--End checkAndRecordHandleInvalidation, found already "
721  "invalidated -> FAILURE";
722  return it->getSecond()(transform->getLoc()), failure();
723  }
724  if (!transform.allowsRepeatedHandleOperands() &&
725  nit != newlyInvalidated.end()) {
726  FULL_LDBG() << "--End checkAndRecordHandleInvalidation, found newly "
727  "invalidated (by this op) -> FAILURE";
728  return nit->getSecond()(transform->getLoc()), failure();
729  }
730 
731  // Invalidate handles pointing to the operations nested in the operation
732  // associated with the handle consumed by this operation.
733  auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
734  return isa<MemoryEffects::Free>(effect.getEffect()) &&
735  effect.getValue() == target.get();
736  };
737  if (llvm::any_of(effects, consumesTarget)) {
738  FULL_LDBG() << "----found consume effect";
739  if (llvm::isa<transform::TransformHandleTypeInterface>(
740  target.get().getType())) {
741  FULL_LDBG() << "----recordOpHandleInvalidation";
742  SmallVector<Operation *> payloadOps =
743  llvm::to_vector(getPayloadOps(target.get()));
744  recordOpHandleInvalidation(target, payloadOps, nullptr,
745  newlyInvalidated);
746  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
747  target.get().getType())) {
748  FULL_LDBG() << "----recordValueHandleInvalidation";
749  recordValueHandleInvalidation(target, newlyInvalidated);
750  } else {
751  FULL_LDBG()
752  << "----not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
753  }
754  } else {
755  FULL_LDBG() << "----no consume effect -> SKIP";
756  }
757  }
758 
759  FULL_LDBG() << "--End checkAndRecordHandleInvalidation -> SUCCESS";
760  return success();
761 }
762 
763 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
764  transform::TransformOpInterface transform) {
765  InvalidatedHandleMap newlyInvalidated;
766  LogicalResult checkResult =
767  checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
768  invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
769  std::make_move_iterator(newlyInvalidated.end()));
770  return checkResult;
771 }
772 
773 template <typename T>
776  transform::TransformOpInterface transform,
777  unsigned operandNumber) {
778  DenseSet<T> seen;
779  for (T p : payload) {
780  if (!seen.insert(p).second) {
782  transform.emitSilenceableError()
783  << "a handle passed as operand #" << operandNumber
784  << " and consumed by this operation points to a payload "
785  "entity more than once";
786  if constexpr (std::is_pointer_v<T>)
787  diag.attachNote(p->getLoc()) << "repeated target op";
788  else
789  diag.attachNote(p.getLoc()) << "repeated target value";
790  return diag;
791  }
792  }
794 }
795 
796 void transform::TransformState::compactOpHandles() {
797  for (Value handle : opHandlesToCompact) {
798  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
799 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
800  if (llvm::is_contained(mappings.direct[handle], nullptr))
801  // Payload IR is removed from the mapping. This invalidates the respective
802  // iterators.
803  mappings.incrementTimestamp(handle);
804 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
805  llvm::erase(mappings.direct[handle], nullptr);
806  }
807  opHandlesToCompact.clear();
808 }
809 
811 transform::TransformState::applyTransform(TransformOpInterface transform) {
812  LDBG() << "applying: "
813  << OpWithFlags(transform, OpPrintingFlags().skipRegions());
814  FULL_LDBG() << "Top-level payload before application:\n" << *getTopLevel();
815  auto printOnFailureRAII = llvm::make_scope_exit([this] {
816  (void)this;
817  LDBG() << "Failing Top-level payload:\n"
818  << OpWithFlags(getTopLevel(),
819  OpPrintingFlags().printGenericOpForm());
820  });
821 
822  // Set current transform op.
823  regionStack.back()->currentTransform = transform;
824 
825  // Expensive checks to detect invalid transform IR.
826  if (options.getExpensiveChecksEnabled()) {
827  FULL_LDBG() << "ExpensiveChecksEnabled";
828  if (failed(checkAndRecordHandleInvalidation(transform)))
830 
831  for (OpOperand &operand : transform->getOpOperands()) {
832  FULL_LDBG() << "iterate on handle: " << operand.get();
833  if (!isHandleConsumed(operand.get(), transform)) {
834  FULL_LDBG() << "--handle not consumed -> SKIP";
835  continue;
836  }
837  if (transform.allowsRepeatedHandleOperands()) {
838  FULL_LDBG() << "--op allows repeated handles -> SKIP";
839  continue;
840  }
841  FULL_LDBG() << "--handle is consumed";
842 
843  Type operandType = operand.get().getType();
844  if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
845  FULL_LDBG() << "--checkRepeatedConsumptionInOperand for Operation*";
847  checkRepeatedConsumptionInOperand<Operation *>(
848  getPayloadOpsView(operand.get()), transform,
849  operand.getOperandNumber());
850  if (!check.succeeded()) {
851  FULL_LDBG() << "----FAILED";
852  return check;
853  }
854  } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
855  FULL_LDBG() << "--checkRepeatedConsumptionInOperand For Value";
857  checkRepeatedConsumptionInOperand<Value>(
858  getPayloadValuesView(operand.get()), transform,
859  operand.getOperandNumber());
860  if (!check.succeeded()) {
861  FULL_LDBG() << "----FAILED";
862  return check;
863  }
864  } else {
865  FULL_LDBG() << "--not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
866  }
867  }
868  }
869 
870  // Find which operands are consumed.
871  SmallVector<OpOperand *> consumedOperands =
872  transform.getConsumedHandleOpOperands();
873 
874  // Remember the results of the payload ops associated with the consumed
875  // op handles or the ops defining the value handles so we can drop the
876  // association with them later. This must happen here because the
877  // transformation may destroy or mutate them so we cannot traverse the payload
878  // IR after that.
879  SmallVector<Value> origOpFlatResults;
880  SmallVector<Operation *> origAssociatedOps;
881  for (OpOperand *opOperand : consumedOperands) {
882  Value operand = opOperand->get();
883  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
884  for (Operation *payloadOp : getPayloadOps(operand)) {
885  llvm::append_range(origOpFlatResults, payloadOp->getResults());
886  }
887  continue;
888  }
889  if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
890  for (Value payloadValue : getPayloadValuesView(operand)) {
891  if (llvm::isa<OpResult>(payloadValue)) {
892  origAssociatedOps.push_back(payloadValue.getDefiningOp());
893  continue;
894  }
895  llvm::append_range(
896  origAssociatedOps,
897  llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
898  [](Operation &op) { return &op; }));
899  }
900  continue;
901  }
903  emitDefiniteFailure(transform->getLoc())
904  << "unexpectedly consumed a value that is not a handle as operand #"
905  << opOperand->getOperandNumber();
906  diag.attachNote(operand.getLoc())
907  << "value defined here with type " << operand.getType();
908  return diag;
909  }
910 
911  // Prepare rewriter and listener.
913  config.skipHandleFn = [&](Value handle) {
914  // Skip handle if it is dead.
915  auto scopeIt =
916  llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
917  return handle.getParentRegion() == scope->region;
918  });
919  assert(scopeIt != regionStack.rend() &&
920  "could not find region scope for handle");
921  RegionScope *scope = *scopeIt;
922  return llvm::all_of(handle.getUsers(), [&](Operation *user) {
923  return user == scope->currentTransform ||
924  happensBefore(user, scope->currentTransform);
925  });
926  };
927  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
928  config);
929  transform::TransformRewriter rewriter(transform->getContext(),
930  &trackingListener);
931 
932  // Compute the result but do not short-circuit the silenceable failure case as
933  // we still want the handles to propagate properly so the "suppress" mode can
934  // proceed on a best effort basis.
935  transform::TransformResults results(transform->getNumResults());
936  DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
937  compactOpHandles();
938 
939  // Error handling: fail if transform or listener failed.
940  DiagnosedSilenceableFailure trackingFailure =
941  trackingListener.checkAndResetError();
942  if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
943  transform->hasAttr(FindPayloadReplacementOpInterface::
944  kSilenceTrackingFailuresAttrName)) {
945  // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
946  // do not report failures if the above mentioned attribute is set.
947  if (trackingFailure.isSilenceableFailure())
948  (void)trackingFailure.silence();
949  trackingFailure = DiagnosedSilenceableFailure::success();
950  }
951  if (!trackingFailure.succeeded()) {
952  if (result.succeeded()) {
953  result = std::move(trackingFailure);
954  } else {
955  // Transform op errors have precedence, report those first.
956  if (result.isSilenceableFailure())
957  result.attachNote() << "tracking listener also failed: "
958  << trackingFailure.getMessage();
959  (void)trackingFailure.silence();
960  }
961  }
962  if (result.isDefiniteFailure())
963  return result;
964 
965  // If a silenceable failure was produced, some results may be unset, set them
966  // to empty lists.
967  if (result.isSilenceableFailure())
968  results.setRemainingToEmpty(transform);
969 
970  // Remove the mapping for the operand if it is consumed by the operation. This
971  // allows us to catch use-after-free with assertions later on.
972  for (OpOperand *opOperand : consumedOperands) {
973  Value operand = opOperand->get();
974  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
975  forgetMapping(operand, origOpFlatResults);
976  } else if (llvm::isa<TransformValueHandleTypeInterface>(
977  operand.getType())) {
978  forgetValueMapping(operand, origAssociatedOps);
979  }
980  }
981 
982  if (failed(updateStateFromResults(results, transform->getResults())))
984 
985  printOnFailureRAII.release();
986  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
987  LDBG() << "Top-level payload:\n" << *getTopLevel();
988  });
989  return result;
990 }
991 
992 LogicalResult transform::TransformState::updateStateFromResults(
993  const TransformResults &results, ResultRange opResults) {
994  for (OpResult result : opResults) {
995  if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
996  assert(results.isParam(result.getResultNumber()) &&
997  "expected parameters for the parameter-typed result");
998  if (failed(
999  setParams(result, results.getParams(result.getResultNumber())))) {
1000  return failure();
1001  }
1002  } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1003  assert(results.isValue(result.getResultNumber()) &&
1004  "expected values for value-type-result");
1005  if (failed(setPayloadValues(
1006  result, results.getValues(result.getResultNumber())))) {
1007  return failure();
1008  }
1009  } else {
1010  assert(!results.isParam(result.getResultNumber()) &&
1011  "expected payload ops for the non-parameter typed result");
1012  if (failed(
1013  setPayloadOps(result, results.get(result.getResultNumber())))) {
1014  return failure();
1015  }
1016  }
1017  }
1018  return success();
1019 }
1020 
1021 //===----------------------------------------------------------------------===//
1022 // TransformState::Extension
1023 //===----------------------------------------------------------------------===//
1024 
1026 
1027 LogicalResult
1029  Operation *replacement) {
1030  // TODO: we may need to invalidate handles to operations and values nested in
1031  // the operation being replaced.
1032  return state.replacePayloadOp(op, replacement);
1033 }
1034 
1035 LogicalResult
1037  Value replacement) {
1038  return state.replacePayloadValue(value, replacement);
1039 }
1040 
1041 //===----------------------------------------------------------------------===//
1042 // TransformState::RegionScope
1043 //===----------------------------------------------------------------------===//
1044 
1046  // Remove handle invalidation notices as handles are going out of scope.
1047  // The same region may be re-entered leading to incorrect invalidation
1048  // errors.
1049  for (Block &block : *region) {
1050  for (Value handle : block.getArguments()) {
1051  state.invalidatedHandles.erase(handle);
1052  }
1053  for (Operation &op : block) {
1054  for (Value handle : op.getResults()) {
1055  state.invalidatedHandles.erase(handle);
1056  }
1057  }
1058  }
1059 
1060 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1061  // Remember pointers to payload ops referenced by the handles going out of
1062  // scope.
1063  SmallVector<Operation *> referencedOps =
1064  llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1065 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1066 
1067  state.mappings.erase(region);
1068  state.regionStack.pop_back();
1069 }
1070 
1071 //===----------------------------------------------------------------------===//
1072 // TransformResults
1073 //===----------------------------------------------------------------------===//
1074 
1075 transform::TransformResults::TransformResults(unsigned numSegments) {
1076  operations.appendEmptyRows(numSegments);
1077  params.appendEmptyRows(numSegments);
1078  values.appendEmptyRows(numSegments);
1079 }
1080 
1083  int64_t position = value.getResultNumber();
1084  assert(position < static_cast<int64_t>(this->params.size()) &&
1085  "setting params for a non-existent handle");
1086  assert(this->params[position].data() == nullptr && "params already set");
1087  assert(operations[position].data() == nullptr &&
1088  "another kind of results already set");
1089  assert(values[position].data() == nullptr &&
1090  "another kind of results already set");
1091  this->params.replace(position, params);
1092 }
1093 
1095  OpResult handle, ArrayRef<MappedValue> values) {
1097  handle, values,
1098  [&](ArrayRef<Operation *> operations) {
1099  return set(handle, operations), success();
1100  },
1101  [&](ArrayRef<Param> params) {
1102  return setParams(handle, params), success();
1103  },
1104  [&](ValueRange payloadValues) {
1105  return setValues(handle, payloadValues), success();
1106  });
1107 #ifndef NDEBUG
1108  if (!diag.succeeded())
1109  llvm::dbgs() << diag.getStatusString() << "\n";
1110  assert(diag.succeeded() && "incorrect mapping");
1111 #endif // NDEBUG
1112  (void)diag.silence();
1113 }
1114 
1116  transform::TransformOpInterface transform) {
1117  for (OpResult opResult : transform->getResults()) {
1118  if (!isSet(opResult.getResultNumber()))
1119  setMappedValues(opResult, {});
1120  }
1121 }
1122 
1124 transform::TransformResults::get(unsigned resultNumber) const {
1125  assert(resultNumber < operations.size() &&
1126  "querying results for a non-existent handle");
1127  assert(operations[resultNumber].data() != nullptr &&
1128  "querying unset results (values or params expected?)");
1129  return operations[resultNumber];
1130 }
1131 
1133 transform::TransformResults::getParams(unsigned resultNumber) const {
1134  assert(resultNumber < params.size() &&
1135  "querying params for a non-existent handle");
1136  assert(params[resultNumber].data() != nullptr &&
1137  "querying unset params (ops or values expected?)");
1138  return params[resultNumber];
1139 }
1140 
1142 transform::TransformResults::getValues(unsigned resultNumber) const {
1143  assert(resultNumber < values.size() &&
1144  "querying values for a non-existent handle");
1145  assert(values[resultNumber].data() != nullptr &&
1146  "querying unset values (ops or params expected?)");
1147  return values[resultNumber];
1148 }
1149 
1150 bool transform::TransformResults::isParam(unsigned resultNumber) const {
1151  assert(resultNumber < params.size() &&
1152  "querying association for a non-existent handle");
1153  return params[resultNumber].data() != nullptr;
1154 }
1155 
1156 bool transform::TransformResults::isValue(unsigned resultNumber) const {
1157  assert(resultNumber < values.size() &&
1158  "querying association for a non-existent handle");
1159  return values[resultNumber].data() != nullptr;
1160 }
1161 
1162 bool transform::TransformResults::isSet(unsigned resultNumber) const {
1163  assert(resultNumber < params.size() &&
1164  "querying association for a non-existent handle");
1165  return params[resultNumber].data() != nullptr ||
1166  operations[resultNumber].data() != nullptr ||
1167  values[resultNumber].data() != nullptr;
1168 }
1169 
1170 //===----------------------------------------------------------------------===//
1171 // TrackingListener
1172 //===----------------------------------------------------------------------===//
1173 
1175  TransformOpInterface op,
1177  : TransformState::Extension(state), transformOp(op), config(config) {
1178  if (op) {
1179  for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1180  consumedHandles.insert(opOperand->get());
1181  }
1182  }
1183 }
1184 
1186  Operation *defOp = nullptr;
1187  for (Value v : values) {
1188  // Skip empty values.
1189  if (!v)
1190  continue;
1191  if (!defOp) {
1192  defOp = v.getDefiningOp();
1193  continue;
1194  }
1195  if (defOp != v.getDefiningOp())
1196  return nullptr;
1197  }
1198  return defOp;
1199 }
1200 
1202  Operation *&result, Operation *op, ValueRange newValues) const {
1203  assert(op->getNumResults() == newValues.size() &&
1204  "invalid number of replacement values");
1205  SmallVector<Value> values(newValues.begin(), newValues.end());
1206 
1208  getTransformOp(), "tracking listener failed to find replacement op "
1209  "during application of this transform op");
1210 
1211  do {
1212  // If the replacement values belong to different ops, drop the mapping.
1213  Operation *defOp = getCommonDefiningOp(values);
1214  if (!defOp) {
1215  diag.attachNote() << "replacement values belong to different ops";
1216  return diag;
1217  }
1218 
1219  // Skip through ops that implement CastOpInterface.
1220  if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1221  values.clear();
1222  values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1223  diag.attachNote(defOp->getLoc())
1224  << "using output of 'CastOpInterface' op";
1225  continue;
1226  }
1227 
1228  // If the defining op has the same name or we do not care about the name of
1229  // op replacements at all, we take it as a replacement.
1230  if (!config.requireMatchingReplacementOpName ||
1231  op->getName() == defOp->getName()) {
1232  result = defOp;
1234  }
1235 
1236  // Replacing an op with a constant-like equivalent is a common
1237  // canonicalization.
1238  if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1239  result = defOp;
1241  }
1242 
1243  values.clear();
1244 
1245  // Skip through ops that implement FindPayloadReplacementOpInterface.
1246  if (auto findReplacementOpInterface =
1247  dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1248  values.assign(findReplacementOpInterface.getNextOperands());
1249  diag.attachNote(defOp->getLoc()) << "using operands provided by "
1250  "'FindPayloadReplacementOpInterface'";
1251  continue;
1252  }
1253  } while (!values.empty());
1254 
1255  diag.attachNote() << "ran out of suitable replacement values";
1256  return diag;
1257 }
1258 
1260  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1261  LLVM_DEBUG({
1263  reasonCallback(diag);
1264  LDBG() << "Match Failure : " << diag.str();
1265  });
1266 }
1267 
1268 void transform::TrackingListener::notifyOperationErased(Operation *op) {
1269  // Remove mappings for result values.
1270  for (OpResult value : op->getResults())
1271  (void)replacePayloadValue(value, nullptr);
1272  // Remove mapping for op.
1273  (void)replacePayloadOp(op, nullptr);
1274 }
1275 
1276 void transform::TrackingListener::notifyOperationReplaced(
1277  Operation *op, ValueRange newValues) {
1278  assert(op->getNumResults() == newValues.size() &&
1279  "invalid number of replacement values");
1280 
1281  // Replace value handles.
1282  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
1283  (void)replacePayloadValue(oldValue, newValue);
1284 
1285  // Replace op handle.
1286  SmallVector<Value> opHandles;
1287  if (failed(getTransformState().getHandlesForPayloadOp(
1288  op, opHandles, /*includeOutOfScope=*/true))) {
1289  // Op is not tracked.
1290  return;
1291  }
1292 
1293  // Helper function to check if the current transform op consumes any handle
1294  // that is mapped to `op`.
1295  //
1296  // Note: If a handle was consumed, there shouldn't be any alive users, so it
1297  // is not really necessary to check for consumed handles. However, in case
1298  // there are indeed alive handles that were consumed (which is undefined
1299  // behavior) and a replacement op could not be found, we want to fail with a
1300  // nicer error message: "op uses a handle invalidated..." instead of "could
1301  // not find replacement op". This nicer error is produced later.
1302  auto handleWasConsumed = [&] {
1303  return llvm::any_of(opHandles,
1304  [&](Value h) { return consumedHandles.contains(h); });
1305  };
1306 
1307  // Check if there are any handles that must be updated.
1308  Value aliveHandle;
1309  if (config.skipHandleFn) {
1310  auto it = llvm::find_if(opHandles,
1311  [&](Value v) { return !config.skipHandleFn(v); });
1312  if (it != opHandles.end())
1313  aliveHandle = *it;
1314  } else if (!opHandles.empty()) {
1315  aliveHandle = opHandles.front();
1316  }
1317  if (!aliveHandle || handleWasConsumed()) {
1318  // The op is tracked but the corresponding handles are dead or were
1319  // consumed. Drop the op form the mapping.
1320  (void)replacePayloadOp(op, nullptr);
1321  return;
1322  }
1323 
1324  Operation *replacement;
1326  findReplacementOp(replacement, op, newValues);
1327  // If the op is tracked but no replacement op was found, send a
1328  // notification.
1329  if (!diag.succeeded()) {
1330  diag.attachNote(aliveHandle.getLoc())
1331  << "replacement is required because this handle must be updated";
1332  notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
1333  (void)replacePayloadOp(op, nullptr);
1334  return;
1335  }
1336 
1337  (void)replacePayloadOp(op, replacement);
1338 }
1339 
1341  // The state of the ErrorCheckingTrackingListener must be checked and reset
1342  // if there was an error. This is to prevent errors from accidentally being
1343  // missed.
1344  assert(status.succeeded() && "listener state was not checked");
1345 }
1346 
1349  DiagnosedSilenceableFailure s = std::move(status);
1351  errorCounter = 0;
1352  return s;
1353 }
1354 
1356  return !status.succeeded();
1357 }
1358 
1361 
1362  // Merge potentially existing diags and store the result in the listener.
1364  diag.takeDiagnostics(diags);
1365  if (!status.succeeded())
1366  status.takeDiagnostics(diags);
1367  status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
1368 
1369  // Report more details.
1370  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
1371  for (auto &&[index, value] : llvm::enumerate(values))
1372  status.attachNote(value.getLoc())
1373  << "[" << errorCounter << "] replacement value " << index;
1374  ++errorCounter;
1375 }
1376 
1377 std::string
1379  if (!matchFailure) {
1380  return "";
1381  }
1382  return matchFailure->str();
1383 }
1384 
1386  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1388  reasonCallback(diag);
1389  matchFailure = std::move(diag);
1390 }
1391 
1392 //===----------------------------------------------------------------------===//
1393 // TransformRewriter
1394 //===----------------------------------------------------------------------===//
1395 
1398  : RewriterBase(ctx), listener(listener) {
1399  setListener(listener);
1400 }
1401 
1403  return listener->failed();
1404 }
1405 
1406 /// Silence all tracking failures that have been encountered so far.
1408  if (hasTrackingFailures()) {
1409  DiagnosedSilenceableFailure status = listener->checkAndResetError();
1410  (void)status.silence();
1411  }
1412 }
1413 
1415  Operation *op, Operation *replacement) {
1416  return listener->replacePayloadOp(op, replacement);
1417 }
1418 
1419 //===----------------------------------------------------------------------===//
1420 // Utilities for TransformEachOpTrait.
1421 //===----------------------------------------------------------------------===//
1422 
1423 LogicalResult
1425  ArrayRef<Operation *> targets) {
1426  for (auto &&[position, parent] : llvm::enumerate(targets)) {
1427  for (Operation *child : targets.drop_front(position + 1)) {
1428  if (parent->isAncestor(child)) {
1430  emitError(loc)
1431  << "transform operation consumes a handle pointing to an ancestor "
1432  "payload operation before its descendant";
1433  diag.attachNote()
1434  << "the ancestor is likely erased or rewritten before the "
1435  "descendant is accessed, leading to undefined behavior";
1436  diag.attachNote(parent->getLoc()) << "ancestor payload op";
1437  diag.attachNote(child->getLoc()) << "descendant payload op";
1438  return diag;
1439  }
1440  }
1441  }
1442  return success();
1443 }
1444 
1445 LogicalResult
1447  Location payloadOpLoc,
1448  const ApplyToEachResultList &partialResult) {
1449  Location transformOpLoc = transformOp->getLoc();
1450  StringRef transformOpName = transformOp->getName().getStringRef();
1451  unsigned expectedNumResults = transformOp->getNumResults();
1452 
1453  // Reuse the emission of the diagnostic note.
1454  auto emitDiag = [&]() {
1455  auto diag = mlir::emitError(transformOpLoc);
1456  diag.attachNote(payloadOpLoc) << "when applied to this op";
1457  return diag;
1458  };
1459 
1460  if (partialResult.size() != expectedNumResults) {
1461  auto diag = emitDiag() << "application of " << transformOpName
1462  << " expected to produce " << expectedNumResults
1463  << " results (actually produced "
1464  << partialResult.size() << ").";
1465  diag.attachNote(transformOpLoc)
1466  << "if you need variadic results, consider a generic `apply` "
1467  << "instead of the specialized `applyToOne`.";
1468  return failure();
1469  }
1470 
1471  // Check that the right kind of value was produced.
1472  for (const auto &[ptr, res] :
1473  llvm::zip(partialResult, transformOp->getResults())) {
1474  if (ptr.isNull())
1475  continue;
1476  if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1477  !isa<Operation *>(ptr)) {
1478  return emitDiag() << "application of " << transformOpName
1479  << " expected to produce an Operation * for result #"
1480  << res.getResultNumber();
1481  }
1482  if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1483  !isa<Attribute>(ptr)) {
1484  return emitDiag() << "application of " << transformOpName
1485  << " expected to produce an Attribute for result #"
1486  << res.getResultNumber();
1487  }
1488  if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1489  !isa<Value>(ptr)) {
1490  return emitDiag() << "application of " << transformOpName
1491  << " expected to produce a Value for result #"
1492  << res.getResultNumber();
1493  }
1494  }
1495  return success();
1496 }
1497 
1498 template <typename T>
1500  return llvm::to_vector(llvm::map_range(
1501  range, [](transform::MappedValue value) { return cast<T>(value); }));
1502 }
1503 
1505  Operation *transformOp, TransformResults &transformResults,
1508  transposed.resize(transformOp->getNumResults());
1509  for (const ApplyToEachResultList &partialResults : results) {
1510  if (llvm::any_of(partialResults,
1511  [](MappedValue value) { return value.isNull(); }))
1512  continue;
1513  assert(transformOp->getNumResults() == partialResults.size() &&
1514  "expected as many partial results as op as results");
1515  for (auto [i, value] : llvm::enumerate(partialResults))
1516  transposed[i].push_back(value);
1517  }
1518 
1519  for (OpResult r : transformOp->getResults()) {
1520  unsigned position = r.getResultNumber();
1521  if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1522  transformResults.setParams(r,
1523  castVector<Attribute>(transposed[position]));
1524  } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1525  transformResults.setValues(r, castVector<Value>(transposed[position]));
1526  } else {
1527  transformResults.set(r, castVector<Operation *>(transposed[position]));
1528  }
1529  }
1530 }
1531 
1532 //===----------------------------------------------------------------------===//
1533 // Utilities for implementing transform ops with regions.
1534 //===----------------------------------------------------------------------===//
1535 
1538  ValueRange values, const transform::TransformState &state, bool flatten) {
1539  assert(mappings.size() == values.size() && "mismatching number of mappings");
1540  for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1541  size_t mappedSize = mapped.size();
1542  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1543  llvm::append_range(mapped, state.getPayloadOps(operand));
1544  } else if (llvm::isa<TransformValueHandleTypeInterface>(
1545  operand.getType())) {
1546  llvm::append_range(mapped, state.getPayloadValues(operand));
1547  } else {
1548  assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1549  "unsupported kind of transform dialect value");
1550  llvm::append_range(mapped, state.getParams(operand));
1551  }
1552 
1553  if (mapped.size() - mappedSize != 1 && !flatten)
1554  return failure();
1555  }
1556  return success();
1557 }
1558 
1561  ValueRange values, const transform::TransformState &state) {
1562  mappings.resize(mappings.size() + values.size());
1563  (void)appendValueMappings(
1565  values.size()),
1566  values, state);
1567 }
1568 
1570  Block *block, transform::TransformState &state,
1571  transform::TransformResults &results) {
1572  for (auto &&[terminatorOperand, result] :
1573  llvm::zip(block->getTerminator()->getOperands(),
1574  block->getParentOp()->getOpResults())) {
1575  if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1576  results.set(result, state.getPayloadOps(terminatorOperand));
1577  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1578  result.getType())) {
1579  results.setValues(result, state.getPayloadValues(terminatorOperand));
1580  } else {
1581  assert(
1582  llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1583  "unhandled transform type interface");
1584  results.setParams(result, state.getParams(terminatorOperand));
1585  }
1586  }
1587 }
1588 
1591  Operation *payloadRoot) {
1592  return TransformState(region, payloadRoot);
1593 }
1594 
1595 //===----------------------------------------------------------------------===//
1596 // Utilities for PossibleTopLevelTransformOpTrait.
1597 //===----------------------------------------------------------------------===//
1598 
/// Appends to `effects` the memory effect instances on `target` with the same
/// resource and effect as the ones the operation `iface` has on `source`.
1601 static void
1602 remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
1603  OpOperand *target,
1606  iface.getEffectsOnValue(source, nestedEffects);
1607  for (const auto &effect : nestedEffects)
1608  effects.emplace_back(effect.getEffect(), target, effect.getResource());
1609 }
1610 
/// Appends to `effects` the same effects that the operations of `block` have
/// on its block arguments, but associated with the corresponding `operands`.
1613 static void
1616  for (Operation &op : block) {
1617  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1618  if (!iface)
1619  continue;
1620 
1621  for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
1622  remapEffects(iface, source, &target, effects);
1623  }
1624 
1626  iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1627  nestedEffects);
1628  llvm::append_range(effects, nestedEffects);
1629  }
1630 }
1631 
1633  Operation *operation, Value root, Block &body,
1635  transform::onlyReadsHandle(operation->getOpOperands(), effects);
1636  transform::producesHandle(operation->getOpResults(), effects);
1637 
1638  if (!root) {
1639  for (Operation &op : body) {
1640  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1641  if (!iface)
1642  continue;
1643 
1644  iface.getEffects(effects);
1645  }
1646  return;
1647  }
1648 
1649  // Carry over all effects on arguments of the entry block as those on the
1650  // operands, this is the same value just remapped.
1651  remapArgumentEffects(body, operation->getOpOperands(), effects);
1652 }
1653 
1655  TransformState &state, Operation *op, Region &region) {
1656  SmallVector<Operation *> targets;
1657  SmallVector<SmallVector<MappedValue>> extraMappings;
1658  if (op->getNumOperands() != 0) {
1659  llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
1660  prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
1661  } else {
1662  if (state.getNumTopLevelMappings() !=
1663  region.front().getNumArguments() - 1) {
1664  return emitError(op->getLoc())
1665  << "operation expects " << region.front().getNumArguments() - 1
1666  << " extra value bindings, but " << state.getNumTopLevelMappings()
1667  << " were provided to the interpreter";
1668  }
1669 
1670  targets.push_back(state.getTopLevel());
1671 
1672  for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1673  extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1674  }
1675 
1676  if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
1677  return failure();
1678 
1679  for (BlockArgument argument : region.front().getArguments().drop_front()) {
1680  if (failed(state.mapBlockArgument(
1681  argument, extraMappings[argument.getArgNumber() - 1])))
1682  return failure();
1683  }
1684 
1685  return success();
1686 }
1687 
1688 LogicalResult
1690  // Attaching this trait without the interface is a misuse of the API, but it
1691  // cannot be caught via a static_assert because interface registration is
1692  // dynamic.
1693  assert(isa<TransformOpInterface>(op) &&
1694  "should implement TransformOpInterface to have "
1695  "PossibleTopLevelTransformOpTrait");
1696 
1697  if (op->getNumRegions() < 1)
1698  return op->emitOpError() << "expects at least one region";
1699 
1700  Region *bodyRegion = &op->getRegion(0);
1701  if (!llvm::hasNItems(*bodyRegion, 1))
1702  return op->emitOpError() << "expects a single-block region";
1703 
1704  Block *body = &bodyRegion->front();
1705  if (body->getNumArguments() == 0) {
1706  return op->emitOpError()
1707  << "expects the entry block to have at least one argument";
1708  }
1709  if (!llvm::isa<TransformHandleTypeInterface>(
1710  body->getArgument(0).getType())) {
1711  return op->emitOpError()
1712  << "expects the first entry block argument to be of type "
1713  "implementing TransformHandleTypeInterface";
1714  }
1715  BlockArgument arg = body->getArgument(0);
1716  if (op->getNumOperands() != 0) {
1717  if (arg.getType() != op->getOperand(0).getType()) {
1718  return op->emitOpError()
1719  << "expects the type of the block argument to match "
1720  "the type of the operand";
1721  }
1722  }
1723  for (BlockArgument arg : body->getArguments().drop_front()) {
1724  if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1725  TransformValueHandleTypeInterface>(arg.getType()))
1726  continue;
1727 
1729  op->emitOpError()
1730  << "expects trailing entry block arguments to be of type implementing "
1731  "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1732  "TransformParamTypeInterface";
1733  diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1734  return diag;
1735  }
1736 
1737  if (auto *parent =
1739  if (op->getNumOperands() != body->getNumArguments()) {
1741  op->emitOpError()
1742  << "expects operands to be provided for a nested op";
1743  diag.attachNote(parent->getLoc())
1744  << "nested in another possible top-level op";
1745  return diag;
1746  }
1747  }
1748 
1749  return success();
1750 }
1751 
1752 //===----------------------------------------------------------------------===//
1753 // Utilities for ParamProducerTransformOpTrait.
1754 //===----------------------------------------------------------------------===//
1755 
1758  producesHandle(op->getResults(), effects);
1759  bool hasPayloadOperands = false;
1760  for (OpOperand &operand : op->getOpOperands()) {
1761  onlyReadsHandle(operand, effects);
1762  if (llvm::isa<TransformHandleTypeInterface,
1763  TransformValueHandleTypeInterface>(operand.get().getType()))
1764  hasPayloadOperands = true;
1765  }
1766  if (hasPayloadOperands)
1767  onlyReadsPayload(effects);
1768 }
1769 
1770 LogicalResult
1772  // Interfaces can be attached dynamically, so this cannot be a static
1773  // assert.
1774  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1775  llvm::report_fatal_error(
1776  Twine("ParamProducerTransformOpTrait must be attached to an op that "
1777  "implements MemoryEffectsOpInterface, found on ") +
1778  op->getName().getStringRef());
1779  }
1780  for (Value result : op->getResults()) {
1781  if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1782  continue;
1783  return op->emitOpError()
1784  << "ParamProducerTransformOpTrait attached to this op expects "
1785  "result types to implement TransformParamTypeInterface";
1786  }
1787  return success();
1788 }
1789 
1790 //===----------------------------------------------------------------------===//
1791 // Memory effects.
1792 //===----------------------------------------------------------------------===//
1793 
1797  for (OpOperand &handle : handles) {
1798  effects.emplace_back(MemoryEffects::Read::get(), &handle,
1800  effects.emplace_back(MemoryEffects::Free::get(), &handle,
1802  }
1803 }
1804 
1805 /// Returns `true` if the given list of effects instances contains an instance
1806 /// with the effect type specified as template parameter.
1807 template <typename EffectTy, typename ResourceTy, typename Range>
1808 static bool hasEffect(Range &&effects) {
1809  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1810  return isa<EffectTy>(effect.getEffect()) &&
1811  isa<ResourceTy>(effect.getResource());
1812  });
1813 }
1814 
1816  transform::TransformOpInterface transform) {
1817  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1819  iface.getEffectsOnValue(handle, effects);
1820  return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1821  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1822 }
1823 
1825  ResultRange handles,
1827  for (OpResult handle : handles) {
1828  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1830  effects.emplace_back(MemoryEffects::Write::get(), handle,
1832  }
1833 }
1834 
1838  for (BlockArgument handle : handles) {
1839  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1841  effects.emplace_back(MemoryEffects::Write::get(), handle,
1843  }
1844 }
1845 
1849  for (OpOperand &handle : handles) {
1850  effects.emplace_back(MemoryEffects::Read::get(), &handle,
1852  }
1853 }
1854 
1857  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1858  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
1859 }
1860 
1863  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1864 }
1865 
1866 bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1867  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1869  iface.getEffects(effects);
1870  return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1871 }
1872 
1873 bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1874  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1876  iface.getEffects(effects);
1877  return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1878 }
1879 
1881  Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1883  for (Operation &nested : block) {
1884  auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1885  if (!iface)
1886  continue;
1887 
1888  effects.clear();
1889  iface.getEffects(effects);
1890  for (const MemoryEffects::EffectInstance &effect : effects) {
1891  BlockArgument argument =
1892  dyn_cast_or_null<BlockArgument>(effect.getValue());
1893  if (!argument || argument.getOwner() != &block ||
1894  !isa<MemoryEffects::Free>(effect.getEffect()) ||
1895  effect.getResource() != transform::TransformMappingResource::get()) {
1896  continue;
1897  }
1898  consumedArguments.insert(argument.getArgNumber());
1899  }
1900  }
1901 }
1902 
1903 //===----------------------------------------------------------------------===//
1904 // Utilities for TransformOpInterface.
1905 //===----------------------------------------------------------------------===//
1906 
1908  TransformOpInterface transformOp) {
1909  SmallVector<OpOperand *> consumedOperands;
1910  consumedOperands.reserve(transformOp->getNumOperands());
1911  auto memEffectInterface =
1912  cast<MemoryEffectOpInterface>(transformOp.getOperation());
1914  for (OpOperand &target : transformOp->getOpOperands()) {
1915  effects.clear();
1916  memEffectInterface.getEffectsOnValue(target.get(), effects);
1917  if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1918  return isa<transform::TransformMappingResource>(
1919  effect.getResource()) &&
1920  isa<MemoryEffects::Free>(effect.getEffect());
1921  })) {
1922  consumedOperands.push_back(&target);
1923  }
1924  }
1925  return consumedOperands;
1926 }
1927 
1929  auto iface = cast<MemoryEffectOpInterface>(op);
1931  iface.getEffects(effects);
1932 
1933  auto effectsOn = [&](Value value) {
1934  return llvm::make_filter_range(
1935  effects, [value](const MemoryEffects::EffectInstance &instance) {
1936  return instance.getValue() == value;
1937  });
1938  };
1939 
1940  std::optional<unsigned> firstConsumedOperand;
1941  for (OpOperand &operand : op->getOpOperands()) {
1942  auto range = effectsOn(operand.get());
1943  if (range.empty()) {
1945  op->emitError() << "TransformOpInterface requires memory effects "
1946  "on operands to be specified";
1947  diag.attachNote() << "no effects specified for operand #"
1948  << operand.getOperandNumber();
1949  return diag;
1950  }
1951  if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1953  << "TransformOpInterface did not expect "
1954  "'allocate' memory effect on an operand";
1955  diag.attachNote() << "specified for operand #"
1956  << operand.getOperandNumber();
1957  return diag;
1958  }
1959  if (!firstConsumedOperand &&
1960  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1961  firstConsumedOperand = operand.getOperandNumber();
1962  }
1963  }
1964 
1965  if (firstConsumedOperand &&
1966  !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1968  op->emitError()
1969  << "TransformOpInterface expects ops consuming operands to have a "
1970  "'write' effect on the payload resource";
1971  diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
1972  return diag;
1973  }
1974 
1975  for (OpResult result : op->getResults()) {
1976  auto range = effectsOn(result);
1977  if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1978  range)) {
1980  op->emitError() << "TransformOpInterface requires 'allocate' memory "
1981  "effect to be specified for results";
1982  diag.attachNote() << "no 'allocate' effect specified for result #"
1983  << result.getResultNumber();
1984  return diag;
1985  }
1986  }
1987 
1988  return success();
1989 }
1990 
1991 //===----------------------------------------------------------------------===//
1992 // Entry point.
1993 //===----------------------------------------------------------------------===//
1994 
1996  Operation *payloadRoot, TransformOpInterface transform,
1997  const RaggedArray<MappedValue> &extraMapping,
1998  const TransformOptions &options, bool enforceToplevelTransformOp,
1999  function_ref<void(TransformState &)> stateInitializer,
2000  function_ref<LogicalResult(TransformState &)> stateExporter) {
2001  if (enforceToplevelTransformOp) {
2002  if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
2003  transform->getNumOperands() != 0) {
2004  return transform->emitError()
2005  << "expected transform to start at the top-level transform op";
2006  }
2007  } else if (failed(
2009  return failure();
2010  }
2011 
2012  TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2013  options);
2014  if (stateInitializer)
2015  stateInitializer(state);
2016  if (state.applyTransform(transform).checkAndReport().failed())
2017  return failure();
2018  if (stateExporter)
2019  return stateExporter(state);
2020  return success();
2021 }
2022 
2023 //===----------------------------------------------------------------------===//
2024 // Generated interface implementation.
2025 //===----------------------------------------------------------------------===//
2026 
2027 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2028 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define FULL_LDBG()
static void remapEffects(MemoryEffectOpInterface iface, BlockArgument source, OpOperand *target, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the memory effect instances on target with the same resource and effect as the one...
static void remapArgumentEffects(Block &block, MutableArrayRef< OpOperand > operands, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the same effects as the operations of block have on block arguments but associated...
static bool happensBefore(Operation *a, Operation *b)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not ...
static void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped)
#define DEBUG_PRINT_AFTER_ALL
static DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef< T > payload, transform::TransformOpInterface transform, unsigned operandNumber)
static DiagnosedSilenceableFailure dispatchMappedValues(Value handle, ArrayRef< transform::MappedValue > values, function_ref< LogicalResult(ArrayRef< Operation * >)> operationsFn, function_ref< LogicalResult(ArrayRef< transform::Param >)> paramsFn, function_ref< LogicalResult(ValueRange)> valuesFn)
Given a list of MappedValues, cast them to the value kind implied by the interface of the handle type...
static SmallVector< T > castVector(ArrayRef< transform::MappedValue > range)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition: Block.cpp:74
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgListType getArguments()
Definition: Block.h:87
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:314
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition: Value.h:447
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition: Operation.h:1111
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A 2D array where each row may have different length.
Definition: RaggedArray.h:18
size_t size() const
Returns the number of rows in the 2D array.
Definition: RaggedArray.h:21
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:247
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given resource class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
size_t size() const
Returns the number of elements in the list.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
std::string getLatestMatchFailureMessage()
Return the latest match notification message.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
TrackingListener failures are reported only for ops that have this trait.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Options controlling the application of transform operations by the TransformState.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult matches the number, kind (payload op or parameter) and nulli...
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true, function_ref< void(TransformState &)> stateInitializer=nullptr, function_ref< LogicalResult(TransformState &)> stateExporter=nullptr)
Entry point to the Transform dialect infrastructure.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A configuration object for customizing a TrackingListener.