17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/ErrorHandling.h"
// Debug-output helpers for this file.
//
// DEBUG_TYPE is the debug category name picked up by LLVM_DEBUG (see
// llvm/Support/Debug.h); DEBUG_TYPE_FULL is a second, separate category name
// used by FULL_LDBG below.
22 #define DEBUG_TYPE "transform-dialect"
23 #define DEBUG_TYPE_FULL "transform-dialect-full"
// NOTE(review): the use site of DEBUG_PRINT_AFTER_ALL is not visible in this
// chunk; judging by its value it presumably gates printing of the top-level
// payload after every transform -- confirm against the use site.
24 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
// Streams the "[transform-dialect] " prefix to llvm::dbgs() and yields the
// stream so that callers can keep appending with operator<<.
25 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
// Prints X with the DBGS() prefix, wrapped in LLVM_DEBUG.
26 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
// Same as LDBG but wrapped in DEBUG_WITH_TYPE under DEBUG_TYPE_FULL. The
// printed prefix still comes from DBGS(), i.e. it reads "[transform-dialect] "
// and not the "-full" name.
27 #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
// Namespace-scope definition of the static data member
// TransformState::kTopLevelValue. setPayloadOps in this file asserts that the
// value being assigned differs from it ("attempting to reset the
// transformation root").
// NOTE(review): the in-class declaration is not visible in this chunk. If it
// is `static constexpr` and the file is built as C++17 or later, the member is
// implicitly inline and this redeclaration is redundant (and deprecated) --
// confirm before removing.
52 constexpr
const Value transform::TransformState::kTopLevelValue;
54 transform::TransformState::TransformState(
57 const TransformOptions &
options)
59 topLevelMappedValues.reserve(extraMappings.
size());
61 topLevelMappedValues.push_back(mapping);
63 RegionScope *scope =
new RegionScope(*
this, *region);
64 topLevelRegionScope.reset(scope);
/// Returns the operation stored in the `topLevel` member of this state. The
/// debug output elsewhere in this file prints the result of this accessor
/// under the "Top-level payload" label. Does not modify the state (const).
68 Operation *transform::TransformState::getTopLevel()
const {
return topLevel; }
71 transform::TransformState::getPayloadOpsView(
Value value)
const {
72 const TransformOpMapping &operationMapping = getMapping(value).direct;
73 auto iter = operationMapping.find(value);
74 assert(iter != operationMapping.end() &&
75 "cannot find mapping for payload handle (param/value handle "
77 return iter->getSecond();
82 auto iter = mapping.find(value);
83 assert(iter != mapping.end() &&
"cannot find mapping for param handle "
84 "(operation/value handle provided?)");
85 return iter->getSecond();
89 transform::TransformState::getPayloadValuesView(
Value handleValue)
const {
90 const ValueMapping &mapping = getMapping(handleValue).values;
91 auto iter = mapping.find(handleValue);
92 assert(iter != mapping.end() &&
"cannot find mapping for value handle "
93 "(param/operation handle provided?)");
94 return iter->getSecond();
99 bool includeOutOfScope)
const {
101 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
102 auto iterator = mapping->reverse.find(op);
103 if (iterator != mapping->reverse.end()) {
104 llvm::append_range(handles, iterator->getSecond());
108 if (!includeOutOfScope &&
118 bool includeOutOfScope)
const {
120 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
121 auto iterator = mapping->reverseValues.find(payloadValue);
122 if (iterator != mapping->reverseValues.end()) {
123 llvm::append_range(handles, iterator->getSecond());
127 if (!includeOutOfScope &&
142 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.
getType())) {
144 operations.reserve(values.size());
146 if (
auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
147 operations.push_back(op);
151 <<
"wrong kind of value provided for top-level operation handle";
153 if (
failed(operationsFn(operations)))
158 if (llvm::isa<transform::TransformValueHandleTypeInterface>(
161 payloadValues.reserve(values.size());
163 if (
auto v = llvm::dyn_cast_if_present<Value>(value)) {
164 payloadValues.push_back(v);
168 <<
"wrong kind of value provided for the top-level value handle";
170 if (
failed(valuesFn(payloadValues)))
175 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.
getType()) &&
176 "unsupported kind of block argument");
178 parameters.reserve(values.size());
180 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
181 parameters.push_back(attr);
185 <<
"wrong kind of value provided for top-level parameter";
187 if (
failed(paramsFn(parameters)))
198 return setPayloadOps(argument, operations);
201 return setParams(argument, params);
204 return setPayloadValues(argument, payloadValues);
210 transform::TransformState::setPayloadOps(
Value value,
212 assert(value != kTopLevelValue &&
213 "attempting to reset the transformation root");
214 assert(llvm::isa<TransformHandleTypeInterface>(value.
getType()) &&
215 "wrong handle type");
221 <<
"attempting to assign a null payload op to this transform value";
224 auto iface = llvm::cast<TransformHandleTypeInterface>(value.
getType());
226 iface.checkPayload(value.
getLoc(), targets);
233 Mappings &mappings = getMapping(value);
235 mappings.direct.insert({value, std::move(storedTargets)}).second;
236 assert(inserted &&
"value is already associated with another list");
240 mappings.reverse[op].push_back(value);
246 transform::TransformState::setPayloadValues(
Value handle,
248 assert(handle !=
nullptr &&
"attempting to set params for a null value");
249 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.
getType()) &&
250 "wrong handle type");
252 for (
Value payload : payloadValues) {
255 return emitError(handle.
getLoc()) <<
"attempting to assign a null payload "
256 "value to this transform handle";
259 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.
getType());
262 iface.checkPayload(handle.
getLoc(), payloadValueVector);
266 Mappings &mappings = getMapping(handle);
268 mappings.values.insert({handle, std::move(payloadValueVector)}).second;
271 "value handle is already associated with another list of payload values");
274 for (
Value payload : payloadValues)
275 mappings.reverseValues[payload].push_back(handle);
282 assert(value !=
nullptr &&
"attempting to set params for a null value");
288 <<
"attempting to assign a null parameter to this transform value";
291 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.
getType());
293 "cannot associate parameter with a value of non-parameter type");
295 valueType.checkPayload(value.
getLoc(), params);
299 Mappings &mappings = getMapping(value);
301 mappings.params.insert({value, llvm::to_vector(params)}).second;
302 assert(inserted &&
"value is already associated with another list of params");
307 template <
typename Mapping,
typename Key,
typename Mapped>
309 auto it = mapping.find(key);
310 if (it == mapping.end())
313 llvm::erase(it->getSecond(), mapped);
314 if (it->getSecond().empty())
318 void transform::TransformState::forgetMapping(
Value opHandle,
320 bool allowOutOfScope) {
321 Mappings &mappings = getMapping(opHandle, allowOutOfScope);
322 for (
Operation *op : mappings.direct[opHandle])
324 mappings.direct.erase(opHandle);
325 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
328 mappings.incrementTimestamp(opHandle);
331 for (
Value opResult : origOpFlatResults) {
333 (void)getHandlesForPayloadValue(opResult, resultHandles);
334 for (
Value resultHandle : resultHandles) {
335 Mappings &localMappings = getMapping(resultHandle);
337 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
340 mappings.incrementTimestamp(resultHandle);
347 void transform::TransformState::forgetValueMapping(
349 Mappings &mappings = getMapping(valueHandle);
350 for (
Value payloadValue : mappings.reverseValues[valueHandle])
352 mappings.values.erase(valueHandle);
353 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
356 mappings.incrementTimestamp(valueHandle);
359 for (
Operation *payloadOp : payloadOperations) {
361 (void)getHandlesForPayloadOp(payloadOp, opHandles);
362 for (
Value opHandle : opHandles) {
363 Mappings &localMappings = getMapping(opHandle);
367 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
370 localMappings.incrementTimestamp(opHandle);
377 transform::TransformState::replacePayloadOp(
Operation *op,
384 (void)getHandlesForPayloadValue(opResult, valueHandles,
386 assert(valueHandles.empty() &&
"expected no mapping to old results");
393 if (
failed(getHandlesForPayloadOp(op, opHandles,
true)))
395 for (
Value handle : opHandles) {
396 Mappings &mappings = getMapping(handle,
true);
409 for (
Value handle : opHandles) {
410 Mappings &mappings = getMapping(handle,
true);
411 auto it = mappings.direct.find(handle);
412 if (it == mappings.direct.end())
419 mapped = replacement;
423 mappings.reverse[replacement].push_back(handle);
425 opHandlesToCompact.insert(handle);
433 transform::TransformState::replacePayloadValue(
Value value,
Value replacement) {
435 if (
failed(getHandlesForPayloadValue(value, valueHandles,
439 for (
Value handle : valueHandles) {
440 Mappings &mappings = getMapping(handle,
true);
447 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
450 mappings.incrementTimestamp(handle);
453 auto it = mappings.values.find(handle);
454 if (it == mappings.values.end())
458 for (
Value &mapped : association) {
460 mapped = replacement;
462 mappings.reverseValues[replacement].push_back(handle);
469 void transform::TransformState::recordOpHandleInvalidationOne(
476 if (invalidatedHandles.count(otherHandle) ||
477 newlyInvalidated.count(otherHandle))
480 FULL_LDBG(
"--recordOpHandleInvalidationOne\n");
483 llvm::interleaveComma(potentialAncestors,
DBGS() <<
"--ancestors: ",
484 [](
Operation *op) { llvm::dbgs() << *op; });
485 llvm::dbgs() <<
"\n");
489 for (
Operation *ancestor : potentialAncestors) {
492 { (
DBGS() <<
"----handle one ancestor: " << *ancestor <<
"\n"); });
494 { (
DBGS() <<
"----of payload with name: "
497 { (
DBGS() <<
"----of payload: " << *payloadOp <<
"\n"); });
499 if (!ancestor->isAncestor(payloadOp))
506 Location ancestorLoc = ancestor->getLoc();
508 std::optional<Location> throughValueLoc =
509 throughValue ? std::make_optional(throughValue.
getLoc()) : std::nullopt;
510 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
512 throughValueLoc](
Location currentLoc) {
514 <<
"op uses a handle invalidated by a "
515 "previously executed transform op";
516 diag.attachNote(otherHandle.getLoc()) <<
"handle to invalidated ops";
517 diag.attachNote(owner->getLoc())
518 <<
"invalidated by this transform op that consumes its operand #"
520 <<
" and invalidates all handles to payload IR entities associated "
521 "with this operand and entities nested in them";
522 diag.attachNote(ancestorLoc) <<
"ancestor payload op";
523 diag.attachNote(opLoc) <<
"nested payload op";
524 if (throughValueLoc) {
525 diag.attachNote(*throughValueLoc)
526 <<
"consumed handle points to this payload value";
532 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
539 if (invalidatedHandles.count(valueHandle) ||
540 newlyInvalidated.count(valueHandle))
543 for (
Operation *ancestor : potentialAncestors) {
545 std::optional<unsigned> resultNo;
549 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
550 definingOp = opResult.getOwner();
551 resultNo = opResult.getResultNumber();
553 auto arg = llvm::cast<BlockArgument>(payloadValue);
555 argumentNo = arg.getArgNumber();
556 blockNo = std::distance(arg.getOwner()->getParent()->begin(),
557 arg.getOwner()->getIterator());
558 regionNo = arg.getOwner()->getParent()->getRegionNumber();
560 assert(definingOp &&
"expected the value to be defined by an op as result "
561 "or block argument");
562 if (!ancestor->isAncestor(definingOp))
567 Location ancestorLoc = ancestor->getLoc();
570 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
571 argumentNo, blockNo, regionNo, ancestorLoc,
572 opLoc, valueLoc](
Location currentLoc) {
574 <<
"op uses a handle invalidated by a "
575 "previously executed transform op";
576 diag.attachNote(valueHandle.
getLoc()) <<
"invalidated handle";
577 diag.attachNote(owner->getLoc())
578 <<
"invalidated by this transform op that consumes its operand #"
580 <<
" and invalidates all handles to payload IR entities "
581 "associated with this operand and entities nested in them";
582 diag.attachNote(ancestorLoc)
583 <<
"ancestor op associated with the consumed handle";
585 diag.attachNote(opLoc)
586 <<
"op defining the value as result #" << *resultNo;
588 diag.attachNote(opLoc)
589 <<
"op defining the value as block argument #" << argumentNo
590 <<
" of block #" << blockNo <<
" in region #" << regionNo;
592 diag.attachNote(valueLoc) <<
"payload value";
597 void transform::TransformState::recordOpHandleInvalidation(
602 if (potentialAncestors.empty()) {
604 (
DBGS() <<
"----recording invalidation for empty handle: " << handle.
get()
610 newlyInvalidated[handle.
get()] = [owner, operandNo](
Location currentLoc) {
612 <<
"op uses a handle associated with empty "
613 "payload and invalidated by a "
614 "previously executed transform op";
615 diag.attachNote(owner->getLoc())
616 <<
"invalidated by this transform op that consumes its operand #"
629 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
633 for (
const auto &[payloadOp, otherHandles] : mapping->reverse) {
634 for (
Value otherHandle : otherHandles)
635 recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
636 otherHandle, throughValue,
644 for (
const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
645 for (
Value valueHandle : valueHandles)
646 recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
647 payloadValue, valueHandle,
657 void transform::TransformState::recordValueHandleInvalidation(
661 for (
Value payloadValue : getPayloadValuesView(valueHandle.
get())) {
663 (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
664 for (
Value otherHandle : otherValueHandles) {
668 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
671 <<
"op uses a handle invalidated by a "
672 "previously executed transform op";
673 diag.attachNote(otherHandle.getLoc()) <<
"invalidated handle";
674 diag.attachNote(owner->getLoc())
675 <<
"invalidated by this transform op that consumes its operand #"
677 <<
" and invalidates handles to the same values as associated with "
679 diag.attachNote(valueLoc) <<
"payload value";
683 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
684 Operation *payloadOp = opResult.getOwner();
685 recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
688 auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
689 for (
Operation &payloadOp : *arg.getOwner())
690 recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
700 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
701 transform::TransformOpInterface transform,
703 FULL_LDBG(
"--Start checkAndRecordHandleInvalidation\n");
704 auto memoryEffectsIface =
705 cast<MemoryEffectOpInterface>(transform.getOperation());
707 memoryEffectsIface.getEffectsOnResource(
710 for (
OpOperand &target : transform->getOpOperands()) {
712 (
DBGS() <<
"----iterate on handle: " << target.get() <<
"\n");
718 auto it = invalidatedHandles.find(target.get());
719 auto nit = newlyInvalidated.find(target.get());
720 if (it != invalidatedHandles.end()) {
721 FULL_LDBG(
"--End checkAndRecordHandleInvalidation, found already "
722 "invalidated -> FAILURE\n");
723 return it->getSecond()(transform->getLoc()),
failure();
725 if (!transform.allowsRepeatedHandleOperands() &&
726 nit != newlyInvalidated.end()) {
727 FULL_LDBG(
"--End checkAndRecordHandleInvalidation, found newly "
728 "invalidated (by this op) -> FAILURE\n");
729 return nit->getSecond()(transform->getLoc()),
failure();
735 return isa<MemoryEffects::Free>(effect.getEffect()) &&
736 effect.getValue() == target.get();
738 if (llvm::any_of(effects, consumesTarget)) {
740 if (llvm::isa<transform::TransformHandleTypeInterface>(
741 target.get().getType())) {
742 FULL_LDBG(
"----recordOpHandleInvalidation\n");
744 llvm::to_vector(getPayloadOps(target.get()));
745 recordOpHandleInvalidation(target, payloadOps,
nullptr,
747 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
748 target.get().getType())) {
749 FULL_LDBG(
"----recordValueHandleInvalidation\n");
750 recordValueHandleInvalidation(target, newlyInvalidated);
752 FULL_LDBG(
"----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
755 FULL_LDBG(
"----no consume effect -> SKIP\n");
759 FULL_LDBG(
"--End checkAndRecordHandleInvalidation -> SUCCESS\n");
763 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
764 transform::TransformOpInterface transform) {
765 InvalidatedHandleMap newlyInvalidated;
767 checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
768 invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
769 std::make_move_iterator(newlyInvalidated.end()));
773 template <
typename T>
776 transform::TransformOpInterface transform,
777 unsigned operandNumber) {
779 for (T p : payload) {
780 if (!seen.insert(p).second) {
782 transform.emitSilenceableError()
783 <<
"a handle passed as operand #" << operandNumber
784 <<
" and consumed by this operation points to a payload "
785 "entity more than once";
786 if constexpr (std::is_pointer_v<T>)
787 diag.attachNote(p->getLoc()) <<
"repeated target op";
789 diag.attachNote(p.getLoc()) <<
"repeated target value";
796 void transform::TransformState::compactOpHandles() {
797 for (
Value handle : opHandlesToCompact) {
798 Mappings &mappings = getMapping(handle,
true);
799 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
800 if (llvm::find(mappings.direct[handle],
nullptr) !=
801 mappings.direct[handle].end())
804 mappings.incrementTimestamp(handle);
806 llvm::erase(mappings.direct[handle],
nullptr);
808 opHandlesToCompact.clear();
814 DBGS() <<
"applying: ";
816 llvm::dbgs() <<
"\n";
819 DBGS() <<
"Top-level payload before application:\n"
820 << *getTopLevel() <<
"\n");
821 auto printOnFailureRAII = llvm::make_scope_exit([
this] {
823 LLVM_DEBUG(
DBGS() <<
"Failing Top-level payload:\n"; getTopLevel()->print(
828 regionStack.back()->currentTransform = transform;
831 if (
options.getExpensiveChecksEnabled()) {
833 if (
failed(checkAndRecordHandleInvalidation(transform)))
836 for (
OpOperand &operand : transform->getOpOperands()) {
838 (
DBGS() <<
"iterate on handle: " << operand.get() <<
"\n");
841 FULL_LDBG(
"--handle not consumed -> SKIP\n");
844 if (transform.allowsRepeatedHandleOperands()) {
845 FULL_LDBG(
"--op allows repeated handles -> SKIP\n");
850 Type operandType = operand.get().getType();
851 if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
852 FULL_LDBG(
"--checkRepeatedConsumptionInOperand for Operation*\n");
854 checkRepeatedConsumptionInOperand<Operation *>(
855 getPayloadOpsView(operand.get()), transform,
856 operand.getOperandNumber());
861 }
else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
862 FULL_LDBG(
"--checkRepeatedConsumptionInOperand For Value\n");
864 checkRepeatedConsumptionInOperand<Value>(
865 getPayloadValuesView(operand.get()), transform,
866 operand.getOperandNumber());
872 FULL_LDBG(
"--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
879 transform.getConsumedHandleOpOperands();
888 for (
OpOperand *opOperand : consumedOperands) {
889 Value operand = opOperand->get();
890 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
891 for (
Operation *payloadOp : getPayloadOps(operand)) {
892 llvm::append_range(origOpFlatResults, payloadOp->
getResults());
896 if (llvm::isa<TransformValueHandleTypeInterface>(operand.
getType())) {
897 for (
Value payloadValue : getPayloadValuesView(operand)) {
898 if (llvm::isa<OpResult>(payloadValue)) {
904 llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
911 <<
"unexpectedly consumed a value that is not a handle as operand #"
912 << opOperand->getOperandNumber();
914 <<
"value defined here with type " << operand.
getType();
923 llvm::find_if(llvm::reverse(regionStack), [&](
RegionScope *scope) {
924 return handle.getParentRegion() == scope->region;
926 assert(scopeIt != regionStack.rend() &&
927 "could not find region scope for handle");
929 for (
Operation *user : handle.getUsers()) {
930 if (user != scope->currentTransform &&
952 transform->hasAttr(FindPayloadReplacementOpInterface::
953 kSilenceTrackingFailuresAttrName)) {
957 (void)trackingFailure.
silence();
962 result = std::move(trackingFailure);
966 result.
attachNote() <<
"tracking listener also failed: "
968 (void)trackingFailure.
silence();
981 for (
OpOperand *opOperand : consumedOperands) {
982 Value operand = opOperand->get();
983 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
984 forgetMapping(operand, origOpFlatResults);
985 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
987 forgetValueMapping(operand, origAssociatedOps);
991 if (
failed(updateStateFromResults(results, transform->getResults())))
994 printOnFailureRAII.release();
996 DBGS() <<
"Top-level payload:\n";
997 getTopLevel()->print(llvm::dbgs());
1002 LogicalResult transform::TransformState::updateStateFromResults(
1004 for (
OpResult result : opResults) {
1005 if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1006 assert(results.isParam(result.getResultNumber()) &&
1007 "expected parameters for the parameter-typed result");
1009 setParams(result, results.getParams(result.getResultNumber())))) {
1012 }
else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1013 assert(results.isValue(result.getResultNumber()) &&
1014 "expected values for value-type-result");
1015 if (
failed(setPayloadValues(
1016 result, results.getValues(result.getResultNumber())))) {
1020 assert(!results.isParam(result.getResultNumber()) &&
1021 "expected payload ops for the non-parameter typed result");
1023 setPayloadOps(result, results.get(result.getResultNumber())))) {
1042 return state.replacePayloadOp(op, replacement);
1047 Value replacement) {
1048 return state.replacePayloadValue(value, replacement);
1059 for (
Block &block : *region) {
1060 for (
Value handle : block.getArguments()) {
1061 state.invalidatedHandles.erase(handle);
1065 state.invalidatedHandles.erase(handle);
1070 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1074 llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1077 state.mappings.erase(region);
1078 state.regionStack.pop_back();
1085 transform::TransformResults::TransformResults(
unsigned numSegments) {
1086 operations.appendEmptyRows(numSegments);
1087 params.appendEmptyRows(numSegments);
1088 values.appendEmptyRows(numSegments);
1094 assert(position <
static_cast<int64_t
>(this->params.size()) &&
1095 "setting params for a non-existent handle");
1096 assert(this->params[position].data() ==
nullptr &&
"params already set");
1097 assert(operations[position].data() ==
nullptr &&
1098 "another kind of results already set");
1099 assert(values[position].data() ==
nullptr &&
1100 "another kind of results already set");
1101 this->params.replace(position, params);
1109 return set(handle, operations),
success();
1112 return setParams(handle, params),
success();
1115 return setValues(handle, payloadValues),
success();
1118 if (!
diag.succeeded())
1119 llvm::dbgs() <<
diag.getStatusString() <<
"\n";
1120 assert(
diag.succeeded() &&
"incorrect mapping");
1122 (void)
diag.silence();
1126 transform::TransformOpInterface transform) {
1127 for (
OpResult opResult : transform->getResults()) {
1128 if (!isSet(opResult.getResultNumber()))
1129 setMappedValues(opResult, {});
1134 transform::TransformResults::get(
unsigned resultNumber)
const {
1135 assert(resultNumber < operations.size() &&
1136 "querying results for a non-existent handle");
1137 assert(operations[resultNumber].data() !=
nullptr &&
1138 "querying unset results (values or params expected?)");
1139 return operations[resultNumber];
1143 transform::TransformResults::getParams(
unsigned resultNumber)
const {
1144 assert(resultNumber < params.size() &&
1145 "querying params for a non-existent handle");
1146 assert(params[resultNumber].data() !=
nullptr &&
1147 "querying unset params (ops or values expected?)");
1148 return params[resultNumber];
1152 transform::TransformResults::getValues(
unsigned resultNumber)
const {
1153 assert(resultNumber < values.size() &&
1154 "querying values for a non-existent handle");
1155 assert(values[resultNumber].data() !=
nullptr &&
1156 "querying unset values (ops or params expected?)");
1157 return values[resultNumber];
1160 bool transform::TransformResults::isParam(
unsigned resultNumber)
const {
1161 assert(resultNumber < params.size() &&
1162 "querying association for a non-existent handle");
1163 return params[resultNumber].data() !=
nullptr;
1166 bool transform::TransformResults::isValue(
unsigned resultNumber)
const {
1167 assert(resultNumber < values.size() &&
1168 "querying association for a non-existent handle");
1169 return values[resultNumber].data() !=
nullptr;
1172 bool transform::TransformResults::isSet(
unsigned resultNumber)
const {
1173 assert(resultNumber < params.size() &&
1174 "querying association for a non-existent handle");
1175 return params[resultNumber].data() !=
nullptr ||
1176 operations[resultNumber].data() !=
nullptr ||
1177 values[resultNumber].data() !=
nullptr;
1185 TransformOpInterface op,
1187 :
TransformState::Extension(state), transformOp(op), config(config) {
1189 for (
OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1190 consumedHandles.insert(opOperand->get());
1197 for (
Value v : values) {
1202 defOp = v.getDefiningOp();
1205 if (defOp != v.getDefiningOp())
1214 "invalid number of replacement values");
1218 getTransformOp(),
"tracking listener failed to find replacement op "
1219 "during application of this transform op");
1223 Operation *defOp = getCommonDefiningOp(values);
1225 diag.attachNote() <<
"replacement values belong to different ops";
1230 if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1234 <<
"using output of 'CastOpInterface' op";
1240 if (!config.requireMatchingReplacementOpName ||
1256 if (
auto findReplacementOpInterface =
1257 dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1258 values.assign(findReplacementOpInterface.getNextOperands());
1259 diag.attachNote(defOp->
getLoc()) <<
"using operands provided by "
1260 "'FindPayloadReplacementOpInterface'";
1263 }
while (!values.empty());
1265 diag.attachNote() <<
"ran out of suitable replacement values";
1273 reasonCallback(
diag);
1274 DBGS() <<
"Match Failure : " <<
diag.str() <<
"\n";
1278 void transform::TrackingListener::notifyOperationErased(
Operation *op) {
1281 (void)replacePayloadValue(value,
nullptr);
1283 (void)replacePayloadOp(op,
nullptr);
1286 void transform::TrackingListener::notifyOperationReplaced(
1289 "invalid number of replacement values");
1292 for (
auto [oldValue, newValue] : llvm::zip(op->
getResults(), newValues))
1293 (void)replacePayloadValue(oldValue, newValue);
1297 if (
failed(getTransformState().getHandlesForPayloadOp(
1298 op, opHandles,
true))) {
1312 auto handleWasConsumed = [&] {
1313 return llvm::any_of(opHandles,
1314 [&](
Value h) {
return consumedHandles.contains(h); });
1319 if (config.skipHandleFn) {
1320 auto it = llvm::find_if(opHandles,
1321 [&](
Value v) {
return !config.skipHandleFn(v); });
1322 if (it != opHandles.end())
1324 }
else if (!opHandles.empty()) {
1325 aliveHandle = opHandles.front();
1327 if (!aliveHandle || handleWasConsumed()) {
1330 (void)replacePayloadOp(op,
nullptr);
1336 findReplacementOp(replacement, op, newValues);
1339 if (!
diag.succeeded()) {
1341 <<
"replacement is required because this handle must be updated";
1342 notifyPayloadReplacementNotFound(op, newValues, std::move(
diag));
1343 (void)replacePayloadOp(op,
nullptr);
1347 (void)replacePayloadOp(op, replacement);
1354 assert(status.succeeded() &&
"listener state was not checked");
1366 return !status.succeeded();
1374 diag.takeDiagnostics(diags);
1375 if (!status.succeeded())
1376 status.takeDiagnostics(diags);
1380 status.attachNote(op->
getLoc()) <<
"[" << errorCounter <<
"] replaced op";
1382 status.attachNote(value.
getLoc())
1383 <<
"[" << errorCounter <<
"] replacement value " << index;
1398 return listener->failed();
1403 if (hasTrackingFailures()) {
1411 return listener->replacePayloadOp(op, replacement);
1422 for (
Operation *child : targets.drop_front(position + 1)) {
1423 if (parent->isAncestor(child)) {
1426 <<
"transform operation consumes a handle pointing to an ancestor "
1427 "payload operation before its descendant";
1429 <<
"the ancestor is likely erased or rewritten before the "
1430 "descendant is accessed, leading to undefined behavior";
1431 diag.attachNote(parent->getLoc()) <<
"ancestor payload op";
1432 diag.attachNote(child->getLoc()) <<
"descendant payload op";
1451 diag.attachNote(payloadOpLoc) <<
"when applied to this op";
1455 if (partialResult.
size() != expectedNumResults) {
1456 auto diag =
emitDiag() <<
"application of " << transformOpName
1457 <<
" expected to produce " << expectedNumResults
1458 <<
" results (actually produced "
1459 << partialResult.
size() <<
").";
1460 diag.attachNote(transformOpLoc)
1461 <<
"if you need variadic results, consider a generic `apply` "
1462 <<
"instead of the specialized `applyToOne`.";
1467 for (
const auto &[ptr, res] :
1468 llvm::zip(partialResult, transformOp->
getResults())) {
1471 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1473 return emitDiag() <<
"application of " << transformOpName
1474 <<
" expected to produce an Operation * for result #"
1475 << res.getResultNumber();
1477 if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1479 return emitDiag() <<
"application of " << transformOpName
1480 <<
" expected to produce an Attribute for result #"
1481 << res.getResultNumber();
1483 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1485 return emitDiag() <<
"application of " << transformOpName
1486 <<
" expected to produce a Value for result #"
1487 << res.getResultNumber();
1493 template <
typename T>
1495 return llvm::to_vector(llvm::map_range(
1505 if (llvm::any_of(partialResults,
1506 [](
MappedValue value) {
return value.isNull(); }))
1508 assert(transformOp->
getNumResults() == partialResults.size() &&
1509 "expected as many partial results as op as results");
1511 transposed[i].push_back(value);
1515 unsigned position = r.getResultNumber();
1516 if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1518 castVector<Attribute>(transposed[position]));
1519 }
else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1520 transformResults.
setValues(r, castVector<Value>(transposed[position]));
1522 transformResults.
set(r, castVector<Operation *>(transposed[position]));
1534 for (
Value operand : values) {
1536 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1537 llvm::append_range(mapped, state.getPayloadOps(operand));
1538 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
1539 operand.getType())) {
1540 llvm::append_range(mapped, state.getPayloadValues(operand));
1542 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1543 "unsupported kind of transform dialect value");
1544 llvm::append_range(mapped, state.getParams(operand));
1552 for (
auto &&[terminatorOperand, result] :
1555 if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1556 results.
set(result, state.getPayloadOps(terminatorOperand));
1557 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1558 result.getType())) {
1559 results.
setValues(result, state.getPayloadValues(terminatorOperand));
1562 llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1563 "unhandled transform type interface");
1564 results.
setParams(result, state.getParams(terminatorOperand));
1585 iface.getEffectsOnValue(source, nestedEffects);
1586 for (
const auto &effect : nestedEffects)
1587 effects.emplace_back(effect.getEffect(), target, effect.getResource());
1596 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1600 for (
auto &&[source, target] : llvm::zip(block.
getArguments(), operands)) {
1607 llvm::append_range(effects, nestedEffects);
1619 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1624 iface.getEffects(effects);
1639 llvm::append_range(targets, state.getPayloadOps(op->
getOperand(0)));
1642 if (state.getNumTopLevelMappings() !=
1646 <<
" extra value bindings, but " << state.getNumTopLevelMappings()
1647 <<
" were provided to the interpreter";
1650 targets.push_back(state.getTopLevel());
1652 for (
unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1653 extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1660 if (
failed(state.mapBlockArgument(
1661 argument, extraMappings[argument.getArgNumber() - 1])))
1673 assert(isa<TransformOpInterface>(op) &&
1674 "should implement TransformOpInterface to have "
1675 "PossibleTopLevelTransformOpTrait");
1678 return op->
emitOpError() <<
"expects at least one region";
1681 if (!llvm::hasNItems(*bodyRegion, 1))
1682 return op->
emitOpError() <<
"expects a single-block region";
1687 <<
"expects the entry block to have at least one argument";
1689 if (!llvm::isa<TransformHandleTypeInterface>(
1692 <<
"expects the first entry block argument to be of type "
1693 "implementing TransformHandleTypeInterface";
1699 <<
"expects the type of the block argument to match "
1700 "the type of the operand";
1704 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1705 TransformValueHandleTypeInterface>(arg.
getType()))
1710 <<
"expects trailing entry block arguments to be of type implementing "
1711 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1712 "TransformParamTypeInterface";
1722 <<
"expects operands to be provided for a nested op";
1723 diag.attachNote(parent->getLoc())
1724 <<
"nested in another possible top-level op";
1739 bool hasPayloadOperands =
false;
1742 if (llvm::isa<TransformHandleTypeInterface,
1743 TransformValueHandleTypeInterface>(operand.getType()))
1744 hasPayloadOperands =
true;
1746 if (hasPayloadOperands)
1755 llvm::report_fatal_error(
1756 Twine(
"ParamProducerTransformOpTrait must be attached to an op that "
1757 "implements MemoryEffectsOpInterface, found on ") +
1761 if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1764 <<
"ParamProducerTransformOpTrait attached to this op expects "
1765 "result types to implement TransformParamTypeInterface";
1777 for (
Value handle : handles) {
1787 template <
typename EffectTy,
typename ResourceTy,
typename Range>
1790 return isa<EffectTy>(effect.
getEffect()) &&
1796 transform::TransformOpInterface transform) {
1797 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1799 iface.getEffectsOnValue(handle, effects);
1800 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1801 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1807 for (
Value handle : handles) {
1818 for (
Value handle : handles) {
1836 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1838 iface.getEffects(effects);
1839 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1843 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1845 iface.getEffects(effects);
1846 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1850 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1853 auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1858 iface.getEffects(effects);
1861 dyn_cast_or_null<BlockArgument>(effect.getValue());
1862 if (!argument || argument.
getOwner() != &block ||
1863 !isa<MemoryEffects::Free>(effect.getEffect()) ||
1877 TransformOpInterface transformOp) {
1879 consumedOperands.reserve(transformOp->getNumOperands());
1880 auto memEffectInterface =
1881 cast<MemoryEffectOpInterface>(transformOp.getOperation());
1883 for (
OpOperand &target : transformOp->getOpOperands()) {
1885 memEffectInterface.getEffectsOnValue(target.get(), effects);
1887 return isa<transform::TransformMappingResource>(
1889 isa<MemoryEffects::Free>(effect.
getEffect());
1891 consumedOperands.push_back(&target);
1894 return consumedOperands;
1898 auto iface = cast<MemoryEffectOpInterface>(op);
1900 iface.getEffects(effects);
1902 auto effectsOn = [&](
Value value) {
1903 return llvm::make_filter_range(
1905 return instance.
getValue() == value;
1909 std::optional<unsigned> firstConsumedOperand;
1911 auto range = effectsOn(operand.get());
1912 if (range.empty()) {
1914 op->
emitError() <<
"TransformOpInterface requires memory effects "
1915 "on operands to be specified";
1916 diag.attachNote() <<
"no effects specified for operand #"
1917 << operand.getOperandNumber();
1920 if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1922 <<
"TransformOpInterface did not expect "
1923 "'allocate' memory effect on an operand";
1924 diag.attachNote() <<
"specified for operand #"
1925 << operand.getOperandNumber();
1928 if (!firstConsumedOperand &&
1929 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1930 firstConsumedOperand = operand.getOperandNumber();
1934 if (firstConsumedOperand &&
1935 !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1938 <<
"TransformOpInterface expects ops consuming operands to have a "
1939 "'write' effect on the payload resource";
1940 diag.attachNote() <<
"consumes operand #" << *firstConsumedOperand;
1945 auto range = effectsOn(result);
1946 if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1949 op->
emitError() <<
"TransformOpInterface requires 'allocate' memory "
1950 "effect to be specified for results";
1951 diag.attachNote() <<
"no 'allocate' effect specified for result #"
1952 << result.getResultNumber();
1965 Operation *payloadRoot, TransformOpInterface transform,
1968 if (enforceToplevelTransformOp) {
1970 transform->getNumOperands() != 0) {
1971 return transform->emitError()
1972 <<
"expected transform to start at the top-level transform op";
1979 TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
1981 return state.applyTransform(transform).checkAndReport();
1988 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
1989 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies in this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an interface similar to the former.
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic, preserves the other states.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given diagnostic.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class provides the API for a subset of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operation, null otherwise.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Values.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A 2D array where each row may have different length.
size_t size() const
Returns the number of rows in the 2D array.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class implements the result iterators for the Operation class.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool hasEffect(Operation *op, Value value=nullptr)
Returns true if op has an effect of type EffectTy on value.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.