MLIR  18.0.0git
TransformInterfaces.cpp
Go to the documentation of this file.
1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/IR/PatternMatch.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/ScopeExit.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/ErrorHandling.h"
23 
// DEBUG_TYPE selects the regular LLVM_DEBUG trace (used by LDBG); DEBUG_TYPE_FULL
// selects the verbose trace used by FULL_LDBG and DEBUG_WITH_TYPE below.
24 #define DEBUG_TYPE "transform-dialect"
25 #define DEBUG_TYPE_FULL "transform-dialect-full"
// NOTE(review): not referenced in this chunk; presumably used further down the
// file to print the top-level payload after every transform — confirm.
26 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
// Prefixes every debug line with "[transform-dialect] ".
27 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
28 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
// Prints only under DEBUG_TYPE_FULL, but still uses the DEBUG_TYPE prefix via
// DBGS().
29 #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
30
31 using namespace mlir;
32 
33 //===----------------------------------------------------------------------===//
34 // TransformState
35 //===----------------------------------------------------------------------===//
36 
// Out-of-line definition of the static constexpr data member. Its initializer
// lives in the in-class declaration (presumably in the header). setPayloadOps
// compares against this sentinel to refuse re-assigning the transformation
// root.
37 constexpr const Value transform::TransformState::kTopLevelValue;
38 
39 transform::TransformState::TransformState(
40  Region *region, Operation *payloadRoot,
41  const RaggedArray<MappedValue> &extraMappings,
42  const TransformOptions &options)
43  : topLevel(payloadRoot), options(options) {
44  topLevelMappedValues.reserve(extraMappings.size());
45  for (ArrayRef<MappedValue> mapping : extraMappings)
46  topLevelMappedValues.push_back(mapping);
47 
48  auto result =
49  mappings.insert(std::make_pair(region, std::make_unique<Mappings>()));
50  assert(result.second && "the region scope is already present");
51  (void)result;
52 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
53  regionStack.push_back(region);
54 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
55 }
56 
57 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
58 
// Returns a read-only view of the payload ops currently associated with the op
// handle `value`, looked up in the "direct" mapping of the handle's scope.
// Asserts (in debug builds) that such a mapping exists; a param or value handle
// has none.
60 transform::TransformState::getPayloadOpsView(Value value) const {
61  const TransformOpMapping &operationMapping = getMapping(value).direct;
62  auto iter = operationMapping.find(value);
63  assert(
64  iter != operationMapping.end() &&
65  "cannot find mapping for payload handle (param/value handle provided?)");
66  return iter->getSecond();
67 }
68 
  // Returns the parameters associated with the param handle `value`, looked up
  // in the "params" mapping of the handle's scope. Asserts (in debug builds)
  // that such a mapping exists; an op or value handle has none.
70  const ParamMapping &mapping = getMapping(value).params;
71  auto iter = mapping.find(value);
72  assert(iter != mapping.end() && "cannot find mapping for param handle "
73  "(operation/value handle provided?)");
74  return iter->getSecond();
75 }
76 
// Returns a read-only view of the payload values currently associated with the
// value handle `handleValue`, looked up in the "values" mapping of the handle's
// scope. Asserts (in debug builds) that such a mapping exists; a param or op
// handle has none.
78 transform::TransformState::getPayloadValuesView(Value handleValue) const {
79  const ValueMapping &mapping = getMapping(handleValue).values;
80  auto iter = mapping.find(handleValue);
81  assert(iter != mapping.end() && "cannot find mapping for value handle "
82  "(param/operation handle provided?)");
83  return iter->getSecond();
84 }
85 
  // Appends to `handles` every transform handle that the reverse op mapping
  // associates with payload op `op`. Scopes are visited in reverse insertion
  // order of `mappings`. Unless `includeOutOfScope` is set, the search stops
  // after the first scope whose region's parent op is IsolatedFromAbove; that
  // scope itself is still searched. `handles` is not cleared first. Succeeds
  // iff at least one handle was found.
87  Operation *op, SmallVectorImpl<Value> &handles,
88  bool includeOutOfScope) const {
89  bool found = false;
90  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
91  auto iterator = mapping->reverse.find(op);
92  if (iterator != mapping->reverse.end()) {
93  llvm::append_range(handles, iterator->getSecond());
94  found = true;
95  }
96  // Stop looking when reaching a region that is isolated from above.
97  if (!includeOutOfScope &&
98  region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
99  break;
100  }
101
102  return success(found);
103 }
104 
  // Appends to `handles` every transform value handle that the reverse value
  // mapping associates with `payloadValue`. Scopes are visited in reverse
  // insertion order of `mappings`. Unless `includeOutOfScope` is set, the search
  // stops after the first scope whose region's parent op is IsolatedFromAbove;
  // that scope itself is still searched. `handles` is not cleared first.
  // Succeeds iff at least one handle was found.
106  Value payloadValue, SmallVectorImpl<Value> &handles,
107  bool includeOutOfScope) const {
108  bool found = false;
109  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
110  auto iterator = mapping->reverseValues.find(payloadValue);
111  if (iterator != mapping->reverseValues.end()) {
112  llvm::append_range(handles, iterator->getSecond());
113  found = true;
114  }
115  // Stop looking when reaching a region that is isolated from above.
116  if (!includeOutOfScope &&
117  region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
118  break;
119  }
120
121  return success(found);
122 }
123 
124 /// Given a list of MappedValues, cast them to the value kind implied by the
125 /// interface of the handle type, and dispatch to one of the callbacks.
/// Every element must match the handle kind. A null entry, or an entry of
/// another kind (e.g. an Attribute for an op handle), produces a silenceable
/// failure at the handle's location and no callback is invoked.
/// NOTE(review): this listing is missing the signature and the statements that
/// follow each callback invocation. Presumably each failed callback yields a
/// definite failure and each branch otherwise returns success — confirm
/// against the full source.
131  if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
132  SmallVector<Operation *> operations;
133  operations.reserve(values.size());
134  for (transform::MappedValue value : values) {
135  if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
136  operations.push_back(op);
137  continue;
138  }
139  return emitSilenceableFailure(handle.getLoc())
140  << "wrong kind of value provided for top-level operation handle";
141  }
142  if (failed(operationsFn(operations)))
145  }
146

  // Value handles: every element must be a non-null payload Value.
147  if (llvm::isa<transform::TransformValueHandleTypeInterface>(
148  handle.getType())) {
149  SmallVector<Value> payloadValues;
150  payloadValues.reserve(values.size());
151  for (transform::MappedValue value : values) {
152  if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
153  payloadValues.push_back(v);
154  continue;
155  }
156  return emitSilenceableFailure(handle.getLoc())
157  << "wrong kind of value provided for the top-level value handle";
158  }
159  if (failed(valuesFn(payloadValues)))
162  }
163

  // Neither an op nor a value handle: the only remaining supported kind is a
  // param handle, whose elements must be non-null Attributes.
164  assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
165  "unsupported kind of block argument");
167  parameters.reserve(values.size());
168  for (transform::MappedValue value : values) {
169  if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
170  parameters.push_back(attr);
171  continue;
172  }
173  return emitSilenceableFailure(handle.getLoc())
174  << "wrong kind of value provided for top-level parameter";
175  }
176  if (failed(paramsFn(parameters)))
179 }
180 
  // Associates `argument` with `values`. dispatchMappedValues routes the
  // values to setPayloadOps, setParams or setPayloadValues depending on the
  // kind of `argument`'s type. Any silenceable failure is reported immediately
  // via checkAndReport().
183  ArrayRef<MappedValue> values) {
184  return dispatchMappedValues(
185  argument, values,
186  [&](ArrayRef<Operation *> operations) {
187  return setPayloadOps(argument, operations);
188  },
189  [&](ArrayRef<Param> params) {
190  return setParams(argument, params);
191  },
192  [&](ValueRange payloadValues) {
193  return setPayloadValues(argument, payloadValues);
194  })
195  .checkAndReport();
196 }
197 
// Associates the op handle `value` with the payload ops `targets` in the
// handle's scope, updating both the direct and the reverse op mappings.
// Emits an error and fails if any target is null or if the handle type's
// checkPayload rejects the targets. Asserts (in debug builds) that `value` is
// not the top-level sentinel, has an op-handle type, and is not already mapped.
199 transform::TransformState::setPayloadOps(Value value,
200  ArrayRef<Operation *> targets) {
201  assert(value != kTopLevelValue &&
202  "attempting to reset the transformation root");
203  assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
204  "wrong handle type");
205
206  for (Operation *target : targets) {
207  if (target)
208  continue;
209  return emitError(value.getLoc())
210  << "attempting to assign a null payload op to this transform value";
211  }
212

  // Let the handle type validate the payload (e.g. op kind constraints).
213  auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
215  iface.checkPayload(value.getLoc(), targets);
216  if (failed(result.checkAndReport()))
217  return failure();
218
219  // Setting new payload for the value without cleaning it first is a misuse of
220  // the API, assert here.
221  SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
222  Mappings &mappings = getMapping(value);
223  bool inserted =
224  mappings.direct.insert({value, std::move(storedTargets)}).second;
225  assert(inserted && "value is already associated with another list");
226  (void)inserted;
227

  // One reverse entry per occurrence, so an op listed twice maps back to the
  // handle twice.
