MLIR 22.0.0git
TransformInterfaces.cpp
Go to the documentation of this file.
1//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
11#include "mlir/IR/Diagnostics.h"
12#include "mlir/IR/Operation.h"
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/ScopeExit.h"
17#include "llvm/ADT/iterator.h"
18#include "llvm/Support/Debug.h"
19#include "llvm/Support/DebugLog.h"
20#include "llvm/Support/ErrorHandling.h"
21#include "llvm/Support/InterleavedRange.h"
22
23#define DEBUG_TYPE "transform-dialect"
24#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
25#define FULL_LDBG() LDBG(4)
26
27using namespace mlir;
28
29//===----------------------------------------------------------------------===//
30// Helper functions
31//===----------------------------------------------------------------------===//
32
33/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
34/// properly dominates `b` and `b` is not inside `a`.
36 do {
37 if (a->isProperAncestor(b))
38 return false;
39 if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
40 return a->isBeforeInBlock(bAncestor);
41 }
42 } while ((a = a->getParentOp()));
43 return false;
44}
45
46//===----------------------------------------------------------------------===//
47// TransformState
48//===----------------------------------------------------------------------===//
49
50transform::TransformState::TransformState(
51 Region *region, Operation *payloadRoot,
52 const RaggedArray<MappedValue> &extraMappings,
53 const TransformOptions &options)
54 : topLevel(payloadRoot), options(options) {
55 topLevelMappedValues.reserve(extraMappings.size());
56 for (ArrayRef<MappedValue> mapping : extraMappings)
57 topLevelMappedValues.push_back(mapping);
58 if (region) {
59 RegionScope *scope = new RegionScope(*this, *region);
60 topLevelRegionScope.reset(scope);
61 }
62}
63
65
67transform::TransformState::getPayloadOpsView(Value value) const {
68 const TransformOpMapping &operationMapping = getMapping(value).direct;
69 auto iter = operationMapping.find(value);
70 assert(iter != operationMapping.end() &&
71 "cannot find mapping for payload handle (param/value handle "
72 "provided?)");
73 return iter->getSecond();
74}
75
77 const ParamMapping &mapping = getMapping(value).params;
78 auto iter = mapping.find(value);
79 assert(iter != mapping.end() && "cannot find mapping for param handle "
80 "(operation/value handle provided?)");
81 return iter->getSecond();
82}
83
85transform::TransformState::getPayloadValuesView(Value handleValue) const {
86 const ValueMapping &mapping = getMapping(handleValue).values;
87 auto iter = mapping.find(handleValue);
88 assert(iter != mapping.end() && "cannot find mapping for value handle "
89 "(param/operation handle provided?)");
90 return iter->getSecond();
91}
92
95 bool includeOutOfScope) const {
96 bool found = false;
97 for (const auto &[region, mapping] : llvm::reverse(mappings)) {
98 auto iterator = mapping->reverse.find(op);
99 if (iterator != mapping->reverse.end()) {
100 llvm::append_range(handles, iterator->getSecond());
101 found = true;
102 }
103 // Stop looking when reaching a region that is isolated from above.
104 if (!includeOutOfScope &&
106 break;
107 }
108
109 return success(found);
110}
111
113 Value payloadValue, SmallVectorImpl<Value> &handles,
114 bool includeOutOfScope) const {
115 bool found = false;
116 for (const auto &[region, mapping] : llvm::reverse(mappings)) {
117 auto iterator = mapping->reverseValues.find(payloadValue);
118 if (iterator != mapping->reverseValues.end()) {
119 llvm::append_range(handles, iterator->getSecond());
120 found = true;
121 }
122 // Stop looking when reaching a region that is isolated from above.
123 if (!includeOutOfScope &&
125 break;
126 }
127
128 return success(found);
129}
130
131/// Given a list of MappedValues, cast them to the value kind implied by the
132/// interface of the handle type, and dispatch to one of the callbacks.
135 function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
136 function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
137 function_ref<LogicalResult(ValueRange)> valuesFn) {
138 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
139 SmallVector<Operation *> operations;
140 operations.reserve(values.size());
141 for (transform::MappedValue value : values) {
142 if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
143 operations.push_back(op);
144 continue;
145 }
146 return emitSilenceableFailure(handle.getLoc())
147 << "wrong kind of value provided for top-level operation handle";
148 }
149 if (failed(operationsFn(operations)))
152 }
153
154 if (llvm::isa<transform::TransformValueHandleTypeInterface>(
155 handle.getType())) {
156 SmallVector<Value> payloadValues;
157 payloadValues.reserve(values.size());
158 for (transform::MappedValue value : values) {
159 if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
160 payloadValues.push_back(v);
161 continue;
162 }
163 return emitSilenceableFailure(handle.getLoc())
164 << "wrong kind of value provided for the top-level value handle";
165 }
166 if (failed(valuesFn(payloadValues)))
169 }
170
171 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
172 "unsupported kind of block argument");
174 parameters.reserve(values.size());
175 for (transform::MappedValue value : values) {
176 if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
177 parameters.push_back(attr);
178 continue;
179 }
180 return emitSilenceableFailure(handle.getLoc())
181 << "wrong kind of value provided for top-level parameter";
182 }
183 if (failed(paramsFn(parameters)))
186}
187
188LogicalResult
190 ArrayRef<MappedValue> values) {
192 argument, values,
193 [&](ArrayRef<Operation *> operations) {
194 return setPayloadOps(argument, operations);
195 },
196 [&](ArrayRef<Param> params) {
197 return setParams(argument, params);
198 },
199 [&](ValueRange payloadValues) {
200 return setPayloadValues(argument, payloadValues);
201 })
202 .checkAndReport();
203}
204
206 Block::BlockArgListType arguments,
208 for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
209 if (failed(mapBlockArgument(argument, values)))
210 return failure();
211 return success();
212}
213
214LogicalResult
215transform::TransformState::setPayloadOps(Value value,
216 ArrayRef<Operation *> targets) {
217 assert(value != kTopLevelValue &&
218 "attempting to reset the transformation root");
219 assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
220 "wrong handle type");
221
222 for (Operation *target : targets) {
223 if (target)
224 continue;
225 return emitError(value.getLoc())
226 << "attempting to assign a null payload op to this transform value";
227 }
228
229 auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
231 iface.checkPayload(value.getLoc(), targets);
232 if (failed(result.checkAndReport()))
233 return failure();
234
235 // Setting new payload for the value without cleaning it first is a misuse of
236 // the API, assert here.
237 SmallVector<Operation *> storedTargets(targets);
238 Mappings &mappings = getMapping(value);
239 bool inserted =
240 mappings.direct.insert({value, std::move(storedTargets)}).second;
241 assert(inserted && "value is already associated with another list");
242 (void)inserted;
243
244 for (Operation *op : targets)
245 mappings.reverse[op].push_back(value);
246
247 return success();
248}
249
250LogicalResult
251transform::TransformState::setPayloadValues(Value handle,
252 ValueRange payloadValues) {
253 assert(handle != nullptr && "attempting to set params for a null value");
254 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
255 "wrong handle type");
256
257 for (Value payload : payloadValues) {
258 if (payload)
259 continue;
260 return emitError(handle.getLoc()) << "attempting to assign a null payload "
261 "value to this transform handle";
262 }
263
264 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
265 SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
266 DiagnosedSilenceableFailure result =
267 iface.checkPayload(handle.getLoc(), payloadValueVector);
268 if (failed(result.checkAndReport()))
269 return failure();
270
271 Mappings &mappings = getMapping(handle);
272 bool inserted =
273 mappings.values.insert({handle, std::move(payloadValueVector)}).second;
274 assert(
275 inserted &&
276 "value handle is already associated with another list of payload values");
277 (void)inserted;
278
279 for (Value payload : payloadValues)
280 mappings.reverseValues[payload].push_back(handle);
281
282 return success();
283}
284
285LogicalResult transform::TransformState::setParams(Value value,
286 ArrayRef<Param> params) {
287 assert(value != nullptr && "attempting to set params for a null value");
288
289 for (Attribute attr : params) {
290 if (attr)
291 continue;
292 return emitError(value.getLoc())
293 << "attempting to assign a null parameter to this transform value";
294 }
295
296 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
297 assert(value &&
298 "cannot associate parameter with a value of non-parameter type");
299 DiagnosedSilenceableFailure result =
300 valueType.checkPayload(value.getLoc(), params);
301 if (failed(result.checkAndReport()))
302 return failure();
303
304 Mappings &mappings = getMapping(value);
305 bool inserted =
306 mappings.params.insert({value, llvm::to_vector(params)}).second;
307 assert(inserted && "value is already associated with another list of params");
308 (void)inserted;
309 return success();
310}
311
312template <typename Mapping, typename Key, typename Mapped>
313static void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
314 auto it = mapping.find(key);
315 if (it == mapping.end())
316 return;
317
318 llvm::erase(it->getSecond(), mapped);
319 if (it->getSecond().empty())
320 mapping.erase(it);
321}
322
323void transform::TransformState::forgetMapping(Value opHandle,
324 ValueRange origOpFlatResults,
325 bool allowOutOfScope) {
326 Mappings &mappings = getMapping(opHandle, allowOutOfScope);
327 for (Operation *op : mappings.direct[opHandle])
328 dropMappingEntry(mappings.reverse, op, opHandle);
329 mappings.direct.erase(opHandle);
330#if LLVM_ENABLE_ABI_BREAKING_CHECKS
331 // Payload IR is removed from the mapping. This invalidates the respective
332 // iterators.
333 mappings.incrementTimestamp(opHandle);
334#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
335
336 for (Value opResult : origOpFlatResults) {
337 SmallVector<Value> resultHandles;
338 (void)getHandlesForPayloadValue(opResult, resultHandles);
339 for (Value resultHandle : resultHandles) {
340 Mappings &localMappings = getMapping(resultHandle);
341 dropMappingEntry(localMappings.values, resultHandle, opResult);
342#if LLVM_ENABLE_ABI_BREAKING_CHECKS
343 // Payload IR is removed from the mapping. This invalidates the respective
344 // iterators.
345 mappings.incrementTimestamp(resultHandle);
346#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
347 dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
348 }
349 }
350}
351
352void transform::TransformState::forgetValueMapping(
353 Value valueHandle, ArrayRef<Operation *> payloadOperations) {
354 Mappings &mappings = getMapping(valueHandle);
355 for (Value payloadValue : mappings.reverseValues[valueHandle])
356 dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
357 mappings.values.erase(valueHandle);
358#if LLVM_ENABLE_ABI_BREAKING_CHECKS
359 // Payload IR is removed from the mapping. This invalidates the respective
360 // iterators.
361 mappings.incrementTimestamp(valueHandle);
362#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
363
364 for (Operation *payloadOp : payloadOperations) {
365 SmallVector<Value> opHandles;
366 (void)getHandlesForPayloadOp(payloadOp, opHandles);
367 for (Value opHandle : opHandles) {
368 Mappings &localMappings = getMapping(opHandle);
369 dropMappingEntry(localMappings.direct, opHandle, payloadOp);
370 dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
371
372#if LLVM_ENABLE_ABI_BREAKING_CHECKS
373 // Payload IR is removed from the mapping. This invalidates the respective
374 // iterators.
375 localMappings.incrementTimestamp(opHandle);
376#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
377 }
378 }
379}
380
381LogicalResult
382transform::TransformState::replacePayloadOp(Operation *op,
383 Operation *replacement) {
384 // TODO: consider invalidating the handles to nested objects here.
385
386#ifndef NDEBUG
387 for (Value opResult : op->getResults()) {
388 SmallVector<Value> valueHandles;
389 (void)getHandlesForPayloadValue(opResult, valueHandles,
390 /*includeOutOfScope=*/true);
391 assert(valueHandles.empty() && "expected no mapping to old results");
392 }
393#endif // NDEBUG
394
395 // Drop the mapping between the op and all handles that point to it. Fail if
396 // there are no handles.
397 SmallVector<Value> opHandles;
398 if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
399 return failure();
400 for (Value handle : opHandles) {
401 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
402 dropMappingEntry(mappings.reverse, op, handle);
403 }
404
405 // Replace the pointed-to object of all handles with the replacement object.
406 // In case a payload op was erased (replacement object is nullptr), a nullptr
407 // is stored in the mapping. These nullptrs are removed after each transform.
408 // Furthermore, nullptrs are not enumerated by payload op iterators. The
409 // relative order of ops is preserved.
410 //
411 // Removing an op from the mapping would be problematic because removing an
412 // element from an array invalidates iterators; merely changing the value of
413 // elements does not.
414 for (Value handle : opHandles) {
415 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
416 auto it = mappings.direct.find(handle);
417 if (it == mappings.direct.end())
418 continue;
419
420 SmallVector<Operation *, 2> &association = it->getSecond();
421 // Note that an operation may be associated with the handle more than once.
422 for (Operation *&mapped : association) {
423 if (mapped == op)
424 mapped = replacement;
425 }
426
427 if (replacement) {
428 mappings.reverse[replacement].push_back(handle);
429 } else {
430 opHandlesToCompact.insert(handle);
431 }
432 }
433
434 return success();
435}
436
437LogicalResult
438transform::TransformState::replacePayloadValue(Value value, Value replacement) {
439 SmallVector<Value> valueHandles;
440 if (failed(getHandlesForPayloadValue(value, valueHandles,
441 /*includeOutOfScope=*/true)))
442 return failure();
443
444 for (Value handle : valueHandles) {
445 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
446 dropMappingEntry(mappings.reverseValues, value, handle);
447
448 // If replacing with null, that is erasing the mapping, drop the mapping
449 // between the handles and the IR objects
450 if (!replacement) {
451 dropMappingEntry(mappings.values, handle, value);
452#if LLVM_ENABLE_ABI_BREAKING_CHECKS
453 // Payload IR is removed from the mapping. This invalidates the respective
454 // iterators.
455 mappings.incrementTimestamp(handle);
456#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
457 } else {
458 auto it = mappings.values.find(handle);
459 if (it == mappings.values.end())
460 continue;
461
462 SmallVector<Value> &association = it->getSecond();
463 for (Value &mapped : association) {
464 if (mapped == value)
465 mapped = replacement;
466 }
467 mappings.reverseValues[replacement].push_back(handle);
468 }
469 }
470
471 return success();
472}
473
474void transform::TransformState::recordOpHandleInvalidationOne(
475 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
476 Operation *payloadOp, Value otherHandle, Value throughValue,
477 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
478 // If the op is associated with invalidated handle, skip the check as it
479 // may be reading invalid IR. This also ensures we report the first
480 // invalidation and not the last one.
481 if (invalidatedHandles.count(otherHandle) ||
482 newlyInvalidated.count(otherHandle))
483 return;
484
485 FULL_LDBG() << "--recordOpHandleInvalidationOne";
486 FULL_LDBG() << "--ancestors: "
487 << llvm::interleaved(
488 llvm::make_pointee_range(potentialAncestors));
489
490 Operation *owner = consumingHandle.getOwner();
491 unsigned operandNo = consumingHandle.getOperandNumber();
492 for (Operation *ancestor : potentialAncestors) {
493 // clang-format off
494 FULL_LDBG() << "----handle one ancestor: " << *ancestor;;
495
496 FULL_LDBG() << "----of payload with name: "
497 << payloadOp->getName().getIdentifier();
498 FULL_LDBG() << "----of payload: " << *payloadOp;
499 // clang-format on
500 if (!ancestor->isAncestor(payloadOp))
501 continue;
502
503 // Make sure the error-reporting lambda doesn't capture anything
504 // by-reference because it will go out of scope. Additionally, extract
505 // location from Payload IR ops because the ops themselves may be
506 // deleted before the lambda gets called.
507 Location ancestorLoc = ancestor->getLoc();
508 Location opLoc = payloadOp->getLoc();
509 std::optional<Location> throughValueLoc =
510 throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
511 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
512 otherHandle,
513 throughValueLoc](Location currentLoc) {
514 InFlightDiagnostic diag = emitError(currentLoc)
515 << "op uses a handle invalidated by a "
516 "previously executed transform op";
517 diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
518 diag.attachNote(owner->getLoc())
519 << "invalidated by this transform op that consumes its operand #"
520 << operandNo
521 << " and invalidates all handles to payload IR entities associated "
522 "with this operand and entities nested in them";
523 diag.attachNote(ancestorLoc) << "ancestor payload op";
524 diag.attachNote(opLoc) << "nested payload op";
525 if (throughValueLoc) {
526 diag.attachNote(*throughValueLoc)
527 << "consumed handle points to this payload value";
528 }
529 };
530 }
531}
532
533void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
534 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
535 Value payloadValue, Value valueHandle,
536 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
537 // If the op is associated with invalidated handle, skip the check as it
538 // may be reading invalid IR. This also ensures we report the first
539 // invalidation and not the last one.
540 if (invalidatedHandles.count(valueHandle) ||
541 newlyInvalidated.count(valueHandle))
542 return;
543
544 for (Operation *ancestor : potentialAncestors) {
545 Operation *definingOp;
546 std::optional<unsigned> resultNo;
547 unsigned argumentNo = std::numeric_limits<unsigned>::max();
548 unsigned blockNo = std::numeric_limits<unsigned>::max();
549 unsigned regionNo = std::numeric_limits<unsigned>::max();
550 if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
551 definingOp = opResult.getOwner();
552 resultNo = opResult.getResultNumber();
553 } else {
554 auto arg = llvm::cast<BlockArgument>(payloadValue);
555 definingOp = arg.getParentBlock()->getParentOp();
556 argumentNo = arg.getArgNumber();
557 blockNo = std::distance(arg.getOwner()->getParent()->begin(),
558 arg.getOwner()->getIterator());
559 regionNo = arg.getOwner()->getParent()->getRegionNumber();
560 }
561 assert(definingOp && "expected the value to be defined by an op as result "
562 "or block argument");
563 if (!ancestor->isAncestor(definingOp))
564 continue;
565
566 Operation *owner = opHandle.getOwner();
567 unsigned operandNo = opHandle.getOperandNumber();
568 Location ancestorLoc = ancestor->getLoc();
569 Location opLoc = definingOp->getLoc();
570 Location valueLoc = payloadValue.getLoc();
571 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
572 argumentNo, blockNo, regionNo, ancestorLoc,
573 opLoc, valueLoc](Location currentLoc) {
574 InFlightDiagnostic diag = emitError(currentLoc)
575 << "op uses a handle invalidated by a "
576 "previously executed transform op";
577 diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
578 diag.attachNote(owner->getLoc())
579 << "invalidated by this transform op that consumes its operand #"
580 << operandNo
581 << " and invalidates all handles to payload IR entities "
582 "associated with this operand and entities nested in them";
583 diag.attachNote(ancestorLoc)
584 << "ancestor op associated with the consumed handle";
585 if (resultNo) {
586 diag.attachNote(opLoc)
587 << "op defining the value as result #" << *resultNo;
588 } else {
589 diag.attachNote(opLoc)
590 << "op defining the value as block argument #" << argumentNo
591 << " of block #" << blockNo << " in region #" << regionNo;
592 }
593 diag.attachNote(valueLoc) << "payload value";
594 };
595 }
596}
597
598void transform::TransformState::recordOpHandleInvalidation(
599 OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
600 Value throughValue,
601 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
602
603 if (potentialAncestors.empty()) {
604 FULL_LDBG() << "----recording invalidation for empty handle: "
605 << handle.get();
606
607 Operation *owner = handle.getOwner();
608 unsigned operandNo = handle.getOperandNumber();
609 newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
610 InFlightDiagnostic diag = emitError(currentLoc)
611 << "op uses a handle associated with empty "
612 "payload and invalidated by a "
613 "previously executed transform op";
614 diag.attachNote(owner->getLoc())
615 << "invalidated by this transform op that consumes its operand #"
616 << operandNo;
617 };
618 return;
619 }
620
621 // Iterate over the mapping and invalidate aliasing handles. This is quite
622 // expensive and only necessary for error reporting in case of transform
623 // dialect misuse with dangling handles. Iteration over the handles is based
624 // on the assumption that the number of handles is significantly less than the
625 // number of IR objects (operations and values). Alternatively, we could walk
626 // the IR nested in each payload op associated with the given handle and look
627 // for handles associated with each operation and value.
628 for (const auto &[region, mapping] : llvm::reverse(mappings)) {
629 // Go over all op handle mappings and mark as invalidated any handle
630 // pointing to any of the payload ops associated with the given handle or
631 // any op nested in them.
632 for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
633 for (Value otherHandle : otherHandles)
634 recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
635 otherHandle, throughValue,
636 newlyInvalidated);
637 }
638 // Go over all value handle mappings and mark as invalidated any handle
639 // pointing to any result of the payload op associated with the given handle
640 // or any op nested in them. Similarly invalidate handles to argument of
641 // blocks belonging to any region of any payload op associated with the
642 // given handle or any op nested in them.
643 for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
644 for (Value valueHandle : valueHandles)
645 recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
646 payloadValue, valueHandle,
647 newlyInvalidated);
648 }
649
650 // Stop lookup when reaching a region that is isolated from above.
651 if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
652 break;
653 }
654}
655
656void transform::TransformState::recordValueHandleInvalidation(
657 OpOperand &valueHandle,
658 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
659 // Invalidate other handles to the same value.
660 for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
661 SmallVector<Value> otherValueHandles;
662 (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
663 for (Value otherHandle : otherValueHandles) {
664 Operation *owner = valueHandle.getOwner();
665 unsigned operandNo = valueHandle.getOperandNumber();
666 Location valueLoc = payloadValue.getLoc();
667 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
668 valueLoc](Location currentLoc) {
669 InFlightDiagnostic diag = emitError(currentLoc)
670 << "op uses a handle invalidated by a "
671 "previously executed transform op";
672 diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
673 diag.attachNote(owner->getLoc())
674 << "invalidated by this transform op that consumes its operand #"
675 << operandNo
676 << " and invalidates handles to the same values as associated with "
677 "it";
678 diag.attachNote(valueLoc) << "payload value";
679 };
680 }
681
682 if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
683 Operation *payloadOp = opResult.getOwner();
684 recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
685 newlyInvalidated);
686 } else {
687 auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
688 for (Operation &payloadOp : *arg.getOwner())
689 recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
690 newlyInvalidated);
691 }
692 }
693}
694
695/// Checks that the operation does not use invalidated handles as operands.
696/// Reports errors and returns failure if it does. Otherwise, invalidates the
697/// handles consumed by the operation as well as any handles pointing to payload
698/// IR operations nested in the operations associated with the consumed handles.
699LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
700 transform::TransformOpInterface transform,
701 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
702 FULL_LDBG() << "--Start checkAndRecordHandleInvalidation";
703 auto memoryEffectsIface =
704 cast<MemoryEffectOpInterface>(transform.getOperation());
705 SmallVector<MemoryEffects::EffectInstance> effects;
706 memoryEffectsIface.getEffectsOnResource(
708
709 for (OpOperand &target : transform->getOpOperands()) {
710 FULL_LDBG() << "----iterate on handle: " << target.get();
711 // If the operand uses an invalidated handle, report it. If the operation
712 // allows handles to point to repeated payload operations, only report
713 // pre-existing invalidation errors. Otherwise, also report invalidations
714 // caused by the current transform operation affecting its other operands.
715 auto it = invalidatedHandles.find(target.get());
716 auto nit = newlyInvalidated.find(target.get());
717 if (it != invalidatedHandles.end()) {
718 FULL_LDBG() << "--End checkAndRecordHandleInvalidation, found already "
719 "invalidated -> FAILURE";
720 return it->getSecond()(transform->getLoc()), failure();
721 }
722 if (!transform.allowsRepeatedHandleOperands() &&
723 nit != newlyInvalidated.end()) {
724 FULL_LDBG() << "--End checkAndRecordHandleInvalidation, found newly "
725 "invalidated (by this op) -> FAILURE";
726 return nit->getSecond()(transform->getLoc()), failure();
727 }
728
729 // Invalidate handles pointing to the operations nested in the operation
730 // associated with the handle consumed by this operation.
731 auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
732 return isa<MemoryEffects::Free>(effect.getEffect()) &&
733 effect.getValue() == target.get();
734 };
735 if (llvm::any_of(effects, consumesTarget)) {
736 FULL_LDBG() << "----found consume effect";
737 if (llvm::isa<transform::TransformHandleTypeInterface>(
738 target.get().getType())) {
739 FULL_LDBG() << "----recordOpHandleInvalidation";
740 SmallVector<Operation *> payloadOps =
741 llvm::to_vector(getPayloadOps(target.get()));
742 recordOpHandleInvalidation(target, payloadOps, nullptr,
743 newlyInvalidated);
744 } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
745 target.get().getType())) {
746 FULL_LDBG() << "----recordValueHandleInvalidation";
747 recordValueHandleInvalidation(target, newlyInvalidated);
748 } else {
749 FULL_LDBG()
750 << "----not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
751 }
752 } else {
753 FULL_LDBG() << "----no consume effect -> SKIP";
754 }
755 }
756
757 FULL_LDBG() << "--End checkAndRecordHandleInvalidation -> SUCCESS";
758 return success();
759}
760
761LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
762 transform::TransformOpInterface transform) {
763 InvalidatedHandleMap newlyInvalidated;
764 LogicalResult checkResult =
765 checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
766 invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
767 std::make_move_iterator(newlyInvalidated.end()));
768 return checkResult;
769}
770
771template <typename T>
772static DiagnosedSilenceableFailure
774 transform::TransformOpInterface transform,
775 unsigned operandNumber) {
776 DenseSet<T> seen;
777 for (T p : payload) {
778 if (!seen.insert(p).second) {
780 transform.emitSilenceableError()
781 << "a handle passed as operand #" << operandNumber
782 << " and consumed by this operation points to a payload "
783 "entity more than once";
784 if constexpr (std::is_pointer_v<T>)
785 diag.attachNote(p->getLoc()) << "repeated target op";
786 else
787 diag.attachNote(p.getLoc()) << "repeated target value";
788 return diag;
789 }
790 }
792}
793
794void transform::TransformState::compactOpHandles() {
795 for (Value handle : opHandlesToCompact) {
796 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
797#if LLVM_ENABLE_ABI_BREAKING_CHECKS
798 if (llvm::is_contained(mappings.direct[handle], nullptr))
799 // Payload IR is removed from the mapping. This invalidates the respective
800 // iterators.
801 mappings.incrementTimestamp(handle);
802#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
803 llvm::erase(mappings.direct[handle], nullptr);
804 }
805 opHandlesToCompact.clear();
806}
807
808DiagnosedSilenceableFailure
810 LDBG() << "applying: "
811 << OpWithFlags(transform, OpPrintingFlags().skipRegions());
812 FULL_LDBG() << "Top-level payload before application:\n" << *getTopLevel();
813 auto printOnFailureRAII = llvm::make_scope_exit([this] {
814 (void)this;
815 LDBG() << "Failing Top-level payload:\n"
817 OpPrintingFlags().printGenericOpForm());
818 });
819
820 // Set current transform op.
821 regionStack.back()->currentTransform = transform;
822
823 // Expensive checks to detect invalid transform IR.
824 if (options.getExpensiveChecksEnabled()) {
825 FULL_LDBG() << "ExpensiveChecksEnabled";
826 if (failed(checkAndRecordHandleInvalidation(transform)))
828
829 for (OpOperand &operand : transform->getOpOperands()) {
830 FULL_LDBG() << "iterate on handle: " << operand.get();
831 if (!isHandleConsumed(operand.get(), transform)) {
832 FULL_LDBG() << "--handle not consumed -> SKIP";
833 continue;
834 }
835 if (transform.allowsRepeatedHandleOperands()) {
836 FULL_LDBG() << "--op allows repeated handles -> SKIP";
837 continue;
838 }
839 FULL_LDBG() << "--handle is consumed";
840
841 Type operandType = operand.get().getType();
842 if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
843 FULL_LDBG() << "--checkRepeatedConsumptionInOperand for Operation*";
846 getPayloadOpsView(operand.get()), transform,
847 operand.getOperandNumber());
848 if (!check.succeeded()) {
849 FULL_LDBG() << "----FAILED";
850 return check;
851 }
852 } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
853 FULL_LDBG() << "--checkRepeatedConsumptionInOperand For Value";
856 getPayloadValuesView(operand.get()), transform,
857 operand.getOperandNumber());
858 if (!check.succeeded()) {
859 FULL_LDBG() << "----FAILED";
860 return check;
861 }
862 } else {
863 FULL_LDBG() << "--not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
864 }
865 }
866 }
867
868 // Find which operands are consumed.
869 SmallVector<OpOperand *> consumedOperands =
870 transform.getConsumedHandleOpOperands();
871
872 // Remember the results of the payload ops associated with the consumed
873 // op handles or the ops defining the value handles so we can drop the
874 // association with them later. This must happen here because the
875 // transformation may destroy or mutate them so we cannot traverse the payload
876 // IR after that.
877 SmallVector<Value> origOpFlatResults;
878 SmallVector<Operation *> origAssociatedOps;
879 for (OpOperand *opOperand : consumedOperands) {
880 Value operand = opOperand->get();
881 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
882 for (Operation *payloadOp : getPayloadOps(operand)) {
883 llvm::append_range(origOpFlatResults, payloadOp->getResults());
884 }
885 continue;
886 }
887 if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
888 for (Value payloadValue : getPayloadValuesView(operand)) {
889 if (llvm::isa<OpResult>(payloadValue)) {
890 origAssociatedOps.push_back(payloadValue.getDefiningOp());
891 continue;
892 }
893 llvm::append_range(
894 origAssociatedOps,
895 llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
896 [](Operation &op) { return &op; }));
897 }
898 continue;
899 }
902 << "unexpectedly consumed a value that is not a handle as operand #"
903 << opOperand->getOperandNumber();
904 diag.attachNote(operand.getLoc())
905 << "value defined here with type " << operand.getType();
906 return diag;
907 }
908
909 // Prepare rewriter and listener.
911 config.skipHandleFn = [&](Value handle) {
912 // Skip handle if it is dead.
913 auto scopeIt =
914 llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
915 return handle.getParentRegion() == scope->region;
916 });
917 assert(scopeIt != regionStack.rend() &&
918 "could not find region scope for handle");
919 RegionScope *scope = *scopeIt;
920 return llvm::all_of(handle.getUsers(), [&](Operation *user) {
921 return user == scope->currentTransform ||
922 happensBefore(user, scope->currentTransform);
923 });
924 };
926 config);
927 transform::TransformRewriter rewriter(transform->getContext(),
928 &trackingListener);
929
930 // Compute the result but do not short-circuit the silenceable failure case as
931 // we still want the handles to propagate properly so the "suppress" mode can
932 // proceed on a best effort basis.
933 transform::TransformResults results(transform->getNumResults());
934 DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
935 compactOpHandles();
936
937 // Error handling: fail if transform or listener failed.
938 DiagnosedSilenceableFailure trackingFailure =
939 trackingListener.checkAndResetError();
941 transform->hasAttr(FindPayloadReplacementOpInterface::
942 kSilenceTrackingFailuresAttrName)) {
943 // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
944 // do not report failures if the above mentioned attribute is set.
945 if (trackingFailure.isSilenceableFailure())
946 (void)trackingFailure.silence();
947 trackingFailure = DiagnosedSilenceableFailure::success();
948 }
949 if (!trackingFailure.succeeded()) {
950 if (result.succeeded()) {
951 result = std::move(trackingFailure);
952 } else {
953 // Transform op errors have precedence, report those first.
954 if (result.isSilenceableFailure())
955 result.attachNote() << "tracking listener also failed: "
956 << trackingFailure.getMessage();
957 (void)trackingFailure.silence();
958 }
959 }
960 if (result.isDefiniteFailure())
961 return result;
962
963 // If a silenceable failure was produced, some results may be unset, set them
964 // to empty lists.
965 if (result.isSilenceableFailure())
967
968 // Remove the mapping for the operand if it is consumed by the operation. This
969 // allows us to catch use-after-free with assertions later on.
970 for (OpOperand *opOperand : consumedOperands) {
971 Value operand = opOperand->get();
972 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
973 forgetMapping(operand, origOpFlatResults);
974 } else if (llvm::isa<TransformValueHandleTypeInterface>(
975 operand.getType())) {
976 forgetValueMapping(operand, origAssociatedOps);
977 }
978 }
979
980 if (failed(updateStateFromResults(results, transform->getResults())))
982
983 printOnFailureRAII.release();
984 DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
985 LDBG() << "Top-level payload:\n" << *getTopLevel();
986 });
987 return result;
988}
989
990LogicalResult transform::TransformState::updateStateFromResults(
991 const TransformResults &results, ResultRange opResults) {
992 for (OpResult result : opResults) {
993 if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
994 assert(results.isParam(result.getResultNumber()) &&
995 "expected parameters for the parameter-typed result");
996 if (failed(
997 setParams(result, results.getParams(result.getResultNumber())))) {
998 return failure();
999 }
1000 } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1001 assert(results.isValue(result.getResultNumber()) &&
1002 "expected values for value-type-result");
1003 if (failed(setPayloadValues(
1004 result, results.getValues(result.getResultNumber())))) {
1005 return failure();
1006 }
1007 } else {
1008 assert(!results.isParam(result.getResultNumber()) &&
1009 "expected payload ops for the non-parameter typed result");
1010 if (failed(
1011 setPayloadOps(result, results.get(result.getResultNumber())))) {
1012 return failure();
1013 }
1014 }
1015 }
1016 return success();
1017}
1018
1019//===----------------------------------------------------------------------===//
1020// TransformState::Extension
1021//===----------------------------------------------------------------------===//
1022
1024
1025LogicalResult
1028 // TODO: we may need to invalidate handles to operations and values nested in
1029 // the operation being replaced.
1030 return state.replacePayloadOp(op, replacement);
1031}
1032
1033LogicalResult
1036 return state.replacePayloadValue(value, replacement);
1037}
1038
1039//===----------------------------------------------------------------------===//
1040// TransformState::RegionScope
1041//===----------------------------------------------------------------------===//
1042
1044 // Remove handle invalidation notices as handles are going out of scope.
1045 // The same region may be re-entered leading to incorrect invalidation
1046 // errors.
1047 for (Block &block : *region) {
1048 for (Value handle : block.getArguments()) {
1049 state.invalidatedHandles.erase(handle);
1050 }
1051 for (Operation &op : block) {
1052 for (Value handle : op.getResults()) {
1053 state.invalidatedHandles.erase(handle);
1054 }
1055 }
1056 }
1057
1058#if LLVM_ENABLE_ABI_BREAKING_CHECKS
1059 // Remember pointers to payload ops referenced by the handles going out of
1060 // scope.
1061 SmallVector<Operation *> referencedOps =
1062 llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1063#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1064
1065 state.mappings.erase(region);
1066 state.regionStack.pop_back();
1067}
1068
1069//===----------------------------------------------------------------------===//
1070// TransformResults
1071//===----------------------------------------------------------------------===//
1072
1073transform::TransformResults::TransformResults(unsigned numSegments) {
1074 operations.appendEmptyRows(numSegments);
1075 params.appendEmptyRows(numSegments);
1076 values.appendEmptyRows(numSegments);
1077}
1078
1081 int64_t position = value.getResultNumber();
1082 assert(position < static_cast<int64_t>(this->params.size()) &&
1083 "setting params for a non-existent handle");
1084 assert(this->params[position].data() == nullptr && "params already set");
1085 assert(operations[position].data() == nullptr &&
1086 "another kind of results already set");
1087 assert(values[position].data() == nullptr &&
1088 "another kind of results already set");
1089 this->params.replace(position, params);
1090}
1091
1093 OpResult handle, ArrayRef<MappedValue> values) {
1095 handle, values,
1096 [&](ArrayRef<Operation *> operations) {
1097 return set(handle, operations), success();
1098 },
1099 [&](ArrayRef<Param> params) {
1100 return setParams(handle, params), success();
1101 },
1102 [&](ValueRange payloadValues) {
1103 return setValues(handle, payloadValues), success();
1104 });
1105#ifndef NDEBUG
1106 if (!diag.succeeded())
1107 llvm::dbgs() << diag.getStatusString() << "\n";
1108 assert(diag.succeeded() && "incorrect mapping");
1109#endif // NDEBUG
1110 (void)diag.silence();
1111}
1112
1114 transform::TransformOpInterface transform) {
1115 for (OpResult opResult : transform->getResults()) {
1116 if (!isSet(opResult.getResultNumber()))
1117 setMappedValues(opResult, {});
1118 }
1119}
1120
1122transform::TransformResults::get(unsigned resultNumber) const {
1123 assert(resultNumber < operations.size() &&
1124 "querying results for a non-existent handle");
1125 assert(operations[resultNumber].data() != nullptr &&
1126 "querying unset results (values or params expected?)");
1127 return operations[resultNumber];
1128}
1129
1131transform::TransformResults::getParams(unsigned resultNumber) const {
1132 assert(resultNumber < params.size() &&
1133 "querying params for a non-existent handle");
1134 assert(params[resultNumber].data() != nullptr &&
1135 "querying unset params (ops or values expected?)");
1136 return params[resultNumber];
1137}
1138
1140transform::TransformResults::getValues(unsigned resultNumber) const {
1141 assert(resultNumber < values.size() &&
1142 "querying values for a non-existent handle");
1143 assert(values[resultNumber].data() != nullptr &&
1144 "querying unset values (ops or params expected?)");
1145 return values[resultNumber];
1146}
1147
1148bool transform::TransformResults::isParam(unsigned resultNumber) const {
1149 assert(resultNumber < params.size() &&
1150 "querying association for a non-existent handle");
1151 return params[resultNumber].data() != nullptr;
1152}
1153
1154bool transform::TransformResults::isValue(unsigned resultNumber) const {
1155 assert(resultNumber < values.size() &&
1156 "querying association for a non-existent handle");
1157 return values[resultNumber].data() != nullptr;
1158}
1159
1160bool transform::TransformResults::isSet(unsigned resultNumber) const {
1161 assert(resultNumber < params.size() &&
1162 "querying association for a non-existent handle");
1163 return params[resultNumber].data() != nullptr ||
1164 operations[resultNumber].data() != nullptr ||
1165 values[resultNumber].data() != nullptr;
1166}
1167
1168//===----------------------------------------------------------------------===//
1169// TrackingListener
1170//===----------------------------------------------------------------------===//
1171
1173 TransformOpInterface op,
1175 : TransformState::Extension(state), transformOp(op), config(config) {
1176 if (op) {
1177 for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1178 consumedHandles.insert(opOperand->get());
1179 }
1180 }
1181}
1182
1184 Operation *defOp = nullptr;
1185 for (Value v : values) {
1186 // Skip empty values.
1187 if (!v)
1188 continue;
1189 if (!defOp) {
1190 defOp = v.getDefiningOp();
1191 continue;
1192 }
1193 if (defOp != v.getDefiningOp())
1194 return nullptr;
1195 }
1196 return defOp;
1197}
1198
1200 Operation *&result, Operation *op, ValueRange newValues) const {
1201 assert(op->getNumResults() == newValues.size() &&
1202 "invalid number of replacement values");
1203 SmallVector<Value> values(newValues.begin(), newValues.end());
1204
1206 getTransformOp(), "tracking listener failed to find replacement op "
1207 "during application of this transform op");
1208
1209 do {
1210 // If the replacement values belong to different ops, drop the mapping.
1211 Operation *defOp = getCommonDefiningOp(values);
1212 if (!defOp) {
1213 diag.attachNote() << "replacement values belong to different ops";
1214 return diag;
1215 }
1216
1217 // Skip through ops that implement CastOpInterface.
1218 if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1219 values.clear();
1220 values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1221 diag.attachNote(defOp->getLoc())
1222 << "using output of 'CastOpInterface' op";
1223 continue;
1224 }
1225
1226 // If the defining op has the same name or we do not care about the name of
1227 // op replacements at all, we take it as a replacement.
1228 if (!config.requireMatchingReplacementOpName ||
1229 op->getName() == defOp->getName()) {
1230 result = defOp;
1232 }
1233
1234 // Replacing an op with a constant-like equivalent is a common
1235 // canonicalization.
1236 if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1237 result = defOp;
1239 }
1240
1241 values.clear();
1242
1243 // Skip through ops that implement FindPayloadReplacementOpInterface.
1244 if (auto findReplacementOpInterface =
1245 dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1246 values.assign(findReplacementOpInterface.getNextOperands());
1247 diag.attachNote(defOp->getLoc()) << "using operands provided by "
1248 "'FindPayloadReplacementOpInterface'";
1249 continue;
1250 }
1251 } while (!values.empty());
1252
1253 diag.attachNote() << "ran out of suitable replacement values";
1254 return diag;
1255}
1256
1258 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1259 LLVM_DEBUG({
1261 reasonCallback(diag);
1262 LDBG() << "Match Failure : " << diag.str();
1263 });
1264}
1265
1266void transform::TrackingListener::notifyOperationErased(Operation *op) {
1267 // Remove mappings for result values.
1268 for (OpResult value : op->getResults())
1269 (void)replacePayloadValue(value, nullptr);
1270 // Remove mapping for op.
1271 (void)replacePayloadOp(op, nullptr);
1272}
1273
1274void transform::TrackingListener::notifyOperationReplaced(
1275 Operation *op, ValueRange newValues) {
1276 assert(op->getNumResults() == newValues.size() &&
1277 "invalid number of replacement values");
1278
1279 // Replace value handles.
1280 for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
1281 (void)replacePayloadValue(oldValue, newValue);
1282
1283 // Replace op handle.
1284 SmallVector<Value> opHandles;
1285 if (failed(getTransformState().getHandlesForPayloadOp(
1286 op, opHandles, /*includeOutOfScope=*/true))) {
1287 // Op is not tracked.
1288 return;
1289 }
1290
1291 // Helper function to check if the current transform op consumes any handle
1292 // that is mapped to `op`.
1293 //
1294 // Note: If a handle was consumed, there shouldn't be any alive users, so it
1295 // is not really necessary to check for consumed handles. However, in case
1296 // there are indeed alive handles that were consumed (which is undefined
1297 // behavior) and a replacement op could not be found, we want to fail with a
1298 // nicer error message: "op uses a handle invalidated..." instead of "could
1299 // not find replacement op". This nicer error is produced later.
1300 auto handleWasConsumed = [&] {
1301 return llvm::any_of(opHandles,
1302 [&](Value h) { return consumedHandles.contains(h); });
1303 };
1304
1305 // Check if there are any handles that must be updated.
1306 Value aliveHandle;
1307 if (config.skipHandleFn) {
1308 auto it = llvm::find_if(opHandles,
1309 [&](Value v) { return !config.skipHandleFn(v); });
1310 if (it != opHandles.end())
1311 aliveHandle = *it;
1312 } else if (!opHandles.empty()) {
1313 aliveHandle = opHandles.front();
1314 }
1315 if (!aliveHandle || handleWasConsumed()) {
1316 // The op is tracked but the corresponding handles are dead or were
1317 // consumed. Drop the op form the mapping.
1318 (void)replacePayloadOp(op, nullptr);
1319 return;
1320 }
1321
1322 Operation *replacement;
1323 DiagnosedSilenceableFailure diag =
1324 findReplacementOp(replacement, op, newValues);
1325 // If the op is tracked but no replacement op was found, send a
1326 // notification.
1327 if (!diag.succeeded()) {
1328 diag.attachNote(aliveHandle.getLoc())
1329 << "replacement is required because this handle must be updated";
1330 notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
1331 (void)replacePayloadOp(op, nullptr);
1332 return;
1333 }
1334
1335 (void)replacePayloadOp(op, replacement);
1336}
1337
1339 // The state of the ErrorCheckingTrackingListener must be checked and reset
1340 // if there was an error. This is to prevent errors from accidentally being
1341 // missed.
1342 assert(status.succeeded() && "listener state was not checked");
1343}
1344
1347 DiagnosedSilenceableFailure s = std::move(status);
1349 errorCounter = 0;
1350 return s;
1351}
1352
1354 return !status.succeeded();
1355}
1356
1359
1360 // Merge potentially existing diags and store the result in the listener.
1362 diag.takeDiagnostics(diags);
1363 if (!status.succeeded())
1364 status.takeDiagnostics(diags);
1365 status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
1366
1367 // Report more details.
1368 status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
1369 for (auto &&[index, value] : llvm::enumerate(values))
1370 status.attachNote(value.getLoc())
1371 << "[" << errorCounter << "] replacement value " << index;
1372 ++errorCounter;
1373}
1374
1375std::string
1377 if (!matchFailure) {
1378 return "";
1379 }
1380 return matchFailure->str();
1381}
1382
1384 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1386 reasonCallback(diag);
1387 matchFailure = std::move(diag);
1388}
1389
1390//===----------------------------------------------------------------------===//
1391// TransformRewriter
1392//===----------------------------------------------------------------------===//
1393
1396 : RewriterBase(ctx), listener(listener) {
1397 setListener(listener);
1398}
1399
1401 return listener->failed();
1402}
1403
1404/// Silence all tracking failures that have been encountered so far.
1406 if (hasTrackingFailures()) {
1407 DiagnosedSilenceableFailure status = listener->checkAndResetError();
1408 (void)status.silence();
1409 }
1410}
1411
1414 return listener->replacePayloadOp(op, replacement);
1415}
1416
1417//===----------------------------------------------------------------------===//
1418// Utilities for TransformEachOpTrait.
1419//===----------------------------------------------------------------------===//
1420
1421LogicalResult
1423 ArrayRef<Operation *> targets) {
1424 for (auto &&[position, parent] : llvm::enumerate(targets)) {
1425 for (Operation *child : targets.drop_front(position + 1)) {
1426 if (parent->isAncestor(child)) {
1428 emitError(loc)
1429 << "transform operation consumes a handle pointing to an ancestor "
1430 "payload operation before its descendant";
1431 diag.attachNote()
1432 << "the ancestor is likely erased or rewritten before the "
1433 "descendant is accessed, leading to undefined behavior";
1434 diag.attachNote(parent->getLoc()) << "ancestor payload op";
1435 diag.attachNote(child->getLoc()) << "descendant payload op";
1436 return diag;
1437 }
1438 }
1439 }
1440 return success();
1441}
1442
1443LogicalResult
1445 Location payloadOpLoc,
1446 const ApplyToEachResultList &partialResult) {
1447 Location transformOpLoc = transformOp->getLoc();
1448 StringRef transformOpName = transformOp->getName().getStringRef();
1449 unsigned expectedNumResults = transformOp->getNumResults();
1450
1451 // Reuse the emission of the diagnostic note.
1452 auto emitDiag = [&]() {
1453 auto diag = mlir::emitError(transformOpLoc);
1454 diag.attachNote(payloadOpLoc) << "when applied to this op";
1455 return diag;
1456 };
1457
1458 if (partialResult.size() != expectedNumResults) {
1459 auto diag = emitDiag() << "application of " << transformOpName
1460 << " expected to produce " << expectedNumResults
1461 << " results (actually produced "
1462 << partialResult.size() << ").";
1463 diag.attachNote(transformOpLoc)
1464 << "if you need variadic results, consider a generic `apply` "
1465 << "instead of the specialized `applyToOne`.";
1466 return failure();
1467 }
1468
1469 // Check that the right kind of value was produced.
1470 for (const auto &[ptr, res] :
1471 llvm::zip(partialResult, transformOp->getResults())) {
1472 if (ptr.isNull())
1473 continue;
1474 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1475 !isa<Operation *>(ptr)) {
1476 return emitDiag() << "application of " << transformOpName
1477 << " expected to produce an Operation * for result #"
1478 << res.getResultNumber();
1479 }
1480 if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1481 !isa<Attribute>(ptr)) {
1482 return emitDiag() << "application of " << transformOpName
1483 << " expected to produce an Attribute for result #"
1484 << res.getResultNumber();
1485 }
1486 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1487 !isa<Value>(ptr)) {
1488 return emitDiag() << "application of " << transformOpName
1489 << " expected to produce a Value for result #"
1490 << res.getResultNumber();
1491 }
1492 }
1493 return success();
1494}
1495
1496template <typename T>
1498 return llvm::map_to_vector(range, llvm::CastTo<T>);
1499}
1500
1502 Operation *transformOp, TransformResults &transformResults,
1505 transposed.resize(transformOp->getNumResults());
1506 for (const ApplyToEachResultList &partialResults : results) {
1507 if (llvm::any_of(partialResults,
1508 [](MappedValue value) { return value.isNull(); }))
1509 continue;
1510 assert(transformOp->getNumResults() == partialResults.size() &&
1511 "expected as many partial results as op as results");
1512 for (auto [i, value] : llvm::enumerate(partialResults))
1513 transposed[i].push_back(value);
1514 }
1515
1516 for (OpResult r : transformOp->getResults()) {
1517 unsigned position = r.getResultNumber();
1518 if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1519 transformResults.setParams(r,
1520 castVector<Attribute>(transposed[position]));
1521 } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1522 transformResults.setValues(r, castVector<Value>(transposed[position]));
1523 } else {
1524 transformResults.set(r, castVector<Operation *>(transposed[position]));
1525 }
1526 }
1527}
1528
1529//===----------------------------------------------------------------------===//
1530// Utilities for implementing transform ops with regions.
1531//===----------------------------------------------------------------------===//
1532
1535 ValueRange values, const transform::TransformState &state, bool flatten) {
1536 assert(mappings.size() == values.size() && "mismatching number of mappings");
1537 for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1538 size_t mappedSize = mapped.size();
1539 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1540 llvm::append_range(mapped, state.getPayloadOps(operand));
1541 } else if (llvm::isa<TransformValueHandleTypeInterface>(
1542 operand.getType())) {
1543 llvm::append_range(mapped, state.getPayloadValues(operand));
1544 } else {
1545 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1546 "unsupported kind of transform dialect value");
1547 llvm::append_range(mapped, state.getParams(operand));
1548 }
1549
1550 if (mapped.size() - mappedSize != 1 && !flatten)
1551 return failure();
1552 }
1553 return success();
1554}
1555
1558 ValueRange values, const transform::TransformState &state) {
1559 mappings.resize(mappings.size() + values.size());
1562 values.size()),
1563 values, state);
1564}
1565
1567 Block *block, transform::TransformState &state,
1568 transform::TransformResults &results) {
1569 for (auto &&[terminatorOperand, result] :
1570 llvm::zip(block->getTerminator()->getOperands(),
1571 block->getParentOp()->getOpResults())) {
1572 if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1573 results.set(result, state.getPayloadOps(terminatorOperand));
1574 } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1575 result.getType())) {
1576 results.setValues(result, state.getPayloadValues(terminatorOperand));
1577 } else {
1578 assert(
1579 llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1580 "unhandled transform type interface");
1581 results.setParams(result, state.getParams(terminatorOperand));
1582 }
1583 }
1584}
1585
1588 Operation *payloadRoot) {
1589 return TransformState(region, payloadRoot);
1590}
1591
1592//===----------------------------------------------------------------------===//
1593// Utilities for PossibleTopLevelTransformOpTrait.
1594//===----------------------------------------------------------------------===//
1595
/// Appends to `effects` the memory effect instances on `target` with the same
/// resource and effect as the ones the operation `iface` has on `source`.
1598static void
1599remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
1603 iface.getEffectsOnValue(source, nestedEffects);
1604 for (const auto &effect : nestedEffects)
1605 effects.emplace_back(effect.getEffect(), target, effect.getResource());
1606}
1607
/// Appends to `effects` the same effects as the operations of `block` have on
/// block arguments, but associated with `operands`.
1610static void
1613 for (Operation &op : block) {
1614 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1615 if (!iface)
1616 continue;
1617
1618 for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
1619 remapEffects(iface, source, &target, effects);
1620 }
1621
1623 iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1624 nestedEffects);
1625 llvm::append_range(effects, nestedEffects);
1626 }
1627}
1628
1630 Operation *operation, Value root, Block &body,
1632 transform::onlyReadsHandle(operation->getOpOperands(), effects);
1633 transform::producesHandle(operation->getOpResults(), effects);
1634
1635 if (!root) {
1636 for (Operation &op : body) {
1637 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1638 if (!iface)
1639 continue;
1640
1641 iface.getEffects(effects);
1642 }
1643 return;
1644 }
1645
1646 // Carry over all effects on arguments of the entry block as those on the
1647 // operands, this is the same value just remapped.
1648 remapArgumentEffects(body, operation->getOpOperands(), effects);
1649}
1650
1652 TransformState &state, Operation *op, Region &region) {
1655 if (op->getNumOperands() != 0) {
1656 llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
1657 prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
1658 } else {
1659 if (state.getNumTopLevelMappings() !=
1660 region.front().getNumArguments() - 1) {
1661 return emitError(op->getLoc())
1662 << "operation expects " << region.front().getNumArguments() - 1
1663 << " extra value bindings, but " << state.getNumTopLevelMappings()
1664 << " were provided to the interpreter";
1665 }
1666
1667 targets.push_back(state.getTopLevel());
1668
1669 for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1670 extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1671 }
1672
1673 if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
1674 return failure();
1675
1676 for (BlockArgument argument : region.front().getArguments().drop_front()) {
1677 if (failed(state.mapBlockArgument(
1678 argument, extraMappings[argument.getArgNumber() - 1])))
1679 return failure();
1680 }
1681
1682 return success();
1683}
1684
1685LogicalResult
1687 // Attaching this trait without the interface is a misuse of the API, but it
1688 // cannot be caught via a static_assert because interface registration is
1689 // dynamic.
1690 assert(isa<TransformOpInterface>(op) &&
1691 "should implement TransformOpInterface to have "
1692 "PossibleTopLevelTransformOpTrait");
1693
1694 if (op->getNumRegions() < 1)
1695 return op->emitOpError() << "expects at least one region";
1696
1697 Region *bodyRegion = &op->getRegion(0);
1698 if (!llvm::hasNItems(*bodyRegion, 1))
1699 return op->emitOpError() << "expects a single-block region";
1700
1701 Block *body = &bodyRegion->front();
1702 if (body->getNumArguments() == 0) {
1703 return op->emitOpError()
1704 << "expects the entry block to have at least one argument";
1705 }
1706 if (!llvm::isa<TransformHandleTypeInterface>(
1707 body->getArgument(0).getType())) {
1708 return op->emitOpError()
1709 << "expects the first entry block argument to be of type "
1710 "implementing TransformHandleTypeInterface";
1711 }
1712 BlockArgument arg = body->getArgument(0);
1713 if (op->getNumOperands() != 0) {
1714 if (arg.getType() != op->getOperand(0).getType()) {
1715 return op->emitOpError()
1716 << "expects the type of the block argument to match "
1717 "the type of the operand";
1718 }
1719 }
1720 for (BlockArgument arg : body->getArguments().drop_front()) {
1721 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1722 TransformValueHandleTypeInterface>(arg.getType()))
1723 continue;
1724
1726 op->emitOpError()
1727 << "expects trailing entry block arguments to be of type implementing "
1728 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1729 "TransformParamTypeInterface";
1730 diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1731 return diag;
1732 }
1733
1734 if (auto *parent =
1736 if (op->getNumOperands() != body->getNumArguments()) {
1738 op->emitOpError()
1739 << "expects operands to be provided for a nested op";
1740 diag.attachNote(parent->getLoc())
1741 << "nested in another possible top-level op";
1742 return diag;
1743 }
1744 }
1745
1746 return success();
1747}
1748
1749//===----------------------------------------------------------------------===//
1750// Utilities for ParamProducedTransformOpTrait.
1751//===----------------------------------------------------------------------===//
1752
1755 producesHandle(op->getResults(), effects);
1756 bool hasPayloadOperands = false;
1757 for (OpOperand &operand : op->getOpOperands()) {
1758 onlyReadsHandle(operand, effects);
1759 if (llvm::isa<TransformHandleTypeInterface,
1760 TransformValueHandleTypeInterface>(operand.get().getType()))
1761 hasPayloadOperands = true;
1762 }
1763 if (hasPayloadOperands)
1764 onlyReadsPayload(effects);
1765}
1766
1767LogicalResult
1769 // Interfaces can be attached dynamically, so this cannot be a static
1770 // assert.
1771 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1772 llvm::report_fatal_error(
1773 Twine("ParamProducerTransformOpTrait must be attached to an op that "
1774 "implements MemoryEffectsOpInterface, found on ") +
1775 op->getName().getStringRef());
1776 }
1777 for (Value result : op->getResults()) {
1778 if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1779 continue;
1780 return op->emitOpError()
1781 << "ParamProducerTransformOpTrait attached to this op expects "
1782 "result types to implement TransformParamTypeInterface";
1783 }
1784 return success();
1785}
1786
1787//===----------------------------------------------------------------------===//
1788// Memory effects.
1789//===----------------------------------------------------------------------===//
1790
1794 for (OpOperand &handle : handles) {
1795 effects.emplace_back(MemoryEffects::Read::get(), &handle,
1797 effects.emplace_back(MemoryEffects::Free::get(), &handle,
1799 }
1800}
1801
1802/// Returns `true` if the given list of effects instances contains an instance
1803/// with the effect type specified as template parameter.
1804template <typename EffectTy, typename ResourceTy, typename Range>
1805static bool hasEffect(Range &&effects) {
1806 return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1807 return isa<EffectTy>(effect.getEffect()) &&
1808 isa<ResourceTy>(effect.getResource());
1809 });
1810}
1811
1813 transform::TransformOpInterface transform) {
1814 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1816 iface.getEffectsOnValue(handle, effects);
1817 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1819}
1820
1822 ResultRange handles,
1824 for (OpResult handle : handles) {
1825 effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1827 effects.emplace_back(MemoryEffects::Write::get(), handle,
1829 }
1830}
1831
1835 for (BlockArgument handle : handles) {
1836 effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1838 effects.emplace_back(MemoryEffects::Write::get(), handle,
1840 }
1841}
1842
1846 for (OpOperand &handle : handles) {
1847 effects.emplace_back(MemoryEffects::Read::get(), &handle,
1849 }
1850}
1851
1857
1862
1863bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1864 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1866 iface.getEffects(effects);
1867 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1868}
1869
1870bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1871 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1873 iface.getEffects(effects);
1874 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1875}
1876
1878 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1880 for (Operation &nested : block) {
1881 auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1882 if (!iface)
1883 continue;
1884
1885 effects.clear();
1886 iface.getEffects(effects);
1887 for (const MemoryEffects::EffectInstance &effect : effects) {
1888 BlockArgument argument =
1889 dyn_cast_or_null<BlockArgument>(effect.getValue());
1890 if (!argument || argument.getOwner() != &block ||
1891 !isa<MemoryEffects::Free>(effect.getEffect()) ||
1892 effect.getResource() != transform::TransformMappingResource::get()) {
1893 continue;
1894 }
1895 consumedArguments.insert(argument.getArgNumber());
1896 }
1897 }
1898}
1899
1900//===----------------------------------------------------------------------===//
1901// Utilities for TransformOpInterface.
1902//===----------------------------------------------------------------------===//
1903
1904SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands(
1905 TransformOpInterface transformOp) {
1906 SmallVector<OpOperand *> consumedOperands;
1907 consumedOperands.reserve(transformOp->getNumOperands());
1908 auto memEffectInterface =
1909 cast<MemoryEffectOpInterface>(transformOp.getOperation());
1910 SmallVector<MemoryEffects::EffectInstance, 2> effects;
1911 for (OpOperand &target : transformOp->getOpOperands()) {
1912 effects.clear();
1913 memEffectInterface.getEffectsOnValue(target.get(), effects);
1914 if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1915 return isa<transform::TransformMappingResource>(
1916 effect.getResource()) &&
1917 isa<MemoryEffects::Free>(effect.getEffect());
1918 })) {
1919 consumedOperands.push_back(&target);
1920 }
1921 }
1922 return consumedOperands;
1923}
1924
1926 auto iface = cast<MemoryEffectOpInterface>(op);
1928 iface.getEffects(effects);
1929
1930 auto effectsOn = [&](Value value) {
1931 return llvm::make_filter_range(
1932 effects, [value](const MemoryEffects::EffectInstance &instance) {
1933 return instance.getValue() == value;
1934 });
1935 };
1936
1937 std::optional<unsigned> firstConsumedOperand;
1938 for (OpOperand &operand : op->getOpOperands()) {
1939 auto range = effectsOn(operand.get());
1940 if (range.empty()) {
1942 op->emitError() << "TransformOpInterface requires memory effects "
1943 "on operands to be specified";
1944 diag.attachNote() << "no effects specified for operand #"
1945 << operand.getOperandNumber();
1946 return diag;
1947 }
1950 << "TransformOpInterface did not expect "
1951 "'allocate' memory effect on an operand";
1952 diag.attachNote() << "specified for operand #"
1953 << operand.getOperandNumber();
1954 return diag;
1955 }
1956 if (!firstConsumedOperand &&
1958 firstConsumedOperand = operand.getOperandNumber();
1959 }
1960 }
1961
1962 if (firstConsumedOperand &&
1965 op->emitError()
1966 << "TransformOpInterface expects ops consuming operands to have a "
1967 "'write' effect on the payload resource";
1968 diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
1969 return diag;
1970 }
1971
1972 for (OpResult result : op->getResults()) {
1973 auto range = effectsOn(result);
1975 range)) {
1977 op->emitError() << "TransformOpInterface requires 'allocate' memory "
1978 "effect to be specified for results";
1979 diag.attachNote() << "no 'allocate' effect specified for result #"
1980 << result.getResultNumber();
1981 return diag;
1982 }
1983 }
1984
1985 return success();
1986}
1987
1988//===----------------------------------------------------------------------===//
1989// Entry point.
1990//===----------------------------------------------------------------------===//
1991
1993 Operation *payloadRoot, TransformOpInterface transform,
1994 const RaggedArray<MappedValue> &extraMapping,
1995 const TransformOptions &options, bool enforceToplevelTransformOp,
1996 function_ref<void(TransformState &)> stateInitializer,
1997 function_ref<LogicalResult(TransformState &)> stateExporter) {
1998 if (enforceToplevelTransformOp) {
1999 if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
2000 transform->getNumOperands() != 0) {
2001 return transform->emitError()
2002 << "expected transform to start at the top-level transform op";
2003 }
2004 } else if (failed(
2006 return failure();
2007 }
2008
2009 TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2010 options);
2011 if (stateInitializer)
2012 stateInitializer(state);
2013 if (state.applyTransform(transform).checkAndReport().failed())
2014 return failure();
2015 if (stateExporter)
2016 return stateExporter(state);
2017 return success();
2018}
2019
2020//===----------------------------------------------------------------------===//
2021// Generated interface implementation.
2022//===----------------------------------------------------------------------===//
2023
2024#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2025#include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
for(Operation *op :ops)
return success()
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
#define FULL_LDBG()
static void remapEffects(MemoryEffectOpInterface iface, BlockArgument source, OpOperand *target, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the memory effect instances on target with the same resource and effect as the one...
static void remapArgumentEffects(Block &block, MutableArrayRef< OpOperand > operands, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the same effects as the operations of block have on block arguments but associated...
static bool happensBefore(Operation *a, Operation *b)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not ...
static void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped)
#define DEBUG_PRINT_AFTER_ALL
static DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef< T > payload, transform::TransformOpInterface transform, unsigned operandNumber)
static DiagnosedSilenceableFailure dispatchMappedValues(Value handle, ArrayRef< transform::MappedValue > values, function_ref< LogicalResult(ArrayRef< Operation * >)> operationsFn, function_ref< LogicalResult(ArrayRef< transform::Param >)> paramsFn, function_ref< LogicalResult(ValueRange)> valuesFn)
Given a list of MappedValues, cast them to the value kind implied by the interface of the handle type...
static SmallVector< T > castVector(ArrayRef< transform::MappedValue > range)
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:85
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition Block.cpp:74
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgListType getArguments()
Definition Block.h:87
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition Builders.h:316
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
Value getOperand(unsigned idx)
Definition Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
unsigned getNumOperands()
Definition Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getOpResults()
Definition Operation.h:420
result_range getResults()
Definition Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A 2D array where each row may have different length.
Definition RaggedArray.h:18
size_t size() const
Returns the number of rows in the 2D array.
Definition RaggedArray.h:21
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
This class implements the result iterators for the Operation class.
Definition ValueRange.h:247
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
size_t size() const
Returns the number of elements in the list.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
std::string getLatestMatchFailureMessage()
Return the latest match notification message.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
TrackingListener failures are reported only for ops that have this trait.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
TransformOpInterface getTransformOp() const
Return the transform op in which this TrackingListener is used.
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Options controlling the application of transform operations by the TransformState.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
Extension(TransformState &state)
Constructs an extension of the given TransformState object.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
size_t getNumTopLevelMappings() const
Returns the number of extra mappings for the top-level operation.
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
ArrayRef< MappedValue > getTopLevelMapping(size_t position) const
Returns the position-th extra mapping for the top-level operation.
SideEffects::EffectInstance< Effect > EffectInstance
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue > > mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult matches the number, kind (payload op or parameter) and nulli...
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true, function_ref< void(TransformState &)> stateInitializer=nullptr, function_ref< LogicalResult(TransformState &)> stateExporter=nullptr)
Entry point to the Transform dialect infrastructure.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
llvm::PointerUnion< Operation *, Param, Value > MappedValue
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A configuration object for customizing a TrackingListener.