MLIR  20.0.0git
TransformInterfaces.cpp
Go to the documentation of this file.
1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/IR/PatternMatch.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/ScopeExit.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/ErrorHandling.h"
20 
// Debug-output categories used by this file. DEBUG_TYPE is the default
// LLVM_DEBUG category; DEBUG_TYPE_FULL gates the more verbose traces emitted
// through FULL_LDBG / DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, ...).
// NOTE(review): DEBUG_PRINT_AFTER_ALL is only defined here; its use is not
// visible in this chunk.
21 #define DEBUG_TYPE "transform-dialect"
22 #define DEBUG_TYPE_FULL "transform-dialect-full"
23 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
// Debug stream prefixed with "[transform-dialect] ".
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
// LDBG prints X under the default category; FULL_LDBG prints X only when the
// "transform-dialect-full" debug type is enabled.
25 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
26 #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Helper functions
32 //===----------------------------------------------------------------------===//
33 
34 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
35 /// properly dominates `b` and `b` is not inside `a`.
36 static bool happensBefore(Operation *a, Operation *b) {
37  do {
38  if (a->isProperAncestor(b))
39  return false;
40  if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
41  return a->isBeforeInBlock(bAncestor);
42  }
43  } while ((a = a->getParentOp()));
44  return false;
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // TransformState
49 //===----------------------------------------------------------------------===//
50 
// Out-of-line definition of the static constexpr member `kTopLevelValue`.
// Needed before C++17 if the member is ODR-used; since C++17 static constexpr
// data members are implicitly inline, so this definition is redundant there.
51 constexpr const Value transform::TransformState::kTopLevelValue;
52 
53 transform::TransformState::TransformState(
54  Region *region, Operation *payloadRoot,
55  const RaggedArray<MappedValue> &extraMappings,
56  const TransformOptions &options)
57  : topLevel(payloadRoot), options(options) {
58  topLevelMappedValues.reserve(extraMappings.size());
59  for (ArrayRef<MappedValue> mapping : extraMappings)
60  topLevelMappedValues.push_back(mapping);
61  if (region) {
62  RegionScope *scope = new RegionScope(*this, *region);
63  topLevelRegionScope.reset(scope);
64  }
65 }
66 
67 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
68 
/// Returns the payload operations associated with the operation handle
/// `value`, read from the `direct` map of the mapping that `getMapping(value)`
/// selects. Debug builds assert if `value` has no entry there (e.g., a param or
/// value handle was passed by mistake). Release builds dereference the lookup
/// result unchecked.
70 transform::TransformState::getPayloadOpsView(Value value) const {
71  const TransformOpMapping &operationMapping = getMapping(value).direct;
72  auto iter = operationMapping.find(value);
73  assert(iter != operationMapping.end() &&
74  "cannot find mapping for payload handle (param/value handle "
75  "provided?)");
76  return iter->getSecond();
77 }
78 
// Looks up the parameters associated with the param handle `value` in the
// `params` map of its mapping. Debug builds assert if there is no entry (e.g.,
// an operation or value handle was passed instead).
// NOTE(review): the signature of this function is missing from this rendering.
80  const ParamMapping &mapping = getMapping(value).params;
81  auto iter = mapping.find(value);
82  assert(iter != mapping.end() && "cannot find mapping for param handle "
83  "(operation/value handle provided?)");
84  return iter->getSecond();
85 }
86 
/// Returns the payload values associated with the value handle `handleValue`,
/// read from the `values` map of its mapping. Debug builds assert if the handle
/// has no entry (e.g., a param or operation handle was passed instead).
88 transform::TransformState::getPayloadValuesView(Value handleValue) const {
89  const ValueMapping &mapping = getMapping(handleValue).values;
90  auto iter = mapping.find(handleValue);
91  assert(iter != mapping.end() && "cannot find mapping for value handle "
92  "(param/operation handle provided?)");
93  return iter->getSecond();
94 }
95 
97  Operation *op, SmallVectorImpl<Value> &handles,
98  bool includeOutOfScope) const {
  // Appends to `handles` every transform handle that a scope's `reverse` map
  // associates with the payload op `op`. Scopes in `mappings` are visited in
  // reverse order (presumably innermost first -- TODO confirm). Returns
  // success iff at least one handle was found.
99  bool found = false;
100  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
101  auto iterator = mapping->reverse.find(op);
102  if (iterator != mapping->reverse.end()) {
103  llvm::append_range(handles, iterator->getSecond());
104  found = true;
105  }
106  // Stop looking when reaching a region that is isolated from above.
  // NOTE(review): the second half of the following condition is missing from
  // this rendering of the file.
107  if (!includeOutOfScope &&
109  break;
110  }
111 
112  return success(found);
113 }
114 
116  Value payloadValue, SmallVectorImpl<Value> &handles,
117  bool includeOutOfScope) const {
  // Appends to `handles` every transform handle that a scope's `reverseValues`
  // map associates with `payloadValue`. Scopes in `mappings` are visited in
  // reverse order (presumably innermost first -- TODO confirm). Returns
  // success iff at least one handle was found.
118  bool found = false;
119  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
120  auto iterator = mapping->reverseValues.find(payloadValue);
121  if (iterator != mapping->reverseValues.end()) {
122  llvm::append_range(handles, iterator->getSecond());
123  found = true;
124  }
125  // Stop looking when reaching a region that is isolated from above.
  // NOTE(review): the second half of the following condition is missing from
  // this rendering of the file.
126  if (!includeOutOfScope &&
128  break;
129  }
130 
131  return success(found);
132 }
133 
134 /// Given a list of MappedValues, cast them to the value kind implied by the
135 /// interface of the handle type, and dispatch to one of the callbacks.
/// A null entry, or an entry of the wrong kind, produces a silenceable failure
/// at the handle's location instead of invoking a callback.
/// NOTE(review): the function name, the `handle`/`values` parameters and a few
/// body lines are missing from this rendering of the file.
138  function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
139  function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
140  function_ref<LogicalResult(ValueRange)> valuesFn) {
  // Operation handle: every mapped value must be a non-null Operation *.
141  if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
142  SmallVector<Operation *> operations;
143  operations.reserve(values.size());
144  for (transform::MappedValue value : values) {
145  if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
146  operations.push_back(op);
147  continue;
148  }
149  return emitSilenceableFailure(handle.getLoc())
150  << "wrong kind of value provided for top-level operation handle";
151  }
152  if (failed(operationsFn(operations)))
155  }
156 
  // Value handle: every mapped value must be a non-null Value.
157  if (llvm::isa<transform::TransformValueHandleTypeInterface>(
158  handle.getType())) {
159  SmallVector<Value> payloadValues;
160  payloadValues.reserve(values.size());
161  for (transform::MappedValue value : values) {
162  if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
163  payloadValues.push_back(v);
164  continue;
165  }
166  return emitSilenceableFailure(handle.getLoc())
167  << "wrong kind of value provided for the top-level value handle";
168  }
169  if (failed(valuesFn(payloadValues)))
172  }
173 
  // Otherwise the handle must be a parameter handle, and every mapped value
  // must be a non-null Attribute.
174  assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
175  "unsupported kind of block argument");
177  parameters.reserve(values.size());
178  for (transform::MappedValue value : values) {
179  if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
180  parameters.push_back(attr);
181  continue;
182  }
183  return emitSilenceableFailure(handle.getLoc())
184  << "wrong kind of value provided for top-level parameter";
185  }
186  if (failed(paramsFn(parameters)))
189 }
190 
/// Binds the block argument `argument` to `values`. Through
/// `dispatchMappedValues`, the values go to `setPayloadOps`, `setParams` or
/// `setPayloadValues`, depending on the argument's handle type. Any silenceable
/// failure is reported via `checkAndReport()`.
191 LogicalResult
193  ArrayRef<MappedValue> values) {
194  return dispatchMappedValues(
195  argument, values,
196  [&](ArrayRef<Operation *> operations) {
197  return setPayloadOps(argument, operations);
198  },
199  [&](ArrayRef<Param> params) {
200  return setParams(argument, params);
201  },
202  [&](ValueRange payloadValues) {
203  return setPayloadValues(argument, payloadValues);
204  })
205  .checkAndReport();
206 }
207 
209  Block::BlockArgListType arguments,
  // Map each block argument to its corresponding row of `mapping` and stop at
  // the first failure. `llvm::zip_equal` requires both ranges to have the same
  // length.
211  for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
212  if (failed(mapBlockArgument(argument, values)))
213  return failure();
214  return success();
215 }
216 
/// Associates the operation handle `value` with the payload ops `targets`.
/// It:
///  - emits an error and fails if any target is null;
///  - fails (after reporting) if the handle type's `checkPayload` rejects the
///    targets;
///  - records the forward mapping handle -> ops, and one reverse entry
///    op -> handle per occurrence of an op in `targets`.
/// The handle must not already have a mapping (asserted).
217 LogicalResult
218 transform::TransformState::setPayloadOps(Value value,
219  ArrayRef<Operation *> targets) {
220  assert(value != kTopLevelValue &&
221  "attempting to reset the transformation root");
222  assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
223  "wrong handle type");
224 
225  for (Operation *target : targets) {
226  if (target)
227  continue;
228  return emitError(value.getLoc())
229  << "attempting to assign a null payload op to this transform value";
230  }
231 
  // Let the handle type validate the payload before anything is recorded.
  // NOTE(review): the declaration of `result` is missing from this rendering.
