MLIR 22.0.0git
TransformInterfaces.cpp
Go to the documentation of this file.
1//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
11#include "mlir/IR/Diagnostics.h"
12#include "mlir/IR/Operation.h"
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/ScopeExit.h"
17#include "llvm/ADT/iterator.h"
18#include "llvm/Support/Debug.h"
19#include "llvm/Support/DebugLog.h"
20#include "llvm/Support/ErrorHandling.h"
21#include "llvm/Support/InterleavedRange.h"
22
23#define DEBUG_TYPE "transform-dialect"
24#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
25#define FULL_LDBG() LDBG(4)
26
27using namespace mlir;
28
29//===----------------------------------------------------------------------===//
30// Helper functions
31//===----------------------------------------------------------------------===//
32
33/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
34/// properly dominates `b` and `b` is not inside `a`.
36 do {
37 if (a->isProperAncestor(b))
38 return false;
39 if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
40 return a->isBeforeInBlock(bAncestor);
41 }
42 } while ((a = a->getParentOp()));
43 return false;
44}
45
46//===----------------------------------------------------------------------===//
47// TransformState
48//===----------------------------------------------------------------------===//
49
50transform::TransformState::TransformState(
51 Region *region, Operation *payloadRoot,
52 const RaggedArray<MappedValue> &extraMappings,
53 const TransformOptions &options)
54 : topLevel(payloadRoot), options(options) {
55 topLevelMappedValues.reserve(extraMappings.size());
56 for (ArrayRef<MappedValue> mapping : extraMappings)
57 topLevelMappedValues.push_back(mapping);
58 if (region) {
59 RegionScope *scope = new RegionScope(*this, *region);
60 topLevelRegionScope.reset(scope);
61 }
62}
63
65
67transform::TransformState::getPayloadOpsView(Value value) const {
68 const TransformOpMapping &operationMapping = getMapping(value).direct;
69 auto iter = operationMapping.find(value);
70 assert(iter != operationMapping.end() &&
71 "cannot find mapping for payload handle (param/value handle "
72 "provided?)");
73 return iter->getSecond();
74}
75
77 const ParamMapping &mapping = getMapping(value).params;
78 auto iter = mapping.find(value);
79 assert(iter != mapping.end() && "cannot find mapping for param handle "
80 "(operation/value handle provided?)");
81 return iter->getSecond();
82}
83
85transform::TransformState::getPayloadValuesView(Value handleValue) const {
86 const ValueMapping &mapping = getMapping(handleValue).values;
87 auto iter = mapping.find(handleValue);
88 assert(iter != mapping.end() && "cannot find mapping for value handle "
89 "(param/operation handle provided?)");
90 return iter->getSecond();
91}
92
95 bool includeOutOfScope) const {
96 bool found = false;
97 for (const auto &[region, mapping] : llvm::reverse(mappings)) {
98 auto iterator = mapping->reverse.find(op);
99 if (iterator != mapping->reverse.end()) {
100 llvm::append_range(handles, iterator->getSecond());
101 found = true;
102 }
103 // Stop looking when reaching a region that is isolated from above.
104 if (!includeOutOfScope &&
106 break;
107 }
108
109 return success(found);
110}
111
113 Value payloadValue, SmallVectorImpl<Value> &handles,
114 bool includeOutOfScope) const {
115 bool found = false;
116 for (const auto &[region, mapping] : llvm::reverse(mappings)) {
117 auto iterator = mapping->reverseValues.find(payloadValue);
118 if (iterator != mapping->reverseValues.end()) {
119 llvm::append_range(handles, iterator->getSecond());
120 found = true;
121 }
122 // Stop looking when reaching a region that is isolated from above.
123 if (!includeOutOfScope &&
125 break;
126 }
127
128 return success(found);
129}
130
131/// Given a list of MappedValues, cast them to the value kind implied by the
132/// interface of the handle type, and dispatch to one of the callbacks.
135 function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
136 function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
137 function_ref<LogicalResult(ValueRange)> valuesFn) {
138 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
139 SmallVector<Operation *> operations;
140 operations.reserve(values.size());
141 for (transform::MappedValue value : values) {
142 if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
143 operations.push_back(op);
144 continue;
145 }
146 return emitSilenceableFailure(handle.getLoc())
147 << "wrong kind of value provided for top-level operation handle";
148 }
149 if (failed(operationsFn(operations)))
152 }
153
154 if (llvm::isa<transform::TransformValueHandleTypeInterface>(
155 handle.getType())) {
156 SmallVector<Value> payloadValues;
157 payloadValues.reserve(values.size());
158 for (transform::MappedValue value : values) {
159 if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
160 payloadValues.push_back(v);
161 continue;
162 }
163 return emitSilenceableFailure(handle.getLoc())
164 << "wrong kind of value provided for the top-level value handle";
165 }
166 if (failed(valuesFn(payloadValues)))
169 }
170
171 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
172 "unsupported kind of block argument");
174 parameters.reserve(values.size());
175 for (transform::MappedValue value : values) {
176 if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
177 parameters.push_back(attr);
178 continue;
179 }
180 return emitSilenceableFailure(handle.getLoc())
181 << "wrong kind of value provided for top-level parameter";
182 }
183 if (failed(paramsFn(parameters)))
186}
187
188LogicalResult
190 ArrayRef<MappedValue> values) {
192 argument, values,
193 [&](ArrayRef<Operation *> operations) {
194 return setPayloadOps(argument, operations);
195 },
196 [&](ArrayRef<Param> params) {
197 return setParams(argument, params);
198 },
199 [&](ValueRange payloadValues) {
200 return setPayloadValues(argument, payloadValues);
201 })
202 .checkAndReport();
203}
204
206 Block::BlockArgListType arguments,
208 for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
209 if (failed(mapBlockArgument(argument, values)))
210 return failure();
211 return success();
212}
213
214LogicalResult
215transform::TransformState::setPayloadOps(Value value,
216 ArrayRef<Operation *> targets) {
217 assert(value != kTopLevelValue &&
218 "attempting to reset the transformation root");
219 assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
220 "wrong handle type");
221
222 for (Operation *target : targets) {
223 if (target)
224 continue;
225 return emitError(value.getLoc())
226 << "attempting to assign a null payload op to this transform value";
227 }
228
229 auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
231 iface.checkPayload(value.getLoc(), targets);
232 if (failed(result.checkAndReport()))
233 return failure();
234
235 // Setting new payload for the value without cleaning it first is a misuse of
236 // the API, assert here.
237 SmallVector<Operation *> storedTargets(targets);
238 Mappings &mappings = getMapping(value);
239 bool inserted =
240 mappings.direct.insert({value, std::move(storedTargets)}).second;
241 assert(inserted && "value is already associated with another list");
242 (void)inserted;
243
244 for (Operation *op : targets)
245 mappings.reverse[op].push_back(value);
246
247 return success();
248}
249
250LogicalResult
251transform::TransformState::setPayloadValues(Value handle,
252 ValueRange payloadValues) {
253 assert(handle != nullptr && "attempting to set params for a null value");
254 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
255 "wrong handle type");
256
257 for (Value payload : payloadValues) {
258 if (payload)
259 continue;
260 return emitError(handle.getLoc()) << "attempting to assign a null payload "
261 "value to this transform handle";
262 }
263
264 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
265 SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
266 DiagnosedSilenceableFailure result =
267 iface.checkPayload(handle.getLoc(), payloadValueVector);
268 if (failed(result.checkAndReport()))
269 return failure();
270
271 Mappings &mappings = getMapping(handle);
272 bool inserted =
273 mappings.values.insert({handle, std::move(payloadValueVector)}).second;
274 assert(
275 inserted &&
276 "value handle is already associated with another list of payload values");
277 (void)inserted;
278
279 for (Value payload : payloadValues)
280 mappings.reverseValues[payload].push_back(handle);
281
282 return success();
283}
284
285LogicalResult transform::TransformState::setParams(Value value,
286 ArrayRef<Param> params) {
287 assert(value != nullptr && "attempting to set params for a null value");
288
289 for (Attribute attr : params) {
290 if (attr)
291 continue;
292 return emitError(value.getLoc())
293 << "attempting to assign a null parameter to this transform value";
294 }
295
296 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
297 assert(value &&
298 "cannot associate parameter with a value of non-parameter type");
299 DiagnosedSilenceableFailure result =
300 valueType.checkPayload(value.getLoc(), params);
301 if (failed(result.checkAndReport()))
302 return failure();
303
304 Mappings &mappings = getMapping(value);
305 bool inserted =
306 mappings.params.insert({value, llvm::to_vector(params)}).second;
307 assert(inserted && "value is already associated with another list of params");
308 (void)inserted;
309 return success();
310}
311
312template <typename Mapping, typename Key, typename Mapped>
313static void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
314 auto it = mapping.find(key);
315 if (it == mapping.end())
316 return;
317
318 llvm::erase(it->getSecond(), mapped);
319 if (it->getSecond().empty())
320 mapping.erase(it);
321}
322
323void transform::TransformState::forgetMapping(Value opHandle,
324 ValueRange origOpFlatResults,
325 bool allowOutOfScope) {
326 Mappings &mappings = getMapping(opHandle, allowOutOfScope);
327 for (Operation *op : mappings.direct[opHandle])
328 dropMappingEntry(mappings.reverse, op, opHandle);
329 mappings.direct.erase(opHandle);
330#if LLVM_ENABLE_ABI_BREAKING_CHECKS
331 // Payload IR is removed from the mapping. This invalidates the respective
332 // iterators.
333 mappings.incrementTimestamp(opHandle);
334#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
335
336 for (Value opResult : origOpFlatResults) {
337 SmallVector<Value> resultHandles;
338 (void)getHandlesForPayloadValue(opResult, resultHandles);
339 for (Value resultHandle : resultHandles) {
340 Mappings &localMappings = getMapping(resultHandle);
341 dropMappingEntry(localMappings.values, resultHandle, opResult);
342#if LLVM_ENABLE_ABI_BREAKING_CHECKS
343 // Payload IR is removed from the mapping. This invalidates the respective
344 // iterators.
345 mappings.incrementTimestamp(resultHandle);
346#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
347 dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
348 }
349 }
350}
351
352void transform::TransformState::forgetValueMapping(
353 Value valueHandle, ArrayRef<Operation *> payloadOperations) {
354 Mappings &mappings = getMapping(valueHandle);
355 for (Value payloadValue : mappings.reverseValues[valueHandle])
356 dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
357 mappings.values.erase(valueHandle);
358#if LLVM_ENABLE_ABI_BREAKING_CHECKS
359 // Payload IR is removed from the mapping. This invalidates the respective
360 // iterators.
361 mappings.incrementTimestamp(valueHandle);
362#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
363
364 for (Operation *payloadOp : payloadOperations) {
365 SmallVector<Value> opHandles;
366 (void)getHandlesForPayloadOp(payloadOp, opHandles);
367 for (Value opHandle : opHandles) {
368 Mappings &localMappings = getMapping(opHandle);
369 dropMappingEntry(localMappings.direct, opHandle, payloadOp);
370 dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
371
372#if LLVM_ENABLE_ABI_BREAKING_CHECKS
373 // Payload IR is removed from the mapping. This invalidates the respective
374 // iterators.
375 localMappings.incrementTimestamp(opHandle);
376#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
377 }
378 }
379}
380
381LogicalResult
382transform::TransformState::replacePayloadOp(Operation *op,
383 Operation *replacement) {
384 // TODO: consider invalidating the handles to nested objects here.
385
386#ifndef NDEBUG
387 for (Value opResult : op->getResults()) {
388 SmallVector<Value> valueHandles;
389 (void)getHandlesForPayloadValue(opResult, valueHandles,
390 /*includeOutOfScope=*/true);
391 assert(valueHandles.empty() && "expected no mapping to old results");
392 }
393#endif // NDEBUG
394
395 // Drop the mapping between the op and all handles that point to it. Fail if
396 // there are no handles.
397 SmallVector<Value> opHandles;
398 if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
399 return failure();
400 for (Value handle : opHandles) {
401 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
402 dropMappingEntry(mappings.reverse, op, handle);
403 }
404
405 // Replace the pointed-to object of all handles with the replacement object.
406 // In case a payload op was erased (replacement object is nullptr), a nullptr
407 // is stored in the mapping. These nullptrs are removed after each transform.
408 // Furthermore, nullptrs are not enumerated by payload op iterators. The
409 // relative order of ops is preserved.
410 //
411 // Removing an op from the mapping would be problematic because removing an
412 // element from an array invalidates iterators; merely changing the value of
413 // elements does not.
414 for (Value handle : opHandles) {
415 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
416 auto it = mappings.direct.find(handle);
417 if (it == mappings.direct.end())
418 continue;
419
420 SmallVector<Operation *, 2> &association = it->getSecond();
421 // Note that an operation may be associated with the handle more than once.
422 for (Operation *&mapped : association) {
423 if (mapped == op)
424 mapped = replacement;
425 }
426
427 if (replacement) {
428 mappings.reverse[replacement].push_back(handle);
429 } else {
430 opHandlesToCompact.insert(handle);
431 }
432 }
433
434 return success();
435}
436
437LogicalResult
438transform::TransformState::replacePayloadValue(Value value, Value replacement) {
439 SmallVector<Value> valueHandles;
440 if (failed(getHandlesForPayloadValue(value, valueHandles,
441 /*includeOutOfScope=*/true)))
442 return failure();
443
444 for (Value handle : valueHandles) {
445 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
446 dropMappingEntry(mappings.reverseValues, value, handle);
447
448 // If replacing with null, that is erasing the mapping, drop the mapping
449 // between the handles and the IR objects
450 if (!replacement) {
451 dropMappingEntry(mappings.values, handle, value);
452#if LLVM_ENABLE_ABI_BREAKING_CHECKS
453 // Payload IR is removed from the mapping. This invalidates the respective
454 // iterators.
455 mappings.incrementTimestamp(handle);
456#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
457 } else {
458 auto it = mappings.values.find(handle);
459 if (it == mappings.values.end())
460 continue;
461
462 SmallVector<Value> &association = it->getSecond();
463 for (Value &mapped : association) {
464 if (mapped == value)
465 mapped = replacement;
466 }
467 mappings.reverseValues[replacement].push_back(handle);
468 }
469 }
470
471 return success();
472}
473
474void transform::TransformState::recordOpHandleInvalidationOne(
475 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
476 Operation *payloadOp, Value otherHandle, Value throughValue,
477 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
478 // If the op is associated with invalidated handle, skip the check as it
479 // may be reading invalid IR. This also ensures we report the first
480 // invalidation and not the last one.
481 if (invalidatedHandles.count(otherHandle) ||
482 newlyInvalidated.count(otherHandle))
483 return;
484
485 FULL_LDBG() << "--recordOpHandleInvalidationOne";
486 FULL_LDBG() << "--ancestors: "
487 << llvm::interleaved(
488 llvm::make_pointee_range(potentialAncestors));
489
490 Operation *owner = consumingHandle.getOwner();
491 unsigned operandNo = consumingHandle.getOperandNumber();
492 for (Operation *ancestor : potentialAncestors) {
493 // clang-format off
494 FULL_LDBG() << "----handle one ancestor: " << *ancestor;;
495
496 FULL_LDBG() << "----of payload with name: "
497 << payloadOp->getName().getIdentifier();
498 FULL_LDBG() << "----of payload: " << *payloadOp;
499 // clang-format on
500 if (!ancestor->isAncestor(payloadOp))
501 continue;
502
503 // Make sure the error-reporting lambda doesn't capture anything
504 // by-reference because it will go out of scope. Additionally, extract
505 // location from Payload IR ops because the ops themselves may be
506 // deleted before the lambda gets called.
507 Location ancestorLoc = ancestor->getLoc();
508 Location opLoc = payloadOp->getLoc();
509 std::optional<Location> throughValueLoc =
510 throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
511 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
512 otherHandle,
513 throughValueLoc](Location currentLoc) {
514 InFlightDiagnostic diag = emitError(currentLoc)
515 << "op uses a handle invalidated by a "
516 "previously executed transform op";
517 diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
518 diag.attachNote(owner->getLoc())
519 << "invalidated by this transform op that consumes its operand #"
520 << operandNo
521 << " and invalidates all handles to payload IR entities associated "
522 "with this operand and entities nested in them";
523 diag.attachNote(ancestorLoc) << "ancestor payload op";
524 diag.attachNote(opLoc) << "nested payload op";
525 if (throughValueLoc) {
526 diag.attachNote(*throughValueLoc)
527 << "consumed handle points to this payload value";
528 }
529 };
530 }
531}
532
533void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
534 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
535 Value payloadValue, Value valueHandle,
536 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
537 // If the op is associated with invalidated handle, skip the check as it
538 // may be reading invalid IR. This also ensures we report the first
539 // invalidation and not the last one.
540 if (invalidatedHandles.count(valueHandle) ||
541 newlyInvalidated.count(valueHandle))
542 return;
543
544 for (Operation *ancestor : potentialAncestors) {
545 Operation *definingOp;
546 std::optional<unsigned> resultNo;
547 unsigned argumentNo = std::numeric_limits<unsigned>::max();
548 unsigned blockNo = std::numeric_limits<unsigned>::max();
549 unsigned regionNo = std::numeric_limits<unsigned>::max();
550 if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
551 definingOp = opResult.getOwner();
552 resultNo = opResult.getResultNumber();
553 } else {
554 auto arg = llvm::cast<BlockArgument>(payloadValue);
555 definingOp = arg.getParentBlock()->getParentOp();
556 argumentNo = arg.getArgNumber();
557 blockNo = arg.getOwner()->computeBlockNumber();
558 regionNo = arg.getOwner()->getParent()->getRegionNumber();
559 }
560 assert(definingOp && "expected the value to be defined by an op as result "
561 "or block argument");
562 if (!ancestor->isAncestor(definingOp))
563 continue;
564
565 Operation *owner = opHandle.getOwner();
566 unsigned operandNo = opHandle.getOperandNumber();
567 Location ancestorLoc = ancestor->getLoc();
568 Location opLoc = definingOp->getLoc();
569 Location valueLoc = payloadValue.getLoc();
570 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
571 argumentNo, blockNo, regionNo, ancestorLoc,
572 opLoc, valueLoc](Location currentLoc) {
573 InFlightDiagnostic diag = emitError(currentLoc)
574 << "op uses a handle invalidated by a "
575 "previously executed transform op";
576 diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
577 diag.attachNote(owner->getLoc())
578 << "invalidated by this transform op that consumes its operand #"
579 << operandNo
580 << " and invalidates all handles to payload IR entities "
581 "associated with this operand and entities nested in them";
582 diag.attachNote(ancestorLoc)
583 << "ancestor op associated with the consumed handle";
584 if (resultNo) {
585 diag.attachNote(opLoc)
586 << "op defining the value as result #" << *resultNo;
587 } else {
588 diag.attachNote(opLoc)
589 << "op defining the value as block argument #" << argumentNo
590 << " of block #" << blockNo << " in region #" << regionNo;
591 }
592 diag.attachNote(valueLoc) << "payload value";
593 };
594 }
595}
596
597void transform::TransformState::recordOpHandleInvalidation(
598 OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
599 Value throughValue,
600 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
601
602 if (potentialAncestors.empty()) {
603 FULL_LDBG() << "----recording invalidation for empty handle: "
604 << handle.get();
605
606 Operation *owner = handle.getOwner();
607 unsigned operandNo = handle.getOperandNumber();
608 newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
609 InFlightDiagnostic diag = emitError(currentLoc)
610 << "op uses a handle associated with empty "
611 "payload and invalidated by a "
612 "previously executed transform op";
613 diag.attachNote(owner->getLoc())
614 << "invalidated by this transform op that consumes its operand #"
615 << operandNo;
616 };
617 return;
618 }
619
620 // Iterate over the mapping and invalidate aliasing handles. This is quite
621 // expensive and only necessary for error reporting in case of transform
622 // dialect misuse with dangling handles. Iteration over the handles is based
623 // on the assumption that the number of handles is significantly less than the
624 // number of IR objects (operations and values). Alternatively, we could walk
625 // the IR nested in each payload op associated with the given handle and look
626 // for handles associated with each operation and value.
627 for (const auto &[region, mapping] : llvm::reverse(mappings)) {
628 // Go over all op handle mappings and mark as invalidated any handle
629 // pointing to any of the payload ops associated with the given handle or
630 // any op nested in them.
631 for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
632 for (Value otherHandle : otherHandles)
633 recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
634 otherHandle, throughValue,
635 newlyInvalidated);
636 }
637 // Go over all value handle mappings and mark as invalidated any handle
638 // pointing to any result of the payload op associated with the given handle
639 // or any op nested in them. Similarly invalidate handles to argument of
640 // blocks belonging to any region of any payload op associated with the
641 // given handle or any op nested in them.
642 for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
643 for (Value valueHandle : valueHandles)
644 recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
645 payloadValue, valueHandle,
646 newlyInvalidated);
647 }
648
649 // Stop lookup when reaching a region that is isolated from above.
650 if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
651 break;
652 }
653}
654
655void transform::TransformState::recordValueHandleInvalidation(
656 OpOperand &valueHandle,
657 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
658 // Invalidate other handles to the same value.
659 for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
660 SmallVector<Value> otherValueHandles;
661 (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
662 for (Value otherHandle : otherValueHandles) {
663 Operation *owner = valueHandle.getOwner();
664 unsigned operandNo = valueHandle.getOperandNumber();
665 Location valueLoc = payloadValue.getLoc();
666 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
667 valueLoc](Location currentLoc) {
668 InFlightDiagnostic diag = emitError(currentLoc)
669 << "op uses a handle invalidated by a "
670 "previously executed transform op";
671 diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
672 diag.attachNote(owner->getLoc())
673 << "invalidated by this transform op that consumes its operand #"
674 << operandNo
675 << " and invalidates handles to the same values as associated with "
676 "it";
677 diag.attachNote(valueLoc) << "payload value";
678 };
679 }
680
681 if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
682 Operation *payloadOp = opResult.getOwner();
683 recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
684 newlyInvalidated);
685 } else {
686 auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
687 for (Operation &payloadOp : *arg.getOwner())
688 recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
689 newlyInvalidated);
690 }
691 }
692}
693
694/// Checks that the operation does not use invalidated handles as operands.
695/// Reports errors and returns failure if it does. Otherwise, invalidates the
696/// handles consumed by the operation as well as any handles pointing to payload
697/// IR operations nested in the operations associated with the consumed handles.
698LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
699 transform::TransformOpInterface transform,
700 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
701 FULL_LDBG() << "--Start checkAndRecordHandleInvalidation";
702 auto memoryEffectsIface =
703 cast<MemoryEffectOpInterface>(transform.getOperation());
704 SmallVector<MemoryEffects::EffectInstance> effects;
705 memoryEffectsIface.getEffectsOnResource(
707
708 for (OpOperand &target : transform->getOpOperands()) {
709 FULL_LDBG() << "----iterate on handle: " << target.get();
710 // If the operand uses an invalidated handle, report it. If the operation
711 // allows handles to point to repeated payload operations, only report
712 // pre-existing invalidation errors. Otherwise, also report invalidations
713 // caused by the current transform operation affecting its other operands.
714 auto it = invalidatedHandles.find(target.get());
715 auto nit = newlyInvalidated.find(target.get());
716 if (it != invalidatedHandles.end()) {
717 FULL_LDBG() << "--End checkAndRecordHandleInvalidation, found already "
718 "invalidated -> FAILURE";
719 return it->getSecond()(transform->getLoc()), failure();
720 }
721 if (!transform.allowsRepeatedHandleOperands() &&
722 nit != newlyInvalidated.end()) {
723 FULL_LDBG() << "--End checkAndRecordHandleInvalidation, found newly "
724 "invalidated (by this op) -> FAILURE";
725 return nit->getSecond()(transform->getLoc()), failure();
726 }
727
728 // Invalidate handles pointing to the operations nested in the operation
729 // associated with the handle consumed by this operation.
730 auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
731 return isa<MemoryEffects::Free>(effect.getEffect()) &&
732 effect.getValue() == target.get();
733 };
734 if (llvm::any_of(effects, consumesTarget)) {
735 FULL_LDBG() << "----found consume effect";
736 if (llvm::isa<transform::TransformHandleTypeInterface>(
737 target.get().getType())) {
738 FULL_LDBG() << "----recordOpHandleInvalidation";
739 SmallVector<Operation *> payloadOps =
740 llvm::to_vector(getPayloadOps(target.get()));
741 recordOpHandleInvalidation(target, payloadOps, nullptr,
742 newlyInvalidated);
743 } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
744 target.get().getType())) {
745 FULL_LDBG() << "----recordValueHandleInvalidation";
746 recordValueHandleInvalidation(target, newlyInvalidated);
747 } else {
748 FULL_LDBG()
749 << "----not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
750 }
751 } else {
752 FULL_LDBG() << "----no consume effect -> SKIP";
753 }
754 }
755
756 FULL_LDBG() << "--End checkAndRecordHandleInvalidation -> SUCCESS";
757 return success();
758}
759
760LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
761 transform::TransformOpInterface transform) {
762 InvalidatedHandleMap newlyInvalidated;
763 LogicalResult checkResult =
764 checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
765 invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
766 std::make_move_iterator(newlyInvalidated.end()));
767 return checkResult;
768}
769
770template <typename T>
771static DiagnosedSilenceableFailure
773 transform::TransformOpInterface transform,
774 unsigned operandNumber) {
775 DenseSet<T> seen;
776 for (T p : payload) {
777 if (!seen.insert(p).second) {
779 transform.emitSilenceableError()
780 << "a handle passed as operand #" << operandNumber
781 << " and consumed by this operation points to a payload "
782 "entity more than once";
783 if constexpr (std::is_pointer_v<T>)
784 diag.attachNote(p->getLoc()) << "repeated target op";
785 else
786 diag.attachNote(p.getLoc()) << "repeated target value";
787 return diag;
788 }
789 }
791}
792
793void transform::TransformState::compactOpHandles() {
794 for (Value handle : opHandlesToCompact) {
795 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
796#if LLVM_ENABLE_ABI_BREAKING_CHECKS
797 if (llvm::is_contained(mappings.direct[handle], nullptr))
798 // Payload IR is removed from the mapping. This invalidates the respective
799 // iterators.
800 mappings.incrementTimestamp(handle);
801#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
802 llvm::erase(mappings.direct[handle], nullptr);
803 }
804 opHandlesToCompact.clear();
805}
806
807DiagnosedSilenceableFailure
809 LDBG() << "applying: "
810 << OpWithFlags(transform, OpPrintingFlags().skipRegions());
811 FULL_LDBG() << "Top-level payload before application:\n" << *getTopLevel();
812 llvm::scope_exit printOnFailureRAII([this] {
813 (void)this;
814 LDBG() << "Failing Top-level payload:\n"
816 OpPrintingFlags().printGenericOpForm());
817 });
818
819 // Set current transform op.
820 regionStack.back()->currentTransform = transform;
821
822 // Expensive checks to detect invalid transform IR.
823 if (options.getExpensiveChecksEnabled()) {
824 FULL_LDBG() << "ExpensiveChecksEnabled";
825 if (failed(checkAndRecordHandleInvalidation(transform)))
827
828 for (OpOperand &operand : transform->getOpOperands()) {
829 FULL_LDBG() << "iterate on handle: " << operand.get();
830 if (!isHandleConsumed(operand.get(), transform)) {
831 FULL_LDBG() << "--handle not consumed -> SKIP";
832 continue;
833 }
834 if (transform.allowsRepeatedHandleOperands()) {
835 FULL_LDBG() << "--op allows repeated handles -> SKIP";
836 continue;
837 }
838 FULL_LDBG() << "--handle is consumed";
839
840 Type operandType = operand.get().getType();
841 if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
842 FULL_LDBG() << "--checkRepeatedConsumptionInOperand for Operation*";
845 getPayloadOpsView(operand.get()), transform,
846 operand.getOperandNumber());
847 if (!check.succeeded()) {
848 FULL_LDBG() << "----FAILED";
849 return check;
850 }
851 } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
852 FULL_LDBG() << "--checkRepeatedConsumptionInOperand For Value";
855 getPayloadValuesView(operand.get()), transform,
856 operand.getOperandNumber());
857 if (!check.succeeded()) {
858 FULL_LDBG() << "----FAILED";
859 return check;
860 }
861 } else {
862 FULL_LDBG() << "--not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
863 }
864 }
865 }
866
867 // Find which operands are consumed.
868 SmallVector<OpOperand *> consumedOperands =
869 transform.getConsumedHandleOpOperands();
870
871 // Remember the results of the payload ops associated with the consumed
872 // op handles or the ops defining the value handles so we can drop the
873 // association with them later. This must happen here because the
874 // transformation may destroy or mutate them so we cannot traverse the payload
875 // IR after that.
876 SmallVector<Value> origOpFlatResults;
877 SmallVector<Operation *> origAssociatedOps;
878 for (OpOperand *opOperand : consumedOperands) {
879 Value operand = opOperand->get();
880 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
881 for (Operation *payloadOp : getPayloadOps(operand)) {
882 llvm::append_range(origOpFlatResults, payloadOp->getResults());
883 }
884 continue;
885 }
886 if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
887 for (Value payloadValue : getPayloadValuesView(operand)) {
888 if (llvm::isa<OpResult>(payloadValue)) {
889 origAssociatedOps.push_back(payloadValue.getDefiningOp());
890 continue;
891 }
892 llvm::append_range(
893 origAssociatedOps,
894 llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
895 [](Operation &op) { return &op; }));
896 }
897 continue;
898 }
901 << "unexpectedly consumed a value that is not a handle as operand #"
902 << opOperand->getOperandNumber();
903 diag.attachNote(operand.getLoc())
904 << "value defined here with type " << operand.getType();
905 return diag;
906 }
907
908 // Prepare rewriter and listener.
910 config.skipHandleFn = [&](Value handle) {
911 // Skip handle if it is dead.
912 auto scopeIt =
913 llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
914 return handle.getParentRegion() == scope->region;
915 });
916 assert(scopeIt != regionStack.rend() &&
917 "could not find region scope for handle");
918 RegionScope *scope = *scopeIt;
919 return llvm::all_of(handle.getUsers(), [&](Operation *user) {
920 return user == scope->currentTransform ||
921 happensBefore(user, scope->currentTransform);
922 });
923 };
925 config);
926 transform::TransformRewriter rewriter(transform->getContext(),
927 &trackingListener);
928
929 // Compute the result but do not short-circuit the silenceable failure case as
930 // we still want the handles to propagate properly so the "suppress" mode can
931 // proceed on a best effort basis.
932 transform::TransformResults results(transform->getNumResults());
933 DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
934 compactOpHandles();
935
936 // Error handling: fail if transform or listener failed.
937 DiagnosedSilenceableFailure trackingFailure =
938 trackingListener.checkAndResetError();
940 transform->hasAttr(FindPayloadReplacementOpInterface::
941 kSilenceTrackingFailuresAttrName)) {
942 // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
943 // do not report failures if the above mentioned attribute is set.
944 if (trackingFailure.isSilenceableFailure())
945 (void)trackingFailure.silence();
946 trackingFailure = DiagnosedSilenceableFailure::success();
947 }
948 if (!trackingFailure.succeeded()) {
949 if (result.succeeded()) {
950 result = std::move(trackingFailure);
951 } else {
952 // Transform op errors have precedence, report those first.
953 if (result.isSilenceableFailure())
954 result.attachNote() << "tracking listener also failed: "
955 << trackingFailure.getMessage();
956 (void)trackingFailure.silence();
957 }
958 }
959 if (result.isDefiniteFailure())
960 return result;
961
962 // If a silenceable failure was produced, some results may be unset, set them
963 // to empty lists.
964 if (result.isSilenceableFailure())
966
967 // Remove the mapping for the operand if it is consumed by the operation. This
968 // allows us to catch use-after-free with assertions later on.
969 for (OpOperand *opOperand : consumedOperands) {
970 Value operand = opOperand->get();
971 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
972 forgetMapping(operand, origOpFlatResults);
973 } else if (llvm::isa<TransformValueHandleTypeInterface>(
974 operand.getType())) {
975 forgetValueMapping(operand, origAssociatedOps);
976 }
977 }
978
979 if (failed(updateStateFromResults(results, transform->getResults())))
981
982 printOnFailureRAII.release();
983 DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
984 LDBG() << "Top-level payload:\n" << *getTopLevel();
985 });
986 return result;
987}
988
989LogicalResult transform::TransformState::updateStateFromResults(
990 const TransformResults &results, ResultRange opResults) {
991 for (OpResult result : opResults) {
992 if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
993 assert(results.isParam(result.getResultNumber()) &&
994 "expected parameters for the parameter-typed result");
995 if (failed(
996 setParams(result, results.getParams(result.getResultNumber())))) {
997 return failure();
998 }
999 } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1000 assert(results.isValue(result.getResultNumber()) &&
1001 "expected values for value-type-result");
1002 if (failed(setPayloadValues(
1003 result, results.getValues(result.getResultNumber())))) {
1004 return failure();
1005 }
1006 } else {
1007 assert(!results.isParam(result.getResultNumber()) &&
1008 "expected payload ops for the non-parameter typed result");
1009 if (failed(
1010 setPayloadOps(result, results.get(result.getResultNumber())))) {
1011 return failure();
1012 }
1013 }
1014 }
1015 return success();
1016}
1017
1018//===----------------------------------------------------------------------===//
1019// TransformState::Extension
1020//===----------------------------------------------------------------------===//
1021
1023
1024LogicalResult
1027 // TODO: we may need to invalidate handles to operations and values nested in
1028 // the operation being replaced.
1029 return state.replacePayloadOp(op, replacement);
1030}
1031
1032LogicalResult
1035 return state.replacePayloadValue(value, replacement);
1036}
1037
1038//===----------------------------------------------------------------------===//
1039// TransformState::RegionScope
1040//===----------------------------------------------------------------------===//
1041
1043 // Remove handle invalidation notices as handles are going out of scope.
1044 // The same region may be re-entered leading to incorrect invalidation
1045 // errors.
1046 for (Block &block : *region) {
1047 for (Value handle : block.getArguments()) {
1048 state.invalidatedHandles.erase(handle);
1049 }
1050 for (Operation &op : block) {
1051 for (Value handle : op.getResults()) {
1052 state.invalidatedHandles.erase(handle);
1053 }
1054 }
1055 }
1056
1057#if LLVM_ENABLE_ABI_BREAKING_CHECKS
1058 // Remember pointers to payload ops referenced by the handles going out of
1059 // scope.
1060 SmallVector<Operation *> referencedOps =
1061 llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1062#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1063
1064 state.mappings.erase(region);
1065 state.regionStack.pop_back();
1066}
1067
1068//===----------------------------------------------------------------------===//
1069// TransformResults
1070//===----------------------------------------------------------------------===//
1071
1072transform::TransformResults::TransformResults(unsigned numSegments) {
1073 operations.appendEmptyRows(numSegments);
1074 params.appendEmptyRows(numSegments);
1075 values.appendEmptyRows(numSegments);
1076}
1077
1080 int64_t position = value.getResultNumber();
1081 assert(position < static_cast<int64_t>(this->params.size()) &&
1082 "setting params for a non-existent handle");
1083 assert(this->params[position].data() == nullptr && "params already set");
1084 assert(operations[position].data() == nullptr &&
1085 "another kind of results already set");
1086 assert(values[position].data() == nullptr &&
1087 "another kind of results already set");
1088 this->params.replace(position, params);
1089}
1090
1092 OpResult handle, ArrayRef<MappedValue> values) {
1094 handle, values,
1095 [&](ArrayRef<Operation *> operations) {
1096 return set(handle, operations), success();
1097 },
1098 [&](ArrayRef<Param> params) {
1099 return setParams(handle, params), success();
1100 },
1101 [&](ValueRange payloadValues) {
1102 return setValues(handle, payloadValues), success();
1103 });
1104#ifndef NDEBUG
1105 if (!diag.succeeded())
1106 llvm::dbgs() << diag.getStatusString() << "\n";
1107 assert(diag.succeeded() && "incorrect mapping");
1108#endif // NDEBUG
1109 (void)diag.silence();
1110}
1111
1113 transform::TransformOpInterface transform) {
1114 for (OpResult opResult : transform->getResults()) {
1115 if (!isSet(opResult.getResultNumber()))
1116 setMappedValues(opResult, {});
1117 }
1118}
1119
1121transform::TransformResults::get(unsigned resultNumber) const {
1122 assert(resultNumber < operations.size() &&
1123 "querying results for a non-existent handle");
1124 assert(operations[resultNumber].data() != nullptr &&
1125 "querying unset results (values or params expected?)");
1126 return operations[resultNumber];
1127}
1128
1130transform::TransformResults::getParams(unsigned resultNumber) const {
1131 assert(resultNumber < params.size() &&
1132 "querying params for a non-existent handle");
1133 assert(params[resultNumber].data() != nullptr &&
1134 "querying unset params (ops or values expected?)");
1135 return params[resultNumber];
1136}
1137
1139transform::TransformResults::getValues(unsigned resultNumber) const {
1140 assert(resultNumber < values.size() &&
1141 "querying values for a non-existent handle");
1142 assert(values[resultNumber].data() != nullptr &&
1143 "querying unset values (ops or params expected?)");
1144 return values[resultNumber];
1145}
1146
1147bool transform::TransformResults::isParam(unsigned resultNumber) const {
1148 assert(resultNumber < params.size() &&
1149 "querying association for a non-existent handle");
1150 return params[resultNumber].data() != nullptr;
1151}
1152
1153bool transform::TransformResults::isValue(unsigned resultNumber) const {
1154 assert(resultNumber < values.size() &&
1155 "querying association for a non-existent handle");
1156 return values[resultNumber].data() != nullptr;
1157}
1158
1159bool transform::TransformResults::isSet(unsigned resultNumber) const {
1160 assert(resultNumber < params.size() &&
1161 "querying association for a non-existent handle");
1162 return params[resultNumber].data() != nullptr ||
1163 operations[resultNumber].data() != nullptr ||
1164 values[resultNumber].data() != nullptr;
1165}
1166
1167//===----------------------------------------------------------------------===//
1168// TrackingListener
1169//===----------------------------------------------------------------------===//
1170
1172 TransformOpInterface op,
1174 : TransformState::Extension(state), transformOp(op), config(config) {
1175 if (op) {
1176 for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1177 consumedHandles.insert(opOperand->get());
1178 }
1179 }
1180}
1181
1183 Operation *defOp = nullptr;
1184 for (Value v : values) {
1185 // Skip empty values.
1186 if (!v)
1187 continue;
1188 if (!defOp) {
1189 defOp = v.getDefiningOp();
1190 continue;
1191 }
1192 if (defOp != v.getDefiningOp())
1193 return nullptr;
1194 }
1195 return defOp;
1196}
1197
1199 Operation *&result, Operation *op, ValueRange newValues) const {
1200 assert(op->getNumResults() == newValues.size() &&
1201 "invalid number of replacement values");
1202 SmallVector<Value> values(newValues.begin(), newValues.end());
1203
1205 getTransformOp(), "tracking listener failed to find replacement op "
1206 "during application of this transform op");
1207
1208 do {
1209 // If the replacement values belong to different ops, drop the mapping.
1210 Operation *defOp = getCommonDefiningOp(values);
1211 if (!defOp) {
1212 diag.attachNote() << "replacement values belong to different ops";
1213 return diag;
1214 }
1215
1216 // Skip through ops that implement CastOpInterface.
1217 if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1218 values.clear();
1219 values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1220 diag.attachNote(defOp->getLoc())
1221 << "using output of 'CastOpInterface' op";
1222 continue;
1223 }
1224
1225 // If the defining op has the same name or we do not care about the name of
1226 // op replacements at all, we take it as a replacement.
1227 if (!config.requireMatchingReplacementOpName ||
1228 op->getName() == defOp->getName()) {
1229 result = defOp;
1231 }
1232
1233 // Replacing an op with a constant-like equivalent is a common
1234 // canonicalization.
1235 if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1236 result = defOp;
1238 }
1239
1240 values.clear();
1241
1242 // Skip through ops that implement FindPayloadReplacementOpInterface.
1243 if (auto findReplacementOpInterface =
1244 dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1245 values.assign(findReplacementOpInterface.getNextOperands());
1246 diag.attachNote(defOp->getLoc()) << "using operands provided by "
1247 "'FindPayloadReplacementOpInterface'";
1248 continue;
1249 }
1250 } while (!values.empty());
1251
1252 diag.attachNote() << "ran out of suitable replacement values";
1253 return diag;
1254}
1255
1257 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1258 LLVM_DEBUG({
1260 reasonCallback(diag);
1261 LDBG() << "Match Failure : " << diag.str();
1262 });
1263}
1264
1265void transform::TrackingListener::notifyOperationErased(Operation *op) {
1266 // Remove mappings for result values.
1267 for (OpResult value : op->getResults())
1268 (void)replacePayloadValue(value, nullptr);
1269 // Remove mapping for op.
1270 (void)replacePayloadOp(op, nullptr);
1271}
1272
1273void transform::TrackingListener::notifyOperationReplaced(
1274 Operation *op, ValueRange newValues) {
1275 assert(op->getNumResults() == newValues.size() &&
1276 "invalid number of replacement values");
1277
1278 // Replace value handles.
1279 for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
1280 (void)replacePayloadValue(oldValue, newValue);
1281
1282 // Replace op handle.
1283 SmallVector<Value> opHandles;
1284 if (failed(getTransformState().getHandlesForPayloadOp(
1285 op, opHandles, /*includeOutOfScope=*/true))) {
1286 // Op is not tracked.
1287 return;
1288 }
1289
1290 // Helper function to check if the current transform op consumes any handle
1291 // that is mapped to `op`.
1292 //
1293 // Note: If a handle was consumed, there shouldn't be any alive users, so it
1294 // is not really necessary to check for consumed handles. However, in case
1295 // there are indeed alive handles that were consumed (which is undefined
1296 // behavior) and a replacement op could not be found, we want to fail with a
1297 // nicer error message: "op uses a handle invalidated..." instead of "could
1298 // not find replacement op". This nicer error is produced later.
1299 auto handleWasConsumed = [&] {
1300 return llvm::any_of(opHandles,
1301 [&](Value h) { return consumedHandles.contains(h); });
1302 };
1303
1304 // Check if there are any handles that must be updated.
1305 Value aliveHandle;
1306 if (config.skipHandleFn) {
1307 auto *it = llvm::find_if(opHandles,
1308 [&](Value v) { return !config.skipHandleFn(v); });
1309 if (it != opHandles.end())
1310 aliveHandle = *it;
1311 } else if (!opHandles.empty()) {
1312 aliveHandle = opHandles.front();
1313 }
1314 if (!aliveHandle || handleWasConsumed()) {
1315 // The op is tracked but the corresponding handles are dead or were
1316 // consumed. Drop the op form the mapping.
1317 (void)replacePayloadOp(op, nullptr);
1318 return;
1319 }
1320
1321 Operation *replacement;
1322 DiagnosedSilenceableFailure diag =
1323 findReplacementOp(replacement, op, newValues);
1324 // If the op is tracked but no replacement op was found, send a
1325 // notification.
1326 if (!diag.succeeded()) {
1327 diag.attachNote(aliveHandle.getLoc())
1328 << "replacement is required because this handle must be updated";
1329 notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
1330 (void)replacePayloadOp(op, nullptr);
1331 return;
1332 }
1333
1334 (void)replacePayloadOp(op, replacement);
1335}
1336
1338 // The state of the ErrorCheckingTrackingListener must be checked and reset
1339 // if there was an error. This is to prevent errors from accidentally being
1340 // missed.
1341 assert(status.succeeded() && "listener state was not checked");
1342}
1343
1346 DiagnosedSilenceableFailure s = std::move(status);
1348 errorCounter = 0;
1349 return s;
1350}
1351
1353 return !status.succeeded();
1354}
1355
1358
1359 // Merge potentially existing diags and store the result in the listener.
1361 diag.takeDiagnostics(diags);
1362 if (!status.succeeded())
1363 status.takeDiagnostics(diags);
1364 status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
1365
1366 // Report more details.
1367 status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
1368 for (auto &&[index, value] : llvm::enumerate(values))
1369 status.attachNote(value.getLoc())
1370 << "[" << errorCounter << "] replacement value " << index;
1371 ++errorCounter;
1372}
1373
1374std::string
1376 if (!matchFailure) {
1377 return "";
1378 }
1379 return matchFailure->str();
1380}
1381
1383 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1385 reasonCallback(diag);
1386 matchFailure = std::move(diag);
1387}
1388
1389//===----------------------------------------------------------------------===//
1390// TransformRewriter
1391//===----------------------------------------------------------------------===//
1392
1395 : RewriterBase(ctx), listener(listener) {
1396 setListener(listener);
1397}
1398
1400 return listener->failed();
1401}
1402
1403/// Silence all tracking failures that have been encountered so far.
1405 if (hasTrackingFailures()) {
1406 DiagnosedSilenceableFailure status = listener->checkAndResetError();
1407 (void)status.silence();
1408 }
1409}
1410
1413 return listener->replacePayloadOp(op, replacement);
1414}
1415
1416//===----------------------------------------------------------------------===//
1417// Utilities for TransformEachOpTrait.
1418//===----------------------------------------------------------------------===//
1419
1420LogicalResult
1422 ArrayRef<Operation *> targets) {
1423 for (auto &&[position, parent] : llvm::enumerate(targets)) {
1424 for (Operation *child : targets.drop_front(position + 1)) {
1425 if (parent->isAncestor(child)) {
1427 emitError(loc)
1428 << "transform operation consumes a handle pointing to an ancestor "
1429 "payload operation before its descendant";
1430 diag.attachNote()
1431 << "the ancestor is likely erased or rewritten before the "
1432 "descendant is accessed, leading to undefined behavior";
1433 diag.attachNote(parent->getLoc()) << "ancestor payload op";
1434 diag.attachNote(child->getLoc()) << "descendant payload op";
1435 return diag;
1436 }
1437 }
1438 }
1439 return success();
1440}
1441
1442LogicalResult
1444 Location payloadOpLoc,
1445 const ApplyToEachResultList &partialResult) {
1446 Location transformOpLoc = transformOp->getLoc();
1447 StringRef transformOpName = transformOp->getName().getStringRef();
1448 unsigned expectedNumResults = transformOp->getNumResults();
1449
1450 // Reuse the emission of the diagnostic note.
1451 auto emitDiag = [&]() {
1452 auto diag = mlir::emitError(transformOpLoc);
1453 diag.attachNote(payloadOpLoc) << "when applied to this op";
1454 return diag;
1455 };
1456
1457 if (partialResult.size() != expectedNumResults) {
1458 auto diag = emitDiag() << "application of " << transformOpName
1459 << " expected to produce " << expectedNumResults
1460 << " results (actually produced "
1461 << partialResult.size() << ").";
1462 diag.attachNote(transformOpLoc)
1463 << "if you need variadic results, consider a generic `apply` "
1464 << "instead of the specialized `applyToOne`.";
1465 return failure();
1466 }
1467
1468 // Check that the right kind of value was produced.
1469 for (const auto &[ptr, res] :
1470 llvm::zip(partialResult, transformOp->getResults())) {
1471 if (ptr.isNull())
1472 continue;
1473 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1474 !isa<Operation *>(ptr)) {
1475 return emitDiag() << "application of " << transformOpName
1476 << " expected to produce an Operation * for result #"
1477 << res.getResultNumber();
1478 }
1479 if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1480 !isa<Attribute>(ptr)) {
1481 return emitDiag() << "application of " << transformOpName
1482 << " expected to produce an Attribute for result #"
1483 << res.getResultNumber();
1484 }
1485 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1486 !isa<Value>(ptr)) {
1487 return emitDiag() << "application of " << transformOpName
1488 << " expected to produce a Value for result #"
1489 << res.getResultNumber();
1490 }
1491 }
1492 return success();
1493}
1494
1495template <typename T>
1497 return llvm::map_to_vector(range, llvm::CastTo<T>);
1498}
1499
1501 Operation *transformOp, TransformResults &transformResults,
1504 transposed.resize(transformOp->getNumResults());
1505 for (const ApplyToEachResultList &partialResults : results) {
1506 if (llvm::any_of(partialResults,
1507 [](MappedValue value) { return value.isNull(); }))
1508 continue;
1509 assert(transformOp->getNumResults() == partialResults.size() &&
1510 "expected as many partial results as op as results");
1511 for (auto [i, value] : llvm::enumerate(partialResults))
1512 transposed[i].push_back(value);
1513 }
1514
1515 for (OpResult r : transformOp->getResults()) {
1516 unsigned position = r.getResultNumber();
1517 if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1518 transformResults.setParams(r,
1519 castVector<Attribute>(transposed[position]));
1520 } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1521 transformResults.setValues(r, castVector<Value>(transposed[position]));
1522 } else {
1523 transformResults.set(r, castVector<Operation *>(transposed[position]));
1524 }
1525 }
1526}
1527
1528//===----------------------------------------------------------------------===//
1529// Utilities for implementing transform ops with regions.
1530//===----------------------------------------------------------------------===//
1531
1534 ValueRange values, const transform::TransformState &state, bool flatten) {
1535 assert(mappings.size() == values.size() && "mismatching number of mappings");
1536 for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1537 size_t mappedSize = mapped.size();
1538 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1539 llvm::append_range(mapped, state.getPayloadOps(operand));
1540 } else if (llvm::isa<TransformValueHandleTypeInterface>(
1541 operand.getType())) {
1542 llvm::append_range(mapped, state.getPayloadValues(operand));
1543 } else {
1544 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1545 "unsupported kind of transform dialect value");
1546 llvm::append_range(mapped, state.getParams(operand));
1547 }
1548
1549 if (mapped.size() - mappedSize != 1 && !flatten)
1550 return failure();
1551 }
1552 return success();
1553}
1554
1557 ValueRange values, const transform::TransformState &state) {
1558 mappings.resize(mappings.size() + values.size());
1561 values.size()),
1562 values, state);
1563}
1564
1566 Block *block, transform::TransformState &state,
1567 transform::TransformResults &results) {
1568 for (auto &&[terminatorOperand, result] :
1569 llvm::zip(block->getTerminator()->getOperands(),
1570 block->getParentOp()->getOpResults())) {
1571 if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1572 results.set(result, state.getPayloadOps(terminatorOperand));
1573 } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1574 result.getType())) {
1575 results.setValues(result, state.getPayloadValues(terminatorOperand));
1576 } else {
1577 assert(
1578 llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1579 "unhandled transform type interface");
1580 results.setParams(result, state.getParams(terminatorOperand));
1581 }
1582 }
1583}
1584
1587 Operation *payloadRoot) {
1588 return TransformState(region, payloadRoot);
1589}
1590
1591//===----------------------------------------------------------------------===//
1592// Utilities for PossibleTopLevelTransformOpTrait.
1593//===----------------------------------------------------------------------===//
1594
1595/// Appends to `effects` the memory effect instances on `target` with the same
1596/// resource and effect as the ones the operation `iface` having on `source`.
1597static void
1598remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
1602 iface.getEffectsOnValue(source, nestedEffects);
1603 for (const auto &effect : nestedEffects)
1604 effects.emplace_back(effect.getEffect(), target, effect.getResource());
1605}
1606
1607/// Appends to `effects` the same effects as the operations of `block` have on
1608/// block arguments but associated with `operands.`
1609static void
1612 for (Operation &op : block) {
1613 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1614 if (!iface)
1615 continue;
1616
1617 for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
1618 remapEffects(iface, source, &target, effects);
1619 }
1620
1622 iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1623 nestedEffects);
1624 llvm::append_range(effects, nestedEffects);
1625 }
1626}
1627
1629 Operation *operation, Value root, Block &body,
1631 transform::onlyReadsHandle(operation->getOpOperands(), effects);
1632 transform::producesHandle(operation->getOpResults(), effects);
1633
1634 if (!root) {
1635 for (Operation &op : body) {
1636 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1637 if (!iface)
1638 continue;
1639
1640 iface.getEffects(effects);
1641 }
1642 return;
1643 }
1644
1645 // Carry over all effects on arguments of the entry block as those on the
1646 // operands, this is the same value just remapped.
1647 remapArgumentEffects(body, operation->getOpOperands(), effects);
1648}
1649
1651 TransformState &state, Operation *op, Region &region) {
1654 if (op->getNumOperands() != 0) {
1655 llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
1656 prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
1657 } else {
1658 if (state.getNumTopLevelMappings() !=
1659 region.front().getNumArguments() - 1) {
1660 return emitError(op->getLoc())
1661 << "operation expects " << region.front().getNumArguments() - 1
1662 << " extra value bindings, but " << state.getNumTopLevelMappings()
1663 << " were provided to the interpreter";
1664 }
1665
1666 targets.push_back(state.getTopLevel());
1667
1668 for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1669 extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1670 }
1671
1672 if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
1673 return failure();
1674
1675 for (BlockArgument argument : region.front().getArguments().drop_front()) {
1676 if (failed(state.mapBlockArgument(
1677 argument, extraMappings[argument.getArgNumber() - 1])))
1678 return failure();
1679 }
1680
1681 return success();
1682}
1683
1684LogicalResult
1686 // Attaching this trait without the interface is a misuse of the API, but it
1687 // cannot be caught via a static_assert because interface registration is
1688 // dynamic.
1689 assert(isa<TransformOpInterface>(op) &&
1690 "should implement TransformOpInterface to have "
1691 "PossibleTopLevelTransformOpTrait");
1692
1693 if (op->getNumRegions() < 1)
1694 return op->emitOpError() << "expects at least one region";
1695
1696 Region *bodyRegion = &op->getRegion(0);
1697 if (!llvm::hasNItems(*bodyRegion, 1))
1698 return op->emitOpError() << "expects a single-block region";
1699
1700 Block *body = &bodyRegion->front();
1701 if (body->getNumArguments() == 0) {
1702 return op->emitOpError()
1703 << "expects the entry block to have at least one argument";
1704 }
1705 if (!llvm::isa<TransformHandleTypeInterface>(
1706 body->getArgument(0).getType())) {
1707 return op->emitOpError()
1708 << "expects the first entry block argument to be of type "
1709 "implementing TransformHandleTypeInterface";
1710 }
1711 BlockArgument arg = body->getArgument(0);
1712 if (op->getNumOperands() != 0) {
1713 if (arg.getType() != op->getOperand(0).getType()) {
1714 return op->emitOpError()
1715 << "expects the type of the block argument to match "
1716 "the type of the operand";
1717 }
1718 }
1719 for (BlockArgument arg : body->getArguments().drop_front()) {
1720 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1721 TransformValueHandleTypeInterface>(arg.getType()))
1722 continue;
1723
1725 op->emitOpError()
1726 << "expects trailing entry block arguments to be of type implementing "
1727 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1728 "TransformParamTypeInterface";
1729 diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1730 return diag;
1731 }
1732
1733 if (auto *parent =
1735 if (op->getNumOperands() != body->getNumArguments()) {
1737 op->emitOpError()
1738 << "expects operands to be provided for a nested op";
1739 diag.attachNote(parent->getLoc())
1740 << "nested in another possible top-level op";
1741 return diag;
1742 }
1743 }
1744
1745 return success();
1746}
1747
1748//===----------------------------------------------------------------------===//
1749// Utilities for ParamProducedTransformOpTrait.
1750//===----------------------------------------------------------------------===//
1751
1754 producesHandle(op->getResults(), effects);
1755 bool hasPayloadOperands = false;
1756 for (OpOperand &operand : op->getOpOperands()) {
1757 onlyReadsHandle(operand, effects);
1758 if (llvm::isa<TransformHandleTypeInterface,
1759 TransformValueHandleTypeInterface>(operand.get().getType()))
1760 hasPayloadOperands = true;
1761 }
1762 if (hasPayloadOperands)
1763 onlyReadsPayload(effects);
1764}
1765
1766LogicalResult
1768 // Interfaces can be attached dynamically, so this cannot be a static
1769 // assert.
1770 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1771 llvm::report_fatal_error(
1772 Twine("ParamProducerTransformOpTrait must be attached to an op that "
1773 "implements MemoryEffectsOpInterface, found on ") +
1774 op->getName().getStringRef());
1775 }
1776 for (Value result : op->getResults()) {
1777 if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1778 continue;
1779 return op->emitOpError()
1780 << "ParamProducerTransformOpTrait attached to this op expects "
1781 "result types to implement TransformParamTypeInterface";
1782 }
1783 return success();
1784}
1785
1786//===----------------------------------------------------------------------===//
1787// Memory effects.
1788//===----------------------------------------------------------------------===//
1789
1793 for (OpOperand &handle : handles) {
1794 effects.emplace_back(MemoryEffects::Read::get(), &handle,
1796 effects.emplace_back(MemoryEffects::Free::get(), &handle,
1798 }
1799}
1800
1801/// Returns `true` if the given list of effects instances contains an instance
1802/// with the effect type specified as template parameter.
1803template <typename EffectTy, typename ResourceTy, typename Range>
1804static bool hasEffect(Range &&effects) {
1805 return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1806 return isa<EffectTy>(effect.getEffect()) &&
1807 isa<ResourceTy>(effect.getResource());
1808 });
1809}
1810
1812 transform::TransformOpInterface transform) {
1813 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1815 iface.getEffectsOnValue(handle, effects);
1816 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1818}
1819
1821 ResultRange handles,
1823 for (OpResult handle : handles) {
1824 effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1826 effects.emplace_back(MemoryEffects::Write::get(), handle,
1828 }
1829}
1830
1834 for (BlockArgument handle : handles) {
1835 effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1837 effects.emplace_back(MemoryEffects::Write::get(), handle,
1839 }
1840}
1841
1845 for (OpOperand &handle : handles) {
1846 effects.emplace_back(MemoryEffects::Read::get(), &handle,
1848 }
1849}
1850
1856
1861
1862bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1863 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1865 iface.getEffects(effects);
1866 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1867}
1868
1869bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1870 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1872 iface.getEffects(effects);
1873 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1874}
1875
1877 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1879 for (Operation &nested : block) {
1880 auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1881 if (!iface)
1882 continue;
1883
1884 effects.clear();
1885 iface.getEffects(effects);
1886 for (const MemoryEffects::EffectInstance &effect : effects) {
1887 BlockArgument argument =
1888 dyn_cast_or_null<BlockArgument>(effect.getValue());
1889 if (!argument || argument.getOwner() != &block ||
1890 !isa<MemoryEffects::Free>(effect.getEffect()) ||
1891 effect.getResource() != transform::TransformMappingResource::get()) {
1892 continue;
1893 }
1894 consumedArguments.insert(argument.getArgNumber());
1895 }
1896 }
1897}
1898
1899//===----------------------------------------------------------------------===//
1900// Utilities for TransformOpInterface.
1901//===----------------------------------------------------------------------===//
1902
1903SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands(
1904 TransformOpInterface transformOp) {
1905 SmallVector<OpOperand *> consumedOperands;
1906 consumedOperands.reserve(transformOp->getNumOperands());
1907 auto memEffectInterface =
1908 cast<MemoryEffectOpInterface>(transformOp.getOperation());
1909 SmallVector<MemoryEffects::EffectInstance, 2> effects;
1910 for (OpOperand &target : transformOp->getOpOperands()) {
1911 effects.clear();
1912 memEffectInterface.getEffectsOnValue(target.get(), effects);
1913 if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1914 return isa<transform::TransformMappingResource>(
1915 effect.getResource()) &&
1916 isa<MemoryEffects::Free>(effect.getEffect());
1917 })) {
1918 consumedOperands.push_back(&target);
1919 }
1920 }
1921 return consumedOperands;
1922}
1923
1925 auto iface = cast<MemoryEffectOpInterface>(op);
1927 iface.getEffects(effects);
1928
1929 auto effectsOn = [&](Value value) {
1930 return llvm::make_filter_range(
1931 effects, [value](const MemoryEffects::EffectInstance &instance) {
1932 return instance.getValue() == value;
1933 });
1934 };
1935
1936 std::optional<unsigned> firstConsumedOperand;
1937 for (OpOperand &operand : op->getOpOperands()) {
1938 auto range = effectsOn(operand.get());
1939 if (range.empty()) {
1941 op->emitError() << "TransformOpInterface requires memory effects "
1942 "on operands to be specified";
1943 diag.attachNote() << "no effects specified for operand #"
1944 << operand.getOperandNumber();
1945 return diag;
1946 }
1949 << "TransformOpInterface did not expect "
1950 "'allocate' memory effect on an operand";
1951 diag.attachNote() << "specified for operand #"
1952 << operand.getOperandNumber();
1953 return diag;
1954 }
1955 if (!firstConsumedOperand &&
1957 firstConsumedOperand = operand.getOperandNumber();
1958 }
1959 }
1960
1961 if (firstConsumedOperand &&
1964 op->emitError()
1965 << "TransformOpInterface expects ops consuming operands to have a "
1966 "'write' effect on the payload resource";
1967 diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
1968 return diag;
1969 }
1970
1971 for (OpResult result : op->getResults()) {
1972 auto range = effectsOn(result);
1974 range)) {
1976 op->emitError() << "TransformOpInterface requires 'allocate' memory "
1977 "effect to be specified for results";
1978 diag.attachNote() << "no 'allocate' effect specified for result #"
1979 << result.getResultNumber();
1980 return diag;
1981 }
1982 }
1983
1984 return success();
1985}
1986
1987//===----------------------------------------------------------------------===//
1988// Entry point.
1989//===----------------------------------------------------------------------===//
1990
1992 Operation *payloadRoot, TransformOpInterface transform,
1993 const RaggedArray<MappedValue> &extraMapping,
1994 const TransformOptions &options, bool enforceToplevelTransformOp,
1995 function_ref<void(TransformState &)> stateInitializer,
1996 function_ref<LogicalResult(TransformState &)> stateExporter) {
1997 if (enforceToplevelTransformOp) {
1998 if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
1999 transform->getNumOperands() != 0) {
2000 return transform->emitError()
2001 << "expected transform to start at the top-level transform op";
2002 }
2003 } else if (failed(
2005 return failure();
2006 }
2007
2008 TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2009 options);
2010 if (stateInitializer)
2011 stateInitializer(state);
2012 if (state.applyTransform(transform).checkAndReport().failed())
2013 return failure();
2014 if (stateExporter)
2015 return stateExporter(state);
2016 return success();
2017}
2018
2019//===----------------------------------------------------------------------===//
2020// Generated interface implementation.
2021//===----------------------------------------------------------------------===//
2022
2023#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2024#include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
for(Operation *op :ops)
return success()
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
#define FULL_LDBG()
static void remapEffects(MemoryEffectOpInterface iface, BlockArgument source, OpOperand *target, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the memory effect instances on target with the same resource and effect as the one...
static void remapArgumentEffects(Block &block, MutableArrayRef< OpOperand > operands, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Appends to effects the same effects as the operations of block have on block arguments but associated...
static bool happensBefore(Operation *a, Operation *b)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not ...
static void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped)
#define DEBUG_PRINT_AFTER_ALL
static DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef< T > payload, transform::TransformOpInterface transform, unsigned operandNumber)
static DiagnosedSilenceableFailure dispatchMappedValues(Value handle, ArrayRef< transform::MappedValue > values, function_ref< LogicalResult(ArrayRef< Operation * >)> operationsFn, function_ref< LogicalResult(ArrayRef< transform::Param >)> paramsFn, function_ref< LogicalResult(ValueRange)> valuesFn)
Given a list of MappedValues, cast them to the value kind implied by the interface of the handle type...
static SmallVector< T > castVector(ArrayRef< transform::MappedValue > range)
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition Block.cpp:74
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition Builders.h:316
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
Value getOperand(unsigned idx)
Definition Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
unsigned getNumOperands()
Definition Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getOpResults()
Definition Operation.h:420
result_range getResults()
Definition Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A 2D array where each row may have different length.
Definition RaggedArray.h:18
size_t size() const
Returns the number of rows in the 2D array.
Definition RaggedArray.h:21
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
This class implements the result iterators for the Operation class.
Definition ValueRange.h:247
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
size_t size() const
Returns the number of elements in the list.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
std::string getLatestMatchFailureMessage()
Return the latest match notification message.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
TrackingListener failures are reported only for ops that have this trait.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
TransformOpInterface getTransformOp() const
Return the transform op in which this TrackingListener is used.
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Options controlling the application of transform operations by the TransformState.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
Extension(TransformState &state)
Constructs an extension of the given TransformState object.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
size_t getNumTopLevelMappings() const
Returns the number of extra mappings for the top-level operation.
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
ArrayRef< MappedValue > getTopLevelMapping(size_t position) const
Returns the position-th extra mapping for the top-level operation.
SideEffects::EffectInstance< Effect > EffectInstance
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue > > mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult matches the number, kind (payload op or parameter) and nulli...
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true, function_ref< void(TransformState &)> stateInitializer=nullptr, function_ref< LogicalResult(TransformState &)> stateExporter=nullptr)
Entry point to the Transform dialect infrastructure.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
llvm::PointerUnion< Operation *, Param, Value > MappedValue
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A configuration object for customizing a TrackingListener.