228  for (Operation *op : targets)
229  mappings.reverse[op].push_back(value);
230
231  return success();
232 }
233 
// Associates the value handle `handle` with `payloadValues` in the handle's
// scope, updating both the values and the reverse-values mappings. Emits an
// error and fails if any payload value is null or if the handle type's
// checkPayload rejects them. Asserts (in debug builds) that `handle` is
// non-null, has a value-handle type, and is not already mapped.
235 transform::TransformState::setPayloadValues(Value handle,
236  ValueRange payloadValues) {
  // NOTE(review): the message below says "set params"; it looks copy-pasted
  // from setParams and presumably should mention payload values.
237  assert(handle != nullptr && "attempting to set params for a null value");
238  assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
239  "wrong handle type");
240
241  for (Value payload : payloadValues) {
242  if (payload)
243  continue;
244  return emitError(handle.getLoc()) << "attempting to assign a null payload "
245  "value to this transform handle";
246  }
247
248  auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
249  SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
251  iface.checkPayload(handle.getLoc(), payloadValueVector);
252  if (failed(result.checkAndReport()))
253  return failure();
254
255  Mappings &mappings = getMapping(handle);
256  bool inserted =
257  mappings.values.insert({handle, std::move(payloadValueVector)}).second;
258  assert(
259  inserted &&
260  "value handle is already associated with another list of payload values");
261  (void)inserted;
262

  // `payloadValues` is still intact here (only the vector copy was moved).
263  for (Value payload : payloadValues)
264  mappings.reverseValues[payload].push_back(handle);
265
266  return success();
267 }
268 
269 LogicalResult transform::TransformState::setParams(Value value,
270  ArrayRef<Param> params) {
271  assert(value != nullptr && "attempting to set params for a null value");
272 
273  for (Attribute attr : params) {
274  if (attr)
275  continue;
276  return emitError(value.getLoc())
277  << "attempting to assign a null parameter to this transform value";
278  }
279 
280  auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
281  assert(value &&
282  "cannot associate parameter with a value of non-parameter type");
284  valueType.checkPayload(value.getLoc(), params);
285  if (failed(result.checkAndReport()))
286  return failure();
287 
288  Mappings &mappings = getMapping(value);
289  bool inserted =
290  mappings.params.insert({value, llvm::to_vector(params)}).second;
291  assert(inserted && "value is already associated with another list of params");
292  (void)inserted;
293  return success();
294 }
295 
296 template <typename Mapping, typename Key, typename Mapped>
297 void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
298  auto it = mapping.find(key);
299  if (it == mapping.end())
300  return;
301 
302  llvm::erase_value(it->getSecond(), mapped);
303  if (it->getSecond().empty())
304  mapping.erase(it);
305 }
306 
307 void transform::TransformState::forgetMapping(Value opHandle,
308  ValueRange origOpFlatResults) {
309  Mappings &mappings = getMapping(opHandle);
310  for (Operation *op : mappings.direct[opHandle])
311  dropMappingEntry(mappings.reverse, op, opHandle);
312  mappings.direct.erase(opHandle);
313 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
314  // Payload IR is removed from the mapping. This invalidates the respective
315  // iterators.
316  mappings.incrementTimestamp(opHandle);
317 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
318 
319  for (Value opResult : origOpFlatResults) {
320  SmallVector<Value> resultHandles;
321  (void)getHandlesForPayloadValue(opResult, resultHandles);
322  for (Value resultHandle : resultHandles) {
323  Mappings &localMappings = getMapping(resultHandle);
324  dropMappingEntry(localMappings.values, resultHandle, opResult);
325 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
326  // Payload IR is removed from the mapping. This invalidates the respective
327  // iterators.
328  mappings.incrementTimestamp(resultHandle);
329 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
330  dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
331  }
332  }
333 }
334 
335 void transform::TransformState::forgetValueMapping(
336  Value valueHandle, ArrayRef<Operation *> payloadOperations) {
337  Mappings &mappings = getMapping(valueHandle);
338  for (Value payloadValue : mappings.reverseValues[valueHandle])
339  dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
340  mappings.values.erase(valueHandle);
341 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
342  // Payload IR is removed from the mapping. This invalidates the respective
343  // iterators.
344  mappings.incrementTimestamp(valueHandle);
345 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
346 
347  for (Operation *payloadOp : payloadOperations) {
348  SmallVector<Value> opHandles;
349  (void)getHandlesForPayloadOp(payloadOp, opHandles);
350  for (Value opHandle : opHandles) {
351  Mappings &localMappings = getMapping(opHandle);
352  dropMappingEntry(localMappings.direct, opHandle, payloadOp);
353  dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
354 
355 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
356  // Payload IR is removed from the mapping. This invalidates the respective
357  // iterators.
358  localMappings.incrementTimestamp(opHandle);
359 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
360  }
361  }
362 }
363 
// Replaces `op` with `replacement` in every op handle associated with `op`,
// including handles in scopes beyond the nearest isolated region.
// `replacement` may be null when `op` was erased. Fails if no handle points
// to `op`.
365 transform::TransformState::replacePayloadOp(Operation *op,
366  Operation *replacement) {
367  // TODO: consider invalidating the handles to nested objects here.
368
  // Debug-only sanity check: no value handle may still point to a result of
  // `op` at this point.
369 #ifndef NDEBUG
370  for (Value opResult : op->getResults()) {
371  SmallVector<Value> valueHandles;
372  (void)getHandlesForPayloadValue(opResult, valueHandles,
373  /*includeOutOfScope=*/true);
374  assert(valueHandles.empty() && "expected no mapping to old results");
375  }
376 #endif // NDEBUG
377
378  // Drop the mapping between the op and all handles that point to it. Fail if
379  // there are no handles.
380  SmallVector<Value> opHandles;
381  if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
382  return failure();
383  for (Value handle : opHandles) {
384  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
385  dropMappingEntry(mappings.reverse, op, handle);
386  }
387
  // With expensive checks enabled, keep the op-name cache in sync: forget
  // `op` and record `replacement` under its current name.
388 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
389  if (options.getExpensiveChecksEnabled()) {
390  auto it = cachedNames.find(op);
391  assert(it != cachedNames.end() && "entry not found");
392  assert(it->second == op->getName() && "operation name mismatch");
393  cachedNames.erase(it);
394  if (replacement) {
395  auto insertion =
396  cachedNames.insert({replacement, replacement->getName()});
397  if (!insertion.second) {
398  assert(insertion.first->second == replacement->getName() &&
399  "operation is already cached with a different name");
400  }
401  }
402  }
403 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
404
405  // Replace the pointed-to object of all handles with the replacement object.
406  // In case a payload op was erased (replacement object is nullptr), a nullptr
407  // is stored in the mapping. These nullptrs are removed after each transform.
408  // Furthermore, nullptrs are not enumerated by payload op iterators. The
409  // relative order of ops is preserved.
410  //
411  // Removing an op from the mapping would be problematic because removing an
412  // element from an array invalidates iterators; merely changing the value of
413  // elements does not.
414  for (Value handle : opHandles) {
415  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
416  auto it = mappings.direct.find(handle);
417  if (it == mappings.direct.end())
418  continue;
419
420  SmallVector<Operation *, 2> &association = it->getSecond();
421  // Note that an operation may be associated with the handle more than once.
422  for (Operation *&mapped : association) {
423  if (mapped == op)
424  mapped = replacement;
425  }
426

    // A handle that now holds nulls is queued for compactOpHandles.
427  if (replacement) {
428  mappings.reverse[replacement].push_back(handle);
429  } else {
430  opHandlesToCompact.insert(handle);
431  }
432  }
433
434  return success();
435 }
436 
438 transform::TransformState::replacePayloadValue(Value value, Value replacement) {
439  SmallVector<Value> valueHandles;
440  if (failed(getHandlesForPayloadValue(value, valueHandles,
441  /*includeOutOfScope=*/true)))
442  return failure();
443 
444  for (Value handle : valueHandles) {
445  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
446  dropMappingEntry(mappings.reverseValues, value, handle);
447 
448  // If replacing with null, that is erasing the mapping, drop the mapping
449  // between the handles and the IR objects
450  if (!replacement) {
451  dropMappingEntry(mappings.values, handle, value);
452 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
453  // Payload IR is removed from the mapping. This invalidates the respective
454  // iterators.
455  mappings.incrementTimestamp(handle);
456 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
457  } else {
458  auto it = mappings.values.find(handle);
459  if (it == mappings.values.end())
460  continue;
461 
462  SmallVector<Value> &association = it->getSecond();
463  for (Value &mapped : association) {
464  if (mapped == value)
465  mapped = replacement;
466  }
467  mappings.reverseValues[replacement].push_back(handle);
468  }
469  }
470 
471  return success();
472 }
473 
// Records in `newlyInvalidated` that `otherHandle` (pointing to `payloadOp`) is
// invalidated by the consumption of `consumingHandle` if one of
// `potentialAncestors` is an ancestor of `payloadOp`. The stored callback
// emits the diagnostic later. `throughValue`, when non-null, only contributes
// an extra note.
474 void transform::TransformState::recordOpHandleInvalidationOne(
475  OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
476  Operation *payloadOp, Value otherHandle, Value throughValue,
477  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
478  // If the op is associated with invalidated handle, skip the check as it
479  // may be reading invalid IR. This also ensures we report the first
480  // invalidation and not the last one.
481  if (invalidatedHandles.count(otherHandle) ||
482  newlyInvalidated.count(otherHandle))
483  return;
484
485  FULL_LDBG("--recordOpHandleInvalidationOne\n");
486  DEBUG_WITH_TYPE(
488  llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ",
489  [](Operation *op) { llvm::dbgs() << *op; });
490  llvm::dbgs() << "\n");
491
492  Operation *owner = consumingHandle.getOwner();
493  unsigned operandNo = consumingHandle.getOperandNumber();
  // NOTE(review): within this loop, every matching ancestor overwrites the
  // entry for `otherHandle`, so the last matching ancestor is the one
  // reported. The early return above only deduplicates across calls.
