17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/ScopeExit.h"
19#include "llvm/ADT/iterator.h"
20#include "llvm/Support/Debug.h"
21#include "llvm/Support/DebugLog.h"
22#include "llvm/Support/ErrorHandling.h"
23#include "llvm/Support/InterleavedRange.h"
// Debug channel name used by the LLVM debug-logging macros (LDBG) in this file.
25#define DEBUG_TYPE "transform-dialect"
// Separate debug channel name; presumably enables printing the top-level
// payload after every transform (its use is not visible in this excerpt).
26#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
// Verbose tracing used throughout this file: LDBG at level 4.
27#define FULL_LDBG() LDBG(4)
/// Constructs the transform interpreter state. Copies the rows of
/// `extraMappings` into `topLevelMappedValues` (presumably one push_back per
/// row; the loop header is not visible here) and opens a region scope for
/// `region` that is owned by `topLevelRegionScope`.
52transform::TransformState::TransformState(
55 const TransformOptions &
options)
57 topLevelMappedValues.reserve(extraMappings.
size());
59 topLevelMappedValues.push_back(mapping);
// The raw pointer is handed straight to `topLevelRegionScope.reset(...)`,
// which presumably owns it (e.g. a std::unique_ptr; confirm in the header).
61 RegionScope *scope =
new RegionScope(*
this, *region);
62 topLevelRegionScope.reset(scope);
/// Returns the list of payload operations associated with the operation
/// handle `value`. Looks it up in the forward (`direct`) handle -> ops map
/// of the scope returned by getMapping. Asserts that a mapping exists.
69transform::TransformState::getPayloadOpsView(
Value value)
const {
70 const TransformOpMapping &operationMapping = getMapping(value).direct;
71 auto iter = operationMapping.find(value);
72 assert(iter != operationMapping.end() &&
73 "cannot find mapping for payload handle (param/value handle "
75 return iter->getSecond();
// Body of the param-handle lookup: returns the parameters recorded for
// `value` in the `params` map of its scope. Asserts that a mapping exists,
// which catches callers passing an operation or value handle instead.
79 const ParamMapping &mapping = getMapping(value).params;
80 auto iter = mapping.find(value);
81 assert(iter != mapping.end() &&
"cannot find mapping for param handle "
82 "(operation/value handle provided?)");
83 return iter->getSecond();
/// Returns the payload values associated with the value handle `handleValue`.
/// Looks it up in the `values` map of the handle's scope. Asserts that a
/// mapping exists.
87transform::TransformState::getPayloadValuesView(
Value handleValue)
const {
88 const ValueMapping &mapping = getMapping(handleValue).values;
89 auto iter = mapping.find(handleValue);
90 assert(iter != mapping.end() &&
"cannot find mapping for value handle "
91 "(param/operation handle provided?)");
92 return iter->getSecond();
// Collects into `handles` every operation handle mapped to payload op `op`.
// It consults the reverse (op -> handles) map of each scope mapping, walking
// `mappings` in reverse order. `includeOutOfScope` controls whether the walk
// may continue past the current scope; the exact stop condition is not fully
// visible here.
97 bool includeOutOfScope)
const {
99 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
100 auto iterator = mapping->reverse.find(op);
101 if (iterator != mapping->reverse.end()) {
102 llvm::append_range(handles, iterator->getSecond());
106 if (!includeOutOfScope &&
// Value counterpart of getHandlesForPayloadOp. Appends to `handles` every
// value handle that maps to `payloadValue`, using the `reverseValues` map of
// each scope mapping, visited in reverse order of `mappings`.
// `includeOutOfScope` limits the walk (stop condition only partly visible).
116 bool includeOutOfScope)
const {
118 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
119 auto iterator = mapping->reverseValues.find(payloadValue);
120 if (iterator != mapping->reverseValues.end()) {
121 llvm::append_range(handles, iterator->getSecond());
125 if (!includeOutOfScope &&
// Splits generic mapped values according to the kind of `handle` and forwards
// them to the matching callback:
//  - operation handles: `Operation *` entries -> operationsFn
//  - value handles:     `Value` entries       -> valuesFn
//  - param handles:     `Attribute` entries   -> paramsFn
// An entry of the wrong kind leads to the error diagnostic shown below.
// dyn_cast_if_present also tolerates null entries, so those fall into the
// "wrong kind" path.
140 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
142 operations.reserve(values.size());
144 if (
auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
145 operations.push_back(op);
149 <<
"wrong kind of value provided for top-level operation handle";
151 if (failed(operationsFn(operations)))
156 if (llvm::isa<transform::TransformValueHandleTypeInterface>(
159 payloadValues.reserve(values.size());
161 if (
auto v = llvm::dyn_cast_if_present<Value>(value)) {
162 payloadValues.push_back(v);
166 <<
"wrong kind of value provided for the top-level value handle";
168 if (failed(valuesFn(payloadValues)))
// Anything that is neither an operation nor a value handle must be a param.
173 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
174 "unsupported kind of block argument");
176 parameters.reserve(values.size());
178 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
179 parameters.push_back(attr);
183 <<
"wrong kind of value provided for top-level parameter";
185 if (failed(paramsFn(parameters)))
// Per-kind callbacks that bind the block `argument` to the dispatched payload
// operations, parameters, or payload values respectively.
196 return setPayloadOps(argument, operations);
199 return setParams(argument, params);
202 return setPayloadValues(argument, payloadValues);
// Pairs each block argument with its list of mapped values. llvm::zip_equal
// asserts that both ranges have the same length.
210 for (
auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
/// Associates the operation handle `value` with a list of payload operations.
/// Null payload ops are rejected with an error. The handle type's
/// checkPayload verifies the list, and both the forward (handle -> ops) and
/// reverse (op -> handles) mappings are updated. A handle may only be set
/// once (asserted).
217transform::TransformState::setPayloadOps(
Value value,
219 assert(value != kTopLevelValue &&
220 "attempting to reset the transformation root");
221 assert(llvm::isa<TransformHandleTypeInterface>(value.
getType()) &&
222 "wrong handle type");
228 <<
"attempting to assign a null payload op to this transform value";
// Let the handle type validate the payload it is about to hold.
231 auto iface = llvm::cast<TransformHandleTypeInterface>(value.
getType());
233 iface.checkPayload(value.
getLoc(), targets);
234 if (failed(
result.checkAndReport()))
// Forward mapping; insert() reports whether the handle was new.
240 Mappings &mappings = getMapping(value);
242 mappings.direct.insert({value, std::move(storedTargets)}).second;
243 assert(
inserted &&
"value is already associated with another list");
// Reverse mapping, used to find all handles pointing to a given payload op.
247 mappings.reverse[op].push_back(value);
/// Associates the value handle `handle` with a list of payload values.
/// Rejects null payload values with an error and lets the handle type check
/// the list. Records the forward (handle -> values) and reverse
/// (value -> handles) mappings.
253transform::TransformState::setPayloadValues(Value handle,
// NOTE(review): the message below says "params" although this sets payload
// values. It looks copy-pasted from setParams; consider rewording.
255 assert(handle !=
nullptr &&
"attempting to set params for a null value");
256 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.
getType()) &&
257 "wrong handle type");
259 for (Value payload : payloadValues) {
262 return emitError(handle.
getLoc()) <<
"attempting to assign a null payload "
263 "value to this transform handle";
266 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.
getType());
267 SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
268 DiagnosedSilenceableFailure
result =
269 iface.checkPayload(handle.
getLoc(), payloadValueVector);
273 Mappings &mappings = getMapping(handle);
275 mappings.values.insert({handle, std::move(payloadValueVector)}).second;
278 "value handle is already associated with another list of payload values");
// Iterates the original `payloadValues` range: `payloadValueVector` was moved
// into the map above and must not be used here.
281 for (Value payload : payloadValues)
282 mappings.reverseValues[payload].push_back(handle);
/// Associates the param handle `value` with the attribute list `params`.
/// Null parameters are rejected with an error. The value must have a
/// TransformParamTypeInterface type, whose checkPayload validates `params`.
/// A handle may only be set once (asserted).
287LogicalResult transform::TransformState::setParams(Value value,
288 ArrayRef<Param> params) {
289 assert(value !=
nullptr &&
"attempting to set params for a null value");
291 for (Attribute attr : params) {
295 <<
"attempting to assign a null parameter to this transform value";
298 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.
getType());
300 "cannot associate parameter with a value of non-parameter type");
301 DiagnosedSilenceableFailure
result =
302 valueType.checkPayload(value.
getLoc(), params);
// Params are copied (to_vector) because `params` is a non-owning ArrayRef.
306 Mappings &mappings = getMapping(value);
308 mappings.params.insert({value, llvm::to_vector(params)}).second;
309 assert(
inserted &&
"value is already associated with another list of params");
/// Removes `mapped` from the list stored under `key` in `mapping`. Does
/// nothing when `key` is absent. llvm::erase removes every occurrence of
/// `mapped`. The final check detects an emptied list; the follow-up action
/// (presumably erasing `key`) is not visible here.
314template <
typename Mapping,
typename Key,
typename Mapped>
316 auto it = mapping.find(key);
317 if (it == mapping.end())
320 llvm::erase(it->getSecond(), mapped);
321 if (it->getSecond().empty())
/// Forgets the association of the operation handle `opHandle`. It walks the
/// handle's payload ops (presumably dropping their reverse entries; that loop
/// body is not visible) and erases the forward entry. Then, for every value in
/// `origOpFlatResults`, it visits the value handles pointing to that value and
/// updates their mappings.
325void transform::TransformState::forgetMapping(Value opHandle,
327 bool allowOutOfScope) {
328 Mappings &mappings = getMapping(opHandle, allowOutOfScope);
329 for (Operation *op : mappings.direct[opHandle])
331 mappings.direct.erase(opHandle);
// With ABI-breaking checks enabled, bump the timestamp so outstanding
// iterators over this handle's payload can detect the modification.
332#if LLVM_ENABLE_ABI_BREAKING_CHECKS
335 mappings.incrementTimestamp(opHandle);
338 for (Value opResult : origOpFlatResults) {
339 SmallVector<Value> resultHandles;
340 (void)getHandlesForPayloadValue(opResult, resultHandles);
341 for (Value resultHandle : resultHandles) {
342 Mappings &localMappings = getMapping(resultHandle);
344#if LLVM_ENABLE_ABI_BREAKING_CHECKS
// NOTE(review): `resultHandle` lives in `localMappings`, yet the timestamp
// below is bumped on `mappings` (the scope of `opHandle`). forgetValueMapping
// uses `localMappings.incrementTimestamp(...)` in its analogous loop.
// This likely should be `localMappings`; confirm.
347 mappings.incrementTimestamp(resultHandle);
/// Forgets the association of the value handle `valueHandle` and erases its
/// forward (`values`) entry. Then, for each op in `payloadOperations`, visits
/// every operation handle pointing to that op and updates that handle's
/// mappings (the updates are partly not visible here).
354void transform::TransformState::forgetValueMapping(
355 Value valueHandle, ArrayRef<Operation *> payloadOperations) {
356 Mappings &mappings = getMapping(valueHandle);
// NOTE(review): `reverseValues` is keyed by *payload* values; setPayloadValues
// does `reverseValues[payload].push_back(handle)`. Here it is indexed with the
// *handle*. If this is a DenseMap-like container, operator[] default-inserts an
// empty list, so the loop likely never runs and stale reverse entries remain.
// Iterating `mappings.values[valueHandle]` looks like the intent; confirm.
357 for (Value payloadValue : mappings.reverseValues[valueHandle])
359 mappings.values.erase(valueHandle);
360#if LLVM_ENABLE_ABI_BREAKING_CHECKS
363 mappings.incrementTimestamp(valueHandle);
366 for (Operation *payloadOp : payloadOperations) {
367 SmallVector<Value> opHandles;
368 (void)getHandlesForPayloadOp(payloadOp, opHandles);
369 for (Value opHandle : opHandles) {
370 Mappings &localMappings = getMapping(opHandle);
374#if LLVM_ENABLE_ABI_BREAKING_CHECKS
377 localMappings.incrementTimestamp(opHandle);
/// Replaces payload op `op` in every operation handle that refers to it,
/// including out-of-scope handles (the `true` arguments below). Asserts that
/// no value handle still refers to the old op's results. Each handle whose
/// association is rewritten in place is queued in `opHandlesToCompact`;
/// compactOpHandles later erases null entries (callers such as
/// notifyOperationErased pass a null replacement).
384transform::TransformState::replacePayloadOp(Operation *op,
390 SmallVector<Value> valueHandles;
391 (void)getHandlesForPayloadValue(opResult, valueHandles,
393 assert(valueHandles.empty() &&
"expected no mapping to old results");
399 SmallVector<Value> opHandles;
400 if (
failed(getHandlesForPayloadOp(op, opHandles,
true)))
402 for (Value handle : opHandles) {
403 Mappings &mappings = getMapping(handle,
true);
// Second pass: rewrite the forward association of each handle in place.
416 for (Value handle : opHandles) {
417 Mappings &mappings = getMapping(handle,
true);
418 auto it = mappings.direct.find(handle);
419 if (it == mappings.direct.end())
422 SmallVector<Operation *, 2> &association = it->getSecond();
424 for (Operation *&mapped : association) {
432 opHandlesToCompact.insert(handle);
/// Replaces payload value `value` with `replacement` in every value handle
/// that refers to it. Bumps each handle's timestamp under ABI-breaking checks,
/// rewrites the forward association in place, and records `replacement` in
/// the reverse (value -> handles) map.
440transform::TransformState::replacePayloadValue(Value value, Value
replacement) {
441 SmallVector<Value> valueHandles;
442 if (
failed(getHandlesForPayloadValue(value, valueHandles,
446 for (Value handle : valueHandles) {
447 Mappings &mappings = getMapping(handle,
true);
454#if LLVM_ENABLE_ABI_BREAKING_CHECKS
457 mappings.incrementTimestamp(handle);
460 auto it = mappings.values.find(handle);
461 if (it == mappings.values.end())
464 SmallVector<Value> &association = it->getSecond();
465 for (Value &mapped : association) {
469 mappings.reverseValues[
replacement].push_back(handle);
/// Considers whether consuming `consumingHandle` invalidates `otherHandle`,
/// which points to `payloadOp`. Handles already recorded in
/// `invalidatedHandles` or `newlyInvalidated` are skipped. Otherwise, if some
/// op in `potentialAncestors` is an ancestor of `payloadOp` (MLIR's
/// Operation::isAncestor also counts the op itself), a callback is stored in
/// `newlyInvalidated`. When later given the location of a use, that callback
/// emits a diagnostic explaining the invalidation.
476void transform::TransformState::recordOpHandleInvalidationOne(
477 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
478 Operation *payloadOp, Value otherHandle, Value throughValue,
479 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
483 if (invalidatedHandles.count(otherHandle) ||
484 newlyInvalidated.count(otherHandle))
487 FULL_LDBG() <<
"--recordOpHandleInvalidationOne";
489 << llvm::interleaved(
490 llvm::make_pointee_range(potentialAncestors));
492 Operation *owner = consumingHandle.
getOwner();
494 for (Operation *ancestor : potentialAncestors) {
// NOTE(review): the next statement ends with a stray second ';' (an empty
// statement). It is harmless and can be dropped.
496 FULL_LDBG() <<
"----handle one ancestor: " << *ancestor;;
498 FULL_LDBG() <<
"----of payload with name: "
500 FULL_LDBG() <<
"----of payload: " << *payloadOp;
502 if (!ancestor->isAncestor(payloadOp))
// Everything the diagnostic needs is captured by value, so the callback
// stays valid after this function returns.
509 Location ancestorLoc = ancestor->
getLoc();
510 Location opLoc = payloadOp->
getLoc();
511 std::optional<Location> throughValueLoc =
512 throughValue ? std::make_optional(throughValue.
getLoc()) : std::nullopt;
513 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
515 throughValueLoc](Location currentLoc) {
517 <<
"op uses a handle invalidated by a "
518 "previously executed transform op";
519 diag.
attachNote(otherHandle.getLoc()) <<
"handle to invalidated ops";
521 <<
"invalidated by this transform op that consumes its operand #"
523 <<
" and invalidates all handles to payload IR entities associated "
524 "with this operand and entities nested in them";
// Only present when the invalidation came through a consumed value handle.
527 if (throughValueLoc) {
528 diag.attachNote(*throughValueLoc)
529 <<
"consumed handle points to this payload value";
/// Considers whether consuming the operation handle `opHandle` invalidates the
/// value handle `valueHandle`, which points to `payloadValue`. The value's
/// definition site is the owning op for an OpResult, or the block-argument
/// position otherwise. When that site is nested in one of
/// `potentialAncestors`, a diagnostic-emitting callback is stored in
/// `newlyInvalidated`.
535void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
536 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
537 Value payloadValue, Value valueHandle,
538 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
542 if (invalidatedHandles.count(valueHandle) ||
543 newlyInvalidated.count(valueHandle))
546 for (Operation *ancestor : potentialAncestors) {
// NOTE(review): `definingOp` has no initializer. The OpResult branch assigns
// it; the block-argument branch's assignment is not visible in this excerpt.
// Initializing it to nullptr would make the assert below reliable on every path.
547 Operation *definingOp;
548 std::optional<unsigned> resultNo;
// max() marks "not a block argument" for the fields used only in that case.
549 unsigned argumentNo = std::numeric_limits<unsigned>::max();
550 unsigned blockNo = std::numeric_limits<unsigned>::max();
551 unsigned regionNo = std::numeric_limits<unsigned>::max();
552 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
553 definingOp = opResult.getOwner();
554 resultNo = opResult.getResultNumber();
556 auto arg = llvm::cast<BlockArgument>(payloadValue);
558 argumentNo = arg.getArgNumber();
559 blockNo = arg.getOwner()->computeBlockNumber();
560 regionNo = arg.getOwner()->getParent()->getRegionNumber();
562 assert(definingOp &&
"expected the value to be defined by an op as result "
563 "or block argument");
564 if (!ancestor->isAncestor(definingOp))
567 Operation *owner = opHandle.
getOwner();
569 Location ancestorLoc = ancestor->getLoc();
570 Location opLoc = definingOp->
getLoc();
571 Location valueLoc = payloadValue.
getLoc();
572 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
573 argumentNo, blockNo, regionNo, ancestorLoc,
574 opLoc, valueLoc](Location currentLoc) {
576 <<
"op uses a handle invalidated by a "
577 "previously executed transform op";
580 <<
"invalidated by this transform op that consumes its operand #"
582 <<
" and invalidates all handles to payload IR entities "
583 "associated with this operand and entities nested in them";
585 <<
"ancestor op associated with the consumed handle";
// Describe the definition site: result number for OpResults, otherwise the
// block-argument/block/region coordinates.
587 diag.attachNote(opLoc)
588 <<
"op defining the value as result #" << *resultNo;
590 diag.attachNote(opLoc)
591 <<
"op defining the value as block argument #" << argumentNo
592 <<
" of block #" << blockNo <<
" in region #" << regionNo;
/// Records the handles invalidated by consuming `handle`, whose payload is
/// `potentialAncestors`. A handle with empty payload only records itself.
/// Otherwise every op handle (reverse map) and value handle (reverseValues
/// map) of every scope mapping is checked against the ancestors.
/// NOTE(review): this scans all mapped entries in every visible scope and, per
/// entry, every ancestor. Cost grows with total mapped payload; worth keeping
/// in mind for large payloads.
599void transform::TransformState::recordOpHandleInvalidation(
600 OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
602 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
604 if (potentialAncestors.empty()) {
605 FULL_LDBG() <<
"----recording invalidation for empty handle: "
608 Operation *owner = handle.
getOwner();
610 newlyInvalidated[handle.
get()] = [owner, operandNo](Location currentLoc) {
612 <<
"op uses a handle associated with empty "
613 "payload and invalidated by a "
614 "previously executed transform op";
616 <<
"invalidated by this transform op that consumes its operand #"
629 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
// Operation handles: check each (payload op, handle) pair.
633 for (
const auto &[payloadOp, otherHandles] : mapping->reverse) {
634 for (Value otherHandle : otherHandles)
635 recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
636 otherHandle, throughValue,
// Value handles: check each (payload value, handle) pair.
644 for (
const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
645 for (Value valueHandle : valueHandles)
646 recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
647 payloadValue, valueHandle,
/// Records the handles invalidated by consuming the value handle
/// `valueHandle`. Every other value handle pointing to the same payload values
/// is recorded. In addition, op-handle invalidation is recorded for the
/// defining op of each OpResult, or for every op in the owning block of a
/// block argument.
657void transform::TransformState::recordValueHandleInvalidation(
658 OpOperand &valueHandle,
659 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
661 for (Value payloadValue : getPayloadValuesView(valueHandle.
get())) {
662 SmallVector<Value> otherValueHandles;
663 (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
664 for (Value otherHandle : otherValueHandles) {
665 Operation *owner = valueHandle.
getOwner();
667 Location valueLoc = payloadValue.
getLoc();
668 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
669 valueLoc](Location currentLoc) {
671 <<
"op uses a handle invalidated by a "
672 "previously executed transform op";
675 <<
"invalidated by this transform op that consumes its operand #"
677 <<
" and invalidates handles to the same values as associated with "
683 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
684 Operation *payloadOp = opResult.getOwner();
685 recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
// NOTE(review): `arg` is dereferenced without a null check. That is fine only
// because a non-OpResult Value is a BlockArgument; llvm::cast would state the
// invariant (and assert it) more directly.
688 auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
689 for (Operation &payloadOp : *arg.getOwner())
690 recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
/// For each operand of `transform`:
///  1. Fails if the operand was invalidated by a previous transform, or by an
///     earlier operand of this one when repeated handles are not allowed.
///  2. If the op has a Free effect on the operand (i.e. consumes it), records
///     which other handles become invalid.
/// Findings are accumulated in `newlyInvalidated`.
700LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
701 transform::TransformOpInterface transform,
702 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
703 FULL_LDBG() <<
"--Start checkAndRecordHandleInvalidation";
704 auto memoryEffectsIface =
705 cast<MemoryEffectOpInterface>(transform.getOperation());
706 SmallVector<MemoryEffects::EffectInstance> effects;
707 memoryEffectsIface.getEffectsOnResource(
710 for (OpOperand &
target : transform->getOpOperands()) {
716 auto it = invalidatedHandles.find(
target.get());
717 auto nit = newlyInvalidated.find(
target.get());
718 if (it != invalidatedHandles.end()) {
719 FULL_LDBG() <<
"--End checkAndRecordHandleInvalidation, found already "
720 "invalidated -> FAILURE";
// Comma operator: invoke the stored callback to emit the diagnostic at this
// transform's location, then return failure().
721 return it->getSecond()(transform->getLoc()), failure();
723 if (!transform.allowsRepeatedHandleOperands() &&
724 nit != newlyInvalidated.end()) {
725 FULL_LDBG() <<
"--End checkAndRecordHandleInvalidation, found newly "
726 "invalidated (by this op) -> FAILURE";
727 return nit->getSecond()(transform->getLoc()), failure();
// An operand is consumed when the op reports a Free effect on that exact value.
733 return isa<MemoryEffects::Free>(effect.getEffect()) &&
734 effect.getValue() ==
target.get();
736 if (llvm::any_of(effects, consumesTarget)) {
737 FULL_LDBG() <<
"----found consume effect";
738 if (llvm::isa<transform::TransformHandleTypeInterface>(
739 target.get().getType())) {
740 FULL_LDBG() <<
"----recordOpHandleInvalidation";
741 SmallVector<Operation *> payloadOps =
742 llvm::to_vector(getPayloadOps(
target.get()));
743 recordOpHandleInvalidation(
target, payloadOps,
nullptr,
745 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
746 target.get().getType())) {
747 FULL_LDBG() <<
"----recordValueHandleInvalidation";
748 recordValueHandleInvalidation(
target, newlyInvalidated);
// Params carry no payload IR, so consuming them invalidates nothing here.
751 <<
"----not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
754 FULL_LDBG() <<
"----no consume effect -> SKIP";
758 FULL_LDBG() <<
"--End checkAndRecordHandleInvalidation -> SUCCESS";
/// Runs the invalidation check for `transform`, then merges every handle it
/// found into `invalidatedHandles`. The merge happens whether or not the check
/// succeeded, so handles recorded before a failure stay recorded.
/// NOTE(review): if `invalidatedHandles` has DenseMap-like insert semantics,
/// existing keys are not overwritten by the merge; confirm that is intended.
762LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
763 transform::TransformOpInterface transform) {
764 InvalidatedHandleMap newlyInvalidated;
765 LogicalResult checkResult =
766 checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
// Move the callbacks rather than copying them.
767 invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
768 std::make_move_iterator(newlyInvalidated.end()));
/// Reports a diagnostic when the payload of the consumed operand
/// `operandNumber` of `transform` contains the same entity more than once.
/// Duplicates are detected through the `seen` set. `if constexpr` picks `->`
/// for pointer payloads (Operation *) and `.` for value payloads.
773static DiagnosedSilenceableFailure
775 transform::TransformOpInterface
transform,
776 unsigned operandNumber) {
778 for (T p : payload) {
779 if (!seen.insert(p).second) {
782 <<
"a handle passed as operand #" << operandNumber
783 <<
" and consumed by this operation points to a payload "
784 "entity more than once";
785 if constexpr (std::is_pointer_v<T>)
786 diag.attachNote(p->getLoc()) <<
"repeated target op";
788 diag.attachNote(p.getLoc()) <<
"repeated target value";
/// Strips null entries from the payload list of every handle queued in
/// `opHandlesToCompact`, then clears the queue. Nulls are left behind when a
/// payload op is replaced with nullptr. Under ABI-breaking checks, a handle's
/// timestamp is bumped only if its list actually contained a null.
/// NOTE(review): `direct[handle]` uses operator[], which (for DenseMap-like
/// maps) inserts an empty entry if the handle was forgotten after being queued.
795void transform::TransformState::compactOpHandles() {
796 for (Value handle : opHandlesToCompact) {
797 Mappings &mappings = getMapping(handle,
true);
798#if LLVM_ENABLE_ABI_BREAKING_CHECKS
799 if (llvm::is_contained(mappings.direct[handle],
nullptr))
802 mappings.incrementTimestamp(handle);
804 llvm::erase(mappings.direct[handle],
nullptr);
806 opHandlesToCompact.clear();
/// Applies `transform` and updates the interpreter state. Steps visible here:
///  - with expensive checks enabled, invalidation tracking and rejection of
///    consumed handles that point to the same payload entity twice;
///  - snapshot of what consumed handles refer to;
///  - merging tracking-listener failures into the result;
///  - forgetting the mappings of consumed handles;
///  - recording the mappings of the transform's results.
/// Some steps (including the call to the transform itself) are not visible.
809DiagnosedSilenceableFailure
811 LDBG() <<
"applying: "
// Dumps the top-level payload on every exit path. It is released on success
// (printOnFailureRAII.release() near the end), so the dump happens only on
// failure.
814 llvm::scope_exit printOnFailureRAII([
this] {
816 LDBG() <<
"Failing Top-level payload:\n"
// Remember the transform being applied in the innermost region scope; the
// handle-skipping predicate below compares handle users against it.
822 regionStack.back()->currentTransform =
transform;
825 if (options.getExpensiveChecksEnabled()) {
827 if (failed(checkAndRecordHandleInvalidation(
transform)))
831 FULL_LDBG() <<
"iterate on handle: " << operand.get();
833 FULL_LDBG() <<
"--handle not consumed -> SKIP";
836 if (
transform.allowsRepeatedHandleOperands()) {
837 FULL_LDBG() <<
"--op allows repeated handles -> SKIP";
842 Type operandType = operand.get().getType();
843 if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
844 FULL_LDBG() <<
"--checkRepeatedConsumptionInOperand for Operation*";
847 getPayloadOpsView(operand.get()),
transform,
848 operand.getOperandNumber());
853 }
else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
854 FULL_LDBG() <<
"--checkRepeatedConsumptionInOperand For Value";
857 getPayloadValuesView(operand.get()),
transform,
858 operand.getOperandNumber());
864 FULL_LDBG() <<
"--not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
// Collect what consumed handles refer to: results of payload ops for op
// handles, and related payload ops for value handles. These are passed to
// forgetMapping / forgetValueMapping further down.
880 for (
OpOperand *opOperand : consumedOperands) {
881 Value operand = opOperand->get();
882 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
884 llvm::append_range(origOpFlatResults, payloadOp->
getResults());
888 if (llvm::isa<TransformValueHandleTypeInterface>(operand.
getType())) {
889 for (
Value payloadValue : getPayloadValuesView(operand)) {
890 if (llvm::isa<OpResult>(payloadValue)) {
896 llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
903 <<
"unexpectedly consumed a value that is not a handle as operand #"
904 << opOperand->getOperandNumber();
906 <<
"value defined here with type " << operand.
getType();
// Handle-skipping predicate (presumably installed as the tracking listener's
// skip hook). A handle is skippable when every user is its scope's current
// transform or happens before it.
915 llvm::find_if(llvm::reverse(regionStack), [&](
RegionScope *scope) {
916 return handle.getParentRegion() == scope->region;
918 assert(scopeIt != regionStack.rend() &&
919 "could not find region scope for handle");
921 return llvm::all_of(handle.getUsers(), [&](
Operation *user) {
922 return user == scope->currentTransform ||
923 happensBefore(user, scope->currentTransform);
// Tracking failures: may be silenced by kSilenceTrackingFailuresAttrName.
// Otherwise they are merged into `result`, as a note when `result` is already
// a silenceable failure.
942 transform->hasAttr(FindPayloadReplacementOpInterface::
943 kSilenceTrackingFailuresAttrName)) {
952 result = std::move(trackingFailure);
955 if (
result.isSilenceableFailure())
956 result.attachNote() <<
"tracking listener also failed: "
961 if (
result.isDefiniteFailure())
966 if (
result.isSilenceableFailure())
// Consumed handles are dead from here on; drop their mappings.
971 for (
OpOperand *opOperand : consumedOperands) {
972 Value operand = opOperand->get();
973 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
974 forgetMapping(operand, origOpFlatResults);
975 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
977 forgetValueMapping(operand, origAssociatedOps);
981 if (failed(updateStateFromResults(results,
transform->getResults())))
// Success path: cancel the failure-only payload dump.
984 printOnFailureRAII.release();
986 LDBG() <<
"Top-level payload:\n" << *getTopLevel();
/// Records the payload a transform produced for each of `opResults`. The
/// result type decides the setter: params for TransformParamTypeInterface,
/// payload values for TransformValueHandleTypeInterface, payload ops
/// otherwise. Each branch asserts that `results` holds the matching kind.
991LogicalResult transform::TransformState::updateStateFromResults(
992 const TransformResults &results,
ResultRange opResults) {
994 if (llvm::isa<TransformParamTypeInterface>(
result.getType())) {
995 assert(results.isParam(
result.getResultNumber()) &&
996 "expected parameters for the parameter-typed result");
998 setParams(
result, results.getParams(
result.getResultNumber())))) {
1001 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
result.getType())) {
1002 assert(results.isValue(
result.getResultNumber()) &&
1003 "expected values for value-type-result");
1004 if (
failed(setPayloadValues(
1005 result, results.getValues(
result.getResultNumber())))) {
1009 assert(!results.isParam(
result.getResultNumber()) &&
1010 "expected payload ops for the non-parameter typed result");
1012 setPayloadOps(
result, results.get(
result.getResultNumber())))) {
// Forwards the payload-value replacement to the owning TransformState.
1037 return state.replacePayloadValue(value,
replacement);
// Scope teardown (presumably the RegionScope destructor; the header is not
// visible). Removes invalidation records for the region's block-argument
// handles, and for further handles whose loop is not visible here. Then drops
// the region's mappings and pops the region stack.
1048 for (
Block &block : *region) {
1049 for (
Value handle : block.getArguments()) {
1050 state.invalidatedHandles.erase(handle);
1054 state.invalidatedHandles.erase(handle);
1059#if LLVM_ENABLE_ABI_BREAKING_CHECKS
// Snapshot the payload ops of this scope's reverse map before the mappings
// are erased below.
1063 llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1066 state.mappings.erase(region);
1067 state.regionStack.pop_back();
/// Creates result storage for `numSegments` results. One empty row is added to
/// each of the operations, params and values tables, so all three are always
/// indexable with the same result number.
1074transform::TransformResults::TransformResults(
unsigned numSegments) {
1075 operations.appendEmptyRows(numSegments);
1076 params.appendEmptyRows(numSegments);
1077 values.appendEmptyRows(numSegments);
// Stores `params` as result `position`. Asserts the position is in range and
// that no kind of result (params, operations or values) was set for it yet.
// A null data() pointer marks an unset row.
1083 assert(position <
static_cast<int64_t>(this->params.size()) &&
1084 "setting params for a non-existent handle");
1085 assert(this->params[position].data() ==
nullptr &&
"params already set");
1086 assert(operations[position].data() ==
nullptr &&
1087 "another kind of results already set");
1088 assert(values[position].data() ==
nullptr &&
1089 "another kind of results already set");
1090 this->params.replace(position, params);
// If the dispatch diagnostic is not a success, print its status to
// llvm::dbgs() first, so the reason is visible when the assertion fires.
1107 if (!
diag.succeeded())
1108 llvm::dbgs() <<
diag.getStatusString() <<
"\n";
1109 assert(
diag.succeeded() &&
"incorrect mapping");
// Visits the results of `transform` and acts on those not yet set
// (presumably giving them an empty payload; that code is not visible here).
1115 transform::TransformOpInterface
transform) {
1117 if (!isSet(opResult.getResultNumber()))
/// Returns the payload operations stored for result `resultNumber`. Asserts
/// that the index is in range and that operations, not params or values, were
/// set for it.
1123transform::TransformResults::get(
unsigned resultNumber)
const {
1124 assert(resultNumber < operations.size() &&
1125 "querying results for a non-existent handle");
1126 assert(operations[resultNumber].data() !=
nullptr &&
1127 "querying unset results (values or params expected?)");
1128 return operations[resultNumber];
/// Returns the parameters stored for result `resultNumber`. Asserts that the
/// index is in range and that params were set for it.
1132transform::TransformResults::getParams(
unsigned resultNumber)
const {
1133 assert(resultNumber < params.size() &&
1134 "querying params for a non-existent handle");
1135 assert(params[resultNumber].data() !=
nullptr &&
1136 "querying unset params (ops or values expected?)");
1137 return params[resultNumber];
/// Returns the payload values stored for result `resultNumber`. Asserts that
/// the index is in range and that values were set for it.
1141transform::TransformResults::getValues(
unsigned resultNumber)
const {
1142 assert(resultNumber < values.size() &&
1143 "querying values for a non-existent handle");
1144 assert(values[resultNumber].data() !=
nullptr &&
1145 "querying unset values (ops or params expected?)");
1146 return values[resultNumber];
/// Returns true if params were set for result `resultNumber`, i.e. its params
/// row has non-null data.
1149bool transform::TransformResults::isParam(
unsigned resultNumber)
const {
1150 assert(resultNumber < params.size() &&
1151 "querying association for a non-existent handle");
1152 return params[resultNumber].data() !=
nullptr;
/// Returns true if payload values were set for result `resultNumber`, i.e. its
/// values row has non-null data.
1155bool transform::TransformResults::isValue(
unsigned resultNumber)
const {
1156 assert(resultNumber < values.size() &&
1157 "querying association for a non-existent handle");
1158 return values[resultNumber].data() !=
nullptr;
/// Returns true if any kind of result (params, operations or values) was set
/// for result `resultNumber`.
/// NOTE(review): only `params.size()` is range-checked, but `operations` and
/// `values` are indexed as well. This is safe while the constructor keeps
/// appending the same number of rows to all three tables.
1161bool transform::TransformResults::isSet(
unsigned resultNumber)
const {
1162 assert(resultNumber < params.size() &&
1163 "querying association for a non-existent handle");
1164 return params[resultNumber].data() !=
nullptr ||
1165 operations[resultNumber].data() !=
nullptr ||
1166 values[resultNumber].data() !=
nullptr;
// Tracking listener constructor: stores `config` and caches the handles
// consumed by the transform op. notifyOperationReplaced queries this cache
// (consumedHandles.contains) to decide whether a replacement is needed.
1174 TransformOpInterface op,
1177 config(std::move(config)) {
1179 for (
OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1180 consumedHandles.insert(opOperand->get());
// Scans `values` for a single op defining all of them. The first defining op
// seen is remembered; a value defined by a different op breaks the
// commonality (the resulting return is not visible here).
1187 for (
Value v : values) {
1192 defOp = v.getDefiningOp();
1195 if (defOp != v.getDefiningOp())
// Searches for a replacement op for an op being replaced, building a
// diagnostic that explains each step. Values defined by different ops are
// rejected. CastOpInterface ops are skipped when config.skipCastOps is set.
// Ops implementing FindPayloadReplacementOpInterface supply the next
// candidate values. The do/while loop ends when the candidates run out.
1204 "invalid number of replacement values");
1208 getTransformOp(),
"tracking listener failed to find replacement op "
1209 "during application of this transform op");
1215 diag.attachNote() <<
"replacement values belong to different ops";
1220 if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1224 <<
"using output of 'CastOpInterface' op";
1230 if (!config.requireMatchingReplacementOpName ||
1246 if (
auto findReplacementOpInterface =
1247 dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1248 values.assign(findReplacementOpInterface.getNextOperands());
1249 diag.attachNote(defOp->
getLoc()) <<
"using operands provided by "
1250 "'FindPayloadReplacementOpInterface'";
1253 }
while (!values.empty());
1255 diag.attachNote() <<
"ran out of suitable replacement values";
// Pattern match-failure hook: lets `reasonCallback` fill `diag`, then only
// logs the text through LDBG (debug output; no state change is visible here).
1263 reasonCallback(
diag);
1264 LDBG() <<
"Match Failure : " <<
diag.str();
/// Called when `op` is erased. The affected payload values (presumably `op`'s
/// results; the loop header is not visible) and `op` itself are replaced with
/// null in all handles. Null op entries are later compacted away.
/// Failures are deliberately ignored via the (void) casts.
1268void transform::TrackingListener::notifyOperationErased(
Operation *op) {
1271 (
void)replacePayloadValue(value,
nullptr);
1273 (
void)replacePayloadOp(op,
nullptr);
/// Called when `op` is replaced by `newValues`. Value handles are remapped
/// result-by-result. For op handles:
///  - if no handle must stay alive (all are skipped, or there are none), or a
///    handle is consumed by the current transform, `op` is dropped from the
///    handles (replaced with null);
///  - otherwise a replacement op is searched for; if none is found, the
///    failure is reported and `op` is dropped.
1276void transform::TrackingListener::notifyOperationReplaced(
1279 "invalid number of replacement values");
// llvm::zip stops at the shorter range; the size check above keeps the two
// in sync.
1282 for (
auto [oldValue, newValue] : llvm::zip(op->
getResults(), newValues))
1283 (
void)replacePayloadValue(oldValue, newValue);
1287 if (failed(getTransformState().getHandlesForPayloadOp(
1288 op, opHandles,
true))) {
1302 auto handleWasConsumed = [&] {
1303 return llvm::any_of(opHandles,
1304 [&](
Value h) {
return consumedHandles.contains(h); });
// Pick a handle that must be kept up to date: the first one not skipped by
// config.skipHandleFn, or simply the first handle when no hook is set.
1309 if (
config.skipHandleFn) {
1310 auto *it = llvm::find_if(opHandles,
1311 [&](Value v) {
return !
config.skipHandleFn(v); });
1312 if (it != opHandles.end())
1314 }
else if (!opHandles.empty()) {
1315 aliveHandle = opHandles.front();
1317 if (!aliveHandle || handleWasConsumed()) {
1320 (void)replacePayloadOp(op,
nullptr);
1325 DiagnosedSilenceableFailure
diag =
1329 if (!
diag.succeeded()) {
1331 <<
"replacement is required because this handle must be updated";
1332 notifyPayloadReplacementNotFound(op, newValues, std::move(
diag));
1333 (void)replacePayloadOp(op,
nullptr);
// Guards against silently dropped errors: the accumulated status must be a
// success (i.e. checked and reset) by the time this assertion runs.
1344 assert(status.succeeded() &&
"listener state was not checked");
// True when the listener has recorded at least one tracking failure.
1356 return !status.succeeded();
// Merges the new diagnostics from `diag` with any already held in `status`,
// then attaches numbered notes (tagged with `errorCounter`) pointing at the
// replaced op and at each replacement value.
1364 diag.takeDiagnostics(diags);
1365 if (!status.succeeded())
1366 status.takeDiagnostics(diags);
1370 status.attachNote(op->
getLoc()) <<
"[" << errorCounter <<
"] replaced op";
1371 for (
auto &&[
index, value] : llvm::enumerate(values))
1372 status.attachNote(value.
getLoc())
1373 <<
"[" << errorCounter <<
"] replacement value " <<
index;
// Returns the text of the most recent recorded match failure. The branch for
// "no failure recorded" is handled first (its body is not visible here).
1379 if (!matchFailure) {
1382 return matchFailure->str();
// Lets `reasonCallback` fill a diagnostic and keeps it as the latest match
// failure, overwriting any previous one.
1388 reasonCallback(
diag);
1389 matchFailure = std::move(
diag);
// Delegates to the listener's failure state.
1403 return listener->failed();
// Forwards a payload-op replacement to the tracking listener.
1416 return listener->replacePayloadOp(op,
replacement);
// Reports a diagnostic when a consumed handle lists an ancestor op before one
// of its descendants. Every target is compared against all targets after it,
// so the check is quadratic in the number of targets.
1426 for (
auto &&[position, parent] : llvm::enumerate(targets)) {
1427 for (
Operation *child : targets.drop_front(position + 1)) {
1428 if (parent->isAncestor(child)) {
1431 <<
"transform operation consumes a handle pointing to an ancestor "
1432 "payload operation before its descendant";
1434 <<
"the ancestor is likely erased or rewritten before the "
1435 "descendant is accessed, leading to undefined behavior";
1436 diag.attachNote(parent->getLoc()) <<
"ancestor payload op";
1437 diag.attachNote(child->getLoc()) <<
"descendant payload op";
// Validates the results of a single applyToOne call. The count must match
// `expectedNumResults`, and each entry must have the kind its result type
// requires: Operation * for handles, Attribute for params, Value for value
// handles. Every diagnostic created through emitDiag carries a note pointing
// at the payload op.
1456 diag.attachNote(payloadOpLoc) <<
"when applied to this op";
1460 if (partialResult.
size() != expectedNumResults) {
1461 auto diag =
emitDiag() <<
"application of " << transformOpName
1462 <<
" expected to produce " << expectedNumResults
1463 <<
" results (actually produced "
1464 << partialResult.
size() <<
").";
1465 diag.attachNote(transformOpLoc)
1466 <<
"if you need variadic results, consider a generic `apply` "
1467 <<
"instead of the specialized `applyToOne`.";
1472 for (
const auto &[
ptr, res] :
1473 llvm::zip(partialResult, transformOp->
getResults())) {
1476 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1477 !isa<Operation *>(
ptr)) {
1478 return emitDiag() <<
"application of " << transformOpName
1479 <<
" expected to produce an Operation * for result #"
1480 << res.getResultNumber();
1482 if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1483 !isa<Attribute>(
ptr)) {
1484 return emitDiag() <<
"application of " << transformOpName
1485 <<
" expected to produce an Attribute for result #"
1486 << res.getResultNumber();
1488 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1490 return emitDiag() <<
"application of " << transformOpName
1491 <<
" expected to produce a Value for result #"
1492 << res.getResultNumber();
/// Builds a vector by casting every element of `range` to `T` with llvm::cast,
/// which asserts that each element really is a `T`.
1498template <
typename T>
1500 return llvm::map_to_vector(range, llvm::CastTo<T>);
// Transposes per-payload-op partial results into per-result lists, then
// stores each list according to its result type (param, value or op).
// Null entries are detected up front; the resulting handling is not visible.
1509 if (llvm::any_of(partialResults,
1510 [](
MappedValue value) {
return value.isNull(); }))
1512 assert(transformOp->
getNumResults() == partialResults.size() &&
1513 "expected as many partial results as op as results");
1514 for (
auto [i, value] : llvm::enumerate(partialResults))
1515 transposed[i].push_back(value);
1519 unsigned position = r.getResultNumber();
1520 if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1523 }
else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
// Appends the payload of each transform value in `values` to the matching
// entry of `mappings`, choosing the lookup by handle kind. When `flatten` is
// false, a value that contributed anything other than exactly one entry is
// handled by the final condition.
// NOTE(review): the explicit size assert duplicates llvm::zip_equal's own
// equal-length assertion.
1538 assert(mappings.size() == values.size() &&
"mismatching number of mappings");
1539 for (
auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1540 size_t mappedSize = mapped.size();
1541 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1543 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
1544 operand.getType())) {
1547 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1548 "unsupported kind of transform dialect value");
1549 llvm::append_range(mapped, state.
getParams(operand));
1552 if (mapped.size() - mappedSize != 1 && !flatten)
// Reserves one new (initially empty) mapping slot per value in `values`.
1561 mappings.resize(mappings.size() + values.size());
// Forwards the payload of each terminator operand to the corresponding op
// result, branching on the result's handle kind. Anything that is neither an
// op handle nor a value handle must be a param (asserted).
1571 for (
auto &&[terminatorOperand,
result] :
1574 if (llvm::isa<transform::TransformHandleTypeInterface>(
result.getType())) {
1576 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1581 llvm::isa<transform::TransformParamTypeInterface>(
result.getType()) &&
1582 "unhandled transform type interface");
// Copies the effects `iface` reports on `source` onto `target`, keeping each
// effect's kind and resource.
1605 iface.getEffectsOnValue(source, nestedEffects);
1606 for (
const auto &effect : nestedEffects)
1607 effects.emplace_back(effect.getEffect(),
target, effect.getResource());
1616 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1627 llvm::append_range(effects, nestedEffects);
1639 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1643 iface.getEffects(effects);
1666 <<
" were provided to the interpreter";
1680 argument, extraMappings[argument.getArgNumber() - 1])))
1692 assert(isa<TransformOpInterface>(op) &&
1693 "should implement TransformOpInterface to have "
1694 "PossibleTopLevelTransformOpTrait");
1697 return op->
emitOpError() <<
"expects at least one region";
1700 if (!llvm::hasNItems(*bodyRegion, 1))
1701 return op->
emitOpError() <<
"expects a single-block region";
1706 <<
"expects the entry block to have at least one argument";
1708 if (!llvm::isa<TransformHandleTypeInterface>(
1711 <<
"expects the first entry block argument to be of type "
1712 "implementing TransformHandleTypeInterface";
1718 <<
"expects the type of the block argument to match "
1719 "the type of the operand";
1723 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1724 TransformValueHandleTypeInterface>(arg.
getType()))
1729 <<
"expects trailing entry block arguments to be of type implementing "
1730 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1731 "TransformParamTypeInterface";
1741 <<
"expects operands to be provided for a nested op";
1742 diag.attachNote(parent->getLoc())
1743 <<
"nested in another possible top-level op";
1758 bool hasPayloadOperands =
false;
1761 if (llvm::isa<TransformHandleTypeInterface,
1762 TransformValueHandleTypeInterface>(operand.get().getType()))
1763 hasPayloadOperands =
true;
1765 if (hasPayloadOperands)
1774 llvm::report_fatal_error(
1775 Twine(
"ParamProducerTransformOpTrait must be attached to an op that "
1776 "implements MemoryEffectsOpInterface, found on ") +
1780 if (llvm::isa<TransformParamTypeInterface>(
result.getType()))
1783 <<
"ParamProducerTransformOpTrait attached to this op expects "
1784 "result types to implement TransformParamTypeInterface";
1806template <
typename EffectTy,
typename ResourceTy,
typename Range>
1809 return isa<EffectTy>(effect.
getEffect()) &&
1815 transform::TransformOpInterface
transform) {
1816 auto iface = cast<MemoryEffectOpInterface>(
transform.getOperation());
1818 iface.getEffectsOnValue(handle, effects);
1819 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1866 auto iface = cast<MemoryEffectOpInterface>(
transform.getOperation());
1868 iface.getEffects(effects);
1869 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1873 auto iface = cast<MemoryEffectOpInterface>(
transform.getOperation());
1875 iface.getEffects(effects);
1876 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1880 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1883 auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1888 iface.getEffects(effects);
1891 dyn_cast_or_null<BlockArgument>(effect.getValue());
1892 if (!argument || argument.
getOwner() != &block ||
1893 !isa<MemoryEffects::Free>(effect.getEffect()) ||
1907 TransformOpInterface transformOp) {
1908 SmallVector<OpOperand *> consumedOperands;
1909 consumedOperands.reserve(transformOp->getNumOperands());
1910 auto memEffectInterface =
1911 cast<MemoryEffectOpInterface>(transformOp.getOperation());
1912 SmallVector<MemoryEffects::EffectInstance, 2> effects;
1913 for (OpOperand &
target : transformOp->getOpOperands()) {
1915 memEffectInterface.getEffectsOnValue(
target.get(), effects);
1917 return isa<transform::TransformMappingResource>(
1919 isa<MemoryEffects::Free>(effect.
getEffect());
1921 consumedOperands.push_back(&
target);
1924 return consumedOperands;
1928 auto iface = cast<MemoryEffectOpInterface>(op);
1930 iface.getEffects(effects);
1932 auto effectsOn = [&](
Value value) {
1933 return llvm::make_filter_range(
1935 return instance.
getValue() == value;
1939 std::optional<unsigned> firstConsumedOperand;
1941 auto range = effectsOn(operand.get());
1942 if (range.empty()) {
1944 op->
emitError() <<
"TransformOpInterface requires memory effects "
1945 "on operands to be specified";
1946 diag.attachNote() <<
"no effects specified for operand #"
1947 << operand.getOperandNumber();
1952 <<
"TransformOpInterface did not expect "
1953 "'allocate' memory effect on an operand";
1954 diag.attachNote() <<
"specified for operand #"
1955 << operand.getOperandNumber();
1958 if (!firstConsumedOperand &&
1960 firstConsumedOperand = operand.getOperandNumber();
1964 if (firstConsumedOperand &&
1968 <<
"TransformOpInterface expects ops consuming operands to have a "
1969 "'write' effect on the payload resource";
1970 diag.attachNote() <<
"consumes operand #" << *firstConsumedOperand;
1975 auto range = effectsOn(
result);
1979 op->
emitError() <<
"TransformOpInterface requires 'allocate' memory "
1980 "effect to be specified for results";
1981 diag.attachNote() <<
"no 'allocate' effect specified for result #"
1982 <<
result.getResultNumber();
2000 if (enforceToplevelTransformOp) {
2004 <<
"expected transform to start at the top-level transform op";
2013 if (stateInitializer)
2014 stateInitializer(state);
2018 return stateExporter(state);
2026#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2027#include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A 2D array where each row may have different length.
size_t size() const
Returns the number of rows in the 2D array.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class implements the result iterators for the Operation class.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
static TransformMappingResource * get()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
llvm::function_ref< Fn > function_ref
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...