MLIR 23.0.0git
TransformInterfaces.cpp
Go to the documentation of this file.
1//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
12
13#include "mlir/IR/Diagnostics.h"
14#include "mlir/IR/Operation.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/ScopeExit.h"
19#include "llvm/ADT/iterator.h"
20#include "llvm/Support/Debug.h"
21#include "llvm/Support/DebugLog.h"
22#include "llvm/Support/ErrorHandling.h"
23#include "llvm/Support/InterleavedRange.h"
24
25#define DEBUG_TYPE "transform-dialect"
26#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
27#define FULL_LDBG() LDBG(4)
28
29using namespace mlir;
30
31//===----------------------------------------------------------------------===//
32// Helper functions
33//===----------------------------------------------------------------------===//
34
35/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
36/// properly dominates `b` and `b` is not inside `a`.
38 do {
39 if (a->isProperAncestor(b))
40 return false;
41 if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
42 return a->isBeforeInBlock(bAncestor);
43 }
44 } while ((a = a->getParentOp()));
45 return false;
46}
47
48//===----------------------------------------------------------------------===//
49// TransformState
50//===----------------------------------------------------------------------===//
51
52transform::TransformState::TransformState(
53 Region *region, Operation *payloadRoot,
54 const RaggedArray<MappedValue> &extraMappings,
55 const TransformOptions &options)
56 : topLevel(payloadRoot), options(options) {
57 topLevelMappedValues.reserve(extraMappings.size());
58 for (ArrayRef<MappedValue> mapping : extraMappings)
59 topLevelMappedValues.push_back(mapping);
60 if (region) {
61 RegionScope *scope = new RegionScope(*this, *region);
62 topLevelRegionScope.reset(scope);
63 }
64}
65
67
69transform::TransformState::getPayloadOpsView(Value value) const {
70 const TransformOpMapping &operationMapping = getMapping(value).direct;
71 auto iter = operationMapping.find(value);
72 assert(iter != operationMapping.end() &&
73 "cannot find mapping for payload handle (param/value handle "
74 "provided?)");
75 return iter->getSecond();
76}
77
79 const ParamMapping &mapping = getMapping(value).params;
80 auto iter = mapping.find(value);
81 assert(iter != mapping.end() && "cannot find mapping for param handle "
82 "(operation/value handle provided?)");
83 return iter->getSecond();
84}
85
87transform::TransformState::getPayloadValuesView(Value handleValue) const {
88 const ValueMapping &mapping = getMapping(handleValue).values;
89 auto iter = mapping.find(handleValue);
90 assert(iter != mapping.end() && "cannot find mapping for value handle "
91 "(param/operation handle provided?)");
92 return iter->getSecond();
93}
94
97 bool includeOutOfScope) const {
98 bool found = false;
99 for (const auto &[region, mapping] : llvm::reverse(mappings)) {
100 auto iterator = mapping->reverse.find(op);
101 if (iterator != mapping->reverse.end()) {
102 llvm::append_range(handles, iterator->getSecond());
103 found = true;
104 }
105 // Stop looking when reaching a region that is isolated from above.
106 if (!includeOutOfScope &&
108 break;
109 }
110
111 return success(found);
112}
113
115 Value payloadValue, SmallVectorImpl<Value> &handles,
116 bool includeOutOfScope) const {
117 bool found = false;
118 for (const auto &[region, mapping] : llvm::reverse(mappings)) {
119 auto iterator = mapping->reverseValues.find(payloadValue);
120 if (iterator != mapping->reverseValues.end()) {
121 llvm::append_range(handles, iterator->getSecond());
122 found = true;
123 }
124 // Stop looking when reaching a region that is isolated from above.
125 if (!includeOutOfScope &&
127 break;
128 }
129
130 return success(found);
131}
132
133/// Given a list of MappedValues, cast them to the value kind implied by the
134/// interface of the handle type, and dispatch to one of the callbacks.
137 function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
138 function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
139 function_ref<LogicalResult(ValueRange)> valuesFn) {
140 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
141 SmallVector<Operation *> operations;
142 operations.reserve(values.size());
143 for (transform::MappedValue value : values) {
144 if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
145 operations.push_back(op);
146 continue;
147 }
148 return emitSilenceableFailure(handle.getLoc())
149 << "wrong kind of value provided for top-level operation handle";
150 }
151 if (failed(operationsFn(operations)))
154 }
155
156 if (llvm::isa<transform::TransformValueHandleTypeInterface>(
157 handle.getType())) {
158 SmallVector<Value> payloadValues;
159 payloadValues.reserve(values.size());
160 for (transform::MappedValue value : values) {
161 if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
162 payloadValues.push_back(v);
163 continue;
164 }
165 return emitSilenceableFailure(handle.getLoc())
166 << "wrong kind of value provided for the top-level value handle";
167 }
168 if (failed(valuesFn(payloadValues)))
171 }
172
173 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
174 "unsupported kind of block argument");
176 parameters.reserve(values.size());
177 for (transform::MappedValue value : values) {
178 if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
179 parameters.push_back(attr);
180 continue;
181 }
182 return emitSilenceableFailure(handle.getLoc())
183 << "wrong kind of value provided for top-level parameter";
184 }
185 if (failed(paramsFn(parameters)))
188}
189
190LogicalResult
192 ArrayRef<MappedValue> values) {
194 argument, values,
195 [&](ArrayRef<Operation *> operations) {
196 return setPayloadOps(argument, operations);
197 },
198 [&](ArrayRef<Param> params) {
199 return setParams(argument, params);
200 },
201 [&](ValueRange payloadValues) {
202 return setPayloadValues(argument, payloadValues);
203 })
204 .checkAndReport();
205}
206
208 Block::BlockArgListType arguments,
210 for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
211 if (failed(mapBlockArgument(argument, values)))
212 return failure();
213 return success();
214}
215
216LogicalResult
217transform::TransformState::setPayloadOps(Value value,
218 ArrayRef<Operation *> targets) {
219 assert(value != kTopLevelValue &&
220 "attempting to reset the transformation root");
221 assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
222 "wrong handle type");
223
224 for (Operation *target : targets) {
225 if (target)
226 continue;
227 return emitError(value.getLoc())
228 << "attempting to assign a null payload op to this transform value";
229 }
230
231 auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
233 iface.checkPayload(value.getLoc(), targets);
234 if (failed(result.checkAndReport()))
235 return failure();
236
237 // Setting new payload for the value without cleaning it first is a misuse of
238 // the API, assert here.
239 SmallVector<Operation *> storedTargets(targets);
240 Mappings &mappings = getMapping(value);
241 bool inserted =
242 mappings.direct.insert({value, std::move(storedTargets)}).second;
243 assert(inserted && "value is already associated with another list");
244 (void)inserted;
245
246 for (Operation *op : targets)
247 mappings.reverse[op].push_back(value);
248
249 return success();
250}
251
252LogicalResult
253transform::TransformState::setPayloadValues(Value handle,
254 ValueRange payloadValues) {
255 assert(handle != nullptr && "attempting to set params for a null value");
256 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
257 "wrong handle type");
258
259 for (Value payload : payloadValues) {
260 if (payload)
261 continue;
262 return emitError(handle.getLoc()) << "attempting to assign a null payload "
263 "value to this transform handle";
264 }
265
266 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
267 SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
268 DiagnosedSilenceableFailure result =
269 iface.checkPayload(handle.getLoc(), payloadValueVector);
270 if (failed(result.checkAndReport()))
271 return failure();
272
273 Mappings &mappings = getMapping(handle);
274 bool inserted =
275 mappings.values.insert({handle, std::move(payloadValueVector)}).second;
276 assert(
277 inserted &&
278 "value handle is already associated with another list of payload values");
279 (void)inserted;
280
281 for (Value payload : payloadValues)
282 mappings.reverseValues[payload].push_back(handle);
283
284 return success();
285}
286
287LogicalResult transform::TransformState::setParams(Value value,
288 ArrayRef<Param> params) {
289 assert(value != nullptr && "attempting to set params for a null value");
290
291 for (Attribute attr : params) {
292 if (attr)
293 continue;
294 return emitError(value.getLoc())
295 << "attempting to assign a null parameter to this transform value";
296 }
297
298 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
299 assert(value &&
300 "cannot associate parameter with a value of non-parameter type");
301 DiagnosedSilenceableFailure result =
302 valueType.checkPayload(value.getLoc(), params);
303 if (failed(result.checkAndReport()))
304 return failure();
305
306 Mappings &mappings = getMapping(value);
307 bool inserted =
308 mappings.params.insert({value, llvm::to_vector(params)}).second;
309 assert(inserted && "value is already associated with another list of params");
310 (void)inserted;
311 return success();
312}
313
314template <typename Mapping, typename Key, typename Mapped>
315static void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
316 auto it = mapping.find(key);
317 if (it == mapping.end())
318 return;
319
320 llvm::erase(it->getSecond(), mapped);
321 if (it->getSecond().empty())
322 mapping.erase(it);
323}
324
325void transform::TransformState::forgetMapping(Value opHandle,
326 ValueRange origOpFlatResults,
327 bool allowOutOfScope) {
328 Mappings &mappings = getMapping(opHandle, allowOutOfScope);
329 for (Operation *op : mappings.direct[opHandle])
330 dropMappingEntry(mappings.reverse, op, opHandle);
331 mappings.direct.erase(opHandle);
332#if LLVM_ENABLE_ABI_BREAKING_CHECKS
333 // Payload IR is removed from the mapping. This invalidates the respective
334 // iterators.
335 mappings.incrementTimestamp(opHandle);
336#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
337
338 for (Value opResult : origOpFlatResults) {
339 SmallVector<Value> resultHandles;
340 (void)getHandlesForPayloadValue(opResult, resultHandles);
341 for (Value resultHandle : resultHandles) {
342 Mappings &localMappings = getMapping(resultHandle);
343 dropMappingEntry(localMappings.values, resultHandle, opResult);
344#if LLVM_ENABLE_ABI_BREAKING_CHECKS
345 // Payload IR is removed from the mapping. This invalidates the respective
346 // iterators.
347 mappings.incrementTimestamp(resultHandle);
348#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
349 dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
350 }
351 }
352}
353
354void transform::TransformState::forgetValueMapping(
355 Value valueHandle, ArrayRef<Operation *> payloadOperations) {
356 Mappings &mappings = getMapping(valueHandle);
357 for (Value payloadValue : mappings.reverseValues[valueHandle])
358 dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
359 mappings.values.erase(valueHandle);
360#if LLVM_ENABLE_ABI_BREAKING_CHECKS
361 // Payload IR is removed from the mapping. This invalidates the respective
362 // iterators.
363 mappings.incrementTimestamp(valueHandle);
364#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
365
366 for (Operation *payloadOp : payloadOperations) {
367 SmallVector<Value> opHandles;
368 (void)getHandlesForPayloadOp(payloadOp, opHandles);
369 for (Value opHandle : opHandles) {
370 Mappings &localMappings = getMapping(opHandle);
371 dropMappingEntry(localMappings.direct, opHandle, payloadOp);
372 dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
373
374#if LLVM_ENABLE_ABI_BREAKING_CHECKS
375 // Payload IR is removed from the mapping. This invalidates the respective
376 // iterators.
377 localMappings.incrementTimestamp(opHandle);
378#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
379 }
380 }
381}
382
383LogicalResult
384transform::TransformState::replacePayloadOp(Operation *op,
385 Operation *replacement) {
386 // TODO: consider invalidating the handles to nested objects here.
387
388#ifndef NDEBUG
389 for (Value opResult : op->getResults()) {
390 SmallVector<Value> valueHandles;
391 (void)getHandlesForPayloadValue(opResult, valueHandles,
392 /*includeOutOfScope=*/true);
393 assert(valueHandles.empty() && "expected no mapping to old results");
394 }
395#endif // NDEBUG
396
397 // Drop the mapping between the op and all handles that point to it. Fail if
398 // there are no handles.
399 SmallVector<Value> opHandles;
400 if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
401 return failure();
402 for (Value handle : opHandles) {
403 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
404 dropMappingEntry(mappings.reverse, op, handle);
405 }
406
407 // Replace the pointed-to object of all handles with the replacement object.
408 // In case a payload op was erased (replacement object is nullptr), a nullptr
409 // is stored in the mapping. These nullptrs are removed after each transform.
410 // Furthermore, nullptrs are not enumerated by payload op iterators. The
411 // relative order of ops is preserved.
412 //
413 // Removing an op from the mapping would be problematic because removing an
414 // element from an array invalidates iterators; merely changing the value of
415 // elements does not.
416 for (Value handle : opHandles) {
417 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
418 auto it = mappings.direct.find(handle);
419 if (it == mappings.direct.end())
420 continue;
421
422 SmallVector<Operation *, 2> &association = it->getSecond();
423 // Note that an operation may be associated with the handle more than once.
424 for (Operation *&mapped : association) {
425 if (mapped == op)
426 mapped = replacement;
427 }
428
429 if (replacement) {
430 mappings.reverse[replacement].push_back(handle);
431 } else {
432 opHandlesToCompact.insert(handle);
433 }
434 }
435
436 return success();
437}
438
439LogicalResult
440transform::TransformState::replacePayloadValue(Value value, Value replacement) {
441 SmallVector<Value> valueHandles;
442 if (failed(getHandlesForPayloadValue(value, valueHandles,
443 /*includeOutOfScope=*/true)))
444 return failure();
445
446 for (Value handle : valueHandles) {
447 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
448 dropMappingEntry(mappings.reverseValues, value, handle);
449
450 // If replacing with null, that is erasing the mapping, drop the mapping
451 // between the handles and the IR objects
452 if (!replacement) {
453 dropMappingEntry(mappings.values, handle, value);
454#if LLVM_ENABLE_ABI_BREAKING_CHECKS
455 // Payload IR is removed from the mapping. This invalidates the respective
456 // iterators.
457 mappings.incrementTimestamp(handle);
458#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
459 } else {
460 auto it = mappings.values.find(handle);
461 if (it == mappings.values.end())
462 continue;
463
464 SmallVector<Value> &association = it->getSecond();
465 for (Value &mapped : association) {
466 if (mapped == value)
467 mapped = replacement;
468 }
469 mappings.reverseValues[replacement].push_back(handle);
470 }
471 }
472
473 return success();
474}
475
476void transform::TransformState::recordOpHandleInvalidationOne(
477 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
478 Operation *payloadOp, Value otherHandle, Value throughValue,
479 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
480 // If the op is associated with invalidated handle, skip the check as it
481 // may be reading invalid IR. This also ensures we report the first
482 // invalidation and not the last one.
483 if (invalidatedHandles.count(otherHandle) ||
484 newlyInvalidated.count(otherHandle))
485 return;
486
487 FULL_LDBG() << "--recordOpHandleInvalidationOne";
488 FULL_LDBG() << "--ancestors: "
489 << llvm::interleaved(
490 llvm::make_pointee_range(potentialAncestors));
491
492 Operation *owner = consumingHandle.getOwner();
493 unsigned operandNo = consumingHandle.getOperandNumber();
494 for (Operation *ancestor : potentialAncestors) {
495 // clang-format off
496 FULL_LDBG() << "----handle one ancestor: " << *ancestor;;
497
498 FULL_LDBG() << "----of payload with name: "
499 << payloadOp->getName().getIdentifier();
500 FULL_LDBG() << "----of payload: " << *payloadOp;
501 // clang-format on
502 if (!ancestor->isAncestor(payloadOp))
503 continue;
504
505 // Make sure the error-reporting lambda doesn't capture anything
506 // by-reference because it will go out of scope. Additionally, extract
507 // location from Payload IR ops because the ops themselves may be
508 // deleted before the lambda gets called.
509 Location ancestorLoc = ancestor->getLoc();
510 Location opLoc = payloadOp->getLoc();
511 std::optional<Location> throughValueLoc =
512 throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
513 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
514 otherHandle,
515 throughValueLoc](Location currentLoc) {
516 InFlightDiagnostic diag = emitError(currentLoc)
517 << "op uses a handle invalidated by a "
518 "previously executed transform op";
519 diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
520 diag.attachNote(owner->getLoc())
521 << "invalidated by this transform op that consumes its operand #"
522 << operandNo
523 << " and invalidates all handles to payload IR entities associated "
524 "with this operand and entities nested in them";
525 diag.attachNote(ancestorLoc) << "ancestor payload op";
526 diag.attachNote(opLoc) << "nested payload op";
527 if (throughValueLoc) {
528 diag.attachNote(*throughValueLoc)
529 << "consumed handle points to this payload value";
530 }
531 };
532 }
533}
534
535void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
536 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
537 Value payloadValue, Value valueHandle,
538 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
539 // If the op is associated with invalidated handle, skip the check as it
540 // may be reading invalid IR. This also ensures we report the first
541 // invalidation and not the last one.
542 if (invalidatedHandles.count(valueHandle) ||
543 newlyInvalidated.count(valueHandle))
544 return;
545
546 for (Operation *ancestor : potentialAncestors) {
547 Operation *definingOp;
548 std::optional<unsigned> resultNo;
549 unsigned argumentNo = std::numeric_limits<unsigned>::max();
550 unsigned blockNo = std::numeric_limits<unsigned>::max();
551 unsigned regionNo = std::numeric_limits<unsigned>::max();
552 if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
553 definingOp = opResult.getOwner();
554 resultNo = opResult.getResultNumber();
555 } else {
556 auto arg = llvm::cast<BlockArgument>(payloadValue);
557 definingOp = arg.getParentBlock()->getParentOp();
558 argumentNo = arg.getArgNumber();
559 blockNo = arg.getOwner()->computeBlockNumber();
560 regionNo = arg.getOwner()->getParent()->getRegionNumber();
561 }
562 assert(definingOp && "expected the value to be defined by an op as result "
563 "or block argument");
564 if (!ancestor->isAncestor(definingOp))
565 continue;
566
567 Operation *owner = opHandle.getOwner();
568 unsigned operandNo = opHandle.getOperandNumber();
569 Location ancestorLoc = ancestor->getLoc();
570 Location opLoc = definingOp->getLoc();
571 Location valueLoc = payloadValue.getLoc();
572 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
573 argumentNo, blockNo, regionNo, ancestorLoc,
574 opLoc, valueLoc](Location currentLoc) {
575 InFlightDiagnostic diag = emitError(currentLoc)
576 << "op uses a handle invalidated by a "
577 "previously executed transform op";
578 diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
579 diag.attachNote(owner->getLoc())
580 << "invalidated by this transform op that consumes its operand #"
581 << operandNo
582 << " and invalidates all handles to payload IR entities "
583 "associated with this operand and entities nested in them";
584 diag.attachNote(ancestorLoc)
585 << "ancestor op associated with the consumed handle";
586 if (resultNo) {
587 diag.attachNote(opLoc)
588 << "op defining the value as result #" << *resultNo;
589 } else {
590 diag.attachNote(opLoc)
591 << "op defining the value as block argument #" << argumentNo
592 << " of block #" << blockNo << " in region #" << regionNo;
593 }
594 diag.attachNote(valueLoc) << "payload value";
595 };
596 }
597}
598
599void transform::TransformState::recordOpHandleInvalidation(
600 OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
601 Value throughValue,
602 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
603
604 if (potentialAncestors.empty()) {
605 FULL_LDBG() << "----recording invalidation for empty handle: "
606 << handle.get();
607
608 Operation *owner = handle.getOwner();
609 unsigned operandNo = handle.getOperandNumber();
610 newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
611 InFlightDiagnostic diag = emitError(currentLoc)
612 << "op uses a handle associated with empty "
613 "payload and invalidated by a "
614 "previously executed transform op";
615 diag.attachNote(owner->getLoc())
616 << "invalidated by this transform op that consumes its operand #"
617 << operandNo;
618 };
619 return;
620 }
621
622 // Iterate over the mapping and invalidate aliasing handles. This is quite
623 // expensive and only necessary for error reporting in case of transform
624 // dialect misuse with dangling handles. Iteration over the handles is based
625 // on the assumption that the number of handles is significantly less than the
626 // number of IR objects (operations and values). Alternatively, we could walk
627 // the IR nested in each payload op associated with the given handle and look
628 // for handles associated with each operation and value.
629 for (const auto &[region, mapping] : llvm::reverse(mappings)) {
630 // Go over all op handle mappings and mark as invalidated any handle
631 // pointing to any of the payload ops associated with the given handle or
632 // any op nested in them.
633 for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
634 for (Value otherHandle : otherHandles)
635 recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
636 otherHandle, throughValue,
637 newlyInvalidated);
638 }
639 // Go over all value handle mappings and mark as invalidated any handle
640 // pointing to any result of the payload op associated with the given handle
641 // or any op nested in them. Similarly invalidate handles to argument of
642 // blocks belonging to any region of any payload op associated with the
643 // given handle or any op nested in them.
644 for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
645 for (Value valueHandle : valueHandles)
646 recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
647 payloadValue, valueHandle,
648 newlyInvalidated);
649 }
650
651 // Stop lookup when reaching a region that is isolated from above.
652 if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
653 break;
654 }
655}
656
657void transform::TransformState::recordValueHandleInvalidation(
658 OpOperand &valueHandle,
659 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
660 // Invalidate other handles to the same value.
661 for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
662 SmallVector<Value> otherValueHandles;
663 (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
664 for (Value otherHandle : otherValueHandles) {
665 Operation *owner = valueHandle.getOwner();
666 unsigned operandNo = valueHandle.getOperandNumber();
667 Location valueLoc = payloadValue.getLoc();
668 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
669 valueLoc](Location currentLoc) {
670 InFlightDiagnostic diag = emitError(currentLoc)
671 << "op uses a handle invalidated by a "
672 "previously executed transform op";
673 diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
674 diag.attachNote(owner->getLoc())
675 << "invalidated by this transform op that consumes its operand #"
676 << operandNo
677 << " and invalidates handles to the same values as associated with "
678 "it";
679 diag.attachNote(valueLoc) << "payload value";
680 };
681 }
682
683 if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
684 Operation *payloadOp = opResult.getOwner();
685 recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
686 newlyInvalidated);
687 } else {
688 auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
689 for (Operation &payloadOp : *arg.getOwner())
690 recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
691 newlyInvalidated);
692 }
693 }
694}
695
696/// Checks that the operation does not use invalidated handles as operands.
697/// Reports errors and returns failure if it does. Otherwise, invalidates the
698/// handles consumed by the operation as well as any handles pointing to payload
699/// IR operations nested in the operations associated with the consumed handles.
700LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
701 transform::TransformOpInterface transform,
702 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
703 FULL_LDBG() << "--Start checkAndRecordHandleInvalidation";
704 auto memoryEffectsIface =
705 cast<MemoryEffectOpInterface>(transform.getOperation());
706 SmallVector<MemoryEffects::EffectInstance> effects;
707 memoryEffectsIface.getEffectsOnResource(
709
710 for (OpOperand &target : transform->getOpOperands()) {
711 FULL_LDBG() << "----iterate on handle: " << target.get();
712 // If the operand uses an invalidated handle, report it. If the operation
713 // allows handles to point to repeated payload operations, only report
714 // pre-existing invalidation errors. Otherwise, also report invalidations
715 // caused by the current transform operation affecting its other operands.
716 auto it = invalidatedHandles.find(target.get());
717 auto nit = newlyInvalidated.find(target.get());
718 if (it != invalidatedHandles.end()) {
719 FULL_LDBG() << "--End checkAndRecordHandleInvalidation, found already "
720 "invalidated -> FAILURE";
721 return it->getSecond()(transform->getLoc()), failure();
722 }
723 if (!transform.allowsRepeatedHandleOperands() &&
724 nit != newlyInvalidated.end()) {
725 FULL_LDBG() << "--End checkAndRecordHandleInvalidation, found newly "
726 "invalidated (by this op) -> FAILURE";
727 return nit->getSecond()(transform->getLoc()), failure();
728 }
729
730 // Invalidate handles pointing to the operations nested in the operation
731 // associated with the handle consumed by this operation.
732 auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
733 return isa<MemoryEffects::Free>(effect.getEffect()) &&
734 effect.getValue() == target.get();
735 };
736 if (llvm::any_of(effects, consumesTarget)) {
737 FULL_LDBG() << "----found consume effect";
738 if (llvm::isa<transform::TransformHandleTypeInterface>(
739 target.get().getType())) {
740 FULL_LDBG() << "----recordOpHandleInvalidation";
741 SmallVector<Operation *> payloadOps =
742 llvm::to_vector(getPayloadOps(target.get()));
743 recordOpHandleInvalidation(target, payloadOps, nullptr,
744 newlyInvalidated);
745 } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
746 target.get().getType())) {
747 FULL_LDBG() << "----recordValueHandleInvalidation";
748 recordValueHandleInvalidation(target, newlyInvalidated);
749 } else {
750 FULL_LDBG()
751 << "----not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
752 }
753 } else {
754 FULL_LDBG() << "----no consume effect -> SKIP";
755 }
756 }
757
758 FULL_LDBG() << "--End checkAndRecordHandleInvalidation -> SUCCESS";
759 return success();
760}
761
762LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
763 transform::TransformOpInterface transform) {
764 InvalidatedHandleMap newlyInvalidated;
765 LogicalResult checkResult =
766 checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
767 invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
768 std::make_move_iterator(newlyInvalidated.end()));
769 return checkResult;
770}
771
772template <typename T>
773static DiagnosedSilenceableFailure
775 transform::TransformOpInterface transform,
776 unsigned operandNumber) {
777 DenseSet<T> seen;
778 for (T p : payload) {
779 if (!seen.insert(p).second) {
781 transform.emitSilenceableError()
782 << "a handle passed as operand #" << operandNumber
783 << " and consumed by this operation points to a payload "
784 "entity more than once";
785 if constexpr (std::is_pointer_v<T>)
786 diag.attachNote(p->getLoc()) << "repeated target op";
787 else
788 diag.attachNote(p.getLoc()) << "repeated target value";
789 return diag;
790 }
791 }
793}
794
795void transform::TransformState::compactOpHandles() {
796 for (Value handle : opHandlesToCompact) {
797 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
798#if LLVM_ENABLE_ABI_BREAKING_CHECKS
799 if (llvm::is_contained(mappings.direct[handle], nullptr))
800 // Payload IR is removed from the mapping. This invalidates the respective
801 // iterators.
802 mappings.incrementTimestamp(handle);
803#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
804 llvm::erase(mappings.direct[handle], nullptr);
805 }
806 opHandlesToCompact.clear();
807}
808
809DiagnosedSilenceableFailure
811 LDBG() << "applying: "
812 << OpWithFlags(transform, OpPrintingFlags().skipRegions());
813 FULL_LDBG() << "Top-level payload before application:\n" << *getTopLevel();
814 llvm::scope_exit printOnFailureRAII([this] {
815 (void)this;
816 LDBG() << "Failing Top-level payload:\n"
818 OpPrintingFlags().printGenericOpForm());
819 });
820
821 // Set current transform op.
822 regionStack.back()->currentTransform = transform;
823
824 // Expensive checks to detect invalid transform IR.
825 if (options.getExpensiveChecksEnabled()) {
826 FULL_LDBG() << "ExpensiveChecksEnabled";
827 if (failed(checkAndRecordHandleInvalidation(transform)))
829
830 for (OpOperand &operand : transform->getOpOperands()) {
831 FULL_LDBG() << "iterate on handle: " << operand.get();
832 if (!isHandleConsumed(operand.get(), transform)) {
833 FULL_LDBG() << "--handle not consumed -> SKIP";
834 continue;
835 }
836 if (transform.allowsRepeatedHandleOperands()) {
837 FULL_LDBG() << "--op allows repeated handles -> SKIP";
838 continue;
839 }
840 FULL_LDBG() << "--handle is consumed";
841
842 Type operandType = operand.get().getType();
843 if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
844 FULL_LDBG() << "--checkRepeatedConsumptionInOperand for Operation*";
847 getPayloadOpsView(operand.get()), transform,
848 operand.getOperandNumber());
849 if (!check.succeeded()) {
850 FULL_LDBG() << "----FAILED";
851 return check;
852 }
853 } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
854 FULL_LDBG() << "--checkRepeatedConsumptionInOperand For Value";
857 getPayloadValuesView(operand.get()), transform,
858 operand.getOperandNumber());
859 if (!check.succeeded()) {
860 FULL_LDBG() << "----FAILED";
861 return check;
862 }
863 } else {
864 FULL_LDBG() << "--not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
865 }
866 }
867 }
868
869 // Find which operands are consumed.
870 SmallVector<OpOperand *> consumedOperands =
871 transform.getConsumedHandleOpOperands();
872
873 // Remember the results of the payload ops associated with the consumed
874 // op handles or the ops defining the value handles so we can drop the
875 // association with them later. This must happen here because the
876 // transformation may destroy or mutate them so we cannot traverse the payload
877 // IR after that.
878 SmallVector<Value> origOpFlatResults;
879 SmallVector<Operation *> origAssociatedOps;
880 for (OpOperand *opOperand : consumedOperands) {
881 Value operand = opOperand->get();
882 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
883 for (Operation *payloadOp : getPayloadOps(operand)) {
884 llvm::append_range(origOpFlatResults, payloadOp->getResults());
885 }
886 continue;
887 }
888 if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
889 for (Value payloadValue : getPayloadValuesView(operand)) {
890 if (llvm::isa<OpResult>(payloadValue)) {
891 origAssociatedOps.push_back(payloadValue.getDefiningOp());
892 continue;
893 }
894 llvm::append_range(
895 origAssociatedOps,
896 llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
897 [](Operation &op) { return &op; }));
898 }
899 continue;
900 }
903 << "unexpectedly consumed a value that is not a handle as operand #"
904 << opOperand->getOperandNumber();
905 diag.attachNote(operand.getLoc())
906 << "value defined here with type " << operand.getType();
907 return diag;
908 }
909
910 // Prepare rewriter and listener.
912 config.skipHandleFn = [&](Value handle) {
913 // Skip handle if it is dead.
914 auto scopeIt =
915 llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
916 return handle.getParentRegion() == scope->region;
917 });
918 assert(scopeIt != regionStack.rend() &&
919 "could not find region scope for handle");
920 RegionScope *scope = *scopeIt;
921 return llvm::all_of(handle.getUsers(), [&](Operation *user) {
922 return user == scope->currentTransform ||
923 happensBefore(user, scope->currentTransform);
924 });
925 };
927 config);
928 transform::TransformRewriter rewriter(transform->getContext(),
929 &trackingListener);
930
931 // Compute the result but do not short-circuit the silenceable failure case as
932 // we still want the handles to propagate properly so the "suppress" mode can
933 // proceed on a best effort basis.
934 transform::TransformResults results(transform->getNumResults());
935 DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
936 compactOpHandles();
937
938 // Error handling: fail if transform or listener failed.
939 DiagnosedSilenceableFailure trackingFailure =
940 trackingListener.checkAndResetError();
942 transform->hasAttr(FindPayloadReplacementOpInterface::
943 kSilenceTrackingFailuresAttrName)) {
944 // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
945 // do not report failures if the above mentioned attribute is set.
946 if (trackingFailure.isSilenceableFailure())
947 (void)trackingFailure.silence();
948 trackingFailure = DiagnosedSilenceableFailure::success();
949 }
950 if (!trackingFailure.succeeded()) {
951 if (result.succeeded()) {
952 result = std::move(trackingFailure);
953 } else {
954 // Transform op errors have precedence, report those first.
955 if (result.isSilenceableFailure())
956 result.attachNote() << "tracking listener also failed: "
957 << trackingFailure.getMessage();
958 (void)trackingFailure.silence();
959 }
960 }
961 if (result.isDefiniteFailure())
962 return result;
963
964 // If a silenceable failure was produced, some results may be unset, set them
965 // to empty lists.
966 if (result.isSilenceableFailure())
968
969 // Remove the mapping for the operand if it is consumed by the operation. This
970 // allows us to catch use-after-free with assertions later on.
971 for (OpOperand *opOperand : consumedOperands) {
972 Value operand = opOperand->get();
973 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
974 forgetMapping(operand, origOpFlatResults);
975 } else if (llvm::isa<TransformValueHandleTypeInterface>(
976 operand.getType())) {
977 forgetValueMapping(operand, origAssociatedOps);
978 }
979 }
980
981 if (failed(updateStateFromResults(results, transform->getResults())))
983
984 printOnFailureRAII.release();
985 DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
986 LDBG() << "Top-level payload:\n" << *getTopLevel();
987 });
988 return result;
989}
990
991LogicalResult transform::TransformState::updateStateFromResults(
992 const TransformResults &results, ResultRange opResults) {
993 for (OpResult result : opResults) {
994 if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
995 assert(results.isParam(result.getResultNumber()) &&
996 "expected parameters for the parameter-typed result");
997 if (failed(
998 setParams(result, results.getParams(result.getResultNumber())))) {
999 return failure();
1000 }
1001 } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1002 assert(results.isValue(result.getResultNumber()) &&
1003 "expected values for value-type-result");
1004 if (failed(setPayloadValues(
1005 result, results.getValues(result.getResultNumber())))) {
1006 return failure();
1007 }
1008 } else {
1009 assert(!results.isParam(result.getResultNumber()) &&
1010 "expected payload ops for the non-parameter typed result");
1011 if (failed(
1012 setPayloadOps(result, results.get(result.getResultNumber())))) {
1013 return failure();
1014 }
1015 }
1016 }
1017 return success();
1018}
1019
1020//===----------------------------------------------------------------------===//
1021// TransformState::Extension
1022//===----------------------------------------------------------------------===//
1023
1025
1026LogicalResult
1029 // TODO: we may need to invalidate handles to operations and values nested in
1030 // the operation being replaced.
1031 return state.replacePayloadOp(op, replacement);
1032}
1033
1034LogicalResult
1037 return state.replacePayloadValue(value, replacement);
1038}
1039
1040//===----------------------------------------------------------------------===//
1041// TransformState::RegionScope
1042//===----------------------------------------------------------------------===//
1043
1045 // Remove handle invalidation notices as handles are going out of scope.
1046 // The same region may be re-entered leading to incorrect invalidation
1047 // errors.
1048 for (Block &block : *region) {
1049 for (Value handle : block.getArguments()) {
1050 state.invalidatedHandles.erase(handle);
1051 }
1052 for (Operation &op : block) {
1053 for (Value handle : op.getResults()) {
1054 state.invalidatedHandles.erase(handle);
1055 }
1056 }
1057 }
1058
1059#if LLVM_ENABLE_ABI_BREAKING_CHECKS
1060 // Remember pointers to payload ops referenced by the handles going out of
1061 // scope.
1062 SmallVector<Operation *> referencedOps =
1063 llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1064#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1065
1066 state.mappings.erase(region);
1067 state.regionStack.pop_back();
1068}
1069
1070//===----------------------------------------------------------------------===//
1071// TransformResults
1072//===----------------------------------------------------------------------===//
1073
1074transform::TransformResults::TransformResults(unsigned numSegments) {
1075 operations.appendEmptyRows(numSegments);
1076 params.appendEmptyRows(numSegments);
1077 values.appendEmptyRows(numSegments);
1078}
1079
1082 int64_t position = value.getResultNumber();
1083 assert(position < static_cast<int64_t>(this->params.size()) &&
1084 "setting params for a non-existent handle");
1085 assert(this->params[position].data() == nullptr && "params already set");
1086 assert(operations[position].data() == nullptr &&
1087 "another kind of results already set");
1088 assert(values[position].data() == nullptr &&
1089 "another kind of results already set");
1090 this->params.replace(position, params);
1091}
1092
1094 OpResult handle, ArrayRef<MappedValue> values) {
1096 handle, values,
1097 [&](ArrayRef<Operation *> operations) {
1098 return set(handle, operations), success();
1099 },
1100 [&](ArrayRef<Param> params) {
1101 return setParams(handle, params), success();
1102 },
1103 [&](ValueRange payloadValues) {
1104 return setValues(handle, payloadValues), success();
1105 });
1106#ifndef NDEBUG
1107 if (!diag.succeeded())
1108 llvm::dbgs() << diag.getStatusString() << "\n";
1109 assert(diag.succeeded() && "incorrect mapping");
1110#endif // NDEBUG
1111 (void)diag.silence();
1112}
1113
1115 transform::TransformOpInterface transform) {
1116 for (OpResult opResult : transform->getResults()) {
1117 if (!isSet(opResult.getResultNumber()))
1118 setMappedValues(opResult, {});
1119 }
1120}
1121
1123transform::TransformResults::get(unsigned resultNumber) const {
1124 assert(resultNumber < operations.size() &&
1125 "querying results for a non-existent handle");
1126 assert(operations[resultNumber].data() != nullptr &&
1127 "querying unset results (values or params expected?)");
1128 return operations[resultNumber];
1129}
1130
1132transform::TransformResults::getParams(unsigned resultNumber) const {
1133 assert(resultNumber < params.size() &&
1134 "querying params for a non-existent handle");
1135 assert(params[resultNumber].data() != nullptr &&
1136 "querying unset params (ops or values expected?)");
1137 return params[resultNumber];
1138}
1139
1141transform::TransformResults::getValues(unsigned resultNumber) const {
1142 assert(resultNumber < values.size() &&
1143 "querying values for a non-existent handle");
1144 assert(values[resultNumber].data() != nullptr &&
1145 "querying unset values (ops or params expected?)");
1146 return values[resultNumber];
1147}
1148
1149bool transform::TransformResults::isParam(unsigned resultNumber) const {
1150 assert(resultNumber < params.size() &&
1151 "querying association for a non-existent handle");
1152 return params[resultNumber].data() != nullptr;
1153}
1154
1155bool transform::TransformResults::isValue(unsigned resultNumber) const {
1156 assert(resultNumber < values.size() &&
1157 "querying association for a non-existent handle");
1158 return values[resultNumber].data() != nullptr;
1159}
1160
1161bool transform::TransformResults::isSet(unsigned resultNumber) const {
1162 assert(resultNumber < params.size() &&
1163 "querying association for a non-existent handle");
1164 return params[resultNumber].data() != nullptr ||
1165 operations[resultNumber].data() != nullptr ||
1166 values[resultNumber].data() != nullptr;
1167}
1168
1169//===----------------------------------------------------------------------===//
1170// TrackingListener
1171//===----------------------------------------------------------------------===//
1172
1174 TransformOpInterface op,
1176 : TransformState::Extension(state), transformOp(op),
1177 config(std::move(config)) {
1178 if (op) {
1179 for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1180 consumedHandles.insert(opOperand->get());
1181 }
1182 }
1183}
1184
1186 Operation *defOp = nullptr;
1187 for (Value v : values) {
1188 // Skip empty values.
1189 if (!v)
1190 continue;
1191 if (!defOp) {
1192 defOp = v.getDefiningOp();
1193 continue;
1194 }
1195 if (defOp != v.getDefiningOp())
1196 return nullptr;
1197 }
1198 return defOp;
1199}
1200
1202 Operation *&result, Operation *op, ValueRange newValues) const {
1203 assert(op->getNumResults() == newValues.size() &&
1204 "invalid number of replacement values");
1205 SmallVector<Value> values(newValues.begin(), newValues.end());
1206
1208 getTransformOp(), "tracking listener failed to find replacement op "
1209 "during application of this transform op");
1210
1211 do {
1212 // If the replacement values belong to different ops, drop the mapping.
1213 Operation *defOp = getCommonDefiningOp(values);
1214 if (!defOp) {
1215 diag.attachNote() << "replacement values belong to different ops";
1216 return diag;
1217 }
1218
1219 // Skip through ops that implement CastOpInterface.
1220 if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1221 values.clear();
1222 values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1223 diag.attachNote(defOp->getLoc())
1224 << "using output of 'CastOpInterface' op";
1225 continue;
1226 }
1227
1228 // If the defining op has the same name or we do not care about the name of
1229 // op replacements at all, we take it as a replacement.
1230 if (!config.requireMatchingReplacementOpName ||
1231 op->getName() == defOp->getName()) {
1232 result = defOp;
1234 }
1235
1236 // Replacing an op with a constant-like equivalent is a common
1237 // canonicalization.
1238 if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1239 result = defOp;
1241 }
1242
1243 values.clear();
1244
1245 // Skip through ops that implement FindPayloadReplacementOpInterface.
1246 if (auto findReplacementOpInterface =
1247 dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1248 values.assign(findReplacementOpInterface.getNextOperands());
1249 diag.attachNote(defOp->getLoc()) << "using operands provided by "
1250 "'FindPayloadReplacementOpInterface'";
1251 continue;
1252 }
1253 } while (!values.empty());
1254
1255 diag.attachNote() << "ran out of suitable replacement values";
1256 return diag;
1257}
1258
1260 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1261 LLVM_DEBUG({
1263 reasonCallback(diag);
1264 LDBG() << "Match Failure : " << diag.str();
1265 });
1266}
1267
1268void transform::TrackingListener::notifyOperationErased(Operation *op) {
1269 // Remove mappings for result values.
1270 for (OpResult value : op->getResults())
1271 (void)replacePayloadValue(value, nullptr);
1272 // Remove mapping for op.
1273 (void)replacePayloadOp(op, nullptr);
1274}
1275
1276void transform::TrackingListener::notifyOperationReplaced(
1277 Operation *op, ValueRange newValues) {
1278 assert(op->getNumResults() == newValues.size() &&
1279 "invalid number of replacement values");
1280
1281 // Replace value handles.
1282 for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
1283 (void)replacePayloadValue(oldValue, newValue);
1284
1285 // Replace op handle.
1286 SmallVector<Value> opHandles;
1287 if (failed(getTransformState().getHandlesForPayloadOp(
1288 op, opHandles, /*includeOutOfScope=*/true))) {
1289 // Op is not tracked.
1290 return;
1291 }
1292
1293 // Helper function to check if the current transform op consumes any handle
1294 // that is mapped to `op`.
1295 //
1296 // Note: If a handle was consumed, there shouldn't be any alive users, so it
1297 // is not really necessary to check for consumed handles. However, in case
1298 // there are indeed alive handles that were consumed (which is undefined
1299 // behavior) and a replacement op could not be found, we want to fail with a
1300 // nicer error message: "op uses a handle invalidated..." instead of "could
1301 // not find replacement op". This nicer error is produced later.
1302 auto handleWasConsumed = [&] {
1303 return llvm::any_of(opHandles,
1304 [&](Value h) { return consumedHandles.contains(h); });
1305 };
1306
1307 // Check if there are any handles that must be updated.
1308 Value aliveHandle;
1309 if (config.skipHandleFn) {
1310 auto *it = llvm::find_if(opHandles,
1311 [&](Value v) { return !config.skipHandleFn(v); });
1312 if (it != opHandles.end())
1313 aliveHandle = *it;
1314 } else if (!opHandles.empty()) {
1315 aliveHandle = opHandles.front();
1316 }
1317 if (!aliveHandle || handleWasConsumed()) {
1318 // The op is tracked but the corresponding handles are dead or were
1319 // consumed. Drop the op form the mapping.
1320 (void)replacePayloadOp(op, nullptr);
1321 return;
1322 }
1323
1324 Operation *replacement;
1325 DiagnosedSilenceableFailure diag =
1326 findReplacementOp(replacement, op, newValues);
1327 // If the op is tracked but no replacement op was found, send a
1328 // notification.
1329 if (!diag.succeeded()) {
1330 diag.attachNote(aliveHandle.getLoc())
1331 << "replacement is required because this handle must be updated";
1332 notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
1333 (void)replacePayloadOp(op, nullptr);
1334 return;
1335 }
1336
1337 (void)replacePayloadOp(op, replacement);
1338}
1339
1341 // The state of the ErrorCheckingTrackingListener must be checked and reset
1342 // if there was an error. This is to prevent errors from accidentally being
1343 // missed.
1344 assert(status.succeeded() && "listener state was not checked");
1345}
1346
1349 DiagnosedSilenceableFailure s = std::move(status);
1351 errorCounter = 0;
1352 return s;
1353}
1354
1356 return !status.succeeded();
1357}
1358
1361
1362 // Merge potentially existing diags and store the result in the listener.
1364 diag.takeDiagnostics(diags);
1365 if (!status.succeeded())
1366 status.takeDiagnostics(diags);
1367 status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
1368
1369 // Report more details.
1370 status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
1371 for (auto &&[index, value] : llvm::enumerate(values))
1372 status.attachNote(value.getLoc())
1373 << "[" << errorCounter << "] replacement value " << index;
1374 ++errorCounter;
1375}
1376
1377std::string
1379 if (!matchFailure) {
1380 return "";
1381 }
1382 return matchFailure->str();
1383}
1384
1386 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1388 reasonCallback(diag);
1389 matchFailure = std::move(diag);
1390}
1391
1392//===----------------------------------------------------------------------===//
1393// TransformRewriter
1394//===----------------------------------------------------------------------===//
1395
1398 : RewriterBase(ctx), listener(listener) {
1399 setListener(listener);
1400}
1401
1403 return listener->failed();
1404}
1405
1406/// Silence all tracking failures that have been encountered so far.
1408 if (hasTrackingFailures()) {
1409 DiagnosedSilenceableFailure status = listener->checkAndResetError();
1410 (void)status.silence();
1411 }
1412}
1413
1416 return listener->replacePayloadOp(op, replacement);
1417}
1418
1419//===----------------------------------------------------------------------===//
1420// Utilities for TransformEachOpTrait.
1421//===----------------------------------------------------------------------===//
1422
1423LogicalResult
1425 ArrayRef<Operation *> targets) {
1426 for (auto &&[position, parent] : llvm::enumerate(targets)) {
1427 for (Operation *child : targets.drop_front(position + 1)) {
1428 if (parent->isAncestor(child)) {
1430 emitError(loc)
1431 << "transform operation consumes a handle pointing to an ancestor "
1432 "payload operation before its descendant";
1433 diag.attachNote()
1434 << "the ancestor is likely erased or rewritten before the "
1435 "descendant is accessed, leading to undefined behavior";
1436 diag.attachNote(parent->getLoc()) << "ancestor payload op";
1437 diag.attachNote(child->getLoc()) << "descendant payload op";
1438 return diag;
1439 }
1440 }
1441 }
1442 return success();
1443}
1444
1445LogicalResult
1447 Location payloadOpLoc,
1448 const ApplyToEachResultList &partialResult) {
1449 Location transformOpLoc = transformOp->getLoc();
1450 StringRef transformOpName = transformOp->getName().getStringRef();
1451 unsigned expectedNumResults = transformOp->getNumResults();
1452
1453 // Reuse the emission of the diagnostic note.
1454 auto emitDiag = [&]() {
1455 auto diag = mlir::emitError(transformOpLoc);
1456 diag.attachNote(payloadOpLoc) << "when applied to this op";
1457 return diag;
1458 };
1459
1460 if (partialResult.size() != expectedNumResults) {
1461 auto diag = emitDiag() << "application of " << transformOpName
1462 << " expected to produce " << expectedNumResults
1463 << " results (actually produced "
1464 << partialResult.size() << ").";
1465 diag.attachNote(transformOpLoc)
1466 << "if you need variadic results, consider a generic `apply` "
1467 << "instead of the specialized `applyToOne`.";
1468 return failure();
1469 }
1470
1471 // Check that the right kind of value was produced.
1472 for (const auto &[ptr, res] :
1473 llvm::zip(partialResult, transformOp->getResults())) {
1474 if (ptr.isNull())
1475 continue;
1476 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1477 !isa<Operation *>(ptr)) {
1478 return emitDiag() << "application of " << transformOpName
1479 << " expected to produce an Operation * for result #"
1480 << res.getResultNumber();
1481 }
1482 if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1483 !isa<Attribute>(ptr)) {
1484 return emitDiag() << "application of " << transformOpName
1485 << " expected to produce an Attribute for result #"
1486 << res.getResultNumber();
1487 }
1488 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1489 !isa<Value>(ptr)) {
1490 return emitDiag() << "application of " << transformOpName
1491 << " expected to produce a Value for result #"
1492 << res.getResultNumber();
1493 }
1494 }
1495 return success();
1496}
1497
1498template <typename T>
1500 return llvm::map_to_vector(range, llvm::CastTo<T>);
1501}
1502
1504 Operation *transformOp, TransformResults &transformResults,
1507 transposed.resize(transformOp->getNumResults());
1508 for (const ApplyToEachResultList &partialResults : results) {
1509 if (llvm::any_of(partialResults,
1510 [](MappedValue value) { return value.isNull(); }))
1511 continue;
1512 assert(transformOp->getNumResults() == partialResults.size() &&
1513 "expected as many partial results as op as results");
1514 for (auto [i, value] : llvm::enumerate(partialResults))
1515 transposed[i].push_back(value);
1516 }
1517
1518 for (OpResult r : transformOp->getResults()) {
1519 unsigned position = r.getResultNumber();
1520 if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1521 transformResults.setParams(r,
1522 castVector<Attribute>(transposed[position]));
1523 } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1524 transformResults.setValues(r, castVector<Value>(transposed[position]));
1525 } else {
1526 transformResults.set(r, castVector<Operation *>(transposed[position]));
1527 }
1528 }
1529}
1530
1531//===----------------------------------------------------------------------===//
1532// Utilities for implementing transform ops with regions.
1533//===----------------------------------------------------------------------===//
1534
1537 ValueRange values, const transform::TransformState &state, bool flatten) {
1538 assert(mappings.size() == values.size() && "mismatching number of mappings");
1539 for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1540 size_t mappedSize = mapped.size();
1541 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1542 llvm::append_range(mapped, state.getPayloadOps(operand));
1543 } else if (llvm::isa<TransformValueHandleTypeInterface>(
1544 operand.getType())) {
1545 llvm::append_range(mapped, state.getPayloadValues(operand));
1546 } else {
1547 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1548 "unsupported kind of transform dialect value");
1549 llvm::append_range(mapped, state.getParams(operand));
1550 }
1551
1552 if (mapped.size() - mappedSize != 1 && !flatten)
1553 return failure();
1554 }
1555 return success();
1556}
1557
1560 ValueRange values, const transform::TransformState &state) {
1561 mappings.resize(mappings.size() + values.size());
1564 values.size()),
1565 values, state);
1566}
1567
1569 Block *block, transform::TransformState &state,
1570 transform::TransformResults &results) {
1571 for (auto &&[terminatorOperand, result] :
1572 llvm::zip(block->getTerminator()->getOperands(),
1573 block->getParentOp()->getOpResults())) {
1574 if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1575 results.set(result, state.getPayloadOps(terminatorOperand));
1576 } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1577 result.getType())) {
1578 results.setValues(result, state.getPayloadValues(terminatorOperand));
1579 } else {
1580 assert(
1581 llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1582 "unhandled transform type interface");
1583 results.setParams(result, state.getParams(terminatorOperand));
1584 }
1585 }
1586}
1587
1590 Operation *payloadRoot) {
1591 return TransformState(region, payloadRoot);
1592}
1593
1594//===----------------------------------------------------------------------===//
1595// Utilities for PossibleTopLevelTransformOpTrait.
1596//===----------------------------------------------------------------------===//
1597
1598/// Appends to `effects` the memory effect instances on `target` with the same
1599/// resource and effect as the ones the operation `iface` having on `source`.
1600static void
1601remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
1605 iface.getEffectsOnValue(source, nestedEffects);
1606 for (const auto &effect : nestedEffects)
1607 effects.emplace_back(effect.getEffect(), target, effect.getResource());
1608}
1609
1610/// Appends to `effects` the same effects as the operations of `block` have on
1611/// block arguments but associated with `operands.`
1612static void
1615 for (Operation &op : block) {
1616 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1617 if (!iface)
1618 continue;
1619
1620 for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
1621 remapEffects(iface, source, &target, effects);
1622 }
1623
1625 iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1626 nestedEffects);
1627 llvm::append_range(effects, nestedEffects);
1628 }
1629}
1630
1632 Operation *operation, Value root, Block &body,
1634 transform::onlyReadsHandle(operation->getOpOperands(), effects);
1635 transform::producesHandle(operation->getOpResults(), effects);
1636
1637 if (!root) {
1638 for (Operation &op : body) {
1639 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1640 if (!iface)
1641 continue;
1642
1643 iface.getEffects(effects);
1644 }
1645 return;
1646 }
1647
1648 // Carry over all effects on arguments of the entry block as those on the
1649 // operands, this is the same value just remapped.
1650 remapArgumentEffects(body, operation->getOpOperands(), effects);
1651}
1652
1654 TransformState &state, Operation *op, Region &region) {
1657 if (op->getNumOperands() != 0) {
1658 llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
1659 prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
1660 } else {
1661 if (state.getNumTopLevelMappings() !=
1662 region.front().getNumArguments() - 1) {
1663 return emitError(op->getLoc())
1664 << "operation expects " << region.front().getNumArguments() - 1
1665 << " extra value bindings, but " << state.getNumTopLevelMappings()
1666 << " were provided to the interpreter";
1667 }
1668
1669 targets.push_back(state.getTopLevel());
1670
1671 for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1672 extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1673 }
1674
1675 if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
1676 return failure();
1677
1678 for (BlockArgument argument : region.front().getArguments().drop_front()) {
1679 if (failed(state.mapBlockArgument(
1680 argument, extraMappings[argument.getArgNumber() - 1])))
1681 return failure();
1682 }
1683
1684 return success();
1685}
1686
1687LogicalResult
1689 // Attaching this trait without the interface is a misuse of the API, but it
1690 // cannot be caught via a static_assert because interface registration is
1691 // dynamic.
1692 assert(isa<TransformOpInterface>(op) &&
1693 "should implement TransformOpInterface to have "
1694 "PossibleTopLevelTransformOpTrait");
1695
1696 if (op->getNumRegions() < 1)
1697 return op->emitOpError() << "expects at least one region";
1698
1699 Region *bodyRegion = &op->getRegion(0);
1700 if (!llvm::hasNItems(*bodyRegion, 1))
1701 return op->emitOpError() << "expects a single-block region";
1702
1703 Block *body = &bodyRegion->front();
1704 if (body->getNumArguments() == 0) {
1705 return op->emitOpError()
1706 << "expects the entry block to have at least one argument";
1707 }
1708 if (!llvm::isa<TransformHandleTypeInterface>(
1709 body->getArgument(0).getType())) {
1710 return op->emitOpError()
1711 << "expects the first entry block argument to be of type "
1712 "implementing TransformHandleTypeInterface";
1713 }
1714 BlockArgument arg = body->getArgument(0);
1715 if (op->getNumOperands() != 0) {
1716 if (arg.getType() != op->getOperand(0).getType()) {
1717 return op->emitOpError()
1718 << "expects the type of the block argument to match "
1719 "the type of the operand";
1720 }
1721 }
1722 for (BlockArgument arg : body->getArguments().drop_front()) {
1723 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1724 TransformValueHandleTypeInterface>(arg.getType()))
1725 continue;
1726
1728 op->emitOpError()
1729 << "expects trailing entry block arguments to be of type implementing "
1730 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1731 "TransformParamTypeInterface";
1732 diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1733 return diag;
1734 }
1735
1736 if (auto *parent =
1738 if (op->getNumOperands() != body->getNumArguments()) {
1740 op->emitOpError()
1741 << "expects operands to be provided for a nested op";
1742 diag.attachNote(parent->getLoc())
1743 << "nested in another possible top-level op";
1744 return diag;
1745 }
1746 }
1747
1748 return success();
1749}
1750
1751//===----------------------------------------------------------------------===//
1752// Utilities for ParamProducedTransformOpTrait.
1753//===----------------------------------------------------------------------===//
1754
1757 producesHandle(op->getResults(), effects);
1758 bool hasPayloadOperands = false;
1759 for (OpOperand &operand : op->getOpOperands()) {
1760 onlyReadsHandle(operand, effects);
1761 if (llvm::isa<TransformHandleTypeInterface,
1762 TransformValueHandleTypeInterface>(operand.get().getType()))
1763 hasPayloadOperands = true;
1764 }
1765 if (hasPayloadOperands)
1766 onlyReadsPayload(effects);
1767}
1768
1769LogicalResult
1771 // Interfaces can be attached dynamically, so this cannot be a static
1772 // assert.
1773 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1774 llvm::report_fatal_error(
1775 Twine("ParamProducerTransformOpTrait must be attached to an op that "
1776 "implements MemoryEffectsOpInterface, found on ") +
1777 op->getName().getStringRef());
1778 }
1779 for (Value result : op->getResults()) {
1780 if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1781 continue;
1782 return op->emitOpError()
1783 << "ParamProducerTransformOpTrait attached to this op expects "
1784 "result types to implement TransformParamTypeInterface";
1785 }
1786 return success();
1787}
1788
1789//===----------------------------------------------------------------------===//
1790// Memory effects.
1791//===----------------------------------------------------------------------===//
1792
1796 for (OpOperand &handle : handles) {
1797 effects.emplace_back(MemoryEffects::Read::get(), &handle,
1799 effects.emplace_back(MemoryEffects::Free::get(), &handle,
1801 }
1802}
1803
1804/// Returns `true` if the given list of effects instances contains an instance
1805/// with the effect type specified as template parameter.
1806template <typename EffectTy, typename ResourceTy, typename Range>
1807static bool hasEffect(Range &&effects) {
1808 return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1809 return isa<EffectTy>(effect.getEffect()) &&
1810 isa<ResourceTy>(effect.getResource());
1811 });
1812}
1813
1815 transform::TransformOpInterface transform) {
1816 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1818 iface.getEffectsOnValue(handle, effects);
1819 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1821}
1822
1824 ResultRange handles,
1826 for (OpResult handle : handles) {
1827 effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1829 effects.emplace_back(MemoryEffects::Write::get(), handle,
1831 }
1832}
1833
1837 for (BlockArgument handle : handles) {
1838 effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1840 effects.emplace_back(MemoryEffects::Write::get(), handle,
1842 }
1843}
1844
1848 for (OpOperand &handle : handles) {
1849 effects.emplace_back(MemoryEffects::Read::get(), &handle,
1851 }
1852}
1853
1859
1864
/// Returns true when the effects that `transform` reports through
/// MemoryEffectOpInterface contain a Write effect on PayloadIRResource.
1865bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
// Unconditional cast: the op is assumed to implement MemoryEffectOpInterface.
1866 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
// NOTE(review): the declaration of `effects` is not visible here; presumably
// a container of MemoryEffects::EffectInstance - confirm.
1868 iface.getEffects(effects);
1869 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1870}
1871
/// Returns true when the effects that `transform` reports through
/// MemoryEffectOpInterface contain a Read effect on PayloadIRResource.
1872bool transform::doesReadPayload(transform::TransformOpInterface transform) {
// Unconditional cast: the op is assumed to implement MemoryEffectOpInterface.
1873 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
// NOTE(review): the declaration of `effects` is not visible here; presumably
// a container of MemoryEffects::EffectInstance - confirm.
1875 iface.getEffects(effects);
1876 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1877}
1878
// Inserts into `consumedArguments` the argument numbers of those arguments of
// `block` on which an operation iterated from `block` declares a Free effect
// on TransformMappingResource.
// NOTE(review): the start of the signature and the declaration of `effects`
// are not visible here - confirm against the header.
1880 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1882 for (Operation &nested : block) {
// Operations that do not implement MemoryEffectOpInterface are skipped.
1883 auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1884 if (!iface)
1885 continue;
1886
// The effects container is emptied and refilled for each operation.
1887 effects.clear();
1888 iface.getEffects(effects);
1889 for (const MemoryEffects::EffectInstance &effect : effects) {
// dyn_cast_or_null yields a null argument when the effect has no value or the
// value is not a block argument.
1890 BlockArgument argument =
1891 dyn_cast_or_null<BlockArgument>(effect.getValue());
// Keep only Free effects on the mapping resource that target an argument
// owned by `block` itself.
1892 if (!argument || argument.getOwner() != &block ||
1893 !isa<MemoryEffects::Free>(effect.getEffect()) ||
1894 effect.getResource() != transform::TransformMappingResource::get()) {
1895 continue;
1896 }
1897 consumedArguments.insert(argument.getArgNumber());
1898 }
1899 }
1900}
1901
1902//===----------------------------------------------------------------------===//
1903// Utilities for TransformOpInterface.
1904//===----------------------------------------------------------------------===//
1905
1906SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands(
1907 TransformOpInterface transformOp) {
1908 SmallVector<OpOperand *> consumedOperands;
1909 consumedOperands.reserve(transformOp->getNumOperands());
1910 auto memEffectInterface =
1911 cast<MemoryEffectOpInterface>(transformOp.getOperation());
1912 SmallVector<MemoryEffects::EffectInstance, 2> effects;
1913 for (OpOperand &target : transformOp->getOpOperands()) {
1914 effects.clear();
1915 memEffectInterface.getEffectsOnValue(target.get(), effects);
1916 if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1917 return isa<transform::TransformMappingResource>(
1918 effect.getResource()) &&
1919 isa<MemoryEffects::Free>(effect.getEffect());
1920 })) {
1921 consumedOperands.push_back(&target);
1922 }
1923 }
1924 return consumedOperands;
1925}
1926
1928 auto iface = cast<MemoryEffectOpInterface>(op);
1930 iface.getEffects(effects);
1931
1932 auto effectsOn = [&](Value value) {
1933 return llvm::make_filter_range(
1934 effects, [value](const MemoryEffects::EffectInstance &instance) {
1935 return instance.getValue() == value;
1936 });
1937 };
1938
1939 std::optional<unsigned> firstConsumedOperand;
1940 for (OpOperand &operand : op->getOpOperands()) {
1941 auto range = effectsOn(operand.get());
1942 if (range.empty()) {
1944 op->emitError() << "TransformOpInterface requires memory effects "
1945 "on operands to be specified";
1946 diag.attachNote() << "no effects specified for operand #"
1947 << operand.getOperandNumber();
1948 return diag;
1949 }
1952 << "TransformOpInterface did not expect "
1953 "'allocate' memory effect on an operand";
1954 diag.attachNote() << "specified for operand #"
1955 << operand.getOperandNumber();
1956 return diag;
1957 }
1958 if (!firstConsumedOperand &&
1960 firstConsumedOperand = operand.getOperandNumber();
1961 }
1962 }
1963
1964 if (firstConsumedOperand &&
1967 op->emitError()
1968 << "TransformOpInterface expects ops consuming operands to have a "
1969 "'write' effect on the payload resource";
1970 diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
1971 return diag;
1972 }
1973
1974 for (OpResult result : op->getResults()) {
1975 auto range = effectsOn(result);
1977 range)) {
1979 op->emitError() << "TransformOpInterface requires 'allocate' memory "
1980 "effect to be specified for results";
1981 diag.attachNote() << "no 'allocate' effect specified for result #"
1982 << result.getResultNumber();
1983 return diag;
1984 }
1985 }
1986
1987 return success();
1988}
1989
1990//===----------------------------------------------------------------------===//
1991// Entry point.
1992//===----------------------------------------------------------------------===//
1993
1995 Operation *payloadRoot, TransformOpInterface transform,
1996 const RaggedArray<MappedValue> &extraMapping,
1997 const TransformOptions &options, bool enforceToplevelTransformOp,
1998 function_ref<void(TransformState &)> stateInitializer,
1999 function_ref<LogicalResult(TransformState &)> stateExporter) {
2000 if (enforceToplevelTransformOp) {
2001 if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
2002 transform->getNumOperands() != 0) {
2003 return transform->emitError()
2004 << "expected transform to start at the top-level transform op";
2005 }
2006 } else if (failed(
2008 return failure();
2009 }
2010
2011 TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2012 options);
2013 if (stateInitializer)
2014 stateInitializer(state);
2015 if (state.applyTransform(transform).checkAndReport().failed())
2016 return failure();
2017 if (stateExporter)
2018 return stateExporter(state);
2019 return success();
2020}
2021
2022//===----------------------------------------------------------------------===//
2023// Generated interface implementation.
2024//===----------------------------------------------------------------------===//
2025
2026#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2027#include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
for(Operation *op :ops)
return success()
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
#define FULL_LDBG()
static void remapEffects(MemoryEffectOpInterface iface, BlockArgument source, OpOperand *target, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the memory effect instances on target with the same resource and effect as the one...
static void remapArgumentEffects(Block &block, MutableArrayRef< OpOperand > operands, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the same effects as the operations of block have on block arguments but associated...
static bool happensBefore(Operation *a, Operation *b)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not inside a.
static void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped)
#define DEBUG_PRINT_AFTER_ALL
static DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef< T > payload, transform::TransformOpInterface transform, unsigned operandNumber)
static DiagnosedSilenceableFailure dispatchMappedValues(Value handle, ArrayRef< transform::MappedValue > values, function_ref< LogicalResult(ArrayRef< Operation * >)> operationsFn, function_ref< LogicalResult(ArrayRef< transform::Param >)> paramsFn, function_ref< LogicalResult(ValueRange)> valuesFn)
Given a list of MappedValues, cast them to the value kind implied by the interface of the handle type...
static SmallVector< T > castVector(ArrayRef< transform::MappedValue > range)
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition Block.cpp:74
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition Builders.h:318
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
Value getOperand(unsigned idx)
Definition Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
unsigned getNumOperands()
Definition Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getOpResults()
Definition Operation.h:420
result_range getResults()
Definition Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A 2D array where each row may have different length.
Definition RaggedArray.h:18
size_t size() const
Returns the number of rows in the 2D array.
Definition RaggedArray.h:21
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
This class implements the result iterators for the Operation class.
Definition ValueRange.h:247
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
size_t size() const
Returns the number of elements in the list.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
std::string getLatestMatchFailureMessage()
Return the latest match notification message.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
TrackingListener failures are reported only for ops that have this trait.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
TransformOpInterface getTransformOp() const
Return the transform op in which this TrackingListener is used.
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Options controlling the application of transform operations by the TransformState.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
Extension(TransformState &state)
Constructs an extension of the given TransformState object.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
size_t getNumTopLevelMappings() const
Returns the number of extra mappings for the top-level operation.
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
ArrayRef< MappedValue > getTopLevelMapping(size_t position) const
Returns the position-th extra mapping for the top-level operation.
SideEffects::EffectInstance< Effect > EffectInstance
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue > > mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult matches the number, kind (payload op or parameter) and nulli...
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true, function_ref< void(TransformState &)> stateInitializer=nullptr, function_ref< LogicalResult(TransformState &)> stateExporter=nullptr)
Entry point to the Transform dialect infrastructure.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
llvm::PointerUnion< Operation *, Param, Value > MappedValue
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A configuration object for customizing a TrackingListener.