232  auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
234  iface.checkPayload(value.getLoc(), targets);
235  if (failed(result.checkAndReport()))
236  return failure();
237 
238  // Setting new payload for the value without cleaning it first is a misuse of
239  // the API, assert here.
240  SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
241  Mappings &mappings = getMapping(value);
242  bool inserted =
243  mappings.direct.insert({value, std::move(storedTargets)}).second;
244  assert(inserted && "value is already associated with another list");
245  (void)inserted;
246 
247  for (Operation *op : targets)
248  mappings.reverse[op].push_back(value);
249 
250  return success();
251 }
252 
253 LogicalResult
254 transform::TransformState::setPayloadValues(Value handle,
255  ValueRange payloadValues) {
256  assert(handle != nullptr && "attempting to set params for a null value");
257  assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
258  "wrong handle type");
259 
260  for (Value payload : payloadValues) {
261  if (payload)
262  continue;
263  return emitError(handle.getLoc()) << "attempting to assign a null payload "
264  "value to this transform handle";
265  }
266 
267  auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
268  SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
270  iface.checkPayload(handle.getLoc(), payloadValueVector);
271  if (failed(result.checkAndReport()))
272  return failure();
273 
274  Mappings &mappings = getMapping(handle);
275  bool inserted =
276  mappings.values.insert({handle, std::move(payloadValueVector)}).second;
277  assert(
278  inserted &&
279  "value handle is already associated with another list of payload values");
280  (void)inserted;
281 
282  for (Value payload : payloadValues)
283  mappings.reverseValues[payload].push_back(handle);
284 
285  return success();
286 }
287 
288 LogicalResult transform::TransformState::setParams(Value value,
289  ArrayRef<Param> params) {
290  assert(value != nullptr && "attempting to set params for a null value");
291 
292  for (Attribute attr : params) {
293  if (attr)
294  continue;
295  return emitError(value.getLoc())
296  << "attempting to assign a null parameter to this transform value";
297  }
298 
299  auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
300  assert(value &&
301  "cannot associate parameter with a value of non-parameter type");
303  valueType.checkPayload(value.getLoc(), params);
304  if (failed(result.checkAndReport()))
305  return failure();
306 
307  Mappings &mappings = getMapping(value);
308  bool inserted =
309  mappings.params.insert({value, llvm::to_vector(params)}).second;
310  assert(inserted && "value is already associated with another list of params");
311  (void)inserted;
312  return success();
313 }
314 
315 template <typename Mapping, typename Key, typename Mapped>
316 void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
317  auto it = mapping.find(key);
318  if (it == mapping.end())
319  return;
320 
321  llvm::erase(it->getSecond(), mapped);
322  if (it->getSecond().empty())
323  mapping.erase(it);
324 }
325 
326 void transform::TransformState::forgetMapping(Value opHandle,
327  ValueRange origOpFlatResults,
328  bool allowOutOfScope) {
329  Mappings &mappings = getMapping(opHandle, allowOutOfScope);
330  for (Operation *op : mappings.direct[opHandle])
331  dropMappingEntry(mappings.reverse, op, opHandle);
332  mappings.direct.erase(opHandle);
333 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
334  // Payload IR is removed from the mapping. This invalidates the respective
335  // iterators.
336  mappings.incrementTimestamp(opHandle);
337 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
338 
339  for (Value opResult : origOpFlatResults) {
340  SmallVector<Value> resultHandles;
341  (void)getHandlesForPayloadValue(opResult, resultHandles);
342  for (Value resultHandle : resultHandles) {
343  Mappings &localMappings = getMapping(resultHandle);
344  dropMappingEntry(localMappings.values, resultHandle, opResult);
345 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
346  // Payload IR is removed from the mapping. This invalidates the respective
347  // iterators.
348  mappings.incrementTimestamp(resultHandle);
349 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
350  dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
351  }
352  }
353 }
354 
355 void transform::TransformState::forgetValueMapping(
356  Value valueHandle, ArrayRef<Operation *> payloadOperations) {
357  Mappings &mappings = getMapping(valueHandle);
358  for (Value payloadValue : mappings.reverseValues[valueHandle])
359  dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
360  mappings.values.erase(valueHandle);
361 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
362  // Payload IR is removed from the mapping. This invalidates the respective
363  // iterators.
364  mappings.incrementTimestamp(valueHandle);
365 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
366 
367  for (Operation *payloadOp : payloadOperations) {
368  SmallVector<Value> opHandles;
369  (void)getHandlesForPayloadOp(payloadOp, opHandles);
370  for (Value opHandle : opHandles) {
371  Mappings &localMappings = getMapping(opHandle);
372  dropMappingEntry(localMappings.direct, opHandle, payloadOp);
373  dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
374 
375 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
376  // Payload IR is removed from the mapping. This invalidates the respective
377  // iterators.
378  localMappings.incrementTimestamp(opHandle);
379 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
380  }
381  }
382 }
383 
/// Makes every handle (in any scope) that points to the payload op `op` point
/// to `replacement` instead. Fails if no handle points to `op`.
///
/// A null `replacement` records the erasure of `op`: a nullptr is stored in
/// place of `op`, and the affected handles are queued in `opHandlesToCompact`.
/// Debug builds assert that no value handle still refers to any result of
/// `op`.
384 LogicalResult
385 transform::TransformState::replacePayloadOp(Operation *op,
386  Operation *replacement) {
387  // TODO: consider invalidating the handles to nested objects here.
388 
389 #ifndef NDEBUG
390  for (Value opResult : op->getResults()) {
391  SmallVector<Value> valueHandles;
392  (void)getHandlesForPayloadValue(opResult, valueHandles,
393  /*includeOutOfScope=*/true);
394  assert(valueHandles.empty() && "expected no mapping to old results");
395  }
396 #endif // NDEBUG
397 
398  // Drop the mapping between the op and all handles that point to it. Fail if
399  // there are no handles.
400  SmallVector<Value> opHandles;
401  if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
402  return failure();
403  for (Value handle : opHandles) {
404  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
405  dropMappingEntry(mappings.reverse, op, handle);
406  }
407 
408  // Replace the pointed-to object of all handles with the replacement object.
409  // In case a payload op was erased (replacement object is nullptr), a nullptr
410  // is stored in the mapping. These nullptrs are removed after each transform.
411  // Furthermore, nullptrs are not enumerated by payload op iterators. The
412  // relative order of ops is preserved.
413  //
414  // Removing an op from the mapping would be problematic because removing an
415  // element from an array invalidates iterators; merely changing the value of
416  // elements does not.
417  for (Value handle : opHandles) {
418  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
419  auto it = mappings.direct.find(handle);
420  if (it == mappings.direct.end())
421  continue;
422 
423  SmallVector<Operation *, 2> &association = it->getSecond();
424  // Note that an operation may be associated with the handle more than once.
425  for (Operation *&mapped : association) {
426  if (mapped == op)
427  mapped = replacement;
428  }
429 
  // Re-establish the reverse link for the new op. Erased ops instead leave
  // nullptrs that still need compacting.
430  if (replacement) {
431  mappings.reverse[replacement].push_back(handle);
432  } else {
433  opHandlesToCompact.insert(handle);
434  }
435  }
436 
437  return success();
438 }
439 
440 LogicalResult
441 transform::TransformState::replacePayloadValue(Value value, Value replacement) {
442  SmallVector<Value> valueHandles;
443  if (failed(getHandlesForPayloadValue(value, valueHandles,
444  /*includeOutOfScope=*/true)))
445  return failure();
446 
447  for (Value handle : valueHandles) {
448  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
449  dropMappingEntry(mappings.reverseValues, value, handle);
450 
451  // If replacing with null, that is erasing the mapping, drop the mapping
452  // between the handles and the IR objects
453  if (!replacement) {
454  dropMappingEntry(mappings.values, handle, value);
455 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
456  // Payload IR is removed from the mapping. This invalidates the respective
457  // iterators.
458  mappings.incrementTimestamp(handle);
459 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
460  } else {
461  auto it = mappings.values.find(handle);
462  if (it == mappings.values.end())
463  continue;
464 
465  SmallVector<Value> &association = it->getSecond();
466  for (Value &mapped : association) {
467  if (mapped == value)
468  mapped = replacement;
469  }
470  mappings.reverseValues[replacement].push_back(handle);
471  }
472  }
473 
474  return success();
475 }
476 
/// If `payloadOp` is nested in (or equal to) one of `potentialAncestors`,
/// records in `newlyInvalidated` a diagnostic-emitting callback for
/// `otherHandle`. The callback explains that the handle was invalidated by the
/// op consuming `consumingHandle`. `throughValue`, when non-null, is the
/// payload value through which the consumed handle reached the payload.
/// NOTE(review): every matching ancestor overwrites the callback, so the last
/// matching ancestor in `potentialAncestors` is the one reported.
477 void transform::TransformState::recordOpHandleInvalidationOne(
478  OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
479  Operation *payloadOp, Value otherHandle, Value throughValue,
480  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
481  // If the op is associated with invalidated handle, skip the check as it
482  // may be reading invalid IR. This also ensures we report the first
483  // invalidation and not the last one.
484  if (invalidatedHandles.count(otherHandle) ||
485  newlyInvalidated.count(otherHandle))
486  return;
487 
488  FULL_LDBG("--recordOpHandleInvalidationOne\n");
489  DEBUG_WITH_TYPE(
491  llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ",
492  [](Operation *op) { llvm::dbgs() << *op; });
493  llvm::dbgs() << "\n");
494 
495  Operation *owner = consumingHandle.getOwner();
496  unsigned operandNo = consumingHandle.getOperandNumber();
497  for (Operation *ancestor : potentialAncestors) {
498  // clang-format off
499  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
500  { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
501  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
502  { (DBGS() << "----of payload with name: "
503  << payloadOp->getName().getIdentifier() << "\n"); });
504  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
505  { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
506  // clang-format on
507  if (!ancestor->isAncestor(payloadOp))
508  continue;
509 
510  // Make sure the error-reporting lambda doesn't capture anything
511  // by-reference because it will go out of scope. Additionally, extract
512  // location from Payload IR ops because the ops themselves may be
513  // deleted before the lambda gets called.
514  Location ancestorLoc = ancestor->getLoc();
515  Location opLoc = payloadOp->getLoc();
516  std::optional<Location> throughValueLoc =
517  throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
518  newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
519  otherHandle,
520  throughValueLoc](Location currentLoc) {
521  InFlightDiagnostic diag = emitError(currentLoc)
522  << "op uses a handle invalidated by a "
523  "previously executed transform op";
524  diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
525  diag.attachNote(owner->getLoc())
526  << "invalidated by this transform op that consumes its operand #"
527  << operandNo
528  << " and invalidates all handles to payload IR entities associated "
529  "with this operand and entities nested in them";
530  diag.attachNote(ancestorLoc) << "ancestor payload op";
531  diag.attachNote(opLoc) << "nested payload op";
532  if (throughValueLoc) {
533  diag.attachNote(*throughValueLoc)
534  << "consumed handle points to this payload value";
535  }
536  };
537  }
538 }
539 
540 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
541  OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
542  Value payloadValue, Value valueHandle,
543  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
544  // If the op is associated with invalidated handle, skip the check as it
545  // may be reading invalid IR. This also ensures we report the first
546  // invalidation and not the last one.
547  if (invalidatedHandles.count(valueHandle) ||
548  newlyInvalidated.count(valueHandle))
549  return;
550 
551  for (Operation *ancestor : potentialAncestors) {
552  Operation *definingOp;
553  std::optional<unsigned> resultNo;
554  unsigned argumentNo = std::numeric_limits<unsigned>::max();
555  unsigned blockNo = std::numeric_limits<unsigned>::max();
556  unsigned regionNo = std::numeric_limits<unsigned>::max();
557  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
558  definingOp = opResult.getOwner();
559  resultNo = opResult.getResultNumber();
560  } else {
561  auto arg = llvm::cast<BlockArgument>(payloadValue);
562  definingOp = arg.getParentBlock()->getParentOp();
563  argumentNo = arg.getArgNumber();
564  blockNo = std::distance(arg.getOwner()->getParent()->begin(),
565  arg.getOwner()->getIterator());
566  regionNo = arg.getOwner()->getParent()->getRegionNumber();
567  }
568  assert(definingOp && "expected the value to be defined by an op as result "
569  "or block argument");
570  if (!ancestor->isAncestor(definingOp))
571  continue;
572 
573  Operation *owner = opHandle.getOwner();
574  unsigned operandNo = opHandle.getOperandNumber();
575  Location ancestorLoc = ancestor->getLoc();
576  Location opLoc = definingOp->getLoc();
577  Location valueLoc = payloadValue.getLoc();
578  newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
579  argumentNo, blockNo, regionNo, ancestorLoc,
580  opLoc, valueLoc](Location currentLoc) {
581  InFlightDiagnostic diag = emitError(currentLoc)
582  << "op uses a handle invalidated by a "
583  "previously executed transform op";
584  diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
585  diag.attachNote(owner->getLoc())
586  << "invalidated by this transform op that consumes its operand #"
587  << operandNo
588  << " and invalidates all handles to payload IR entities "
589  "associated with this operand and entities nested in them";
590  diag.attachNote(ancestorLoc)
591  << "ancestor op associated with the consumed handle";
592  if (resultNo) {
593  diag.attachNote(opLoc)
594  << "op defining the value as result #" << *resultNo;
595  } else {
596  diag.attachNote(opLoc)
597  << "op defining the value as block argument #" << argumentNo
598  << " of block #" << blockNo << " in region #" << regionNo;
599  }
600  diag.attachNote(valueLoc) << "payload value";
601  };
602  }
603 }
604 
/// Records in `newlyInvalidated` the invalidation of every handle that aliases
/// payload nested in `potentialAncestors`. This covers op handles and value
/// handles alike. The operand `handle` is being consumed; `throughValue` is
/// forwarded for diagnostics.
/// If `potentialAncestors` is empty, only `handle` itself is marked, with a
/// diagnostic mentioning its empty payload.
605 void transform::TransformState::recordOpHandleInvalidation(
606  OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
607  Value throughValue,
608  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
609 
610  if (potentialAncestors.empty()) {
611  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
612  (DBGS() << "----recording invalidation for empty handle: " << handle.get()
613  << "\n");
614  });
615 
616  Operation *owner = handle.getOwner();
617  unsigned operandNo = handle.getOperandNumber();
618  newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
619  InFlightDiagnostic diag = emitError(currentLoc)
620  << "op uses a handle associated with empty "
621  "payload and invalidated by a "
622  "previously executed transform op";
623  diag.attachNote(owner->getLoc())
624  << "invalidated by this transform op that consumes its operand #"
625  << operandNo;
626  };
627  return;
628  }
629 
630  // Iterate over the mapping and invalidate aliasing handles. This is quite
631  // expensive and only necessary for error reporting in case of transform
632  // dialect misuse with dangling handles. Iteration over the handles is based
633  // on the assumption that the number of handles is significantly less than the
634  // number of IR objects (operations and values). Alternatively, we could walk
635  // the IR nested in each payload op associated with the given handle and look
636  // for handles associated with each operation and value.
637  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
638  // Go over all op handle mappings and mark as invalidated any handle
639  // pointing to any of the payload ops associated with the given handle or
640  // any op nested in them.
641  for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
642  for (Value otherHandle : otherHandles)
643  recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
644  otherHandle, throughValue,
645  newlyInvalidated);
646  }
647  // Go over all value handle mappings and mark as invalidated any handle
648  // pointing to any result of the payload op associated with the given handle
649  // or any op nested in them. Similarly invalidate handles to argument of
650  // blocks belonging to any region of any payload op associated with the
651  // given handle or any op nested in them.
652  for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
653  for (Value valueHandle : valueHandles)
654  recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
655  payloadValue, valueHandle,
656  newlyInvalidated);
657  }
658 
  // NOTE(review): the condition guarding the `break` below is missing from
  // this rendering. As shown, the loop would stop after the first scope.
