16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/ScopeExit.h"
18 #include "llvm/ADT/iterator.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include "llvm/Support/InterleavedRange.h"
// Names of the debug channels used in this file.
23 #define DEBUG_TYPE "transform-dialect"
24 #define DEBUG_TYPE_FULL "transform-dialect-full"
25 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
// Debug stream with the "[transform-dialect] " prefix already written to it.
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
// Streams X, with the prefix, through LLVM_DEBUG.
27 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
// Streams X, with the prefix, only under the DEBUG_TYPE_FULL channel.
28 #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
// Out-of-class definition of the TransformState::kTopLevelValue member.
// setPayloadOps asserts that this value is never re-assigned.
53 constexpr
const Value transform::TransformState::kTopLevelValue;
// TransformState constructor. Only part of the parameter list and body is
// visible in this excerpt.
55 transform::TransformState::TransformState(
58 const TransformOptions &
options)
// Pre-size the storage for the extra top-level mappings before appending.
60 topLevelMappedValues.reserve(extraMappings.
size());
62 topLevelMappedValues.push_back(mapping);
// Create the scope for the top-level region and hand the pointer over to
// topLevelRegionScope through reset().
64 RegionScope *scope =
new RegionScope(*
this, *region);
65 topLevelRegionScope.reset(scope);
// Returns the stored top-level operation pointer.
69 Operation *transform::TransformState::getTopLevel()
const {
return topLevel; }
// Returns the entry stored for `value` in the `direct` map of the mapping
// that getMapping(value) selects. The handle must already be mapped; this is
// checked only by an assertion.
72 transform::TransformState::getPayloadOpsView(
Value value)
const {
73 const TransformOpMapping &operationMapping = getMapping(value).direct;
74 auto iter = operationMapping.find(value);
// A miss here suggests the caller passed a param or value handle instead.
75 assert(iter != operationMapping.end() &&
76 "cannot find mapping for payload handle (param/value handle "
78 return iter->getSecond();
83 auto iter = mapping.find(value);
84 assert(iter != mapping.end() &&
"cannot find mapping for param handle "
85 "(operation/value handle provided?)");
86 return iter->getSecond();
// Returns the entry stored for `handleValue` in the `values` map of the
// mapping that getMapping(handleValue) selects. The handle must already be
// mapped; this is checked only by an assertion.
90 transform::TransformState::getPayloadValuesView(
Value handleValue)
const {
91 const ValueMapping &mapping = getMapping(handleValue).values;
92 auto iter = mapping.find(handleValue);
// A miss here suggests the caller passed a param or operation handle instead.
93 assert(iter != mapping.end() &&
"cannot find mapping for value handle "
94 "(param/operation handle provided?)");
95 return iter->getSecond();
100 bool includeOutOfScope)
const {
102 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
103 auto iterator = mapping->reverse.find(op);
104 if (iterator != mapping->reverse.end()) {
105 llvm::append_range(handles, iterator->getSecond());
109 if (!includeOutOfScope &&
114 return success(found);
119 bool includeOutOfScope)
const {
121 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
122 auto iterator = mapping->reverseValues.find(payloadValue);
123 if (iterator != mapping->reverseValues.end()) {
124 llvm::append_range(handles, iterator->getSecond());
128 if (!includeOutOfScope &&
133 return success(found);
143 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.
getType())) {
145 operations.reserve(values.size());
147 if (
auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
148 operations.push_back(op);
152 <<
"wrong kind of value provided for top-level operation handle";
154 if (failed(operationsFn(operations)))
159 if (llvm::isa<transform::TransformValueHandleTypeInterface>(
162 payloadValues.reserve(values.size());
164 if (
auto v = llvm::dyn_cast_if_present<Value>(value)) {
165 payloadValues.push_back(v);
169 <<
"wrong kind of value provided for the top-level value handle";
171 if (failed(valuesFn(payloadValues)))
176 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.
getType()) &&
177 "unsupported kind of block argument");
179 parameters.reserve(values.size());
181 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
182 parameters.push_back(attr);
186 <<
"wrong kind of value provided for top-level parameter";
188 if (failed(paramsFn(parameters)))
199 return setPayloadOps(argument, operations);
202 return setParams(argument, params);
205 return setPayloadValues(argument, payloadValues);
213 for (
auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
214 if (failed(mapBlockArgument(argument, values)))
// Associates an operation handle with a list of payload operations.
// Several lines of the body are not visible in this excerpt.
220 transform::TransformState::setPayloadOps(
Value value,
// The root handle may not be re-assigned, and only operation-handle types
// are accepted here.
222 assert(value != kTopLevelValue &&
223 "attempting to reset the transformation root");
224 assert(llvm::isa<TransformHandleTypeInterface>(value.
getType()) &&
225 "wrong handle type");
231 <<
"attempting to assign a null payload op to this transform value";
// Let the handle's type interface check the payload against the handle type.
234 auto iface = llvm::cast<TransformHandleTypeInterface>(value.
getType());
236 iface.checkPayload(value.
getLoc(), targets);
// Record the forward association; a handle may be assigned only once.
243 Mappings &mappings = getMapping(value);
245 mappings.direct.insert({value, std::move(storedTargets)}).second;
246 assert(inserted &&
"value is already associated with another list");
// Record the reverse association from a payload op back to this handle.
250 mappings.reverse[op].push_back(value);
// Associates a value handle with a list of payload values.
// Several lines of the body are not visible in this excerpt.
256 transform::TransformState::setPayloadValues(
Value handle,
// NOTE(review): the message below mentions "params" although this function
// sets payload values; it reads like a copy of the setParams assertion.
// Confirm and fix the wording in a code change.
258 assert(handle !=
nullptr &&
"attempting to set params for a null value");
259 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.
getType()) &&
260 "wrong handle type");
// A null payload value is reported as an error at the handle's location.
262 for (
Value payload : payloadValues) {
265 return emitError(handle.
getLoc()) <<
"attempting to assign a null payload "
266 "value to this transform handle";
// Let the handle's type interface check the payload against the handle type.
269 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.
getType());
272 iface.checkPayload(handle.
getLoc(), payloadValueVector);
// Record the forward association (handle -> payload values).
276 Mappings &mappings = getMapping(handle);
278 mappings.values.insert({handle, std::move(payloadValueVector)}).second;
281 "value handle is already associated with another list of payload values");
// Record the reverse association (payload value -> handles), keyed by the
// payload value.
284 for (
Value payload : payloadValues)
285 mappings.reverseValues[payload].push_back(handle);
// Associates a parameter-typed transform value with a list of parameters.
// Several lines of the body are not visible in this excerpt.
290 LogicalResult transform::TransformState::setParams(
Value value,
292 assert(value !=
nullptr &&
"attempting to set params for a null value");
298 <<
"attempting to assign a null parameter to this transform value";
// dyn_cast rather than cast: the result is checked below, as the message on
// the following visible line indicates.
301 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.
getType());
303 "cannot associate parameter with a value of non-parameter type");
// Let the parameter type interface check the parameters against the type.
305 valueType.checkPayload(value.
getLoc(), params);
// Store a copy of the parameter list; a value may be assigned only once.
309 Mappings &mappings = getMapping(value);
311 mappings.params.insert({value, llvm::to_vector(params)}).second;
312 assert(inserted &&
"value is already associated with another list of params");
317 template <
typename Mapping,
typename Key,
typename Mapped>
319 auto it = mapping.find(key);
320 if (it == mapping.end())
323 llvm::erase(it->getSecond(), mapped);
324 if (it->getSecond().empty())
// Removes the association of an operation handle, and processes handles to
// the values listed in origOpFlatResults. Several lines of the body are not
// visible in this excerpt.
328 void transform::TransformState::forgetMapping(
Value opHandle,
330 bool allowOutOfScope) {
331 Mappings &mappings = getMapping(opHandle, allowOutOfScope);
// Visit each payload op of the handle (loop body not visible), then erase
// the forward entry.
332 for (
Operation *op : mappings.direct[opHandle])
334 mappings.direct.erase(opHandle);
335 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
338 mappings.incrementTimestamp(opHandle);
// For every former result value, find the handles pointing at it and work on
// the mapping each of those handles lives in.
341 for (
Value opResult : origOpFlatResults) {
343 (void)getHandlesForPayloadValue(opResult, resultHandles);
344 for (
Value resultHandle : resultHandles) {
345 Mappings &localMappings = getMapping(resultHandle);
347 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
// NOTE(review): this bumps the timestamp on `mappings` (the mapping of
// opHandle) although resultHandle was looked up in `localMappings`.
// forgetValueMapping uses `localMappings` at the analogous place. Confirm
// which mapping is intended.
350 mappings.incrementTimestamp(resultHandle);
// Removes the association of a value handle, and processes handles to the
// operations listed in payloadOperations. Several lines of the body are not
// visible in this excerpt.
357 void transform::TransformState::forgetValueMapping(
359 Mappings &mappings = getMapping(valueHandle);
// NOTE(review): `reverseValues` is filled with payload values as keys in
// setPayloadValues, yet it is indexed by the handle here. `mappings.values`
// is the map keyed by handles. Verify that this lookup is intended.
360 for (
Value payloadValue : mappings.reverseValues[valueHandle])
362 mappings.values.erase(valueHandle);
363 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
366 mappings.incrementTimestamp(valueHandle);
// For every listed payload op, find the op handles pointing at it and work
// on the mapping each of those handles lives in.
369 for (
Operation *payloadOp : payloadOperations) {
371 (void)getHandlesForPayloadOp(payloadOp, opHandles);
372 for (
Value opHandle : opHandles) {
373 Mappings &localMappings = getMapping(opHandle);
377 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
380 localMappings.incrementTimestamp(opHandle);
// Updates all op handles that point to `op` so that they point to
// `replacement` instead. Several lines of the body are not visible here.
387 transform::TransformState::replacePayloadOp(
Operation *op,
// Value handles to the results of the old op are expected to be gone already.
394 (void)getHandlesForPayloadValue(opResult, valueHandles,
396 assert(valueHandles.empty() &&
"expected no mapping to old results");
// Collect the handles to `op`; the `true` argument is passed in the
// includeOutOfScope position.
403 if (failed(getHandlesForPayloadOp(op, opHandles,
true)))
405 for (
Value handle : opHandles) {
406 Mappings &mappings = getMapping(handle,
true);
// Second pass over the handles: rewrite the stored payload pointers.
419 for (
Value handle : opHandles) {
420 Mappings &mappings = getMapping(handle,
true);
421 auto it = mappings.direct.find(handle);
422 if (it == mappings.direct.end())
429 mapped = replacement;
// Keep the reverse map in sync with the rewritten forward entry.
433 mappings.reverse[replacement].push_back(handle);
// Mark the handle for compactOpHandles, which erases nullptr entries from
// its payload list.
435 opHandlesToCompact.insert(handle);
// Updates all value handles that point to `value` so that they point to
// `replacement` instead. Several lines of the body are not visible here.
443 transform::TransformState::replacePayloadValue(
Value value,
Value replacement) {
445 if (failed(getHandlesForPayloadValue(value, valueHandles,
449 for (
Value handle : valueHandles) {
450 Mappings &mappings = getMapping(handle,
true);
457 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
460 mappings.incrementTimestamp(handle);
463 auto it = mappings.values.find(handle);
464 if (it == mappings.values.end())
// Entries are taken by reference so that the assignment rewrites the stored
// association in place.
468 for (
Value &mapped : association) {
470 mapped = replacement;
// Keep the reverse map in sync with the rewritten forward entry.
472 mappings.reverseValues[replacement].push_back(handle);
// Considers one (payloadOp, otherHandle) pair against a list of potential
// ancestor ops and stores a diagnostic-producing callback for otherHandle in
// newlyInvalidated. The parameter list and several body lines are not
// visible in this excerpt.
479 void transform::TransformState::recordOpHandleInvalidationOne(
// Tests whether the handle already has an entry in either map; the guarded
// statement is not visible here.
486 if (invalidatedHandles.count(otherHandle) ||
487 newlyInvalidated.count(otherHandle))
490 FULL_LDBG(
"--recordOpHandleInvalidationOne\n");
492 (
DBGS() <<
"--ancestors: "
493 << llvm::interleaved(llvm::make_pointee_range(potentialAncestors))
499 for (
Operation *ancestor : potentialAncestors) {
502 { (
DBGS() <<
"----handle one ancestor: " << *ancestor <<
"\n"); });
504 { (
DBGS() <<
"----of payload with name: "
507 { (
DBGS() <<
"----of payload: " << *payloadOp <<
"\n"); });
// Tests whether this ancestor encloses the payload op; the guarded statement
// is not visible here.
509 if (!ancestor->isAncestor(payloadOp))
// Locations are copied into locals and captured by value in the callback
// stored below.
516 Location ancestorLoc = ancestor->getLoc();
518 std::optional<Location> throughValueLoc =
519 throughValue ? std::make_optional(throughValue.
getLoc()) : std::nullopt;
// The stored callback takes the location of the offending use and attaches
// the notes below to a diagnostic.
520 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
522 throughValueLoc](
Location currentLoc) {
524 <<
"op uses a handle invalidated by a "
525 "previously executed transform op";
526 diag.attachNote(otherHandle.getLoc()) <<
"handle to invalidated ops";
527 diag.attachNote(owner->getLoc())
528 <<
"invalidated by this transform op that consumes its operand #"
530 <<
" and invalidates all handles to payload IR entities associated "
531 "with this operand and entities nested in them";
532 diag.attachNote(ancestorLoc) <<
"ancestor payload op";
533 diag.attachNote(opLoc) <<
"nested payload op";
// This note is attached only when a through-value location was supplied.
534 if (throughValueLoc) {
535 diag.attachNote(*throughValueLoc)
536 <<
"consumed handle points to this payload value";
// Considers one (payloadValue, valueHandle) pair against a list of potential
// ancestor ops and stores a diagnostic-producing callback for valueHandle in
// newlyInvalidated. The parameter list and several body lines are not
// visible in this excerpt.
542 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
// Tests whether the handle already has an entry in either map; the guarded
// statement is not visible here.
549 if (invalidatedHandles.count(valueHandle) ||
550 newlyInvalidated.count(valueHandle))
553 for (
Operation *ancestor : potentialAncestors) {
555 std::optional<unsigned> resultNo;
// Locate the op the value belongs to. For an op result that is the owner of
// the result.
559 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
560 definingOp = opResult.getOwner();
561 resultNo = opResult.getResultNumber();
// Otherwise the value must be a block argument. Its position is recorded as
// argument number, index of the block within its region, and region number.
563 auto arg = llvm::cast<BlockArgument>(payloadValue);
565 argumentNo = arg.getArgNumber();
566 blockNo = std::distance(arg.getOwner()->getParent()->begin(),
567 arg.getOwner()->getIterator());
568 regionNo = arg.getOwner()->getParent()->getRegionNumber();
570 assert(definingOp &&
"expected the value to be defined by an op as result "
571 "or block argument");
// Tests whether this ancestor encloses the defining op; the guarded
// statement is not visible here.
572 if (!ancestor->isAncestor(definingOp))
577 Location ancestorLoc = ancestor->getLoc();
// The stored callback takes the location of the offending use and attaches
// the notes below to a diagnostic.
580 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
581 argumentNo, blockNo, regionNo, ancestorLoc,
582 opLoc, valueLoc](
Location currentLoc) {
584 <<
"op uses a handle invalidated by a "
585 "previously executed transform op";
586 diag.attachNote(valueHandle.
getLoc()) <<
"invalidated handle";
587 diag.attachNote(owner->getLoc())
588 <<
"invalidated by this transform op that consumes its operand #"
590 <<
" and invalidates all handles to payload IR entities "
591 "associated with this operand and entities nested in them";
592 diag.attachNote(ancestorLoc)
593 <<
"ancestor op associated with the consumed handle";
// Two alternative notes follow, one for the op-result case and one for the
// block-argument case. The condition selecting between them is not visible.
595 diag.attachNote(opLoc)
596 <<
"op defining the value as result #" << *resultNo;
598 diag.attachNote(opLoc)
599 <<
"op defining the value as block argument #" << argumentNo
600 <<
" of block #" << blockNo <<
" in region #" << regionNo;
602 diag.attachNote(valueLoc) <<
"payload value";
// Records invalidation for the handles affected by consuming `handle`, given
// the list of potential ancestor payload ops. The parameter list and several
// body lines are not visible in this excerpt.
607 void transform::TransformState::recordOpHandleInvalidation(
// With no payload ops, a callback is stored for the consumed handle itself.
612 if (potentialAncestors.empty()) {
614 (
DBGS() <<
"----recording invalidation for empty handle: " << handle.
get()
620 newlyInvalidated[handle.
get()] = [owner, operandNo](
Location currentLoc) {
622 <<
"op uses a handle associated with empty "
623 "payload and invalidated by a "
624 "previously executed transform op";
625 diag.attachNote(owner->getLoc())
626 <<
"invalidated by this transform op that consumes its operand #"
// Walk the per-region mappings in reverse order.
639 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
// Check every op handle of every mapped payload op.
643 for (
const auto &[payloadOp, otherHandles] : mapping->reverse) {
644 for (
Value otherHandle : otherHandles)
645 recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
646 otherHandle, throughValue,
// Check every value handle of every mapped payload value.
654 for (
const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
655 for (
Value valueHandle : valueHandles)
656 recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
657 payloadValue, valueHandle,
// Records invalidation caused by consuming a value handle. This covers other
// handles to the same payload values, and handles related to the ops those
// values belong to. The parameter list and several body lines are not
// visible in this excerpt.
667 void transform::TransformState::recordValueHandleInvalidation(
671 for (
Value payloadValue : getPayloadValuesView(valueHandle.
get())) {
// Collect every value handle that points to this payload value.
673 (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
674 for (
Value otherHandle : otherValueHandles) {
678 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
681 <<
"op uses a handle invalidated by a "
682 "previously executed transform op";
683 diag.attachNote(otherHandle.getLoc()) <<
"invalidated handle";
684 diag.attachNote(owner->getLoc())
685 <<
"invalidated by this transform op that consumes its operand #"
687 <<
" and invalidates handles to the same values as associated with "
689 diag.attachNote(valueLoc) <<
"payload value";
// An op result delegates to op-handle invalidation for the op that owns it.
693 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
694 Operation *payloadOp = opResult.getOwner();
695 recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
// A block argument delegates to every operation in the block that owns it.
// NOTE(review): the dyn_cast result is dereferenced without a null check.
// recordValueHandleInvalidationByOpHandleOne uses llvm::cast for the same
// conversion. Confirm whether cast is intended here as well.
698 auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
699 for (
Operation &payloadOp : *arg.getOwner())
700 recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
// Checks the operands of `transform` against the handles already known to be
// invalidated, and records the handles that this op invalidates into
// newlyInvalidated. Several lines of the body are not visible here.
710 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
711 transform::TransformOpInterface transform,
713 FULL_LDBG(
"--Start checkAndRecordHandleInvalidation\n");
// Uses cast, so the op must implement MemoryEffectOpInterface.
714 auto memoryEffectsIface =
715 cast<MemoryEffectOpInterface>(transform.getOperation());
717 memoryEffectsIface.getEffectsOnResource(
720 for (
OpOperand &target : transform->getOpOperands()) {
722 (
DBGS() <<
"----iterate on handle: " << target.get() <<
"\n");
// Look the operand up in both the old and the newly recorded map.
728 auto it = invalidatedHandles.find(target.get());
729 auto nit = newlyInvalidated.find(target.get());
// Previously invalidated: invoke the stored callback at this op's location,
// then return failure via the comma operator.
730 if (it != invalidatedHandles.end()) {
731 FULL_LDBG(
"--End checkAndRecordHandleInvalidation, found already "
732 "invalidated -> FAILURE\n");
733 return it->getSecond()(transform->getLoc()), failure();
// Invalidated earlier while processing this same op: also a failure, unless
// the op allows repeated handle operands.
735 if (!transform.allowsRepeatedHandleOperands() &&
736 nit != newlyInvalidated.end()) {
737 FULL_LDBG(
"--End checkAndRecordHandleInvalidation, found newly "
738 "invalidated (by this op) -> FAILURE\n");
739 return nit->getSecond()(transform->getLoc()), failure();
// An operand counts as consumed when a Free effect is attached to its value.
745 return isa<MemoryEffects::Free>(effect.getEffect()) &&
746 effect.getValue() == target.get();
748 if (llvm::any_of(effects, consumesTarget)) {
// Dispatch on the kind of handle: operation handle first, then value handle.
750 if (llvm::isa<transform::TransformHandleTypeInterface>(
751 target.get().getType())) {
752 FULL_LDBG(
"----recordOpHandleInvalidation\n");
754 llvm::to_vector(getPayloadOps(target.get()));
755 recordOpHandleInvalidation(target, payloadOps,
nullptr,
757 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
758 target.get().getType())) {
759 FULL_LDBG(
"----recordValueHandleInvalidation\n");
760 recordValueHandleInvalidation(target, newlyInvalidated);
762 FULL_LDBG(
"----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
765 FULL_LDBG(
"----no consume effect -> SKIP\n");
769 FULL_LDBG(
"--End checkAndRecordHandleInvalidation -> SUCCESS\n");
// Runs checkAndRecordHandleInvalidationImpl with a fresh local map, then
// merges that map into invalidatedHandles. The return statement is not
// visible in this excerpt.
773 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
774 transform::TransformOpInterface transform) {
775 InvalidatedHandleMap newlyInvalidated;
776 LogicalResult checkResult =
777 checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
// Entries are moved rather than copied. In the visible lines the merge does
// not depend on checkResult.
778 invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
779 std::make_move_iterator(newlyInvalidated.end()));
783 template <
typename T>
786 transform::TransformOpInterface transform,
787 unsigned operandNumber) {
789 for (T p : payload) {
790 if (!seen.insert(p).second) {
792 transform.emitSilenceableError()
793 <<
"a handle passed as operand #" << operandNumber
794 <<
" and consumed by this operation points to a payload "
795 "entity more than once";
796 if constexpr (std::is_pointer_v<T>)
797 diag.attachNote(p->getLoc()) <<
"repeated target op";
799 diag.attachNote(p.getLoc()) <<
"repeated target value";
// Removes nullptr entries from the payload list of every handle recorded in
// opHandlesToCompact, then empties that set.
806 void transform::TransformState::compactOpHandles() {
807 for (
Value handle : opHandlesToCompact) {
808 Mappings &mappings = getMapping(handle,
true);
// With ABI-breaking checks enabled, the timestamp is bumped for a handle
// whose payload list holds a nullptr. Intermediate lines are not visible.
809 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
810 if (llvm::is_contained(mappings.direct[handle],
nullptr))
813 mappings.incrementTimestamp(handle);
815 llvm::erase(mappings.direct[handle],
nullptr);
817 opHandlesToCompact.clear();
823 DBGS() <<
"applying: ";
825 llvm::dbgs() <<
"\n";
828 DBGS() <<
"Top-level payload before application:\n"
829 << *getTopLevel() <<
"\n");
830 auto printOnFailureRAII = llvm::make_scope_exit([
this] {
832 LLVM_DEBUG(
DBGS() <<
"Failing Top-level payload:\n"; getTopLevel()->print(
837 regionStack.back()->currentTransform = transform;
840 if (
options.getExpensiveChecksEnabled()) {
842 if (failed(checkAndRecordHandleInvalidation(transform)))
845 for (
OpOperand &operand : transform->getOpOperands()) {
847 (
DBGS() <<
"iterate on handle: " << operand.get() <<
"\n");
850 FULL_LDBG(
"--handle not consumed -> SKIP\n");
853 if (transform.allowsRepeatedHandleOperands()) {
854 FULL_LDBG(
"--op allows repeated handles -> SKIP\n");
859 Type operandType = operand.get().getType();
860 if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
861 FULL_LDBG(
"--checkRepeatedConsumptionInOperand for Operation*\n");
863 checkRepeatedConsumptionInOperand<Operation *>(
864 getPayloadOpsView(operand.get()), transform,
865 operand.getOperandNumber());
870 }
else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
871 FULL_LDBG(
"--checkRepeatedConsumptionInOperand For Value\n");
873 checkRepeatedConsumptionInOperand<Value>(
874 getPayloadValuesView(operand.get()), transform,
875 operand.getOperandNumber());
881 FULL_LDBG(
"--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
888 transform.getConsumedHandleOpOperands();
897 for (
OpOperand *opOperand : consumedOperands) {
898 Value operand = opOperand->get();
899 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
900 for (
Operation *payloadOp : getPayloadOps(operand)) {
901 llvm::append_range(origOpFlatResults, payloadOp->
getResults());
905 if (llvm::isa<TransformValueHandleTypeInterface>(operand.
getType())) {
906 for (
Value payloadValue : getPayloadValuesView(operand)) {
907 if (llvm::isa<OpResult>(payloadValue)) {
913 llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
920 <<
"unexpectedly consumed a value that is not a handle as operand #"
921 << opOperand->getOperandNumber();
923 <<
"value defined here with type " << operand.
getType();
932 llvm::find_if(llvm::reverse(regionStack), [&](
RegionScope *scope) {
933 return handle.getParentRegion() == scope->region;
935 assert(scopeIt != regionStack.rend() &&
936 "could not find region scope for handle");
938 return llvm::all_of(handle.getUsers(), [&](
Operation *user) {
939 return user == scope->currentTransform ||
940 happensBefore(user, scope->currentTransform);
959 transform->hasAttr(FindPayloadReplacementOpInterface::
960 kSilenceTrackingFailuresAttrName)) {
964 (void)trackingFailure.
silence();
969 result = std::move(trackingFailure);
973 result.
attachNote() <<
"tracking listener also failed: "
975 (void)trackingFailure.
silence();
988 for (
OpOperand *opOperand : consumedOperands) {
989 Value operand = opOperand->get();
990 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
991 forgetMapping(operand, origOpFlatResults);
992 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
994 forgetValueMapping(operand, origAssociatedOps);
998 if (failed(updateStateFromResults(results, transform->getResults())))
1001 printOnFailureRAII.release();
1003 DBGS() <<
"Top-level payload:\n";
1004 getTopLevel()->print(llvm::dbgs());
// Stores the results produced by a transform op into the state, one op
// result at a time, picking the setter from the result's type interface.
// The parameter list and several body lines are not visible in this excerpt.
1009 LogicalResult transform::TransformState::updateStateFromResults(
1011 for (
OpResult result : opResults) {
// Parameter-typed result: stored through setParams.
1012 if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1013 assert(results.isParam(result.getResultNumber()) &&
1014 "expected parameters for the parameter-typed result");
1016 setParams(result, results.getParams(result.getResultNumber())))) {
// Value-handle result: stored through setPayloadValues.
1019 }
else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1020 assert(results.isValue(result.getResultNumber()) &&
1021 "expected values for value-type-result");
1022 if (failed(setPayloadValues(
1023 result, results.getValues(result.getResultNumber())))) {
// Remaining case: stored through setPayloadOps.
1027 assert(!results.isParam(result.getResultNumber()) &&
1028 "expected payload ops for the non-parameter typed result");
1030 setPayloadOps(result, results.get(result.getResultNumber())))) {
1049 return state.replacePayloadOp(op, replacement);
1054 Value replacement) {
1055 return state.replacePayloadValue(value, replacement);
1066 for (
Block &block : *region) {
1067 for (
Value handle : block.getArguments()) {
1068 state.invalidatedHandles.erase(handle);
1072 state.invalidatedHandles.erase(handle);
1077 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1081 llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1084 state.mappings.erase(region);
1085 state.regionStack.pop_back();
// Constructs the result storage with `numSegments` empty rows in each of the
// three containers (operations, params, values).
1092 transform::TransformResults::TransformResults(
unsigned numSegments) {
1093 operations.appendEmptyRows(numSegments);
1094 params.appendEmptyRows(numSegments);
1095 values.appendEmptyRows(numSegments);
1101 assert(position <
static_cast<int64_t
>(this->params.size()) &&
1102 "setting params for a non-existent handle");
1103 assert(this->params[position].data() ==
nullptr &&
"params already set");
1104 assert(operations[position].data() ==
nullptr &&
1105 "another kind of results already set");
1106 assert(values[position].data() ==
nullptr &&
1107 "another kind of results already set");
1108 this->params.replace(position, params);
1116 return set(handle, operations), success();
1119 return setParams(handle, params), success();
1122 return setValues(handle, payloadValues), success();
1125 if (!
diag.succeeded())
1126 llvm::dbgs() <<
diag.getStatusString() <<
"\n";
1127 assert(
diag.succeeded() &&
"incorrect mapping");
1129 (void)
diag.silence();
1133 transform::TransformOpInterface transform) {
1134 for (
OpResult opResult : transform->getResults()) {
1135 if (!isSet(opResult.getResultNumber()))
1136 setMappedValues(opResult, {});
// Returns the row of `operations` for the given result number.
// The index must be in range and the row must have non-null data(); both are
// checked only by assertions.
1141 transform::TransformResults::get(
unsigned resultNumber)
const {
1142 assert(resultNumber < operations.size() &&
1143 "querying results for a non-existent handle");
1144 assert(operations[resultNumber].data() !=
nullptr &&
1145 "querying unset results (values or params expected?)");
1146 return operations[resultNumber];
// Returns the row of `params` for the given result number.
// The index must be in range and the row must have non-null data(); both are
// checked only by assertions.
1150 transform::TransformResults::getParams(
unsigned resultNumber)
const {
1151 assert(resultNumber < params.size() &&
1152 "querying params for a non-existent handle");
1153 assert(params[resultNumber].data() !=
nullptr &&
1154 "querying unset params (ops or values expected?)");
1155 return params[resultNumber];
// Returns the row of `values` for the given result number.
// The index must be in range and the row must have non-null data(); both are
// checked only by assertions.
1159 transform::TransformResults::getValues(
unsigned resultNumber)
const {
1160 assert(resultNumber < values.size() &&
1161 "querying values for a non-existent handle");
1162 assert(values[resultNumber].data() !=
nullptr &&
1163 "querying unset values (ops or params expected?)");
1164 return values[resultNumber];
// Returns true when the `params` row for this result has non-null data().
// The index must be in range; this is checked only by an assertion.
1167 bool transform::TransformResults::isParam(
unsigned resultNumber)
const {
1168 assert(resultNumber < params.size() &&
1169 "querying association for a non-existent handle");
1170 return params[resultNumber].data() !=
nullptr;
// Returns true when the `values` row for this result has non-null data().
// The index must be in range; this is checked only by an assertion.
1173 bool transform::TransformResults::isValue(
unsigned resultNumber)
const {
1174 assert(resultNumber < values.size() &&
1175 "querying association for a non-existent handle");
1176 return values[resultNumber].data() !=
nullptr;
// Returns true when any of the three rows for this result (params,
// operations, values) has non-null data().
1179 bool transform::TransformResults::isSet(
unsigned resultNumber)
const {
// Only the size of `params` is checked, although `operations` and `values`
// are indexed too. The constructor appends the same number of rows to all
// three containers.
1180 assert(resultNumber < params.size() &&
1181 "querying association for a non-existent handle");
1182 return params[resultNumber].data() !=
nullptr ||
1183 operations[resultNumber].data() !=
nullptr ||
1184 values[resultNumber].data() !=
nullptr;
1192 TransformOpInterface op,
1196 for (
OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1197 consumedHandles.insert(opOperand->get());
1204 for (
Value v : values) {
1209 defOp = v.getDefiningOp();
1212 if (defOp != v.getDefiningOp())
1221 "invalid number of replacement values");
1225 getTransformOp(),
"tracking listener failed to find replacement op "
1226 "during application of this transform op");
1230 Operation *defOp = getCommonDefiningOp(values);
1232 diag.attachNote() <<
"replacement values belong to different ops";
1237 if (
config.skipCastOps && isa<CastOpInterface>(defOp)) {
1241 <<
"using output of 'CastOpInterface' op";
1247 if (!
config.requireMatchingReplacementOpName ||
1263 if (
auto findReplacementOpInterface =
1264 dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1265 values.assign(findReplacementOpInterface.getNextOperands());
1266 diag.attachNote(defOp->
getLoc()) <<
"using operands provided by "
1267 "'FindPayloadReplacementOpInterface'";
1270 }
while (!values.empty());
1272 diag.attachNote() <<
"ran out of suitable replacement values";
1280 reasonCallback(
diag);
1281 DBGS() <<
"Match Failure : " <<
diag.str() <<
"\n";
// Listener hook for an erased payload op. It replaces a payload value and
// the op itself with nullptr in the tracked mappings. The loop header that
// supplies `value` is not visible in this excerpt.
1285 void transform::TrackingListener::notifyOperationErased(
Operation *op) {
// Both results are discarded on purpose via the (void) cast.
1288 (void)replacePayloadValue(value,
nullptr);
1290 (void)replacePayloadOp(op,
nullptr);
// Listener hook for a payload op that was replaced by a list of values.
// It updates value handles first, then decides what to do with op handles.
// The parameter list and several body lines are not visible in this excerpt.
1293 void transform::TrackingListener::notifyOperationReplaced(
1296 "invalid number of replacement values");
// Redirect handles of each old result to the matching new value.
1299 for (
auto [oldValue, newValue] : llvm::zip(op->
getResults(), newValues))
1300 (void)replacePayloadValue(oldValue, newValue);
// Collect the op handles pointing at the replaced op.
1304 if (failed(getTransformState().getHandlesForPayloadOp(
1305 op, opHandles,
true))) {
// True when at least one of those handles is in consumedHandles.
1319 auto handleWasConsumed = [&] {
1320 return llvm::any_of(opHandles,
1321 [&](
Value h) {
return consumedHandles.contains(h); });
// Pick a handle that still needs updating: the first one that skipHandleFn
// does not skip, or simply the first handle when no skipHandleFn is set.
1326 if (
config.skipHandleFn) {
1327 auto it = llvm::find_if(opHandles,
1328 [&](
Value v) {
return !
config.skipHandleFn(v); });
1329 if (it != opHandles.end())
1331 }
else if (!opHandles.empty()) {
1332 aliveHandle = opHandles.front();
// With no such handle, or with a consumed handle, the op mapping is replaced
// by nullptr.
1334 if (!aliveHandle || handleWasConsumed()) {
1337 (void)replacePayloadOp(op,
nullptr);
// Otherwise search for a replacement op.
1343 findReplacementOp(replacement, op, newValues);
// When the search fails, the failure is reported through
// notifyPayloadReplacementNotFound and the op mapping is replaced by nullptr.
1346 if (!
diag.succeeded()) {
1348 <<
"replacement is required because this handle must be updated";
1349 notifyPayloadReplacementNotFound(op, newValues, std::move(
diag));
1350 (void)replacePayloadOp(op,
nullptr);
1354 (void)replacePayloadOp(op, replacement);
1361 assert(status.succeeded() &&
"listener state was not checked");
1373 return !status.succeeded();
1381 diag.takeDiagnostics(diags);
1382 if (!status.succeeded())
1383 status.takeDiagnostics(diags);
1387 status.attachNote(op->
getLoc()) <<
"[" << errorCounter <<
"] replaced op";
1389 status.attachNote(value.
getLoc())
1390 <<
"[" << errorCounter <<
"] replacement value " << index;
1396 if (!matchFailure) {
1399 return matchFailure->str();
1405 reasonCallback(
diag);
1406 matchFailure = std::move(
diag);
1420 return listener->failed();
1425 if (hasTrackingFailures()) {
1433 return listener->replacePayloadOp(op, replacement);
1444 for (
Operation *child : targets.drop_front(position + 1)) {
1445 if (parent->isAncestor(child)) {
1448 <<
"transform operation consumes a handle pointing to an ancestor "
1449 "payload operation before its descendant";
1451 <<
"the ancestor is likely erased or rewritten before the "
1452 "descendant is accessed, leading to undefined behavior";
1453 diag.attachNote(parent->getLoc()) <<
"ancestor payload op";
1454 diag.attachNote(child->getLoc()) <<
"descendant payload op";
1473 diag.attachNote(payloadOpLoc) <<
"when applied to this op";
1477 if (partialResult.
size() != expectedNumResults) {
1478 auto diag =
emitDiag() <<
"application of " << transformOpName
1479 <<
" expected to produce " << expectedNumResults
1480 <<
" results (actually produced "
1481 << partialResult.
size() <<
").";
1482 diag.attachNote(transformOpLoc)
1483 <<
"if you need variadic results, consider a generic `apply` "
1484 <<
"instead of the specialized `applyToOne`.";
1489 for (
const auto &[ptr, res] :
1490 llvm::zip(partialResult, transformOp->
getResults())) {
1493 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1494 !isa<Operation *>(ptr)) {
1495 return emitDiag() <<
"application of " << transformOpName
1496 <<
" expected to produce an Operation * for result #"
1497 << res.getResultNumber();
1499 if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1500 !isa<Attribute>(ptr)) {
1501 return emitDiag() <<
"application of " << transformOpName
1502 <<
" expected to produce an Attribute for result #"
1503 << res.getResultNumber();
1505 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1507 return emitDiag() <<
"application of " << transformOpName
1508 <<
" expected to produce a Value for result #"
1509 << res.getResultNumber();
1515 template <
typename T>
1517 return llvm::to_vector(llvm::map_range(
1527 if (llvm::any_of(partialResults,
1528 [](
MappedValue value) {
return value.isNull(); }))
1530 assert(transformOp->
getNumResults() == partialResults.size() &&
1531 "expected as many partial results as op as results");
1533 transposed[i].push_back(value);
1537 unsigned position = r.getResultNumber();
1538 if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1540 castVector<Attribute>(transposed[position]));
1541 }
else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1542 transformResults.
setValues(r, castVector<Value>(transposed[position]));
1544 transformResults.
set(r, castVector<Operation *>(transposed[position]));
1556 assert(mappings.size() == values.size() &&
"mismatching number of mappings");
1557 for (
auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1558 size_t mappedSize = mapped.size();
1559 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1560 llvm::append_range(mapped, state.getPayloadOps(operand));
1561 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
1562 operand.getType())) {
1563 llvm::append_range(mapped, state.getPayloadValues(operand));
1565 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1566 "unsupported kind of transform dialect value");
1567 llvm::append_range(mapped, state.getParams(operand));
1570 if (mapped.size() - mappedSize != 1 && !flatten)
1579 mappings.resize(mappings.size() + values.size());
1589 for (
auto &&[terminatorOperand, result] :
1592 if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1593 results.
set(result, state.getPayloadOps(terminatorOperand));
1594 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1595 result.getType())) {
1596 results.
setValues(result, state.getPayloadValues(terminatorOperand));
1599 llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1600 "unhandled transform type interface");
1601 results.
setParams(result, state.getParams(terminatorOperand));
1623 iface.getEffectsOnValue(source, nestedEffects);
1624 for (
const auto &effect : nestedEffects)
1625 effects.emplace_back(effect.getEffect(), target, effect.getResource());
1634 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1638 for (
auto &&[source, target] : llvm::zip(block.
getArguments(), operands)) {
1645 llvm::append_range(effects, nestedEffects);
1657 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1661 iface.getEffects(effects);
1676 llvm::append_range(targets, state.getPayloadOps(op->
getOperand(0)));
1679 if (state.getNumTopLevelMappings() !=
1683 <<
" extra value bindings, but " << state.getNumTopLevelMappings()
1684 <<
" were provided to the interpreter";
1687 targets.push_back(state.getTopLevel());
1689 for (
unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1690 extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1693 if (failed(state.mapBlockArguments(region.
front().
getArgument(0), targets)))
1697 if (failed(state.mapBlockArgument(
1698 argument, extraMappings[argument.getArgNumber() - 1])))
1710 assert(isa<TransformOpInterface>(op) &&
1711 "should implement TransformOpInterface to have "
1712 "PossibleTopLevelTransformOpTrait");
1715 return op->
emitOpError() <<
"expects at least one region";
1718 if (!llvm::hasNItems(*bodyRegion, 1))
1719 return op->
emitOpError() <<
"expects a single-block region";
1724 <<
"expects the entry block to have at least one argument";
1726 if (!llvm::isa<TransformHandleTypeInterface>(
1729 <<
"expects the first entry block argument to be of type "
1730 "implementing TransformHandleTypeInterface";
1736 <<
"expects the type of the block argument to match "
1737 "the type of the operand";
1741 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1742 TransformValueHandleTypeInterface>(arg.
getType()))
1747 <<
"expects trailing entry block arguments to be of type implementing "
1748 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1749 "TransformParamTypeInterface";
1759 <<
"expects operands to be provided for a nested op";
1760 diag.attachNote(parent->getLoc())
1761 <<
"nested in another possible top-level op";
1776 bool hasPayloadOperands =
false;
1779 if (llvm::isa<TransformHandleTypeInterface,
1780 TransformValueHandleTypeInterface>(operand.get().getType()))
1781 hasPayloadOperands =
true;
1783 if (hasPayloadOperands)
1792 llvm::report_fatal_error(
1793 Twine(
"ParamProducerTransformOpTrait must be attached to an op that "
1794 "implements MemoryEffectsOpInterface, found on ") +
1798 if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1801 <<
"ParamProducerTransformOpTrait attached to this op expects "
1802 "result types to implement TransformParamTypeInterface";
1824 template <
typename EffectTy,
typename ResourceTy,
typename Range>
1827 return isa<EffectTy>(effect.
getEffect()) &&
1833 transform::TransformOpInterface transform) {
1834 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1836 iface.getEffectsOnValue(handle, effects);
1837 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1838 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1884 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1886 iface.getEffects(effects);
1887 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1891 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1893 iface.getEffects(effects);
1894 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1898 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1901 auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1906 iface.getEffects(effects);
1909 dyn_cast_or_null<BlockArgument>(effect.getValue());
1910 if (!argument || argument.
getOwner() != &block ||
1911 !isa<MemoryEffects::Free>(effect.getEffect()) ||
1925 TransformOpInterface transformOp) {
1927 consumedOperands.reserve(transformOp->getNumOperands());
1928 auto memEffectInterface =
1929 cast<MemoryEffectOpInterface>(transformOp.getOperation());
1931 for (
OpOperand &target : transformOp->getOpOperands()) {
1933 memEffectInterface.getEffectsOnValue(target.get(), effects);
1935 return isa<transform::TransformMappingResource>(
1937 isa<MemoryEffects::Free>(effect.
getEffect());
1939 consumedOperands.push_back(&target);
1942 return consumedOperands;
1946 auto iface = cast<MemoryEffectOpInterface>(op);
1948 iface.getEffects(effects);
1950 auto effectsOn = [&](
Value value) {
1951 return llvm::make_filter_range(
1953 return instance.
getValue() == value;
1957 std::optional<unsigned> firstConsumedOperand;
1959 auto range = effectsOn(operand.get());
1960 if (range.empty()) {
1962 op->
emitError() <<
"TransformOpInterface requires memory effects "
1963 "on operands to be specified";
1964 diag.attachNote() <<
"no effects specified for operand #"
1965 << operand.getOperandNumber();
1968 if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1970 <<
"TransformOpInterface did not expect "
1971 "'allocate' memory effect on an operand";
1972 diag.attachNote() <<
"specified for operand #"
1973 << operand.getOperandNumber();
1976 if (!firstConsumedOperand &&
1977 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1978 firstConsumedOperand = operand.getOperandNumber();
1982 if (firstConsumedOperand &&
1983 !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1986 <<
"TransformOpInterface expects ops consuming operands to have a "
1987 "'write' effect on the payload resource";
1988 diag.attachNote() <<
"consumes operand #" << *firstConsumedOperand;
1993 auto range = effectsOn(result);
1994 if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1997 op->
emitError() <<
"TransformOpInterface requires 'allocate' memory "
1998 "effect to be specified for results";
1999 diag.attachNote() <<
"no 'allocate' effect specified for result #"
2000 << result.getResultNumber();
2013 Operation *payloadRoot, TransformOpInterface transform,
2018 if (enforceToplevelTransformOp) {
2020 transform->getNumOperands() != 0) {
2021 return transform->emitError()
2022 <<
"expected transform to start at the top-level transform op";
2029 TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2031 if (stateInitializer)
2032 stateInitializer(state);
2033 if (state.applyTransform(transform).checkAndReport().failed())
2036 return stateExporter(state);
2044 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2045 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A 2D array where each row may have different length.
size_t size() const
Returns the number of rows in the 2D array.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class implements the result iterators for the Operation class.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns true if op has an effect of type EffectTy.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...