494  for (Operation *ancestor : potentialAncestors) {
495  // clang-format off
496  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
497  { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
498  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
499  { (DBGS() << "----of payload with name: "
500  << payloadOp->getName().getIdentifier() << "\n"); });
501  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
502  { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
503  // clang-format on
504  if (!ancestor->isAncestor(payloadOp))
505  continue;
506
507  // Make sure the error-reporting lambda doesn't capture anything
508  // by-reference because it will go out of scope. Additionally, extract
509  // location from Payload IR ops because the ops themselves may be
510  // deleted before the lambda gets called.
511  Location ancestorLoc = ancestor->getLoc();
512  Location opLoc = payloadOp->getLoc();
513  std::optional<Location> throughValueLoc =
514  throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
515  newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
516  otherHandle,
517  throughValueLoc](Location currentLoc) {
518  InFlightDiagnostic diag = emitError(currentLoc)
519  << "op uses a handle invalidated by a "
520  "previously executed transform op";
521  diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
522  diag.attachNote(owner->getLoc())
523  << "invalidated by this transform op that consumes its operand #"
524  << operandNo
525  << " and invalidates all handles to payload IR entities associated "
526  "with this operand and entities nested in them";
527  diag.attachNote(ancestorLoc) << "ancestor payload op";
528  diag.attachNote(opLoc) << "nested payload op";
529  if (throughValueLoc) {
530  diag.attachNote(*throughValueLoc)
531  << "consumed handle points to this payload value";
532  }
533  };
534  }
535 }
536 
537 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
538  OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
539  Value payloadValue, Value valueHandle,
540  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
541  // If the op is associated with invalidated handle, skip the check as it
542  // may be reading invalid IR. This also ensures we report the first
543  // invalidation and not the last one.
544  if (invalidatedHandles.count(valueHandle) ||
545  newlyInvalidated.count(valueHandle))
546  return;
547 
548  for (Operation *ancestor : potentialAncestors) {
549  Operation *definingOp;
550  std::optional<unsigned> resultNo;
551  unsigned argumentNo = std::numeric_limits<unsigned>::max();
552  unsigned blockNo = std::numeric_limits<unsigned>::max();
553  unsigned regionNo = std::numeric_limits<unsigned>::max();
554  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
555  definingOp = opResult.getOwner();
556  resultNo = opResult.getResultNumber();
557  } else {
558  auto arg = llvm::cast<BlockArgument>(payloadValue);
559  definingOp = arg.getParentBlock()->getParentOp();
560  argumentNo = arg.getArgNumber();
561  blockNo = std::distance(arg.getOwner()->getParent()->begin(),
562  arg.getOwner()->getIterator());
563  regionNo = arg.getOwner()->getParent()->getRegionNumber();
564  }
565  assert(definingOp && "expected the value to be defined by an op as result "
566  "or block argument");
567  if (!ancestor->isAncestor(definingOp))
568  continue;
569 
570  Operation *owner = opHandle.getOwner();
571  unsigned operandNo = opHandle.getOperandNumber();
572  Location ancestorLoc = ancestor->getLoc();
573  Location opLoc = definingOp->getLoc();
574  Location valueLoc = payloadValue.getLoc();
575  newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
576  argumentNo, blockNo, regionNo, ancestorLoc,
577  opLoc, valueLoc](Location currentLoc) {
578  InFlightDiagnostic diag = emitError(currentLoc)
579  << "op uses a handle invalidated by a "
580  "previously executed transform op";
581  diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
582  diag.attachNote(owner->getLoc())
583  << "invalidated by this transform op that consumes its operand #"
584  << operandNo
585  << " and invalidates all handles to payload IR entities "
586  "associated with this operand and entities nested in them";
587  diag.attachNote(ancestorLoc)
588  << "ancestor op associated with the consumed handle";
589  if (resultNo) {
590  diag.attachNote(opLoc)
591  << "op defining the value as result #" << *resultNo;
592  } else {
593  diag.attachNote(opLoc)
594  << "op defining the value as block argument #" << argumentNo
595  << " of block #" << blockNo << " in region #" << regionNo;
596  }
597  diag.attachNote(valueLoc) << "payload value";
598  };
599  }
600 }
601 
602 void transform::TransformState::recordOpHandleInvalidation(
603  OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
604  Value throughValue,
605  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
606 
607  if (potentialAncestors.empty()) {
608  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
609  (DBGS() << "----recording invalidation for empty handle: " << handle.get()
610  << "\n");
611  });
612 
613  Operation *owner = handle.getOwner();
614  unsigned operandNo = handle.getOperandNumber();
615  newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
616  InFlightDiagnostic diag = emitError(currentLoc)
617  << "op uses a handle associated with empty "
618  "payload and invalidated by a "
619  "previously executed transform op";
620  diag.attachNote(owner->getLoc())
621  << "invalidated by this transform op that consumes its operand #"
622  << operandNo;
623  };
624  return;
625  }
626 
627  // Iterate over the mapping and invalidate aliasing handles. This is quite
628  // expensive and only necessary for error reporting in case of transform
629  // dialect misuse with dangling handles. Iteration over the handles is based
630  // on the assumption that the number of handles is significantly less than the
631  // number of IR objects (operations and values). Alternatively, we could walk
632  // the IR nested in each payload op associated with the given handle and look
633  // for handles associated with each operation and value.
634  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
635  // Stop lookup when reaching a region that is isolated from above.
636  if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
637  break;
638  // Go over all op handle mappings and mark as invalidated any handle
639  // pointing to any of the payload ops associated with the given handle or
640  // any op nested in them.
641  for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
642  for (Value otherHandle : otherHandles)
643  recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
644  otherHandle, throughValue,
645  newlyInvalidated);
646  }
647  // Go over all value handle mappings and mark as invalidated any handle
648  // pointing to any result of the payload op associated with the given handle
649  // or any op nested in them. Similarly invalidate handles to argument of
650  // blocks belonging to any region of any payload op associated with the
651  // given handle or any op nested in them.
652  for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
653  for (Value valueHandle : valueHandles)
654  recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
655  payloadValue, valueHandle,
656  newlyInvalidated);
657  }
658  }
659 }
660 
661 void transform::TransformState::recordValueHandleInvalidation(
662  OpOperand &valueHandle,
663  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
664  // Invalidate other handles to the same value.
665  for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
666  SmallVector<Value> otherValueHandles;
667  (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
668  for (Value otherHandle : otherValueHandles) {
669  Operation *owner = valueHandle.getOwner();
670  unsigned operandNo = valueHandle.getOperandNumber();
671  Location valueLoc = payloadValue.getLoc();
672  newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
673  valueLoc](Location currentLoc) {
674  InFlightDiagnostic diag = emitError(currentLoc)
675  << "op uses a handle invalidated by a "
676  "previously executed transform op";
677  diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
678  diag.attachNote(owner->getLoc())
679  << "invalidated by this transform op that consumes its operand #"
680  << operandNo
681  << " and invalidates handles to the same values as associated with "
682  "it";
683  diag.attachNote(valueLoc) << "payload value";
684  };
685  }
686 
687  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
688  Operation *payloadOp = opResult.getOwner();
689  recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
690  newlyInvalidated);
691  } else {
692  auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
693  for (Operation &payloadOp : *arg.getOwner())
694  recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
695  newlyInvalidated);
696  }
697  }
698 }
699 
700 /// Checks that the operation does not use invalidated handles as operands.
701 /// Reports errors and returns failure if it does. Otherwise, invalidates the
702 /// handles consumed by the operation as well as any handles pointing to payload
703 /// IR operations nested in the operations associated with the consumed handles.
/// Invalidations recorded while processing one operand are visible when the
/// following operands are checked (via `newlyInvalidated`).
704 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
705  transform::TransformOpInterface transform,
706  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
707  FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
  // Collect the op's memory effects on the transform mapping resource; a Free
  // effect on an operand marks that operand's handle as consumed.
708  auto memoryEffectsIface =
709  cast<MemoryEffectOpInterface>(transform.getOperation());
711  memoryEffectsIface.getEffectsOnResource(
713
714  for (OpOperand &target : transform->getOpOperands()) {
715  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
716  (DBGS() << "----iterate on handle: " << target.get() << "\n");
717  });
718  // If the operand uses an invalidated handle, report it. If the operation
719  // allows handles to point to repeated payload operations, only report
720  // pre-existing invalidation errors. Otherwise, also report invalidations
721  // caused by the current transform operation affecting its other operands.
722  auto it = invalidatedHandles.find(target.get());
723  auto nit = newlyInvalidated.find(target.get());
    // The comma expressions below invoke the stored diagnostic emitter at this
    // op's location, then return failure.