659  // Stop lookup when reaching a region that is isolated from above.
661  break;
662  }
663 }
664 
665 void transform::TransformState::recordValueHandleInvalidation(
666  OpOperand &valueHandle,
667  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
668  // Invalidate other handles to the same value.
669  for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
670  SmallVector<Value> otherValueHandles;
671  (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
672  for (Value otherHandle : otherValueHandles) {
673  Operation *owner = valueHandle.getOwner();
674  unsigned operandNo = valueHandle.getOperandNumber();
675  Location valueLoc = payloadValue.getLoc();
676  newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
677  valueLoc](Location currentLoc) {
678  InFlightDiagnostic diag = emitError(currentLoc)
679  << "op uses a handle invalidated by a "
680  "previously executed transform op";
681  diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
682  diag.attachNote(owner->getLoc())
683  << "invalidated by this transform op that consumes its operand #"
684  << operandNo
685  << " and invalidates handles to the same values as associated with "
686  "it";
687  diag.attachNote(valueLoc) << "payload value";
688  };
689  }
690 
691  if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
692  Operation *payloadOp = opResult.getOwner();
693  recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
694  newlyInvalidated);
695  } else {
696  auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
697  for (Operation &payloadOp : *arg.getOwner())
698  recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
699  newlyInvalidated);
700  }
701  }
702 }
703 
704 /// Checks that the operation does not use invalidated handles as operands.
705 /// Reports errors and returns failure if it does. Otherwise, invalidates the
706 /// handles consumed by the operation as well as any handles pointing to payload
707 /// IR operations nested in the operations associated with the consumed handles.
/// An operand counts as consumed when `transform` declares a
/// MemoryEffects::Free effect on it. Newly invalidated handles are
/// accumulated in `newlyInvalidated`; they are not merged into
/// `invalidatedHandles` here.
708 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
709  transform::TransformOpInterface transform,
710  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
711  FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
  // Collect the memory effects of `transform` into `effects`.
  // NOTE(review): the declaration of `effects` and the arguments of
  // getEffectsOnResource are missing from this rendering of the file.
