16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/ScopeExit.h"
18 #include "llvm/ADT/iterator.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include "llvm/Support/InterleavedRange.h"
23 #define DEBUG_TYPE "transform-dialect"
24 #define DEBUG_TYPE_FULL "transform-dialect-full"
25 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
27 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
28 #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
// Out-of-line definition of the static data member
// `TransformState::kTopLevelValue` (its in-class declaration is not visible
// here). It serves as the sentinel value standing for the transformation
// root: setPayloadOps asserts that a handle is never equal to it
// ("attempting to reset the transformation root").
// NOTE(review): if the in-class declaration is `static constexpr`, this
// definition is redundant under C++17 (such members are implicitly inline);
// presumably kept for ODR-use compatibility with older standards — confirm.
53 constexpr
const Value transform::TransformState::kTopLevelValue;
55 transform::TransformState::TransformState(
58 const TransformOptions &
options)
60 topLevelMappedValues.reserve(extraMappings.
size());
62 topLevelMappedValues.push_back(mapping);
64 RegionScope *scope =
new RegionScope(*
this, *region);
65 topLevelRegionScope.reset(scope);
// Returns the top-level payload operation held in the `topLevel` member.
// Other parts of the interpreter use it as a payload target when mapping the
// entry block of a possible top-level transform op. The debug output also
// prints it as the "Top-level payload".
69 Operation *transform::TransformState::getTopLevel()
const {
return topLevel; }
72 transform::TransformState::getPayloadOpsView(
Value value)
const {
73 const TransformOpMapping &operationMapping = getMapping(value).direct;
74 auto iter = operationMapping.find(value);
75 assert(iter != operationMapping.end() &&
76 "cannot find mapping for payload handle (param/value handle "
78 return iter->getSecond();
83 auto iter = mapping.find(value);
84 assert(iter != mapping.end() &&
"cannot find mapping for param handle "
85 "(operation/value handle provided?)");
86 return iter->getSecond();
90 transform::TransformState::getPayloadValuesView(
Value handleValue)
const {
91 const ValueMapping &mapping = getMapping(handleValue).values;
92 auto iter = mapping.find(handleValue);
93 assert(iter != mapping.end() &&
"cannot find mapping for value handle "
94 "(param/operation handle provided?)");
95 return iter->getSecond();
100 bool includeOutOfScope)
const {
102 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
103 auto iterator = mapping->reverse.find(op);
104 if (iterator != mapping->reverse.end()) {
105 llvm::append_range(handles, iterator->getSecond());
109 if (!includeOutOfScope &&
114 return success(found);
119 bool includeOutOfScope)
const {
121 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
122 auto iterator = mapping->reverseValues.find(payloadValue);
123 if (iterator != mapping->reverseValues.end()) {
124 llvm::append_range(handles, iterator->getSecond());
128 if (!includeOutOfScope &&
133 return success(found);
143 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.
getType())) {
145 operations.reserve(values.size());
147 if (
auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
148 operations.push_back(op);
152 <<
"wrong kind of value provided for top-level operation handle";
154 if (failed(operationsFn(operations)))
159 if (llvm::isa<transform::TransformValueHandleTypeInterface>(
162 payloadValues.reserve(values.size());
164 if (
auto v = llvm::dyn_cast_if_present<Value>(value)) {
165 payloadValues.push_back(v);
169 <<
"wrong kind of value provided for the top-level value handle";
171 if (failed(valuesFn(payloadValues)))
176 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.
getType()) &&
177 "unsupported kind of block argument");
179 parameters.reserve(values.size());
181 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
182 parameters.push_back(attr);
186 <<
"wrong kind of value provided for top-level parameter";
188 if (failed(paramsFn(parameters)))
199 return setPayloadOps(argument, operations);
202 return setParams(argument, params);
205 return setPayloadValues(argument, payloadValues);
213 for (
auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
214 if (failed(mapBlockArgument(argument, values)))
220 transform::TransformState::setPayloadOps(
Value value,
222 assert(value != kTopLevelValue &&
223 "attempting to reset the transformation root");
224 assert(llvm::isa<TransformHandleTypeInterface>(value.
getType()) &&
225 "wrong handle type");
231 <<
"attempting to assign a null payload op to this transform value";
234 auto iface = llvm::cast<TransformHandleTypeInterface>(value.
getType());
236 iface.checkPayload(value.
getLoc(), targets);
243 Mappings &mappings = getMapping(value);
245 mappings.direct.insert({value, std::move(storedTargets)}).second;
246 assert(inserted &&
"value is already associated with another list");
250 mappings.reverse[op].push_back(value);
256 transform::TransformState::setPayloadValues(
Value handle,
258 assert(handle !=
nullptr &&
"attempting to set params for a null value");
259 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.
getType()) &&
260 "wrong handle type");
262 for (
Value payload : payloadValues) {
265 return emitError(handle.
getLoc()) <<
"attempting to assign a null payload "
266 "value to this transform handle";
269 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.
getType());
272 iface.checkPayload(handle.
getLoc(), payloadValueVector);
276 Mappings &mappings = getMapping(handle);
278 mappings.values.insert({handle, std::move(payloadValueVector)}).second;
281 "value handle is already associated with another list of payload values");
284 for (
Value payload : payloadValues)
285 mappings.reverseValues[payload].push_back(handle);
290 LogicalResult transform::TransformState::setParams(
Value value,
292 assert(value !=
nullptr &&
"attempting to set params for a null value");
298 <<
"attempting to assign a null parameter to this transform value";
301 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.
getType());
303 "cannot associate parameter with a value of non-parameter type");
305 valueType.checkPayload(value.
getLoc(), params);
309 Mappings &mappings = getMapping(value);
311 mappings.params.insert({value, llvm::to_vector(params)}).second;
312 assert(inserted &&
"value is already associated with another list of params");
317 template <
typename Mapping,
typename Key,
typename Mapped>
319 auto it = mapping.find(key);
320 if (it == mapping.end())
323 llvm::erase(it->getSecond(), mapped);
324 if (it->getSecond().empty())
328 void transform::TransformState::forgetMapping(
Value opHandle,
330 bool allowOutOfScope) {
331 Mappings &mappings = getMapping(opHandle, allowOutOfScope);
332 for (
Operation *op : mappings.direct[opHandle])
334 mappings.direct.erase(opHandle);
335 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
338 mappings.incrementTimestamp(opHandle);
341 for (
Value opResult : origOpFlatResults) {
343 (void)getHandlesForPayloadValue(opResult, resultHandles);
344 for (
Value resultHandle : resultHandles) {
345 Mappings &localMappings = getMapping(resultHandle);
347 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
350 mappings.incrementTimestamp(resultHandle);
357 void transform::TransformState::forgetValueMapping(
359 Mappings &mappings = getMapping(valueHandle);
360 for (
Value payloadValue : mappings.reverseValues[valueHandle])
362 mappings.values.erase(valueHandle);
363 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
366 mappings.incrementTimestamp(valueHandle);
369 for (
Operation *payloadOp : payloadOperations) {
371 (void)getHandlesForPayloadOp(payloadOp, opHandles);
372 for (
Value opHandle : opHandles) {
373 Mappings &localMappings = getMapping(opHandle);
377 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
380 localMappings.incrementTimestamp(opHandle);
387 transform::TransformState::replacePayloadOp(
Operation *op,
394 (void)getHandlesForPayloadValue(opResult, valueHandles,
396 assert(valueHandles.empty() &&
"expected no mapping to old results");
403 if (failed(getHandlesForPayloadOp(op, opHandles,
true)))
405 for (
Value handle : opHandles) {
406 Mappings &mappings = getMapping(handle,
true);
419 for (
Value handle : opHandles) {
420 Mappings &mappings = getMapping(handle,
true);
421 auto it = mappings.direct.find(handle);
422 if (it == mappings.direct.end())
429 mapped = replacement;
433 mappings.reverse[replacement].push_back(handle);
435 opHandlesToCompact.insert(handle);
443 transform::TransformState::replacePayloadValue(
Value value,
Value replacement) {
445 if (failed(getHandlesForPayloadValue(value, valueHandles,
449 for (
Value handle : valueHandles) {
450 Mappings &mappings = getMapping(handle,
true);
457 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
460 mappings.incrementTimestamp(handle);
463 auto it = mappings.values.find(handle);
464 if (it == mappings.values.end())
468 for (
Value &mapped : association) {
470 mapped = replacement;
472 mappings.reverseValues[replacement].push_back(handle);
479 void transform::TransformState::recordOpHandleInvalidationOne(
486 if (invalidatedHandles.count(otherHandle) ||
487 newlyInvalidated.count(otherHandle))
490 FULL_LDBG(
"--recordOpHandleInvalidationOne\n");
492 (
DBGS() <<
"--ancestors: "
493 << llvm::interleaved(llvm::make_pointee_range(potentialAncestors))
499 for (
Operation *ancestor : potentialAncestors) {
502 { (
DBGS() <<
"----handle one ancestor: " << *ancestor <<
"\n"); });
504 { (
DBGS() <<
"----of payload with name: "
507 { (
DBGS() <<
"----of payload: " << *payloadOp <<
"\n"); });
509 if (!ancestor->isAncestor(payloadOp))
516 Location ancestorLoc = ancestor->getLoc();
518 std::optional<Location> throughValueLoc =
519 throughValue ? std::make_optional(throughValue.
getLoc()) : std::nullopt;
520 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
522 throughValueLoc](
Location currentLoc) {
524 <<
"op uses a handle invalidated by a "
525 "previously executed transform op";
526 diag.attachNote(otherHandle.getLoc()) <<
"handle to invalidated ops";
527 diag.attachNote(owner->getLoc())
528 <<
"invalidated by this transform op that consumes its operand #"
530 <<
" and invalidates all handles to payload IR entities associated "
531 "with this operand and entities nested in them";
532 diag.attachNote(ancestorLoc) <<
"ancestor payload op";
533 diag.attachNote(opLoc) <<
"nested payload op";
534 if (throughValueLoc) {
535 diag.attachNote(*throughValueLoc)
536 <<
"consumed handle points to this payload value";
542 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
549 if (invalidatedHandles.count(valueHandle) ||
550 newlyInvalidated.count(valueHandle))
553 for (
Operation *ancestor : potentialAncestors) {
555 std::optional<unsigned> resultNo;
559 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
560 definingOp = opResult.getOwner();
561 resultNo = opResult.getResultNumber();
563 auto arg = llvm::cast<BlockArgument>(payloadValue);
565 argumentNo = arg.getArgNumber();
566 blockNo = std::distance(arg.getOwner()->getParent()->begin(),
567 arg.getOwner()->getIterator());
568 regionNo = arg.getOwner()->getParent()->getRegionNumber();
570 assert(definingOp &&
"expected the value to be defined by an op as result "
571 "or block argument");
572 if (!ancestor->isAncestor(definingOp))
577 Location ancestorLoc = ancestor->getLoc();
580 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
581 argumentNo, blockNo, regionNo, ancestorLoc,
582 opLoc, valueLoc](
Location currentLoc) {
584 <<
"op uses a handle invalidated by a "
585 "previously executed transform op";
586 diag.attachNote(valueHandle.
getLoc()) <<
"invalidated handle";
587 diag.attachNote(owner->getLoc())
588 <<
"invalidated by this transform op that consumes its operand #"
590 <<
" and invalidates all handles to payload IR entities "
591 "associated with this operand and entities nested in them";
592 diag.attachNote(ancestorLoc)
593 <<
"ancestor op associated with the consumed handle";
595 diag.attachNote(opLoc)
596 <<
"op defining the value as result #" << *resultNo;
598 diag.attachNote(opLoc)
599 <<
"op defining the value as block argument #" << argumentNo
600 <<
" of block #" << blockNo <<
" in region #" << regionNo;
602 diag.attachNote(valueLoc) <<
"payload value";
607 void transform::TransformState::recordOpHandleInvalidation(
612 if (potentialAncestors.empty()) {
614 (
DBGS() <<
"----recording invalidation for empty handle: " << handle.
get()
620 newlyInvalidated[handle.
get()] = [owner, operandNo](
Location currentLoc) {
622 <<
"op uses a handle associated with empty "
623 "payload and invalidated by a "
624 "previously executed transform op";
625 diag.attachNote(owner->getLoc())
626 <<
"invalidated by this transform op that consumes its operand #"
639 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
643 for (
const auto &[payloadOp, otherHandles] : mapping->reverse) {
644 for (
Value otherHandle : otherHandles)
645 recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
646 otherHandle, throughValue,
654 for (
const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
655 for (
Value valueHandle : valueHandles)
656 recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
657 payloadValue, valueHandle,
667 void transform::TransformState::recordValueHandleInvalidation(
671 for (
Value payloadValue : getPayloadValuesView(valueHandle.
get())) {
673 (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
674 for (
Value otherHandle : otherValueHandles) {
678 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
681 <<
"op uses a handle invalidated by a "
682 "previously executed transform op";
683 diag.attachNote(otherHandle.getLoc()) <<
"invalidated handle";
684 diag.attachNote(owner->getLoc())
685 <<
"invalidated by this transform op that consumes its operand #"
687 <<
" and invalidates handles to the same values as associated with "
689 diag.attachNote(valueLoc) <<
"payload value";
693 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
694 Operation *payloadOp = opResult.getOwner();
695 recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
698 auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
699 for (
Operation &payloadOp : *arg.getOwner())
700 recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
710 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
711 transform::TransformOpInterface transform,
713 FULL_LDBG(
"--Start checkAndRecordHandleInvalidation\n");
714 auto memoryEffectsIface =
715 cast<MemoryEffectOpInterface>(transform.getOperation());
717 memoryEffectsIface.getEffectsOnResource(
720 for (
OpOperand &target : transform->getOpOperands()) {
722 (
DBGS() <<
"----iterate on handle: " << target.get() <<
"\n");
728 auto it = invalidatedHandles.find(target.get());
729 auto nit = newlyInvalidated.find(target.get());
730 if (it != invalidatedHandles.end()) {
731 FULL_LDBG(
"--End checkAndRecordHandleInvalidation, found already "
732 "invalidated -> FAILURE\n");
733 return it->getSecond()(transform->getLoc()), failure();
735 if (!transform.allowsRepeatedHandleOperands() &&
736 nit != newlyInvalidated.end()) {
737 FULL_LDBG(
"--End checkAndRecordHandleInvalidation, found newly "
738 "invalidated (by this op) -> FAILURE\n");
739 return nit->getSecond()(transform->getLoc()), failure();
745 return isa<MemoryEffects::Free>(effect.getEffect()) &&
746 effect.getValue() == target.get();
748 if (llvm::any_of(effects, consumesTarget)) {
750 if (llvm::isa<transform::TransformHandleTypeInterface>(
751 target.get().getType())) {
752 FULL_LDBG(
"----recordOpHandleInvalidation\n");
754 llvm::to_vector(getPayloadOps(target.get()));
755 recordOpHandleInvalidation(target, payloadOps,
nullptr,
757 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
758 target.get().getType())) {
759 FULL_LDBG(
"----recordValueHandleInvalidation\n");
760 recordValueHandleInvalidation(target, newlyInvalidated);
762 FULL_LDBG(
"----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
765 FULL_LDBG(
"----no consume effect -> SKIP\n");
769 FULL_LDBG(
"--End checkAndRecordHandleInvalidation -> SUCCESS\n");
773 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
774 transform::TransformOpInterface transform) {
775 InvalidatedHandleMap newlyInvalidated;
776 LogicalResult checkResult =
777 checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
778 invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
779 std::make_move_iterator(newlyInvalidated.end()));
783 template <
typename T>
786 transform::TransformOpInterface transform,
787 unsigned operandNumber) {
789 for (T p : payload) {
790 if (!seen.insert(p).second) {
792 transform.emitSilenceableError()
793 <<
"a handle passed as operand #" << operandNumber
794 <<
" and consumed by this operation points to a payload "
795 "entity more than once";
796 if constexpr (std::is_pointer_v<T>)
797 diag.attachNote(p->getLoc()) <<
"repeated target op";
799 diag.attachNote(p.getLoc()) <<
"repeated target value";
806 void transform::TransformState::compactOpHandles() {
807 for (
Value handle : opHandlesToCompact) {
808 Mappings &mappings = getMapping(handle,
true);
809 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
810 if (llvm::find(mappings.direct[handle],
nullptr) !=
811 mappings.direct[handle].end())
814 mappings.incrementTimestamp(handle);
816 llvm::erase(mappings.direct[handle],
nullptr);
818 opHandlesToCompact.clear();
824 DBGS() <<
"applying: ";
826 llvm::dbgs() <<
"\n";
829 DBGS() <<
"Top-level payload before application:\n"
830 << *getTopLevel() <<
"\n");
831 auto printOnFailureRAII = llvm::make_scope_exit([
this] {
833 LLVM_DEBUG(
DBGS() <<
"Failing Top-level payload:\n"; getTopLevel()->print(
838 regionStack.back()->currentTransform = transform;
841 if (
options.getExpensiveChecksEnabled()) {
843 if (failed(checkAndRecordHandleInvalidation(transform)))
846 for (
OpOperand &operand : transform->getOpOperands()) {
848 (
DBGS() <<
"iterate on handle: " << operand.get() <<
"\n");
851 FULL_LDBG(
"--handle not consumed -> SKIP\n");
854 if (transform.allowsRepeatedHandleOperands()) {
855 FULL_LDBG(
"--op allows repeated handles -> SKIP\n");
860 Type operandType = operand.get().getType();
861 if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
862 FULL_LDBG(
"--checkRepeatedConsumptionInOperand for Operation*\n");
864 checkRepeatedConsumptionInOperand<Operation *>(
865 getPayloadOpsView(operand.get()), transform,
866 operand.getOperandNumber());
871 }
else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
872 FULL_LDBG(
"--checkRepeatedConsumptionInOperand For Value\n");
874 checkRepeatedConsumptionInOperand<Value>(
875 getPayloadValuesView(operand.get()), transform,
876 operand.getOperandNumber());
882 FULL_LDBG(
"--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
889 transform.getConsumedHandleOpOperands();
898 for (
OpOperand *opOperand : consumedOperands) {
899 Value operand = opOperand->get();
900 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
901 for (
Operation *payloadOp : getPayloadOps(operand)) {
902 llvm::append_range(origOpFlatResults, payloadOp->
getResults());
906 if (llvm::isa<TransformValueHandleTypeInterface>(operand.
getType())) {
907 for (
Value payloadValue : getPayloadValuesView(operand)) {
908 if (llvm::isa<OpResult>(payloadValue)) {
914 llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
921 <<
"unexpectedly consumed a value that is not a handle as operand #"
922 << opOperand->getOperandNumber();
924 <<
"value defined here with type " << operand.
getType();
933 llvm::find_if(llvm::reverse(regionStack), [&](
RegionScope *scope) {
934 return handle.getParentRegion() == scope->region;
936 assert(scopeIt != regionStack.rend() &&
937 "could not find region scope for handle");
939 return llvm::all_of(handle.getUsers(), [&](
Operation *user) {
940 return user == scope->currentTransform ||
941 happensBefore(user, scope->currentTransform);
960 transform->hasAttr(FindPayloadReplacementOpInterface::
961 kSilenceTrackingFailuresAttrName)) {
965 (void)trackingFailure.
silence();
970 result = std::move(trackingFailure);
974 result.
attachNote() <<
"tracking listener also failed: "
976 (void)trackingFailure.
silence();
989 for (
OpOperand *opOperand : consumedOperands) {
990 Value operand = opOperand->get();
991 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
992 forgetMapping(operand, origOpFlatResults);
993 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
995 forgetValueMapping(operand, origAssociatedOps);
999 if (failed(updateStateFromResults(results, transform->getResults())))
1002 printOnFailureRAII.release();
1004 DBGS() <<
"Top-level payload:\n";
1005 getTopLevel()->print(llvm::dbgs());
1010 LogicalResult transform::TransformState::updateStateFromResults(
1012 for (
OpResult result : opResults) {
1013 if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1014 assert(results.isParam(result.getResultNumber()) &&
1015 "expected parameters for the parameter-typed result");
1017 setParams(result, results.getParams(result.getResultNumber())))) {
1020 }
else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1021 assert(results.isValue(result.getResultNumber()) &&
1022 "expected values for value-type-result");
1023 if (failed(setPayloadValues(
1024 result, results.getValues(result.getResultNumber())))) {
1028 assert(!results.isParam(result.getResultNumber()) &&
1029 "expected payload ops for the non-parameter typed result");
1031 setPayloadOps(result, results.get(result.getResultNumber())))) {
1050 return state.replacePayloadOp(op, replacement);
1055 Value replacement) {
1056 return state.replacePayloadValue(value, replacement);
1067 for (
Block &block : *region) {
1068 for (
Value handle : block.getArguments()) {
1069 state.invalidatedHandles.erase(handle);
1073 state.invalidatedHandles.erase(handle);
1078 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1082 llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1085 state.mappings.erase(region);
1086 state.regionStack.pop_back();
1093 transform::TransformResults::TransformResults(
unsigned numSegments) {
1094 operations.appendEmptyRows(numSegments);
1095 params.appendEmptyRows(numSegments);
1096 values.appendEmptyRows(numSegments);
1102 assert(position <
static_cast<int64_t
>(this->params.size()) &&
1103 "setting params for a non-existent handle");
1104 assert(this->params[position].data() ==
nullptr &&
"params already set");
1105 assert(operations[position].data() ==
nullptr &&
1106 "another kind of results already set");
1107 assert(values[position].data() ==
nullptr &&
1108 "another kind of results already set");
1109 this->params.replace(position, params);
1117 return set(handle, operations), success();
1120 return setParams(handle, params), success();
1123 return setValues(handle, payloadValues), success();
1126 if (!
diag.succeeded())
1127 llvm::dbgs() <<
diag.getStatusString() <<
"\n";
1128 assert(
diag.succeeded() &&
"incorrect mapping");
1130 (void)
diag.silence();
1134 transform::TransformOpInterface transform) {
1135 for (
OpResult opResult : transform->getResults()) {
1136 if (!isSet(opResult.getResultNumber()))
1137 setMappedValues(opResult, {});
1142 transform::TransformResults::get(
unsigned resultNumber)
const {
1143 assert(resultNumber < operations.size() &&
1144 "querying results for a non-existent handle");
1145 assert(operations[resultNumber].data() !=
nullptr &&
1146 "querying unset results (values or params expected?)");
1147 return operations[resultNumber];
1151 transform::TransformResults::getParams(
unsigned resultNumber)
const {
1152 assert(resultNumber < params.size() &&
1153 "querying params for a non-existent handle");
1154 assert(params[resultNumber].data() !=
nullptr &&
1155 "querying unset params (ops or values expected?)");
1156 return params[resultNumber];
1160 transform::TransformResults::getValues(
unsigned resultNumber)
const {
1161 assert(resultNumber < values.size() &&
1162 "querying values for a non-existent handle");
1163 assert(values[resultNumber].data() !=
nullptr &&
1164 "querying unset values (ops or params expected?)");
1165 return values[resultNumber];
1168 bool transform::TransformResults::isParam(
unsigned resultNumber)
const {
1169 assert(resultNumber < params.size() &&
1170 "querying association for a non-existent handle");
1171 return params[resultNumber].data() !=
nullptr;
1174 bool transform::TransformResults::isValue(
unsigned resultNumber)
const {
1175 assert(resultNumber < values.size() &&
1176 "querying association for a non-existent handle");
1177 return values[resultNumber].data() !=
nullptr;
1180 bool transform::TransformResults::isSet(
unsigned resultNumber)
const {
1181 assert(resultNumber < params.size() &&
1182 "querying association for a non-existent handle");
1183 return params[resultNumber].data() !=
nullptr ||
1184 operations[resultNumber].data() !=
nullptr ||
1185 values[resultNumber].data() !=
nullptr;
1193 TransformOpInterface op,
1197 for (
OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1198 consumedHandles.insert(opOperand->get());
1205 for (
Value v : values) {
1210 defOp = v.getDefiningOp();
1213 if (defOp != v.getDefiningOp())
1222 "invalid number of replacement values");
1226 getTransformOp(),
"tracking listener failed to find replacement op "
1227 "during application of this transform op");
1231 Operation *defOp = getCommonDefiningOp(values);
1233 diag.attachNote() <<
"replacement values belong to different ops";
1238 if (
config.skipCastOps && isa<CastOpInterface>(defOp)) {
1242 <<
"using output of 'CastOpInterface' op";
1248 if (!
config.requireMatchingReplacementOpName ||
1264 if (
auto findReplacementOpInterface =
1265 dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1266 values.assign(findReplacementOpInterface.getNextOperands());
1267 diag.attachNote(defOp->
getLoc()) <<
"using operands provided by "
1268 "'FindPayloadReplacementOpInterface'";
1271 }
while (!values.empty());
1273 diag.attachNote() <<
"ran out of suitable replacement values";
1281 reasonCallback(
diag);
1282 DBGS() <<
"Match Failure : " <<
diag.str() <<
"\n";
1286 void transform::TrackingListener::notifyOperationErased(
Operation *op) {
1289 (void)replacePayloadValue(value,
nullptr);
1291 (void)replacePayloadOp(op,
nullptr);
1294 void transform::TrackingListener::notifyOperationReplaced(
1297 "invalid number of replacement values");
1300 for (
auto [oldValue, newValue] : llvm::zip(op->
getResults(), newValues))
1301 (void)replacePayloadValue(oldValue, newValue);
1305 if (failed(getTransformState().getHandlesForPayloadOp(
1306 op, opHandles,
true))) {
1320 auto handleWasConsumed = [&] {
1321 return llvm::any_of(opHandles,
1322 [&](
Value h) {
return consumedHandles.contains(h); });
1327 if (
config.skipHandleFn) {
1328 auto it = llvm::find_if(opHandles,
1329 [&](
Value v) {
return !
config.skipHandleFn(v); });
1330 if (it != opHandles.end())
1332 }
else if (!opHandles.empty()) {
1333 aliveHandle = opHandles.front();
1335 if (!aliveHandle || handleWasConsumed()) {
1338 (void)replacePayloadOp(op,
nullptr);
1344 findReplacementOp(replacement, op, newValues);
1347 if (!
diag.succeeded()) {
1349 <<
"replacement is required because this handle must be updated";
1350 notifyPayloadReplacementNotFound(op, newValues, std::move(
diag));
1351 (void)replacePayloadOp(op,
nullptr);
1355 (void)replacePayloadOp(op, replacement);
1362 assert(status.succeeded() &&
"listener state was not checked");
1374 return !status.succeeded();
1382 diag.takeDiagnostics(diags);
1383 if (!status.succeeded())
1384 status.takeDiagnostics(diags);
1388 status.attachNote(op->
getLoc()) <<
"[" << errorCounter <<
"] replaced op";
1390 status.attachNote(value.
getLoc())
1391 <<
"[" << errorCounter <<
"] replacement value " << index;
1397 if (!matchFailure) {
1400 return matchFailure->str();
1406 reasonCallback(
diag);
1407 matchFailure = std::move(
diag);
1421 return listener->failed();
1426 if (hasTrackingFailures()) {
1434 return listener->replacePayloadOp(op, replacement);
1445 for (
Operation *child : targets.drop_front(position + 1)) {
1446 if (parent->isAncestor(child)) {
1449 <<
"transform operation consumes a handle pointing to an ancestor "
1450 "payload operation before its descendant";
1452 <<
"the ancestor is likely erased or rewritten before the "
1453 "descendant is accessed, leading to undefined behavior";
1454 diag.attachNote(parent->getLoc()) <<
"ancestor payload op";
1455 diag.attachNote(child->getLoc()) <<
"descendant payload op";
1474 diag.attachNote(payloadOpLoc) <<
"when applied to this op";
1478 if (partialResult.
size() != expectedNumResults) {
1479 auto diag =
emitDiag() <<
"application of " << transformOpName
1480 <<
" expected to produce " << expectedNumResults
1481 <<
" results (actually produced "
1482 << partialResult.
size() <<
").";
1483 diag.attachNote(transformOpLoc)
1484 <<
"if you need variadic results, consider a generic `apply` "
1485 <<
"instead of the specialized `applyToOne`.";
1490 for (
const auto &[ptr, res] :
1491 llvm::zip(partialResult, transformOp->
getResults())) {
1494 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1495 !isa<Operation *>(ptr)) {
1496 return emitDiag() <<
"application of " << transformOpName
1497 <<
" expected to produce an Operation * for result #"
1498 << res.getResultNumber();
1500 if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1501 !isa<Attribute>(ptr)) {
1502 return emitDiag() <<
"application of " << transformOpName
1503 <<
" expected to produce an Attribute for result #"
1504 << res.getResultNumber();
1506 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1508 return emitDiag() <<
"application of " << transformOpName
1509 <<
" expected to produce a Value for result #"
1510 << res.getResultNumber();
1516 template <
typename T>
1518 return llvm::to_vector(llvm::map_range(
1528 if (llvm::any_of(partialResults,
1529 [](
MappedValue value) {
return value.isNull(); }))
1531 assert(transformOp->
getNumResults() == partialResults.size() &&
1532 "expected as many partial results as op as results");
1534 transposed[i].push_back(value);
1538 unsigned position = r.getResultNumber();
1539 if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1541 castVector<Attribute>(transposed[position]));
1542 }
else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1543 transformResults.
setValues(r, castVector<Value>(transposed[position]));
1545 transformResults.
set(r, castVector<Operation *>(transposed[position]));
1557 assert(mappings.size() == values.size() &&
"mismatching number of mappings");
1558 for (
auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1559 size_t mappedSize = mapped.size();
1560 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1561 llvm::append_range(mapped, state.getPayloadOps(operand));
1562 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
1563 operand.getType())) {
1564 llvm::append_range(mapped, state.getPayloadValues(operand));
1566 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1567 "unsupported kind of transform dialect value");
1568 llvm::append_range(mapped, state.getParams(operand));
1571 if (mapped.size() - mappedSize != 1 && !flatten)
1580 mappings.resize(mappings.size() + values.size());
1590 for (
auto &&[terminatorOperand, result] :
1593 if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1594 results.
set(result, state.getPayloadOps(terminatorOperand));
1595 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1596 result.getType())) {
1597 results.
setValues(result, state.getPayloadValues(terminatorOperand));
1600 llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1601 "unhandled transform type interface");
1602 results.
setParams(result, state.getParams(terminatorOperand));
1624 iface.getEffectsOnValue(source, nestedEffects);
1625 for (
const auto &effect : nestedEffects)
1626 effects.emplace_back(effect.getEffect(), target, effect.getResource());
1635 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1639 for (
auto &&[source, target] : llvm::zip(block.
getArguments(), operands)) {
1646 llvm::append_range(effects, nestedEffects);
1658 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1663 iface.getEffects(effects);
1678 llvm::append_range(targets, state.getPayloadOps(op->
getOperand(0)));
1681 if (state.getNumTopLevelMappings() !=
1685 <<
" extra value bindings, but " << state.getNumTopLevelMappings()
1686 <<
" were provided to the interpreter";
1689 targets.push_back(state.getTopLevel());
1691 for (
unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1692 extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1695 if (failed(state.mapBlockArguments(region.
front().
getArgument(0), targets)))
1699 if (failed(state.mapBlockArgument(
1700 argument, extraMappings[argument.getArgNumber() - 1])))
1712 assert(isa<TransformOpInterface>(op) &&
1713 "should implement TransformOpInterface to have "
1714 "PossibleTopLevelTransformOpTrait");
1717 return op->
emitOpError() <<
"expects at least one region";
1720 if (!llvm::hasNItems(*bodyRegion, 1))
1721 return op->
emitOpError() <<
"expects a single-block region";
1726 <<
"expects the entry block to have at least one argument";
1728 if (!llvm::isa<TransformHandleTypeInterface>(
1731 <<
"expects the first entry block argument to be of type "
1732 "implementing TransformHandleTypeInterface";
1738 <<
"expects the type of the block argument to match "
1739 "the type of the operand";
1743 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1744 TransformValueHandleTypeInterface>(arg.
getType()))
1749 <<
"expects trailing entry block arguments to be of type implementing "
1750 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1751 "TransformParamTypeInterface";
1761 <<
"expects operands to be provided for a nested op";
1762 diag.attachNote(parent->getLoc())
1763 <<
"nested in another possible top-level op";
1778 bool hasPayloadOperands =
false;
1781 if (llvm::isa<TransformHandleTypeInterface,
1782 TransformValueHandleTypeInterface>(operand.get().getType()))
1783 hasPayloadOperands =
true;
1785 if (hasPayloadOperands)
1794 llvm::report_fatal_error(
1795 Twine(
"ParamProducerTransformOpTrait must be attached to an op that "
1796 "implements MemoryEffectsOpInterface, found on ") +
1800 if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1803 <<
"ParamProducerTransformOpTrait attached to this op expects "
1804 "result types to implement TransformParamTypeInterface";
1826 template <
typename EffectTy,
typename ResourceTy,
typename Range>
1829 return isa<EffectTy>(effect.
getEffect()) &&
1835 transform::TransformOpInterface transform) {
1836 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1838 iface.getEffectsOnValue(handle, effects);
1839 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1840 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1886 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1888 iface.getEffects(effects);
1889 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1893 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1895 iface.getEffects(effects);
1896 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1900 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1903 auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1908 iface.getEffects(effects);
1911 dyn_cast_or_null<BlockArgument>(effect.getValue());
1912 if (!argument || argument.
getOwner() != &block ||
1913 !isa<MemoryEffects::Free>(effect.getEffect()) ||
1927 TransformOpInterface transformOp) {
1929 consumedOperands.reserve(transformOp->getNumOperands());
1930 auto memEffectInterface =
1931 cast<MemoryEffectOpInterface>(transformOp.getOperation());
1933 for (
OpOperand &target : transformOp->getOpOperands()) {
1935 memEffectInterface.getEffectsOnValue(target.get(), effects);
1937 return isa<transform::TransformMappingResource>(
1939 isa<MemoryEffects::Free>(effect.
getEffect());
1941 consumedOperands.push_back(&target);
1944 return consumedOperands;
1948 auto iface = cast<MemoryEffectOpInterface>(op);
1950 iface.getEffects(effects);
1952 auto effectsOn = [&](
Value value) {
1953 return llvm::make_filter_range(
1955 return instance.
getValue() == value;
1959 std::optional<unsigned> firstConsumedOperand;
1961 auto range = effectsOn(operand.get());
1962 if (range.empty()) {
1964 op->
emitError() <<
"TransformOpInterface requires memory effects "
1965 "on operands to be specified";
1966 diag.attachNote() <<
"no effects specified for operand #"
1967 << operand.getOperandNumber();
1970 if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1972 <<
"TransformOpInterface did not expect "
1973 "'allocate' memory effect on an operand";
1974 diag.attachNote() <<
"specified for operand #"
1975 << operand.getOperandNumber();
1978 if (!firstConsumedOperand &&
1979 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1980 firstConsumedOperand = operand.getOperandNumber();
1984 if (firstConsumedOperand &&
1985 !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1988 <<
"TransformOpInterface expects ops consuming operands to have a "
1989 "'write' effect on the payload resource";
1990 diag.attachNote() <<
"consumes operand #" << *firstConsumedOperand;
1995 auto range = effectsOn(result);
1996 if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1999 op->
emitError() <<
"TransformOpInterface requires 'allocate' memory "
2000 "effect to be specified for results";
2001 diag.attachNote() <<
"no 'allocate' effect specified for result #"
2002 << result.getResultNumber();
2015 Operation *payloadRoot, TransformOpInterface transform,
2020 if (enforceToplevelTransformOp) {
2022 transform->getNumOperands() != 0) {
2023 return transform->emitError()
2024 <<
"expected transform to start at the top-level transform op";
2031 TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2033 if (stateInitializer)
2034 stateInitializer(state);
2035 if (state.applyTransform(transform).checkAndReport().failed())
2038 return stateExporter(state);
2046 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2047 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A 2D array where each row may have different length.
size_t size() const
Returns the number of rows in the 2D array.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class implements the result iterators for the Operation class.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns true if op has an effect of type EffectTy.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...