724  if (it != invalidatedHandles.end()) {
725  FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
726  "invalidated -> FAILURE\n");
727  return it->getSecond()(transform->getLoc()), failure();
728  }
729  if (!transform.allowsRepeatedHandleOperands() &&
730  nit != newlyInvalidated.end()) {
731  FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
732  "invalidated (by this op) -> FAILURE\n");
733  return nit->getSecond()(transform->getLoc()), failure();
734  }
735
736  // Invalidate handles pointing to the operations nested in the operation
737  // associated with the handle consumed by this operation.
738  auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
739  return isa<MemoryEffects::Free>(effect.getEffect()) &&
740  effect.getValue() == target.get();
741  };
742  if (llvm::any_of(effects, consumesTarget)) {
743  FULL_LDBG("----found consume effect\n");
744  if (llvm::isa<transform::TransformHandleTypeInterface>(
745  target.get().getType())) {
746  FULL_LDBG("----recordOpHandleInvalidation\n");
747  SmallVector<Operation *> payloadOps =
748  llvm::to_vector(getPayloadOps(target.get()));
749  recordOpHandleInvalidation(target, payloadOps, nullptr,
750  newlyInvalidated);
751  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
752  target.get().getType())) {
753  FULL_LDBG("----recordValueHandleInvalidation\n");
754  recordValueHandleInvalidation(target, newlyInvalidated);
755  } else {
        // Consumed params have no payload IR to alias, so nothing is recorded.
756  FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
757  }
758  } else {
759  FULL_LDBG("----no consume effect -> SKIP\n");
760  }
761  }
762
763  FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
764  return success();
765 }
766 
767 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
768  transform::TransformOpInterface transform) {
769  InvalidatedHandleMap newlyInvalidated;
770  LogicalResult checkResult =
771  checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
772  invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
773  std::make_move_iterator(newlyInvalidated.end()));
774  return checkResult;
775 }
776 
// Checks that the payload of a consumed handle (payload ops when T is a
// pointer, payload values otherwise) contains no entity twice. On the first
// repetition, returns a silenceable error that names the operand and carries a
// note at the repeated entity's location.
777 template <typename T>
780  transform::TransformOpInterface transform,
781  unsigned operandNumber) {
782  DenseSet<T> seen;
783  for (T p : payload) {
    // insert() reports false when `p` was already seen.
784  if (!seen.insert(p).second) {
786  transform.emitSilenceableError()
787  << "a handle passed as operand #" << operandNumber
788  << " and consumed by this operation points to a payload "
789  "entity more than once";
      // Operation* is accessed through `->`, Value through `.`.
790  if constexpr (std::is_pointer_v<T>)
791  diag.attachNote(p->getLoc()) << "repeated target op";
792  else
793  diag.attachNote(p.getLoc()) << "repeated target value";
794  return diag;
795  }
796  }
798 }
799 
800 void transform::TransformState::compactOpHandles() {
801  for (Value handle : opHandlesToCompact) {
802  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
803 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
804  if (llvm::find(mappings.direct[handle], nullptr) !=
805  mappings.direct[handle].end())
806  // Payload IR is removed from the mapping. This invalidates the respective
807  // iterators.
808  mappings.incrementTimestamp(handle);
809 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
810  llvm::erase_value(mappings.direct[handle], nullptr);
811  }
812  opHandlesToCompact.clear();
813 }
814 
816 transform::TransformState::applyTransform(TransformOpInterface transform) {
817  LLVM_DEBUG({
818  DBGS() << "applying: ";
819  transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
820  llvm::dbgs() << "\n";
821  });
822  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
823  DBGS() << "Top-level payload before application:\n"
824  << *getTopLevel() << "\n");
825  auto printOnFailureRAII = llvm::make_scope_exit([this] {
826  (void)this;
827  LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
828  llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
829  });
830  if (options.getExpensiveChecksEnabled()) {
831  FULL_LDBG("ExpensiveChecksEnabled\n");
832  if (failed(checkAndRecordHandleInvalidation(transform)))
834 
835  for (OpOperand &operand : transform->getOpOperands()) {
836  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
837  (DBGS() << "iterate on handle: " << operand.get() << "\n");
838  });
839  if (!isHandleConsumed(operand.get(), transform)) {
840  FULL_LDBG("--handle not consumed -> SKIP\n");
841  continue;
842  }
843  if (transform.allowsRepeatedHandleOperands()) {
844  FULL_LDBG("--op allows repeated handles -> SKIP\n");
845  continue;
846  }
847  FULL_LDBG("--handle is consumed\n");
848 
849  Type operandType = operand.get().getType();
850  if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
851  FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
853  checkRepeatedConsumptionInOperand<Operation *>(
854  getPayloadOpsView(operand.get()), transform,
855  operand.getOperandNumber());
856  if (!check.succeeded()) {
857  FULL_LDBG("----FAILED\n");
858  return check;
859  }
860  } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
861  FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
863  checkRepeatedConsumptionInOperand<Value>(
864  getPayloadValuesView(operand.get()), transform,
865  operand.getOperandNumber());
866  if (!check.succeeded()) {
867  FULL_LDBG("----FAILED\n");
868  return check;
869  }
870  } else {
871  FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
872  }
873  }
874 
875 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
876  // Cache Operation* -> OperationName mappings. These will be checked after
877  // the transform has been applied to detect incorrect memory side effects
878  // and missing op tracking.
879  for (std::unique_ptr<Mappings> &mapping :
880  llvm::make_second_range(mappings)) {
881  for (Operation *op : llvm::make_first_range(mapping->reverse)) {
882  auto insertion = cachedNames.insert({op, op->getName()});
883  if (!insertion.second) {
884  if (insertion.first->second != op->getName()) {
885  // Operation is already in the cache, but with a different name.
887  emitDefiniteFailure(transform->getLoc())
888  << "expensive checks failure: operation mismatch, expected "
889  << insertion.first->second;
890  diag.attachNote(op->getLoc()) << "payload op: " << op->getName();
891  return diag;
892  }
893  }
894  }
895  }
896 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
897  }
898 
899  // Find which operands are consumed.
900  SmallVector<OpOperand *> consumedOperands =
901  transform.getConsumedHandleOpOperands();
902 
903  // Remember the results of the payload ops associated with the consumed
904  // op handles or the ops defining the value handles so we can drop the
905  // association with them later. This must happen here because the
906  // transformation may destroy or mutate them so we cannot traverse the payload
907  // IR after that.
908  SmallVector<Value> origOpFlatResults;
909  SmallVector<Operation *> origAssociatedOps;
910 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
911  DenseSet<Operation *> consumedPayloadOps;
912 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
913  for (OpOperand *opOperand : consumedOperands) {
914  Value operand = opOperand->get();
915  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
916  for (Operation *payloadOp : getPayloadOps(operand)) {
917  llvm::append_range(origOpFlatResults, payloadOp->getResults());
918 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
919  if (options.getExpensiveChecksEnabled()) {
920  // Store all consumed payload ops (and their nested ops) in a set for
921  // extra error checking.
922  payloadOp->walk(
923  [&](Operation *op) { consumedPayloadOps.insert(op); });
924  }
925 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
926  }
927  continue;
928  }
929  if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
930  for (Value payloadValue : getPayloadValuesView(operand)) {
931  if (llvm::isa<OpResult>(payloadValue)) {
932  origAssociatedOps.push_back(payloadValue.getDefiningOp());
933  continue;
934  }
935  llvm::append_range(
936  origAssociatedOps,
937  llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
938  [](Operation &op) { return &op; }));
939  }
940  continue;
941  }
943  emitDefiniteFailure(transform->getLoc())
944  << "unexpectedly consumed a value that is not a handle as operand #"
945  << opOperand->getOperandNumber();
946  diag.attachNote(operand.getLoc())
947  << "value defined here with type " << operand.getType();
948  return diag;
949  }
950 
951  // Prepare rewriter and listener.
952  transform::ErrorCheckingTrackingListener trackingListener(*this, transform);
953  transform::TransformRewriter rewriter(transform->getContext(),
954  &trackingListener);
955 
956  // Compute the result but do not short-circuit the silenceable failure case as
957  // we still want the handles to propagate properly so the "suppress" mode can
958  // proceed on a best effort basis.
959  transform::TransformResults results(transform->getNumResults());
960  DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
961  compactOpHandles();
962 
963  // Error handling: fail if transform or listener failed.
964  DiagnosedSilenceableFailure trackingFailure =
965  trackingListener.checkAndResetError();
966  if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
967  transform->hasAttr(
968  transform::TransformDialect::kSilenceTrackingFailuresAttrName)) {
969  // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
970  // do not report failures if the above mentioned attribute is set.
971  if (trackingFailure.isSilenceableFailure())
972  (void)trackingFailure.silence();
973  trackingFailure = DiagnosedSilenceableFailure::success();
974  }
975  if (!trackingFailure.succeeded()) {
976  if (result.succeeded()) {
977  result = std::move(trackingFailure);
978  } else {
979  // Transform op errors have precedence, report those first.
980  if (result.isSilenceableFailure())
981  result.attachNote() << "tracking listener also failed: "
982  << trackingFailure.getMessage();
983  (void)trackingFailure.silence();
984  }
985  }
986  if (result.isDefiniteFailure())
987  return result;
988 
989  // If a silenceable failure was produced, some results may be unset, set them
990  // to empty lists.
991  if (result.isSilenceableFailure())
992  results.setRemainingToEmpty(transform);
993 
994  // Remove the mapping for the operand if it is consumed by the operation. This
995  // allows us to catch use-after-free with assertions later on.
996  for (OpOperand *opOperand : consumedOperands) {
997  Value operand = opOperand->get();
998  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
999  forgetMapping(operand, origOpFlatResults);
1000  } else if (llvm::isa<TransformValueHandleTypeInterface>(
1001  operand.getType())) {
1002  forgetValueMapping(operand, origAssociatedOps);
1003  }
1004  }
1005 
1006 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1007  if (options.getExpensiveChecksEnabled()) {
1008  // Remove erased ops from the transform state.
1009  for (Operation *op : consumedPayloadOps) {
1010  // This payload op was consumed but it may still be mapped to one or
1011  // multiple handles. Forget all handles that are mapped to the op, so that
1012  // there are no dangling pointers in the transform dialect state. This is
1013  // necessary so that the `cachedNames`-based checks work correctly.
1014  //
1015  // Note: Dangling pointers to erased payload ops are allowed if the
1016  // corresponding handles are not used anymore. There is another
1017  // "expensive-check" that looks for future uses of dangling payload op
1018  // pointers (through arbitrary handles). Removing handles to erased ops
1019  // does not interfere with the other expensive checks: handle invalidation
1020  // happens earlier and keeps track of invalidated handles with
1021  // pre-generated error messages, so we do not need the association to
1022  // still be there when the invalidated handle is accessed.
1023  SmallVector<Value> handles;
1024  (void)getHandlesForPayloadOp(op, handles);
1025  for (Value handle : handles)
1026  forgetMapping(handle, /*origOpFlatResults=*/ValueRange());
1027  cachedNames.erase(op);
1028  }
1029 
1030  // Check cached operation names.
1031  for (std::unique_ptr<Mappings> &mapping :
1032  llvm::make_second_range(mappings)) {
1033  for (Operation *op : llvm::make_first_range(mapping->reverse)) {
1034  // Make sure that the name of the op has not changed. If it has changed,
1035  // the op was removed and a new op was allocated at the same memory
1036  // location. This means that we are missing op tracking somewhere.
1037  auto cacheIt = cachedNames.find(op);
1038  if (cacheIt == cachedNames.end()) {
1040  emitDefiniteFailure(transform->getLoc())
1041  << "expensive checks failure: operation not found in cache";
1042  diag.attachNote(op->getLoc()) << "payload op";
1043  return diag;
1044  }
1045  // If the `getName` call (or the above `attachNote`) is crashing, we
1046  // have a dangling pointer. This usually means that an op was erased but
1047  // the transform dialect was not made aware of that; e.g., missing
1048  // "consumesHandle" or rewriter usage.
1049  if (cacheIt->second != op->getName()) {
1051  emitDefiniteFailure(transform->getLoc())
1052  << "expensive checks failure: operation mismatch, expected "
1053  << cacheIt->second;
1054  diag.attachNote(op->getLoc()) << "payload op: " << op->getName();
1055  return diag;
1056  }
1057  }
1058  }
1059  }
1060 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1061 
1062  if (failed(updateStateFromResults(results, transform->getResults())))
1064 
1065  printOnFailureRAII.release();
1066  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
1067  DBGS() << "Top-level payload:\n";
1068  getTopLevel()->print(llvm::dbgs());
1069  });
1070  return result;
1071 }
1072 
1073 LogicalResult transform::TransformState::updateStateFromResults(
1074  const TransformResults &results, ResultRange opResults) {
1075  for (OpResult result : opResults) {
1076  if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1077  assert(results.isParam(result.getResultNumber()) &&
1078  "expected parameters for the parameter-typed result");
1079  if (failed(
1080  setParams(result, results.getParams(result.getResultNumber())))) {
1081  return failure();
1082  }
1083  } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1084  assert(results.isValue(result.getResultNumber()) &&
1085  "expected values for value-type-result");
1086  if (failed(setPayloadValues(
1087  result, results.getValues(result.getResultNumber())))) {
1088  return failure();
1089  }
1090  } else {
1091  assert(!results.isParam(result.getResultNumber()) &&
1092  "expected payload ops for the non-parameter typed result");
1093  if (failed(
1094  setPayloadOps(result, results.get(result.getResultNumber())))) {
1095  return failure();
1096  }
1097  }
1098  }
1099  return success();
1100 }
1101 
1102 //===----------------------------------------------------------------------===//
1103 // TransformState::Extension
1104 //===----------------------------------------------------------------------===//
1105 
1107 
1110  Operation *replacement) {
1111  // TODO: we may need to invalidate handles to operations and values nested in
1112  // the operation being replaced.
1113  return state.replacePayloadOp(op, replacement);
1114 }
1115 
1118  Value replacement) {
1119  return state.replacePayloadValue(value, replacement);
1120 }
1121 
1122 //===----------------------------------------------------------------------===//
1123 // TransformState::RegionScope
1124 //===----------------------------------------------------------------------===//
1125 
1127  // Remove handle invalidation notices as handles are going out of scope.
1128  // The same region may be re-entered leading to incorrect invalidation
1129  // errors.
1130  for (Block &block : *region) {
1131  for (Value handle : block.getArguments()) {
1132  state.invalidatedHandles.erase(handle);
1133  }
1134  for (Operation &op : block) {
1135  for (Value handle : op.getResults()) {
1136  state.invalidatedHandles.erase(handle);
1137  }
1138  }
1139  }
1140 
1141 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1142  // Remember pointers to payload ops referenced by the handles going out of
1143  // scope.
1144  SmallVector<Operation *> referencedOps =
1145  llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1146 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1147 
1148  state.mappings.erase(region);
1149 
1150 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1151  // If the last handle to a payload op has gone out of scope, we no longer
1152  // need to store the cached name. Pointers may get reused, leading to
1153  // incorrect associations in the cache.
1154  for (Operation *op : referencedOps) {
1155  SmallVector<Value> handles;
1156  if (succeeded(state.getHandlesForPayloadOp(op, handles)))
1157  continue;
1158  state.cachedNames.erase(op);
1159  }
1160 
1161  state.regionStack.pop_back();
1162 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1163 }
1164 
1165 //===----------------------------------------------------------------------===//
1166 // TransformResults
1167 //===----------------------------------------------------------------------===//
1168 
1169 transform::TransformResults::TransformResults(unsigned numSegments) {
1170  operations.appendEmptyRows(numSegments);
1171  params.appendEmptyRows(numSegments);
1172  values.appendEmptyRows(numSegments);
1173 }
1174 
1177  int64_t position = value.getResultNumber();
1178  assert(position < static_cast<int64_t>(this->params.size()) &&
1179  "setting params for a non-existent handle");
1180  assert(this->params[position].data() == nullptr && "params already set");
1181  assert(operations[position].data() == nullptr &&
1182  "another kind of results already set");
1183  assert(values[position].data() == nullptr &&
1184  "another kind of results already set");
1185  this->params.replace(position, params);
1186 }
1187 
1189  OpResult handle, ArrayRef<MappedValue> values) {
1191  handle, values,
1192  [&](ArrayRef<Operation *> operations) {
1193  return set(handle, operations), success();
1194  },
1195  [&](ArrayRef<Param> params) {
1196  return setParams(handle, params), success();
1197  },
1198  [&](ValueRange payloadValues) {
1199  return setValues(handle, payloadValues), success();
1200  });
1201 #ifndef NDEBUG
1202  if (!diag.succeeded())
1203  llvm::dbgs() << diag.getStatusString() << "\n";
1204  assert(diag.succeeded() && "incorrect mapping");
1205 #endif // NDEBUG
1206  (void)diag.silence();
1207 }
1208 
1210  transform::TransformOpInterface transform) {
1211  for (OpResult opResult : transform->getResults()) {
1212  if (!isSet(opResult.getResultNumber()))
1213  setMappedValues(opResult, {});
1214  }
1215 }
1216 
1218 transform::TransformResults::get(unsigned resultNumber) const {
1219  assert(resultNumber < operations.size() &&
1220  "querying results for a non-existent handle");
1221  assert(operations[resultNumber].data() != nullptr &&
1222  "querying unset results (values or params expected?)");
1223  return operations[resultNumber];
1224 }
1225 
1227 transform::TransformResults::getParams(unsigned resultNumber) const {
1228  assert(resultNumber < params.size() &&
1229  "querying params for a non-existent handle");
1230  assert(params[resultNumber].data() != nullptr &&
1231  "querying unset params (ops or values expected?)");
1232  return params[resultNumber];
1233 }
1234 
1236 transform::TransformResults::getValues(unsigned resultNumber) const {
1237  assert(resultNumber < values.size() &&
1238  "querying values for a non-existent handle");
1239  assert(values[resultNumber].data() != nullptr &&
1240  "querying unset values (ops or params expected?)");
1241  return values[resultNumber];
1242 }
1243 
1244 bool transform::TransformResults::isParam(unsigned resultNumber) const {
1245  assert(resultNumber < params.size() &&
1246  "querying association for a non-existent handle");
1247  return params[resultNumber].data() != nullptr;
1248 }
1249 
1250 bool transform::TransformResults::isValue(unsigned resultNumber) const {
1251  assert(resultNumber < values.size() &&
1252  "querying association for a non-existent handle");
1253  return values[resultNumber].data() != nullptr;
1254 }
1255 
1256 bool transform::TransformResults::isSet(unsigned resultNumber) const {
1257  assert(resultNumber < params.size() &&
1258  "querying association for a non-existent handle");
1259  return params[resultNumber].data() != nullptr ||
1260  operations[resultNumber].data() != nullptr ||
1261  values[resultNumber].data() != nullptr;
1262 }
1263 
1264 //===----------------------------------------------------------------------===//
1265 // TrackingListener
1266 //===----------------------------------------------------------------------===//
1267 
1269  TransformOpInterface op)
1270  : TransformState::Extension(state), transformOp(op) {
1271  if (op) {
1272  for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1273  consumedHandles.insert(opOperand->get());
1274  }
1275  }
1276 }
1277 
1279  Operation *defOp = nullptr;
1280  for (Value v : values) {
1281  // Skip empty values.
1282  if (!v)
1283  continue;
1284  if (!defOp) {
1285  defOp = v.getDefiningOp();
1286  continue;
1287  }
1288  if (defOp != v.getDefiningOp())
1289  return nullptr;
1290  }
1291  return defOp;
1292 }
1293 
1295  Operation *&result, Operation *op, ValueRange newValues) const {
1296  assert(op->getNumResults() == newValues.size() &&
1297  "invalid number of replacement values");
1298  SmallVector<Value> values(newValues.begin(), newValues.end());
1299 
1301  getTransformOp(), "tracking listener failed to find replacement op "
1302  "during application of this transform op");
1303 
1304  do {
1305  // If the replacement values belong to different ops, drop the mapping.
1306  Operation *defOp = getCommonDefiningOp(values);
1307  if (!defOp) {
1308  diag.attachNote() << "replacement values belong to different ops";
1309  return diag;
1310  }
1311 
1312  // If the defining op has the same type, we take it as a replacement.
1313  if (op->getName() == defOp->getName()) {
1314  result = defOp;
1316  }
1317 
1318  // Replacing an op with a constant-like equivalent is a common
1319  // canonicalization.
1320  if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1321  result = defOp;
1323  }
1324 
1325  values.clear();
1326 
1327  // Skip through ops that implement FindPayloadReplacementOpInterface.
1328  if (auto findReplacementOpInterface =
1329  dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1330  values.assign(findReplacementOpInterface.getNextOperands());
1331  diag.attachNote(defOp->getLoc()) << "using operands provided by "
1332  "'FindPayloadReplacementOpInterface'";
1333  continue;
1334  }
1335 
1336  // Skip through ops that implement CastOpInterface.
1337  if (isa<CastOpInterface>(defOp)) {
1338  values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1339  diag.attachNote(defOp->getLoc())
1340  << "using output of 'CastOpInterface' op";
1341  continue;
1342  }
1343  } while (!values.empty());
1344 
1345  diag.attachNote() << "ran out of suitable replacement values";
1346  return diag;
1347 }
1348 
1350  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1351  LLVM_DEBUG({
1353  reasonCallback(diag);
1354  DBGS() << "Match Failure : " << diag.str() << "\n";
1355  });
1356  return failure();
1357 }
1358 
1359 void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
1360  // TODO: Walk can be removed when D144193 has landed.
1361  op->walk([&](Operation *op) {
1362  // Remove mappings for result values.
1363  for (OpResult value : op->getResults())
1364  (void)replacePayloadValue(value, nullptr);
1365  // Remove mapping for op.
1366  (void)replacePayloadOp(op, nullptr);
1367  });
1368 }
1369 
1370 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
1371 /// properly dominates `b` and `b` is not inside `a`.
1372 static bool happensBefore(Operation *a, Operation *b) {
1373  do {
1374  if (a->isProperAncestor(b))
1375  return false;
1376  if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
1377  return a->isBeforeInBlock(bAncestor);
1378  }
1379  } while ((a = a->getParentOp()));
1380  return false;
1381 }
1382 
1383 void transform::TrackingListener::notifyOperationReplaced(
1384  Operation *op, ValueRange newValues) {
1385  assert(op->getNumResults() == newValues.size() &&
1386  "invalid number of replacement values");
1387 
1388  // Replace value handles.
1389  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
1390  (void)replacePayloadValue(oldValue, newValue);
1391 
1392  // Replace op handle.
1393  SmallVector<Value> opHandles;
1394  if (failed(getTransformState().getHandlesForPayloadOp(
1395  op, opHandles, /*includeOutOfScope=*/true))) {
1396  // Op is not tracked.
1397  return;
1398  }
1399 
1400  // Helper function to check if the current transform op consumes any handle
1401  // that is mapped to `op`.
1402  //
1403  // Note: If a handle was consumed, there shouldn't be any alive users, so it
1404  // is not really necessary to check for consumed handles. However, in case
1405  // there are indeed alive handles that were consumed (which is undefined
1406  // behavior) and a replacement op could not be found, we want to fail with a
1407  // nicer error message: "op uses a handle invalidated..." instead of "could
1408  // not find replacement op". This nicer error is produced later.
1409  auto handleWasConsumed = [&] {
1410  return llvm::any_of(opHandles,
1411  [&](Value h) { return consumedHandles.contains(h); });
1412  };
1413 
1414  // Helper function to check if the handle is alive.
1415  auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
1416  for (Value v : opHandles) {
1417  for (OpOperand &use : v.getUses())
1418  if (use.getOwner() != transformOp &&
1419  !happensBefore(use.getOwner(), transformOp))
1420  return &use;
1421  }
1422  return std::nullopt;
1423  }();
1424 
1425  if (!firstAliveUser.has_value() || handleWasConsumed()) {
1426  // The op is tracked but the corresponding handles are dead or were
1427  // consumed. Drop the op form the mapping.
1428  (void)replacePayloadOp(op, nullptr);
1429  return;
1430  }
1431 
1432  Operation *replacement;
1434  findReplacementOp(replacement, op, newValues);
1435  // If the op is tracked but no replacement op was found, send a
1436  // notification.
1437  if (!diag.succeeded()) {
1438  diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
1439  << "replacement is required because alive handle(s) exist "
1440  << "(first use in this op as operand number "
1441  << (*firstAliveUser)->getOperandNumber() << ")";
1442  notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
1443  (void)replacePayloadOp(op, nullptr);
1444  return;
1445  }
1446 
1447  (void)replacePayloadOp(op, replacement);
1448 }
1449 
1451  // The state of the ErrorCheckingTrackingListener must be checked and reset
1452  // if there was an error. This is to prevent errors from accidentally being
1453  // missed.
1454  assert(status.succeeded() && "listener state was not checked");
1455 }
1456 
1459  DiagnosedSilenceableFailure s = std::move(status);
1461  errorCounter = 0;
1462  return s;
1463 }
1464 
1466  return !status.succeeded();
1467 }
1468 
1471 
1472  // Merge potentially existing diags and store the result in the listener.
1474  diag.takeDiagnostics(diags);
1475  if (!status.succeeded())
1476  status.takeDiagnostics(diags);
1477  status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
1478 
1479  // Report more details.
1480  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
1481  for (auto &&[index, value] : llvm::enumerate(values))
1482  status.attachNote(value.getLoc())
1483  << "[" << errorCounter << "] replacement value " << index;
1484  ++errorCounter;
1485 }
1486 
1487 //===----------------------------------------------------------------------===//
1488 // TransformRewriter
1489 //===----------------------------------------------------------------------===//
1490 
1493  : RewriterBase(ctx), listener(listener) {
1494  setListener(listener);
1495 }
1496 
1498  return listener->failed();
1499 }
1500 
1501 /// Silence all tracking failures that have been encountered so far.
1503  if (hasTrackingFailures()) {
1504  DiagnosedSilenceableFailure status = listener->checkAndResetError();
1505  (void)status.silence();
1506  }
1507 }
1508 
1510  Operation *op, Operation *replacement) {
1511  return listener->replacePayloadOp(op, replacement);
1512 }
1513 
1514 //===----------------------------------------------------------------------===//
1515 // Utilities for TransformEachOpTrait.
1516 //===----------------------------------------------------------------------===//
1517 
1520  ArrayRef<Operation *> targets) {
1521  for (auto &&[position, parent] : llvm::enumerate(targets)) {
1522  for (Operation *child : targets.drop_front(position + 1)) {
1523  if (parent->isAncestor(child)) {
1525  emitError(loc)
1526  << "transform operation consumes a handle pointing to an ancestor "
1527  "payload operation before its descendant";
1528  diag.attachNote()
1529  << "the ancestor is likely erased or rewritten before the "
1530  "descendant is accessed, leading to undefined behavior";
1531  diag.attachNote(parent->getLoc()) << "ancestor payload op";
1532  diag.attachNote(child->getLoc()) << "descendant payload op";
1533  return diag;
1534  }
1535  }
1536  }
1537  return success();
1538 }
1539 
1542  Location payloadOpLoc,
1543  const ApplyToEachResultList &partialResult) {
1544  Location transformOpLoc = transformOp->getLoc();
1545  StringRef transformOpName = transformOp->getName().getStringRef();
1546  unsigned expectedNumResults = transformOp->getNumResults();
1547 
1548  // Reuse the emission of the diagnostic note.
1549  auto emitDiag = [&]() {
1550  auto diag = mlir::emitError(transformOpLoc);
1551  diag.attachNote(payloadOpLoc) << "when applied to this op";
1552  return diag;
1553  };
1554 
1555  if (partialResult.size() != expectedNumResults) {
1556  auto diag = emitDiag() << "application of " << transformOpName
1557  << " expected to produce " << expectedNumResults
1558  << " results (actually produced "
1559  << partialResult.size() << ").";
1560  diag.attachNote(transformOpLoc)
1561  << "if you need variadic results, consider a generic `apply` "
1562  << "instead of the specialized `applyToOne`.";
1563  return failure();
1564  }
1565 
1566  // Check that the right kind of value was produced.
1567  for (const auto &[ptr, res] :
1568  llvm::zip(partialResult, transformOp->getResults())) {
1569  if (ptr.isNull())
1570  continue;
1571  if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1572  !ptr.is<Operation *>()) {
1573  return emitDiag() << "application of " << transformOpName
1574  << " expected to produce an Operation * for result #"
1575  << res.getResultNumber();
1576  }
1577  if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1578  !ptr.is<Attribute>()) {
1579  return emitDiag() << "application of " << transformOpName
1580  << " expected to produce an Attribute for result #"
1581  << res.getResultNumber();
1582  }
1583  if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1584  !ptr.is<Value>()) {
1585  return emitDiag() << "application of " << transformOpName
1586  << " expected to produce a Value for result #"
1587  << res.getResultNumber();
1588  }
1589  }
1590  return success();
1591 }
1592 
1593 template <typename T>
1595  return llvm::to_vector(llvm::map_range(
1596  range, [](transform::MappedValue value) { return value.get<T>(); }));
1597 }
1598 
1600  Operation *transformOp, TransformResults &transformResults,
1603  transposed.resize(transformOp->getNumResults());
1604  for (const ApplyToEachResultList &partialResults : results) {
1605  if (llvm::any_of(partialResults,
1606  [](MappedValue value) { return value.isNull(); }))
1607  continue;
1608  assert(transformOp->getNumResults() == partialResults.size() &&
1609  "expected as many partial results as op as results");
1610  for (auto [i, value] : llvm::enumerate(partialResults))
1611  transposed[i].push_back(value);
1612  }
1613 
1614  for (OpResult r : transformOp->getResults()) {
1615  unsigned position = r.getResultNumber();
1616  if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1617  transformResults.setParams(r,
1618  castVector<Attribute>(transposed[position]));
1619  } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1620  transformResults.setValues(r, castVector<Value>(transposed[position]));
1621  } else {
1622  transformResults.set(r, castVector<Operation *>(transposed[position]));
1623  }
1624  }
1625 }
1626 
1627 //===----------------------------------------------------------------------===//
1628 // Utilities for implementing transform ops with regions.
1629 //===----------------------------------------------------------------------===//
1630 
1633  ValueRange values, const transform::TransformState &state) {
1634  for (Value operand : values) {
1635  SmallVector<MappedValue> &mapped = mappings.emplace_back();
1636  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1637  llvm::append_range(mapped, state.getPayloadOps(operand));
1638  } else if (llvm::isa<TransformValueHandleTypeInterface>(
1639  operand.getType())) {
1640  llvm::append_range(mapped, state.getPayloadValues(operand));
1641  } else {
1642  assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1643  "unsupported kind of transform dialect value");
1644  llvm::append_range(mapped, state.getParams(operand));
1645  }
1646  }
1647 }
1648 
1650  Block *block, transform::TransformState &state,
1651  transform::TransformResults &results) {
1652  for (auto &&[terminatorOperand, result] :
1653  llvm::zip(block->getTerminator()->getOperands(),
1654  block->getParentOp()->getOpResults())) {
1655  if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1656  results.set(result, state.getPayloadOps(terminatorOperand));
1657  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1658  result.getType())) {
1659  results.setValues(result, state.getPayloadValues(terminatorOperand));
1660  } else {
1661  assert(
1662  llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1663  "unhandled transform type interface");
1664  results.setParams(result, state.getParams(terminatorOperand));
1665  }
1666  }
1667 }
1668 
1671  Operation *payloadRoot) {
1672  return TransformState(region, payloadRoot);
1673 }
1674 
1675 //===----------------------------------------------------------------------===//
1676 // Utilities for PossibleTopLevelTransformOpTrait.
1677 //===----------------------------------------------------------------------===//
1678 
1679 /// Appends to `effects` the memory effect instances on `target` with the same
1680 /// resource and effect as the ones the operation `iface` has on `source`.
1681 static void
1682 remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
1685  iface.getEffectsOnValue(source, nestedEffects);
1686  for (const auto &effect : nestedEffects)
1687  effects.emplace_back(effect.getEffect(), target, effect.getResource());
1688 }
1689 
1690 /// Appends to `effects` the same effects as the operations of `block` have on
1691 /// block arguments, but associated with `operands`.
1692 static void
1695  for (Operation &op : block) {
1696  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1697  if (!iface)
1698  continue;
1699 
1700  for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
1701  remapEffects(iface, source, target, effects);
1702  }
1703 
1705  iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1706  nestedEffects);
1707  llvm::append_range(effects, nestedEffects);
1708  }
1709 }
1710 
1712  Operation *operation, Value root, Block &body,
1714  transform::onlyReadsHandle(operation->getOperands(), effects);
1715  transform::producesHandle(operation->getResults(), effects);
1716 
1717  if (!root) {
1718  for (Operation &op : body) {
1719  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1720  if (!iface)
1721  continue;
1722 
1724  iface.getEffects(effects);
1725  }
1726  return;
1727  }
1728 
1729  // Carry over all effects on arguments of the entry block as those on the
1730  // operands, this is the same value just remapped.
1731  remapArgumentEffects(body, operation->getOperands(), effects);
1732 }
1733 
1735  TransformState &state, Operation *op, Region &region) {
1736  SmallVector<Operation *> targets;
1737  SmallVector<SmallVector<MappedValue>> extraMappings;
1738  if (op->getNumOperands() != 0) {
1739  llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
1740  prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
1741  } else {
1742  if (state.getNumTopLevelMappings() !=
1743  region.front().getNumArguments() - 1) {
1744  return emitError(op->getLoc())
1745  << "operation expects " << region.front().getNumArguments() - 1
1746  << " extra value bindings, but " << state.getNumTopLevelMappings()
1747  << " were provided to the interpreter";
1748  }
1749 
1750  // Top-level transforms can be used for matching. If no concrete operation
1751  // type is specified, the block argument is mapped to the top-level op.
1752  // Otherwise, it is mapped to all ops of the specified type within the
1753  // top-level op (including the top-level op itself). Once an op is added as
1754  // a target, its descendants are not explored any further.
1755  BlockArgument bbArg = region.front().getArgument(0);
1756  if (auto bbArgType = dyn_cast<transform::OperationType>(bbArg.getType())) {
1757  state.getTopLevel()->walk<WalkOrder::PreOrder>([&](Operation *op) {
1758  if (op->getName().getStringRef() == bbArgType.getOperationName()) {
1759  targets.push_back(op);
1760  return WalkResult::skip();
1761  }
1762  return WalkResult::advance();
1763  });
1764  } else {
1765  targets.push_back(state.getTopLevel());
1766  }
1767 
1768  for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1769  extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1770  }
1771 
1772  if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
1773  return failure();
1774 
1775  for (BlockArgument argument : region.front().getArguments().drop_front()) {
1776  if (failed(state.mapBlockArgument(
1777  argument, extraMappings[argument.getArgNumber() - 1])))
1778  return failure();
1779  }
1780 
1781  return success();
1782 }
1783 
1786  // Attaching this trait without the interface is a misuse of the API, but it
1787  // cannot be caught via a static_assert because interface registration is
1788  // dynamic.
1789  assert(isa<TransformOpInterface>(op) &&
1790  "should implement TransformOpInterface to have "
1791  "PossibleTopLevelTransformOpTrait");
1792 
1793  if (op->getNumRegions() < 1)
1794  return op->emitOpError() << "expects at least one region";
1795 
1796  Region *bodyRegion = &op->getRegion(0);
1797  if (!llvm::hasNItems(*bodyRegion, 1))
1798  return op->emitOpError() << "expects a single-block region";
1799 
1800  Block *body = &bodyRegion->front();
1801  if (body->getNumArguments() == 0) {
1802  return op->emitOpError()
1803  << "expects the entry block to have at least one argument";
1804  }
1805  if (!llvm::isa<TransformHandleTypeInterface>(
1806  body->getArgument(0).getType())) {
1807  return op->emitOpError()
1808  << "expects the first entry block argument to be of type "
1809  "implementing TransformHandleTypeInterface";
1810  }
1811  BlockArgument arg = body->getArgument(0);
1812  if (op->getNumOperands() != 0) {
1813  if (arg.getType() != op->getOperand(0).getType()) {
1814  return op->emitOpError()
1815  << "expects the type of the block argument to match "
1816  "the type of the operand";
1817  }
1818  }
1819  for (BlockArgument arg : body->getArguments().drop_front()) {
1820  if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1821  TransformValueHandleTypeInterface>(arg.getType()))
1822  continue;
1823 
1825  op->emitOpError()
1826  << "expects trailing entry block arguments to be of type implementing "
1827  "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1828  "TransformParamTypeInterface";
1829  diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1830  return diag;
1831  }
1832 
1833  if (auto *parent =
1835  if (op->getNumOperands() != body->getNumArguments()) {
1837  op->emitOpError()
1838  << "expects operands to be provided for a nested op";
1839  diag.attachNote(parent->getLoc())
1840  << "nested in another possible top-level op";
1841  return diag;
1842  }
1843  }
1844 
1845  return success();
1846 }
1847 
1848 //===----------------------------------------------------------------------===//
1849 // Utilities for ParamProducedTransformOpTrait.
1850 //===----------------------------------------------------------------------===//
1851 
1854  producesHandle(op->getResults(), effects);
1855  bool hasPayloadOperands = false;
1856  for (Value operand : op->getOperands()) {
1857  onlyReadsHandle(operand, effects);
1858  if (llvm::isa<TransformHandleTypeInterface,
1859  TransformValueHandleTypeInterface>(operand.getType()))
1860  hasPayloadOperands = true;
1861  }
1862  if (hasPayloadOperands)
1863  onlyReadsPayload(effects);
1864 }
1865 
1868  // Interfaces can be attached dynamically, so this cannot be a static
1869  // assert.
1870  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1871  llvm::report_fatal_error(
1872  Twine("ParamProducerTransformOpTrait must be attached to an op that "
1873  "implements MemoryEffectsOpInterface, found on ") +
1874  op->getName().getStringRef());
1875  }
1876  for (Value result : op->getResults()) {
1877  if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1878  continue;
1879  return op->emitOpError()
1880  << "ParamProducerTransformOpTrait attached to this op expects "
1881  "result types to implement TransformParamTypeInterface";
1882  }
1883  return success();
1884 }
1885 
1886 //===----------------------------------------------------------------------===//
1887 // Memory effects.
1888 //===----------------------------------------------------------------------===//
1889 
1891  ValueRange handles,
1893  for (Value handle : handles) {
1894  effects.emplace_back(MemoryEffects::Read::get(), handle,
1896  effects.emplace_back(MemoryEffects::Free::get(), handle,
1898  }
1899 }
1900 
1901 /// Returns `true` if the given list of effects instances contains an instance
1902 /// with the effect type specified as template parameter.
1903 template <typename EffectTy, typename ResourceTy, typename Range>
1904 static bool hasEffect(Range &&effects) {
1905  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1906  return isa<EffectTy>(effect.getEffect()) &&
1907  isa<ResourceTy>(effect.getResource());
1908  });
1909 }
1910 
1912  transform::TransformOpInterface transform) {
1913  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1915  iface.getEffectsOnValue(handle, effects);
1916  return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1917  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1918 }
1919 
1921  ValueRange handles,
1923  for (Value handle : handles) {
1924  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1926  effects.emplace_back(MemoryEffects::Write::get(), handle,
1928  }
1929 }
1930 
1932  ValueRange handles,
1934  for (Value handle : handles) {
1935  effects.emplace_back(MemoryEffects::Read::get(), handle,
1937  }
1938 }
1939 
1942  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1943  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
1944 }
1945 
1948  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1949 }
1950 
1951 bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1952  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1954  iface.getEffects(effects);
1955  return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1956 }
1957 
1958 bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1959  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1961  iface.getEffects(effects);
1962  return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1963 }
1964 
1966  Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1968  for (Operation &nested : block) {
1969  auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1970  if (!iface)
1971  continue;
1972 
1973  effects.clear();
1974  iface.getEffects(effects);
1975  for (const MemoryEffects::EffectInstance &effect : effects) {
1976  BlockArgument argument =
1977  dyn_cast_or_null<BlockArgument>(effect.getValue());
1978  if (!argument || argument.getOwner() != &block ||
1979  !isa<MemoryEffects::Free>(effect.getEffect()) ||
1980  effect.getResource() != transform::TransformMappingResource::get()) {
1981  continue;
1982  }
1983  consumedArguments.insert(argument.getArgNumber());
1984  }
1985  }
1986 }
1987 
1988 //===----------------------------------------------------------------------===//
1989 // Utilities for TransformOpInterface.
1990 //===----------------------------------------------------------------------===//
1991 
1993  TransformOpInterface transformOp) {
1994  SmallVector<OpOperand *> consumedOperands;
1995  consumedOperands.reserve(transformOp->getNumOperands());
1996  auto memEffectInterface =
1997  cast<MemoryEffectOpInterface>(transformOp.getOperation());
1999  for (OpOperand &target : transformOp->getOpOperands()) {
2000  effects.clear();
2001  memEffectInterface.getEffectsOnValue(target.get(), effects);
2002  if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
2003  return isa<transform::TransformMappingResource>(
2004  effect.getResource()) &&
2005  isa<MemoryEffects::Free>(effect.getEffect());
2006  })) {
2007  consumedOperands.push_back(&target);
2008  }
2009  }
2010  return consumedOperands;
2011 }
2012 
2014  auto iface = cast<MemoryEffectOpInterface>(op);
2016  iface.getEffects(effects);
2017 
2018  auto effectsOn = [&](Value value) {
2019  return llvm::make_filter_range(
2020  effects, [value](const MemoryEffects::EffectInstance &instance) {
2021  return instance.getValue() == value;
2022  });
2023  };
2024 
2025  std::optional<unsigned> firstConsumedOperand;
2026  for (OpOperand &operand : op->getOpOperands()) {
2027  auto range = effectsOn(operand.get());
2028  if (range.empty()) {
2030  op->emitError() << "TransformOpInterface requires memory effects "
2031  "on operands to be specified";
2032  diag.attachNote() << "no effects specified for operand #"
2033  << operand.getOperandNumber();
2034  return diag;
2035  }
2036  if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
2038  << "TransformOpInterface did not expect "
2039  "'allocate' memory effect on an operand";
2040  diag.attachNote() << "specified for operand #"
2041  << operand.getOperandNumber();
2042  return diag;
2043  }
2044  if (!firstConsumedOperand &&
2045  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
2046  firstConsumedOperand = operand.getOperandNumber();
2047  }
2048  }
2049 
2050  if (firstConsumedOperand &&
2051  !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
2053  op->emitError()
2054  << "TransformOpInterface expects ops consuming operands to have a "
2055  "'write' effect on the payload resource";
2056  diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
2057  return diag;
2058  }
2059 
2060  for (OpResult result : op->getResults()) {
2061  auto range = effectsOn(result);
2062  if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
2063  range)) {
2065  op->emitError() << "TransformOpInterface requires 'allocate' memory "
2066  "effect to be specified for results";
2067  diag.attachNote() << "no 'allocate' effect specified for result #"
2068  << result.getResultNumber();
2069  return diag;
2070  }
2071  }
2072 
2073  return success();
2074 }
2075 
2076 //===----------------------------------------------------------------------===//
2077 // Entry point.
2078 //===----------------------------------------------------------------------===//
2079 
2082  TransformOpInterface transform,
2083  const RaggedArray<MappedValue> &extraMapping,
2084  const TransformOptions &options) {
2085 #ifndef NDEBUG
2086  if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
2087  transform->getNumOperands() != 0) {
2088  transform->emitError()
2089  << "expected transform to start at the top-level transform op";
2090  llvm::report_fatal_error("could not run transforms",
2091  /*gen_crash_diag=*/false);
2092  }
2093 #endif // NDEBUG
2094 
2095  TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2096  options);
2097  return state.applyTransform(transform).checkAndReport();
2098 }
2099 
2100 //===----------------------------------------------------------------------===//
2101 // Generated interface implementation.
2102 //===----------------------------------------------------------------------===//
2103 
2104 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define DEBUG_TYPE_FULL
static void remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the memory effect instances on target with the same resource and effect as the one...
#define FULL_LDBG(X)
static void remapArgumentEffects(Block &block, ValueRange operands, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the same effects as the operations of block have on block arguments but associated...
static bool happensBefore(Operation *a, Operation *b)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not ...
void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped)
#define DEBUG_PRINT_AFTER_ALL
DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef< T > payload, transform::TransformOpInterface transform, unsigned operandNumber)
#define DBGS()
static DiagnosedSilenceableFailure dispatchMappedValues(Value handle, ArrayRef< transform::MappedValue > values, function_ref< LogicalResult(ArrayRef< Operation * >)> operationsFn, function_ref< LogicalResult(ArrayRef< transform::Param >)> paramsFn, function_ref< LogicalResult(ValueRange)> valuesFn)
Given a list of MappedValues, cast them to the value kind implied by the interface of the handle type...
static SmallVector< T > castVector(ArrayRef< transform::MappedValue > range)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:310
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:322
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition: Block.cpp:68
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
unsigned getNumArguments()
Definition: Block.h:121
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
BlockArgListType getArguments()
Definition: Block.h:80
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:150
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:301
This class represents an operand of an operation.
Definition: Value.h:261
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:217
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition: Value.h:448
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:460
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:728
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:776
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:652
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getOpResults()
Definition: Operation.h:415
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A 2D array where each row may have different length.
Definition: RaggedArray.h:18
size_t size() const
Returns the number of rows in the 2D array.
Definition: RaggedArray.h:21
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:233
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
size_t size() const
Returns the number of elements in the list.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
TrackingListener failures are reported only for ops that have this trait.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
TrackingListener(TransformState &state, TransformOpInterface op)
Create a new TrackingListener for usage in the specified transform op.
Options controlling the application of transform operations by the TransformState.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult matches the number, kind (payload op or parameter) and nulli...
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions())
Entry point to the Transform dialect infrastructure.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool hasEffect(Operation *op, Value value=nullptr)
Returns true if op has an effect of type EffectTy on value.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...