712  auto memoryEffectsIface =
713  cast<MemoryEffectOpInterface>(transform.getOperation());
715  memoryEffectsIface.getEffectsOnResource(
717 
718  for (OpOperand &target : transform->getOpOperands()) {
719  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
720  (DBGS() << "----iterate on handle: " << target.get() << "\n");
721  });
722  // If the operand uses an invalidated handle, report it. If the operation
723  // allows handles to point to repeated payload operations, only report
724  // pre-existing invalidation errors. Otherwise, also report invalidations
725  // caused by the current transform operation affecting its other operands.
726  auto it = invalidatedHandles.find(target.get());
727  auto nit = newlyInvalidated.find(target.get());
728  if (it != invalidatedHandles.end()) {
729  FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
730  "invalidated -> FAILURE\n");
  // Invoke the stored callback to emit its diagnostic at this op's
  // location, then return failure (comma operator).
731  return it->getSecond()(transform->getLoc()), failure();
732  }
733  if (!transform.allowsRepeatedHandleOperands() &&
734  nit != newlyInvalidated.end()) {
735  FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
736  "invalidated (by this op) -> FAILURE\n");
737  return nit->getSecond()(transform->getLoc()), failure();
738  }
739 
740  // Invalidate handles pointing to the operations nested in the operation
741  // associated with the handle consumed by this operation.
742  auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
743  return isa<MemoryEffects::Free>(effect.getEffect()) &&
744  effect.getValue() == target.get();
745  };
746  if (llvm::any_of(effects, consumesTarget)) {
747  FULL_LDBG("----found consume effect\n");
748  if (llvm::isa<transform::TransformHandleTypeInterface>(
749  target.get().getType())) {
750  FULL_LDBG("----recordOpHandleInvalidation\n");
751  SmallVector<Operation *> payloadOps =
752  llvm::to_vector(getPayloadOps(target.get()));
753  recordOpHandleInvalidation(target, payloadOps, nullptr,
754  newlyInvalidated);
755  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
756  target.get().getType())) {
757  FULL_LDBG("----recordValueHandleInvalidation\n");
758  recordValueHandleInvalidation(target, newlyInvalidated);
759  } else {
760  FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
761  }
762  } else {
763  FULL_LDBG("----no consume effect -> SKIP\n");
764  }
765  }
766 
767  FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
768  return success();
769 }
770 
771 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
772  transform::TransformOpInterface transform) {
773  InvalidatedHandleMap newlyInvalidated;
774  LogicalResult checkResult =
775  checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
776  invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
777  std::make_move_iterator(newlyInvalidated.end()));
778  return checkResult;
779 }
780 
/// Checks that the payload of a consumed handle (operand `operandNumber` of
/// `transform`) contains no entity more than once. At the first repeat, it
/// returns a silenceable error with a note pointing at the repeated entity.
/// NOTE(review): the function name, the `payload` parameter, the declaration
/// of `diag` and the final return are missing from this rendering of the file.
781 template <typename T>
784  transform::TransformOpInterface transform,
785  unsigned operandNumber) {
  // `seen.insert(p).second` is false exactly when `p` was already visited.
786  DenseSet<T> seen;
787  for (T p : payload) {
788  if (!seen.insert(p).second) {
790  transform.emitSilenceableError()
791  << "a handle passed as operand #" << operandNumber
792  << " and consumed by this operation points to a payload "
793  "entity more than once";
  // T is a pointer for op payloads (`Operation *`) and a value type for
  // value payloads (`Value`); pick the matching member access.
794  if constexpr (std::is_pointer_v<T>)
795  diag.attachNote(p->getLoc()) << "repeated target op";
796  else
797  diag.attachNote(p.getLoc()) << "repeated target value";
798  return diag;
799  }
800  }
802 }
803 
804 void transform::TransformState::compactOpHandles() {
805  for (Value handle : opHandlesToCompact) {
806  Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
807 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
808  if (llvm::find(mappings.direct[handle], nullptr) !=
809  mappings.direct[handle].end())
810  // Payload IR is removed from the mapping. This invalidates the respective
811  // iterators.
812  mappings.incrementTimestamp(handle);
813 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
814  llvm::erase(mappings.direct[handle], nullptr);
815  }
816  opHandlesToCompact.clear();
817 }
818 
/// Applies `transform` to the payload IR and updates the handle mappings.
///
/// In order, this function:
///  - optionally (expensive checks) verifies handle invalidation and that no
///    consumed handle points to the same payload entity twice;
///  - records the payload associated with consumed handles, because the
///    transformation may destroy or mutate it;
///  - applies the transform through a `TransformRewriter` whose tracking
///    listener follows payload replacements;
///  - merges tracking-listener failures into the result;
///  - drops the mappings of consumed operands;
///  - associates the transform's results with the produced payload.
///
/// A definite failure is returned immediately. On a silenceable failure, unset
/// results are set to empty lists and the mappings are still updated, so that
/// the "suppress" mode can proceed.
DiagnosedSilenceableFailure
transform::TransformState::applyTransform(TransformOpInterface transform) {
  LLVM_DEBUG({
    DBGS() << "applying: ";
    transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
    llvm::dbgs() << "\n";
  });
  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
                  DBGS() << "Top-level payload before application:\n"
                         << *getTopLevel() << "\n");
  // Prints the payload on every exit path; released just before the final
  // successful return below. `(void)this` keeps the capture used when
  // LLVM_DEBUG expands to nothing.
  auto printOnFailureRAII = llvm::make_scope_exit([this] {
    (void)this;
    LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
        llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
  });

  // Set current transform op.
  regionStack.back()->currentTransform = transform;

  // Expensive checks to detect invalid transform IR.
  if (options.getExpensiveChecksEnabled()) {
    FULL_LDBG("ExpensiveChecksEnabled\n");
    if (failed(checkAndRecordHandleInvalidation(transform)))
      return DiagnosedSilenceableFailure::definiteFailure();

    for (OpOperand &operand : transform->getOpOperands()) {
      DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
        (DBGS() << "iterate on handle: " << operand.get() << "\n");
      });
      if (!isHandleConsumed(operand.get(), transform)) {
        FULL_LDBG("--handle not consumed -> SKIP\n");
        continue;
      }
      if (transform.allowsRepeatedHandleOperands()) {
        FULL_LDBG("--op allows repeated handles -> SKIP\n");
        continue;
      }
      FULL_LDBG("--handle is consumed\n");

      // A consumed handle must not list the same payload op/value twice.
      Type operandType = operand.get().getType();
      if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
        FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
        DiagnosedSilenceableFailure check =
            checkRepeatedConsumptionInOperand<Operation *>(
                getPayloadOpsView(operand.get()), transform,
                operand.getOperandNumber());
        if (!check.succeeded()) {
          FULL_LDBG("----FAILED\n");
          return check;
        }
      } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
        FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
        DiagnosedSilenceableFailure check =
            checkRepeatedConsumptionInOperand<Value>(
                getPayloadValuesView(operand.get()), transform,
                operand.getOperandNumber());
        if (!check.succeeded()) {
          FULL_LDBG("----FAILED\n");
          return check;
        }
      } else {
        FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
      }
    }
  }

  // Find which operands are consumed.
  SmallVector<OpOperand *> consumedOperands =
      transform.getConsumedHandleOpOperands();

  // Remember the results of the payload ops associated with the consumed
  // op handles or the ops defining the value handles so we can drop the
  // association with them later. This must happen here because the
  // transformation may destroy or mutate them so we cannot traverse the payload
  // IR after that.
  SmallVector<Value> origOpFlatResults;
  SmallVector<Operation *> origAssociatedOps;
  for (OpOperand *opOperand : consumedOperands) {
    Value operand = opOperand->get();
    if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
      for (Operation *payloadOp : getPayloadOps(operand)) {
        llvm::append_range(origOpFlatResults, payloadOp->getResults());
      }
      continue;
    }
    if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
      for (Value payloadValue : getPayloadValuesView(operand)) {
        if (llvm::isa<OpResult>(payloadValue)) {
          origAssociatedOps.push_back(payloadValue.getDefiningOp());
          continue;
        }
        // For a block argument, remember every op of its owning block.
        llvm::append_range(
            origAssociatedOps,
            llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
                            [](Operation &op) { return &op; }));
      }
      continue;
    }
    DiagnosedSilenceableFailure diag =
        emitDefiniteFailure(transform->getLoc())
        << "unexpectedly consumed a value that is not a handle as operand #"
        << opOperand->getOperandNumber();
    diag.attachNote(operand.getLoc())
        << "value defined here with type " << operand.getType();
    return diag;
  }

  // Prepare rewriter and listener.
  TrackingListenerConfig config;
  // A handle is skipped (treated as dead) when each of its users is the
  // current transform of the handle's region scope or happens before it.
  config.skipHandleFn = [&](Value handle) {
    // Skip handle if it is dead.
    auto scopeIt =
        llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
          return handle.getParentRegion() == scope->region;
        });
    assert(scopeIt != regionStack.rend() &&
           "could not find region scope for handle");
    RegionScope *scope = *scopeIt;
    for (Operation *user : handle.getUsers()) {
      if (user != scope->currentTransform &&
          !happensBefore(user, scope->currentTransform))
        return false;
    }
    return true;
  };
  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
                                                            config);
  transform::TransformRewriter rewriter(transform->getContext(),
                                        &trackingListener);

  // Compute the result but do not short-circuit the silenceable failure case as
  // we still want the handles to propagate properly so the "suppress" mode can
  // proceed on a best effort basis.
  transform::TransformResults results(transform->getNumResults());
  DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
  compactOpHandles();

  // Error handling: fail if transform or listener failed.
  DiagnosedSilenceableFailure trackingFailure =
      trackingListener.checkAndResetError();
  if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
      transform->hasAttr(FindPayloadReplacementOpInterface::
                             kSilenceTrackingFailuresAttrName)) {
    // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
    // do not report failures if the above mentioned attribute is set.
    if (trackingFailure.isSilenceableFailure())
      (void)trackingFailure.silence();
    trackingFailure = DiagnosedSilenceableFailure::success();
  }
  if (!trackingFailure.succeeded()) {
    if (result.succeeded()) {
      result = std::move(trackingFailure);
    } else {
      // Transform op errors have precedence, report those first.
      if (result.isSilenceableFailure())
        result.attachNote() << "tracking listener also failed: "
                            << trackingFailure.getMessage();
      (void)trackingFailure.silence();
    }
  }
  if (result.isDefiniteFailure())
    return result;

  // If a silenceable failure was produced, some results may be unset, set them
  // to empty lists.
  if (result.isSilenceableFailure())
    results.setRemainingToEmpty(transform);

  // Remove the mapping for the operand if it is consumed by the operation. This
  // allows us to catch use-after-free with assertions later on.
  for (OpOperand *opOperand : consumedOperands) {
    Value operand = opOperand->get();
    if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
      forgetMapping(operand, origOpFlatResults);
    } else if (llvm::isa<TransformValueHandleTypeInterface>(
                   operand.getType())) {
      forgetValueMapping(operand, origAssociatedOps);
    }
  }

  if (failed(updateStateFromResults(results, transform->getResults())))
    return DiagnosedSilenceableFailure::definiteFailure();

  printOnFailureRAII.release();
  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
    DBGS() << "Top-level payload:\n";
    getTopLevel()->print(llvm::dbgs());
  });
  return result;
}
1009 
1010 LogicalResult transform::TransformState::updateStateFromResults(
1011  const TransformResults &results, ResultRange opResults) {
1012  for (OpResult result : opResults) {
1013  if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1014  assert(results.isParam(result.getResultNumber()) &&
1015  "expected parameters for the parameter-typed result");
1016  if (failed(
1017  setParams(result, results.getParams(result.getResultNumber())))) {
1018  return failure();
1019  }
1020  } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1021  assert(results.isValue(result.getResultNumber()) &&
1022  "expected values for value-type-result");
1023  if (failed(setPayloadValues(
1024  result, results.getValues(result.getResultNumber())))) {
1025  return failure();
1026  }
1027  } else {
1028  assert(!results.isParam(result.getResultNumber()) &&
1029  "expected payload ops for the non-parameter typed result");
1030  if (failed(
1031  setPayloadOps(result, results.get(result.getResultNumber())))) {
1032  return failure();
1033  }
1034  }
1035  }
1036  return success();
1037 }
1038 
1039 //===----------------------------------------------------------------------===//
1040 // TransformState::Extension
1041 //===----------------------------------------------------------------------===//
1042 
1044 
1045 LogicalResult
1047  Operation *replacement) {
1048  // TODO: we may need to invalidate handles to operations and values nested in
1049  // the operation being replaced.
1050  return state.replacePayloadOp(op, replacement);
1051 }
1052 
1053 LogicalResult
1055  Value replacement) {
1056  return state.replacePayloadValue(value, replacement);
1057 }
1058 
1059 //===----------------------------------------------------------------------===//
1060 // TransformState::RegionScope
1061 //===----------------------------------------------------------------------===//
1062 
1064  // Remove handle invalidation notices as handles are going out of scope.
1065  // The same region may be re-entered leading to incorrect invalidation
1066  // errors.
1067  for (Block &block : *region) {
1068  for (Value handle : block.getArguments()) {
1069  state.invalidatedHandles.erase(handle);
1070  }
1071  for (Operation &op : block) {
1072  for (Value handle : op.getResults()) {
1073  state.invalidatedHandles.erase(handle);
1074  }
1075  }
1076  }
1077 
1078 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1079  // Remember pointers to payload ops referenced by the handles going out of
1080  // scope.
1081  SmallVector<Operation *> referencedOps =
1082  llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1083 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1084 
1085  state.mappings.erase(region);
1086  state.regionStack.pop_back();
1087 }
1088 
1089 //===----------------------------------------------------------------------===//
1090 // TransformResults
1091 //===----------------------------------------------------------------------===//
1092 
1093 transform::TransformResults::TransformResults(unsigned numSegments) {
1094  operations.appendEmptyRows(numSegments);
1095  params.appendEmptyRows(numSegments);
1096  values.appendEmptyRows(numSegments);
1097 }
1098 
1101  int64_t position = value.getResultNumber();
1102  assert(position < static_cast<int64_t>(this->params.size()) &&
1103  "setting params for a non-existent handle");
1104  assert(this->params[position].data() == nullptr && "params already set");
1105  assert(operations[position].data() == nullptr &&
1106  "another kind of results already set");
1107  assert(values[position].data() == nullptr &&
1108  "another kind of results already set");
1109  this->params.replace(position, params);
1110 }
1111 
1113  OpResult handle, ArrayRef<MappedValue> values) {
1115  handle, values,
1116  [&](ArrayRef<Operation *> operations) {
1117  return set(handle, operations), success();
1118  },
1119  [&](ArrayRef<Param> params) {
1120  return setParams(handle, params), success();
1121  },
1122  [&](ValueRange payloadValues) {
1123  return setValues(handle, payloadValues), success();
1124  });
1125 #ifndef NDEBUG
1126  if (!diag.succeeded())
1127  llvm::dbgs() << diag.getStatusString() << "\n";
1128  assert(diag.succeeded() && "incorrect mapping");
1129 #endif // NDEBUG
1130  (void)diag.silence();
1131 }
1132 
1134  transform::TransformOpInterface transform) {
1135  for (OpResult opResult : transform->getResults()) {
1136  if (!isSet(opResult.getResultNumber()))
1137  setMappedValues(opResult, {});
1138  }
1139 }
1140 
1142 transform::TransformResults::get(unsigned resultNumber) const {
1143  assert(resultNumber < operations.size() &&
1144  "querying results for a non-existent handle");
1145  assert(operations[resultNumber].data() != nullptr &&
1146  "querying unset results (values or params expected?)");
1147  return operations[resultNumber];
1148 }
1149 
1151 transform::TransformResults::getParams(unsigned resultNumber) const {
1152  assert(resultNumber < params.size() &&
1153  "querying params for a non-existent handle");
1154  assert(params[resultNumber].data() != nullptr &&
1155  "querying unset params (ops or values expected?)");
1156  return params[resultNumber];
1157 }
1158 
1160 transform::TransformResults::getValues(unsigned resultNumber) const {
1161  assert(resultNumber < values.size() &&
1162  "querying values for a non-existent handle");
1163  assert(values[resultNumber].data() != nullptr &&
1164  "querying unset values (ops or params expected?)");
1165  return values[resultNumber];
1166 }
1167 
1168 bool transform::TransformResults::isParam(unsigned resultNumber) const {
1169  assert(resultNumber < params.size() &&
1170  "querying association for a non-existent handle");
1171  return params[resultNumber].data() != nullptr;
1172 }
1173 
1174 bool transform::TransformResults::isValue(unsigned resultNumber) const {
1175  assert(resultNumber < values.size() &&
1176  "querying association for a non-existent handle");
1177  return values[resultNumber].data() != nullptr;
1178 }
1179 
1180 bool transform::TransformResults::isSet(unsigned resultNumber) const {
1181  assert(resultNumber < params.size() &&
1182  "querying association for a non-existent handle");
1183  return params[resultNumber].data() != nullptr ||
1184  operations[resultNumber].data() != nullptr ||
1185  values[resultNumber].data() != nullptr;
1186 }
1187 
1188 //===----------------------------------------------------------------------===//
1189 // TrackingListener
1190 //===----------------------------------------------------------------------===//
1191 
1193  TransformOpInterface op,
1194  TrackingListenerConfig config)
1195  : TransformState::Extension(state), transformOp(op), config(config) {
1196  if (op) {
1197  for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1198  consumedHandles.insert(opOperand->get());
1199  }
1200  }
1201 }
1202 
1204  Operation *defOp = nullptr;
1205  for (Value v : values) {
1206  // Skip empty values.
1207  if (!v)
1208  continue;
1209  if (!defOp) {
1210  defOp = v.getDefiningOp();
1211  continue;
1212  }
1213  if (defOp != v.getDefiningOp())
1214  return nullptr;
1215  }
1216  return defOp;
1217 }
1218 
/// Searches for an op that can replace `op`, whose results are replaced by
/// `newValues`, in the tracked handles.
///
/// The search starts from the common defining op of the replacement values.
/// If `config.skipCastOps` is set, it continues through the operands of
/// `CastOpInterface` ops, and it always continues through the operands of ops
/// implementing `FindPayloadReplacementOpInterface`. A candidate is accepted if
/// its name matches `op`'s (or `config.requireMatchingReplacementOpName` is
/// unset) or if it is constant-like. On success `result` is set and success is
/// returned; otherwise a silenceable failure with notes describing the search
/// is returned and `result` is left untouched.
DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
    Operation *&result, Operation *op, ValueRange newValues) const {
  assert(op->getNumResults() == newValues.size() &&
         "invalid number of replacement values");
  SmallVector<Value> values(newValues.begin(), newValues.end());

  // Notes are attached to this diagnostic as the search progresses; it is only
  // returned if no replacement is found.
  DiagnosedSilenceableFailure diag = emitSilenceableFailure(
      getTransformOp(), "tracking listener failed to find replacement op "
                        "during application of this transform op");

  do {
    // If the replacement values belong to different ops, drop the mapping.
    Operation *defOp = getCommonDefiningOp(values);
    if (!defOp) {
      diag.attachNote() << "replacement values belong to different ops";
      return diag;
    }

    // Skip through ops that implement CastOpInterface.
    if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
      values.clear();
      values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
      diag.attachNote(defOp->getLoc())
          << "using output of 'CastOpInterface' op";
      continue;
    }

    // If the defining op has the same name or we do not care about the name of
    // op replacements at all, we take it as a replacement.
    if (!config.requireMatchingReplacementOpName ||
        op->getName() == defOp->getName()) {
      result = defOp;
      return DiagnosedSilenceableFailure::success();
    }

    // Replacing an op with a constant-like equivalent is a common
    // canonicalization.
    if (defOp->hasTrait<OpTrait::ConstantLike>()) {
      result = defOp;
      return DiagnosedSilenceableFailure::success();
    }

    values.clear();

    // Skip through ops that implement FindPayloadReplacementOpInterface.
    if (auto findReplacementOpInterface =
            dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
      values.assign(findReplacementOpInterface.getNextOperands());
      diag.attachNote(defOp->getLoc()) << "using operands provided by "
                                          "'FindPayloadReplacementOpInterface'";
      continue;
    }
  } while (!values.empty());

  diag.attachNote() << "ran out of suitable replacement values";
  return diag;
}
1276 
1278  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1279  LLVM_DEBUG({
1281  reasonCallback(diag);
1282  DBGS() << "Match Failure : " << diag.str() << "\n";
1283  });
1284 }
1285 
1286 void transform::TrackingListener::notifyOperationErased(Operation *op) {
1287  // Remove mappings for result values.
1288  for (OpResult value : op->getResults())
1289  (void)replacePayloadValue(value, nullptr);
1290  // Remove mapping for op.
1291  (void)replacePayloadOp(op, nullptr);
1292 }
1293 
1294 void transform::TrackingListener::notifyOperationReplaced(
1295  Operation *op, ValueRange newValues) {
1296  assert(op->getNumResults() == newValues.size() &&
1297  "invalid number of replacement values");
1298 
1299  // Replace value handles.
1300  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
1301  (void)replacePayloadValue(oldValue, newValue);
1302 
1303  // Replace op handle.
1304  SmallVector<Value> opHandles;
1305  if (failed(getTransformState().getHandlesForPayloadOp(
1306  op, opHandles, /*includeOutOfScope=*/true))) {
1307  // Op is not tracked.
1308  return;
1309  }
1310 
1311  // Helper function to check if the current transform op consumes any handle
1312  // that is mapped to `op`.
1313  //
1314  // Note: If a handle was consumed, there shouldn't be any alive users, so it
1315  // is not really necessary to check for consumed handles. However, in case
1316  // there are indeed alive handles that were consumed (which is undefined
1317  // behavior) and a replacement op could not be found, we want to fail with a
1318  // nicer error message: "op uses a handle invalidated..." instead of "could
1319  // not find replacement op". This nicer error is produced later.
1320  auto handleWasConsumed = [&] {
1321  return llvm::any_of(opHandles,
1322  [&](Value h) { return consumedHandles.contains(h); });
1323  };
1324 
1325  // Check if there are any handles that must be updated.
1326  Value aliveHandle;
1327  if (config.skipHandleFn) {
1328  auto it = llvm::find_if(opHandles,
1329  [&](Value v) { return !config.skipHandleFn(v); });
1330  if (it != opHandles.end())
1331  aliveHandle = *it;
1332  } else if (!opHandles.empty()) {
1333  aliveHandle = opHandles.front();
1334  }
1335  if (!aliveHandle || handleWasConsumed()) {
1336  // The op is tracked but the corresponding handles are dead or were
1337  // consumed. Drop the op form the mapping.
1338  (void)replacePayloadOp(op, nullptr);
1339  return;
1340  }
1341 
1342  Operation *replacement;
1344  findReplacementOp(replacement, op, newValues);
1345  // If the op is tracked but no replacement op was found, send a
1346  // notification.
1347  if (!diag.succeeded()) {
1348  diag.attachNote(aliveHandle.getLoc())
1349  << "replacement is required because this handle must be updated";
1350  notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
1351  (void)replacePayloadOp(op, nullptr);
1352  return;
1353  }
1354 
1355  (void)replacePayloadOp(op, replacement);
1356 }
1357 
1359  // The state of the ErrorCheckingTrackingListener must be checked and reset
1360  // if there was an error. This is to prevent errors from accidentally being
1361  // missed.
1362  assert(status.succeeded() && "listener state was not checked");
1363 }
1364 
1367  DiagnosedSilenceableFailure s = std::move(status);
1369  errorCounter = 0;
1370  return s;
1371 }
1372 
1374  return !status.succeeded();
1375 }
1376 
1379 
1380  // Merge potentially existing diags and store the result in the listener.
1382  diag.takeDiagnostics(diags);
1383  if (!status.succeeded())
1384  status.takeDiagnostics(diags);
1385  status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
1386 
1387  // Report more details.
1388  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
1389  for (auto &&[index, value] : llvm::enumerate(values))
1390  status.attachNote(value.getLoc())
1391  << "[" << errorCounter << "] replacement value " << index;
1392  ++errorCounter;
1393 }
1394 
1395 //===----------------------------------------------------------------------===//
1396 // TransformRewriter
1397 //===----------------------------------------------------------------------===//
1398 
1401  : RewriterBase(ctx), listener(listener) {
1402  setListener(listener);
1403 }
1404 
1406  return listener->failed();
1407 }
1408 
1409 /// Silence all tracking failures that have been encountered so far.
1411  if (hasTrackingFailures()) {
1412  DiagnosedSilenceableFailure status = listener->checkAndResetError();
1413  (void)status.silence();
1414  }
1415 }
1416 
1418  Operation *op, Operation *replacement) {
1419  return listener->replacePayloadOp(op, replacement);
1420 }
1421 
1422 //===----------------------------------------------------------------------===//
1423 // Utilities for TransformEachOpTrait.
1424 //===----------------------------------------------------------------------===//
1425 
1426 LogicalResult
1428  ArrayRef<Operation *> targets) {
1429  for (auto &&[position, parent] : llvm::enumerate(targets)) {
1430  for (Operation *child : targets.drop_front(position + 1)) {
1431  if (parent->isAncestor(child)) {
1433  emitError(loc)
1434  << "transform operation consumes a handle pointing to an ancestor "
1435  "payload operation before its descendant";
1436  diag.attachNote()
1437  << "the ancestor is likely erased or rewritten before the "
1438  "descendant is accessed, leading to undefined behavior";
1439  diag.attachNote(parent->getLoc()) << "ancestor payload op";
1440  diag.attachNote(child->getLoc()) << "descendant payload op";
1441  return diag;
1442  }
1443  }
1444  }
1445  return success();
1446 }
1447 
1448 LogicalResult
1450  Location payloadOpLoc,
1451  const ApplyToEachResultList &partialResult) {
1452  Location transformOpLoc = transformOp->getLoc();
1453  StringRef transformOpName = transformOp->getName().getStringRef();
1454  unsigned expectedNumResults = transformOp->getNumResults();
1455 
1456  // Reuse the emission of the diagnostic note.
1457  auto emitDiag = [&]() {
1458  auto diag = mlir::emitError(transformOpLoc);
1459  diag.attachNote(payloadOpLoc) << "when applied to this op";
1460  return diag;
1461  };
1462 
1463  if (partialResult.size() != expectedNumResults) {
1464  auto diag = emitDiag() << "application of " << transformOpName
1465  << " expected to produce " << expectedNumResults
1466  << " results (actually produced "
1467  << partialResult.size() << ").";
1468  diag.attachNote(transformOpLoc)
1469  << "if you need variadic results, consider a generic `apply` "
1470  << "instead of the specialized `applyToOne`.";
1471  return failure();
1472  }
1473 
1474  // Check that the right kind of value was produced.
1475  for (const auto &[ptr, res] :
1476  llvm::zip(partialResult, transformOp->getResults())) {
1477  if (ptr.isNull())
1478  continue;
1479  if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1480  !ptr.is<Operation *>()) {
1481  return emitDiag() << "application of " << transformOpName
1482  << " expected to produce an Operation * for result #"
1483  << res.getResultNumber();
1484  }
1485  if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1486  !ptr.is<Attribute>()) {
1487  return emitDiag() << "application of " << transformOpName
1488  << " expected to produce an Attribute for result #"
1489  << res.getResultNumber();
1490  }
1491  if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1492  !ptr.is<Value>()) {
1493  return emitDiag() << "application of " << transformOpName
1494  << " expected to produce a Value for result #"
1495  << res.getResultNumber();
1496  }
1497  }
1498  return success();
1499 }
1500 
1501 template <typename T>
1503  return llvm::to_vector(llvm::map_range(
1504  range, [](transform::MappedValue value) { return value.get<T>(); }));
1505 }
1506 
1508  Operation *transformOp, TransformResults &transformResults,
1511  transposed.resize(transformOp->getNumResults());
1512  for (const ApplyToEachResultList &partialResults : results) {
1513  if (llvm::any_of(partialResults,
1514  [](MappedValue value) { return value.isNull(); }))
1515  continue;
1516  assert(transformOp->getNumResults() == partialResults.size() &&
1517  "expected as many partial results as op as results");
1518  for (auto [i, value] : llvm::enumerate(partialResults))
1519  transposed[i].push_back(value);
1520  }
1521 
1522  for (OpResult r : transformOp->getResults()) {
1523  unsigned position = r.getResultNumber();
1524  if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1525  transformResults.setParams(r,
1526  castVector<Attribute>(transposed[position]));
1527  } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1528  transformResults.setValues(r, castVector<Value>(transposed[position]));
1529  } else {
1530  transformResults.set(r, castVector<Operation *>(transposed[position]));
1531  }
1532  }
1533 }
1534 
1535 //===----------------------------------------------------------------------===//
1536 // Utilities for implementing transform ops with regions.
1537 //===----------------------------------------------------------------------===//
1538 
1541  ValueRange values, const transform::TransformState &state, bool flatten) {
1542  assert(mappings.size() == values.size() && "mismatching number of mappings");
1543  for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1544  size_t mappedSize = mapped.size();
1545  if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1546  llvm::append_range(mapped, state.getPayloadOps(operand));
1547  } else if (llvm::isa<TransformValueHandleTypeInterface>(
1548  operand.getType())) {
1549  llvm::append_range(mapped, state.getPayloadValues(operand));
1550  } else {
1551  assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1552  "unsupported kind of transform dialect value");
1553  llvm::append_range(mapped, state.getParams(operand));
1554  }
1555 
1556  if (mapped.size() - mappedSize != 1 && !flatten)
1557  return failure();
1558  }
1559  return success();
1560 }
1561 
1564  ValueRange values, const transform::TransformState &state) {
1565  mappings.resize(mappings.size() + values.size());
1566  (void)appendValueMappings(
1568  values.size()),
1569  values, state);
1570 }
1571 
1573  Block *block, transform::TransformState &state,
1574  transform::TransformResults &results) {
1575  for (auto &&[terminatorOperand, result] :
1576  llvm::zip(block->getTerminator()->getOperands(),
1577  block->getParentOp()->getOpResults())) {
1578  if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1579  results.set(result, state.getPayloadOps(terminatorOperand));
1580  } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1581  result.getType())) {
1582  results.setValues(result, state.getPayloadValues(terminatorOperand));
1583  } else {
1584  assert(
1585  llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1586  "unhandled transform type interface");
1587  results.setParams(result, state.getParams(terminatorOperand));
1588  }
1589  }
1590 }
1591 
1594  Operation *payloadRoot) {
1595  return TransformState(region, payloadRoot);
1596 }
1597 
1598 //===----------------------------------------------------------------------===//
1599 // Utilities for PossibleTopLevelTransformOpTrait.
1600 //===----------------------------------------------------------------------===//
1601 
1602 /// Appends to `effects` the memory effect instances on `target` with the same
1603 /// resource and effect as the ones the operation `iface` having on `source`.
1604 static void
1605 remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
1606  OpOperand *target,
1609  iface.getEffectsOnValue(source, nestedEffects);
1610  for (const auto &effect : nestedEffects)
1611  effects.emplace_back(effect.getEffect(), target, effect.getResource());
1612 }
1613 
1614 /// Appends to `effects` the same effects as the operations of `block` have on
1615 /// block arguments but associated with `operands.`
1616 static void
1619  for (Operation &op : block) {
1620  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1621  if (!iface)
1622  continue;
1623 
1624  for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
1625  remapEffects(iface, source, &target, effects);
1626  }
1627 
1629  iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1630  nestedEffects);
1631  llvm::append_range(effects, nestedEffects);
1632  }
1633 }
1634 
1636  Operation *operation, Value root, Block &body,
1638  transform::onlyReadsHandle(operation->getOpOperands(), effects);
1639  transform::producesHandle(operation->getOpResults(), effects);
1640 
1641  if (!root) {
1642  for (Operation &op : body) {
1643  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1644  if (!iface)
1645  continue;
1646 
1648  iface.getEffects(effects);
1649  }
1650  return;
1651  }
1652 
1653  // Carry over all effects on arguments of the entry block as those on the
1654  // operands, this is the same value just remapped.
1655  remapArgumentEffects(body, operation->getOpOperands(), effects);
1656 }
1657 
1659  TransformState &state, Operation *op, Region &region) {
1660  SmallVector<Operation *> targets;
1661  SmallVector<SmallVector<MappedValue>> extraMappings;
1662  if (op->getNumOperands() != 0) {
1663  llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
1664  prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
1665  } else {
1666  if (state.getNumTopLevelMappings() !=
1667  region.front().getNumArguments() - 1) {
1668  return emitError(op->getLoc())
1669  << "operation expects " << region.front().getNumArguments() - 1
1670  << " extra value bindings, but " << state.getNumTopLevelMappings()
1671  << " were provided to the interpreter";
1672  }
1673 
1674  targets.push_back(state.getTopLevel());
1675 
1676  for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1677  extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1678  }
1679 
1680  if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
1681  return failure();
1682 
1683  for (BlockArgument argument : region.front().getArguments().drop_front()) {
1684  if (failed(state.mapBlockArgument(
1685  argument, extraMappings[argument.getArgNumber() - 1])))
1686  return failure();
1687  }
1688 
1689  return success();
1690 }
1691 
1692 LogicalResult
1694  // Attaching this trait without the interface is a misuse of the API, but it
1695  // cannot be caught via a static_assert because interface registration is
1696  // dynamic.
1697  assert(isa<TransformOpInterface>(op) &&
1698  "should implement TransformOpInterface to have "
1699  "PossibleTopLevelTransformOpTrait");
1700 
1701  if (op->getNumRegions() < 1)
1702  return op->emitOpError() << "expects at least one region";
1703 
1704  Region *bodyRegion = &op->getRegion(0);
1705  if (!llvm::hasNItems(*bodyRegion, 1))
1706  return op->emitOpError() << "expects a single-block region";
1707 
1708  Block *body = &bodyRegion->front();
1709  if (body->getNumArguments() == 0) {
1710  return op->emitOpError()
1711  << "expects the entry block to have at least one argument";
1712  }
1713  if (!llvm::isa<TransformHandleTypeInterface>(
1714  body->getArgument(0).getType())) {
1715  return op->emitOpError()
1716  << "expects the first entry block argument to be of type "
1717  "implementing TransformHandleTypeInterface";
1718  }
1719  BlockArgument arg = body->getArgument(0);
1720  if (op->getNumOperands() != 0) {
1721  if (arg.getType() != op->getOperand(0).getType()) {
1722  return op->emitOpError()
1723  << "expects the type of the block argument to match "
1724  "the type of the operand";
1725  }
1726  }
1727  for (BlockArgument arg : body->getArguments().drop_front()) {
1728  if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1729  TransformValueHandleTypeInterface>(arg.getType()))
1730  continue;
1731 
1733  op->emitOpError()
1734  << "expects trailing entry block arguments to be of type implementing "
1735  "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1736  "TransformParamTypeInterface";
1737  diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1738  return diag;
1739  }
1740 
1741  if (auto *parent =
1743  if (op->getNumOperands() != body->getNumArguments()) {
1745  op->emitOpError()
1746  << "expects operands to be provided for a nested op";
1747  diag.attachNote(parent->getLoc())
1748  << "nested in another possible top-level op";
1749  return diag;
1750  }
1751  }
1752 
1753  return success();
1754 }
1755 
1756 //===----------------------------------------------------------------------===//
1757 // Utilities for ParamProducedTransformOpTrait.
1758 //===----------------------------------------------------------------------===//
1759 
1762  producesHandle(op->getResults(), effects);
1763  bool hasPayloadOperands = false;
1764  for (OpOperand &operand : op->getOpOperands()) {
1765  onlyReadsHandle(operand, effects);
1766  if (llvm::isa<TransformHandleTypeInterface,
1767  TransformValueHandleTypeInterface>(operand.get().getType()))
1768  hasPayloadOperands = true;
1769  }
1770  if (hasPayloadOperands)
1771  onlyReadsPayload(effects);
1772 }
1773 
1774 LogicalResult
1776  // Interfaces can be attached dynamically, so the required
1777  // MemoryEffectOpInterface cannot be checked with a static assert.
1778  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1779  llvm::report_fatal_error(
1780  Twine("ParamProducerTransformOpTrait must be attached to an op that "
1781  "implements MemoryEffectOpInterface, found on ") +
1782  op->getName().getStringRef());
1783  }
1784  for (Value result : op->getResults()) {
1785  if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1786  continue;
1787  return op->emitOpError()
1788  << "ParamProducerTransformOpTrait attached to this op expects "
1789  "result types to implement TransformParamTypeInterface";
1790  }
1791  return success();
1792 }
1793 
1794 //===----------------------------------------------------------------------===//
1795 // Memory effects.
1796 //===----------------------------------------------------------------------===//
1797 
1801  for (OpOperand &handle : handles) {
1802  effects.emplace_back(MemoryEffects::Read::get(), &handle,
1804  effects.emplace_back(MemoryEffects::Free::get(), &handle,
1806  }
1807 }
1808 
1809 /// Returns `true` if the given range of effect instances contains an instance
1810 /// whose effect is of type `EffectTy` and whose resource is of type `ResourceTy`.
1811 template <typename EffectTy, typename ResourceTy, typename Range>
1812 static bool hasEffect(Range &&effects) {
1813  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1814  return isa<EffectTy>(effect.getEffect()) &&
1815  isa<ResourceTy>(effect.getResource());
1816  });
1817 }
1818 
1820  transform::TransformOpInterface transform) {
1821  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1823  iface.getEffectsOnValue(handle, effects);
1824  return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1825  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1826 }
1827 
1829  ResultRange handles,
1831  for (OpResult handle : handles) {
1832  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1834  effects.emplace_back(MemoryEffects::Write::get(), handle,
1836  }
1837 }
1838 
1842  for (BlockArgument handle : handles) {
1843  effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1845  effects.emplace_back(MemoryEffects::Write::get(), handle,
1847  }
1848 }
1849 
1853  for (OpOperand &handle : handles) {
1854  effects.emplace_back(MemoryEffects::Read::get(), &handle,
1856  }
1857 }
1858 
1861  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1862  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
1863 }
1864 
1867  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1868 }
1869 
1870 bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1871  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1873  iface.getEffects(effects);
1874  return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1875 }
1876 
1877 bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1878  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1880  iface.getEffects(effects);
1881  return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1882 }
1883 
1885  Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1887  for (Operation &nested : block) {
1888  auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1889  if (!iface)
1890  continue;
1891 
1892  effects.clear();
1893  iface.getEffects(effects);
1894  for (const MemoryEffects::EffectInstance &effect : effects) {
1895  BlockArgument argument =
1896  dyn_cast_or_null<BlockArgument>(effect.getValue());
1897  if (!argument || argument.getOwner() != &block ||
1898  !isa<MemoryEffects::Free>(effect.getEffect()) ||
1899  effect.getResource() != transform::TransformMappingResource::get()) {
1900  continue;
1901  }
1902  consumedArguments.insert(argument.getArgNumber());
1903  }
1904  }
1905 }
1906 
1907 //===----------------------------------------------------------------------===//
1908 // Utilities for TransformOpInterface.
1909 //===----------------------------------------------------------------------===//
1910 
1912  TransformOpInterface transformOp) {
1913  SmallVector<OpOperand *> consumedOperands;
1914  consumedOperands.reserve(transformOp->getNumOperands());
1915  auto memEffectInterface =
1916  cast<MemoryEffectOpInterface>(transformOp.getOperation());
1918  for (OpOperand &target : transformOp->getOpOperands()) {
1919  effects.clear();
1920  memEffectInterface.getEffectsOnValue(target.get(), effects);
1921  if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1922  return isa<transform::TransformMappingResource>(
1923  effect.getResource()) &&
1924  isa<MemoryEffects::Free>(effect.getEffect());
1925  })) {
1926  consumedOperands.push_back(&target);
1927  }
1928  }
1929  return consumedOperands;
1930 }
1931 
1933  auto iface = cast<MemoryEffectOpInterface>(op);
1935  iface.getEffects(effects);
1936 
1937  auto effectsOn = [&](Value value) {
1938  return llvm::make_filter_range(
1939  effects, [value](const MemoryEffects::EffectInstance &instance) {
1940  return instance.getValue() == value;
1941  });
1942  };
1943 
1944  std::optional<unsigned> firstConsumedOperand;
1945  for (OpOperand &operand : op->getOpOperands()) {
1946  auto range = effectsOn(operand.get());
1947  if (range.empty()) {
1949  op->emitError() << "TransformOpInterface requires memory effects "
1950  "on operands to be specified";
1951  diag.attachNote() << "no effects specified for operand #"
1952  << operand.getOperandNumber();
1953  return diag;
1954  }
1955  if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1957  << "TransformOpInterface did not expect "
1958  "'allocate' memory effect on an operand";
1959  diag.attachNote() << "specified for operand #"
1960  << operand.getOperandNumber();
1961  return diag;
1962  }
1963  if (!firstConsumedOperand &&
1964  ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1965  firstConsumedOperand = operand.getOperandNumber();
1966  }
1967  }
1968 
1969  if (firstConsumedOperand &&
1970  !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1972  op->emitError()
1973  << "TransformOpInterface expects ops consuming operands to have a "
1974  "'write' effect on the payload resource";
1975  diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
1976  return diag;
1977  }
1978 
1979  for (OpResult result : op->getResults()) {
1980  auto range = effectsOn(result);
1981  if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1982  range)) {
1984  op->emitError() << "TransformOpInterface requires 'allocate' memory "
1985  "effect to be specified for results";
1986  diag.attachNote() << "no 'allocate' effect specified for result #"
1987  << result.getResultNumber();
1988  return diag;
1989  }
1990  }
1991 
1992  return success();
1993 }
1994 
1995 //===----------------------------------------------------------------------===//
1996 // Entry point.
1997 //===----------------------------------------------------------------------===//
1998 
2000  Operation *payloadRoot, TransformOpInterface transform,
2001  const RaggedArray<MappedValue> &extraMapping,
2002  const TransformOptions &options, bool enforceToplevelTransformOp) {
2003  if (enforceToplevelTransformOp) {
2004  if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
2005  transform->getNumOperands() != 0) {
2006  return transform->emitError()
2007  << "expected transform to start at the top-level transform op";
2008  }
2009  } else if (failed(
2011  return failure();
2012  }
2013 
2014  TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2015  options);
2016  return state.applyTransform(transform).checkAndReport();
2017 }
2018 
2019 //===----------------------------------------------------------------------===//
2020 // Generated interface implementation.
2021 //===----------------------------------------------------------------------===//
2022 
2023 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2024 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define DEBUG_TYPE_FULL
#define FULL_LDBG(X)
static void remapEffects(MemoryEffectOpInterface iface, BlockArgument source, OpOperand *target, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the memory effect instances on target with the same resource and effect as the one...
static void remapArgumentEffects(Block &block, MutableArrayRef< OpOperand > operands, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the same effects as the operations of block have on block arguments but associated...
static bool happensBefore(Operation *a, Operation *b)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not ...
void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped)
#define DEBUG_PRINT_AFTER_ALL
DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef< T > payload, transform::TransformOpInterface transform, unsigned operandNumber)
#define DBGS()
static DiagnosedSilenceableFailure dispatchMappedValues(Value handle, ArrayRef< transform::MappedValue > values, function_ref< LogicalResult(ArrayRef< Operation * >)> operationsFn, function_ref< LogicalResult(ArrayRef< transform::Param >)> paramsFn, function_ref< LogicalResult(ValueRange)> valuesFn)
Given a list of MappedValues, cast them to the value kind implied by the interface of the handle type...
static SmallVector< T > castVector(ArrayRef< transform::MappedValue > range)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:31
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition: Block.cpp:73
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgListType getArguments()
Definition: Block.h:85
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:313
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:319
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getOpResults()
Definition: Operation.h:415
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A 2D array where each row may have different length.
Definition: RaggedArray.h:18
size_t size() const
Returns the number of rows in the 2D array.
Definition: RaggedArray.h:21
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
size_t size() const
Returns the number of elements in the list.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
TrackingListener failures are reported only for ops that have this trait.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Options controlling the application of transform operations by the TransformState.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult matches the number, kind (payload op or parameter) and nulli...
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true)
Entry point to the Transform dialect infrastructure.
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns true if op has an effect of type EffectTy.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A configuration object for customizing a TrackingListener.
SkipHandleFn skipHandleFn
An optional function that returns "true" for handles that do not have to be updated.