16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/ScopeExit.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/ErrorHandling.h"
// Debug-output helpers for this file.
// DEBUG_TYPE is the category used by LLVM_DEBUG/LDBG; DEBUG_TYPE_FULL is a
// separate, more verbose category used only by FULL_LDBG.
21 #define DEBUG_TYPE "transform-dialect"
22 #define DEBUG_TYPE_FULL "transform-dialect-full"
// Name of a further debug category. Its uses are not visible in this excerpt
// (presumably it gates printing of the top-level payload after each
// transform -- TODO confirm).
23 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
// DBGS(): stream to llvm::dbgs() prefixed with "[transform-dialect] ".
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
// LDBG(X): print X with the DBGS() prefix under LLVM_DEBUG (DEBUG_TYPE).
25 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
// FULL_LDBG(X): print X, still with the "[transform-dialect] " prefix, only
// when the DEBUG_TYPE_FULL category is enabled.
26 #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
// Out-of-class definition (no initializer) of the constexpr static member
// kTopLevelValue. It provides storage in case the member is ODR-used; before
// C++17, constexpr static data members were not implicitly inline.
// The value acts as a sentinel for the transformation root: setPayloadOps
// asserts that it is never reassigned.
51 constexpr
const Value transform::TransformState::kTopLevelValue;
53 transform::TransformState::TransformState(
56 const TransformOptions &
options)
58 topLevelMappedValues.reserve(extraMappings.
size());
60 topLevelMappedValues.push_back(mapping);
62 RegionScope *scope =
new RegionScope(*
this, *region);
63 topLevelRegionScope.reset(scope);
// Returns the stored top-level payload operation (`topLevel`). Debug output
// labels this op as the "Top-level payload".
67 Operation *transform::TransformState::getTopLevel()
const {
return topLevel; }
70 transform::TransformState::getPayloadOpsView(
Value value)
const {
71 const TransformOpMapping &operationMapping = getMapping(value).direct;
72 auto iter = operationMapping.find(value);
73 assert(iter != operationMapping.end() &&
74 "cannot find mapping for payload handle (param/value handle "
76 return iter->getSecond();
81 auto iter = mapping.find(value);
82 assert(iter != mapping.end() &&
"cannot find mapping for param handle "
83 "(operation/value handle provided?)");
84 return iter->getSecond();
88 transform::TransformState::getPayloadValuesView(
Value handleValue)
const {
89 const ValueMapping &mapping = getMapping(handleValue).values;
90 auto iter = mapping.find(handleValue);
91 assert(iter != mapping.end() &&
"cannot find mapping for value handle "
92 "(param/operation handle provided?)");
93 return iter->getSecond();
98 bool includeOutOfScope)
const {
100 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
101 auto iterator = mapping->reverse.find(op);
102 if (iterator != mapping->reverse.end()) {
103 llvm::append_range(handles, iterator->getSecond());
107 if (!includeOutOfScope &&
112 return success(found);
117 bool includeOutOfScope)
const {
119 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
120 auto iterator = mapping->reverseValues.find(payloadValue);
121 if (iterator != mapping->reverseValues.end()) {
122 llvm::append_range(handles, iterator->getSecond());
126 if (!includeOutOfScope &&
131 return success(found);
141 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.
getType())) {
143 operations.reserve(values.size());
145 if (
auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
146 operations.push_back(op);
150 <<
"wrong kind of value provided for top-level operation handle";
152 if (failed(operationsFn(operations)))
157 if (llvm::isa<transform::TransformValueHandleTypeInterface>(
160 payloadValues.reserve(values.size());
162 if (
auto v = llvm::dyn_cast_if_present<Value>(value)) {
163 payloadValues.push_back(v);
167 <<
"wrong kind of value provided for the top-level value handle";
169 if (failed(valuesFn(payloadValues)))
174 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.
getType()) &&
175 "unsupported kind of block argument");
177 parameters.reserve(values.size());
179 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
180 parameters.push_back(attr);
184 <<
"wrong kind of value provided for top-level parameter";
186 if (failed(paramsFn(parameters)))
197 return setPayloadOps(argument, operations);
200 return setParams(argument, params);
203 return setPayloadValues(argument, payloadValues);
211 for (
auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
212 if (failed(mapBlockArgument(argument, values)))
218 transform::TransformState::setPayloadOps(
Value value,
220 assert(value != kTopLevelValue &&
221 "attempting to reset the transformation root");
222 assert(llvm::isa<TransformHandleTypeInterface>(value.
getType()) &&
223 "wrong handle type");
229 <<
"attempting to assign a null payload op to this transform value";
232 auto iface = llvm::cast<TransformHandleTypeInterface>(value.
getType());
234 iface.checkPayload(value.
getLoc(), targets);
241 Mappings &mappings = getMapping(value);
243 mappings.direct.insert({value, std::move(storedTargets)}).second;
244 assert(inserted &&
"value is already associated with another list");
248 mappings.reverse[op].push_back(value);
254 transform::TransformState::setPayloadValues(
Value handle,
256 assert(handle !=
nullptr &&
"attempting to set params for a null value");
257 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.
getType()) &&
258 "wrong handle type");
260 for (
Value payload : payloadValues) {
263 return emitError(handle.
getLoc()) <<
"attempting to assign a null payload "
264 "value to this transform handle";
267 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.
getType());
270 iface.checkPayload(handle.
getLoc(), payloadValueVector);
274 Mappings &mappings = getMapping(handle);
276 mappings.values.insert({handle, std::move(payloadValueVector)}).second;
279 "value handle is already associated with another list of payload values");
282 for (
Value payload : payloadValues)
283 mappings.reverseValues[payload].push_back(handle);
288 LogicalResult transform::TransformState::setParams(
Value value,
290 assert(value !=
nullptr &&
"attempting to set params for a null value");
296 <<
"attempting to assign a null parameter to this transform value";
299 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.
getType());
301 "cannot associate parameter with a value of non-parameter type");
303 valueType.checkPayload(value.
getLoc(), params);
307 Mappings &mappings = getMapping(value);
309 mappings.params.insert({value, llvm::to_vector(params)}).second;
310 assert(inserted &&
"value is already associated with another list of params");
315 template <
typename Mapping,
typename Key,
typename Mapped>
317 auto it = mapping.find(key);
318 if (it == mapping.end())
321 llvm::erase(it->getSecond(), mapped);
322 if (it->getSecond().empty())
// Drops `opHandle` from the direct (handle -> payload ops) mapping of the
// scope returned by getMapping(opHandle, allowOutOfScope). Then, for every
// value in `origOpFlatResults`, it visits the value handles associated with
// that value (some lines of this function are not visible in this excerpt).
326 void transform::TransformState::forgetMapping(
Value opHandle,
328 bool allowOutOfScope) {
329 Mappings &mappings = getMapping(opHandle, allowOutOfScope);
330 for (
Operation *op : mappings.direct[opHandle])
332 mappings.direct.erase(opHandle);
// With ABI-breaking checks on, bump the handle's timestamp because its
// payload list was modified.
333 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
336 mappings.incrementTimestamp(opHandle);
339 for (
Value opResult : origOpFlatResults) {
341 (void)getHandlesForPayloadValue(opResult, resultHandles);
342 for (
Value resultHandle : resultHandles) {
343 Mappings &localMappings = getMapping(resultHandle);
// NOTE(review): the timestamp below is bumped on `mappings` (the scope of
// `opHandle`), but `resultHandle` was resolved to `localMappings`. The
// analogous code in forgetValueMapping uses `localMappings`. This looks like
// it should be `localMappings.incrementTimestamp(resultHandle)` -- confirm.
345 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
348 mappings.incrementTimestamp(resultHandle);
// Drops the value handle `valueHandle` from the `values` (handle -> payload
// values) mapping of its scope. Then, for each op in `payloadOperations`, it
// visits the op handles associated with that op (some lines of this function
// are not visible in this excerpt).
355 void transform::TransformState::forgetValueMapping(
357 Mappings &mappings = getMapping(valueHandle);
// NOTE(review): elsewhere `reverseValues` is keyed by *payload* values
// (setPayloadValues inserts `reverseValues[payload]`;
// getHandlesForPayloadValue looks up `reverseValues.find(payloadValue)`).
// The handle-keyed map is `values`. Indexing `reverseValues` with the handle
// therefore looks wrong; `mappings.values[valueHandle]` seems intended. If
// this is a DenseMap-like map, operator[] also default-inserts an entry.
// Confirm against the (not visible) loop body.
358 for (
Value payloadValue : mappings.reverseValues[valueHandle])
360 mappings.values.erase(valueHandle);
// With ABI-breaking checks on, bump the handle's timestamp because its
// payload list was modified.
361 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
364 mappings.incrementTimestamp(valueHandle);
367 for (
Operation *payloadOp : payloadOperations) {
369 (void)getHandlesForPayloadOp(payloadOp, opHandles);
370 for (
Value opHandle : opHandles) {
371 Mappings &localMappings = getMapping(opHandle);
375 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
378 localMappings.incrementTimestamp(opHandle);
385 transform::TransformState::replacePayloadOp(
Operation *op,
392 (void)getHandlesForPayloadValue(opResult, valueHandles,
394 assert(valueHandles.empty() &&
"expected no mapping to old results");
401 if (failed(getHandlesForPayloadOp(op, opHandles,
true)))
403 for (
Value handle : opHandles) {
404 Mappings &mappings = getMapping(handle,
true);
417 for (
Value handle : opHandles) {
418 Mappings &mappings = getMapping(handle,
true);
419 auto it = mappings.direct.find(handle);
420 if (it == mappings.direct.end())
427 mapped = replacement;
431 mappings.reverse[replacement].push_back(handle);
433 opHandlesToCompact.insert(handle);
441 transform::TransformState::replacePayloadValue(
Value value,
Value replacement) {
443 if (failed(getHandlesForPayloadValue(value, valueHandles,
447 for (
Value handle : valueHandles) {
448 Mappings &mappings = getMapping(handle,
true);
455 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
458 mappings.incrementTimestamp(handle);
461 auto it = mappings.values.find(handle);
462 if (it == mappings.values.end())
466 for (
Value &mapped : association) {
468 mapped = replacement;
470 mappings.reverseValues[replacement].push_back(handle);
477 void transform::TransformState::recordOpHandleInvalidationOne(
484 if (invalidatedHandles.count(otherHandle) ||
485 newlyInvalidated.count(otherHandle))
488 FULL_LDBG(
"--recordOpHandleInvalidationOne\n");
491 llvm::interleaveComma(potentialAncestors,
DBGS() <<
"--ancestors: ",
492 [](
Operation *op) { llvm::dbgs() << *op; });
493 llvm::dbgs() <<
"\n");
497 for (
Operation *ancestor : potentialAncestors) {
500 { (
DBGS() <<
"----handle one ancestor: " << *ancestor <<
"\n"); });
502 { (
DBGS() <<
"----of payload with name: "
505 { (
DBGS() <<
"----of payload: " << *payloadOp <<
"\n"); });
507 if (!ancestor->isAncestor(payloadOp))
514 Location ancestorLoc = ancestor->getLoc();
516 std::optional<Location> throughValueLoc =
517 throughValue ? std::make_optional(throughValue.
getLoc()) : std::nullopt;
518 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
520 throughValueLoc](
Location currentLoc) {
522 <<
"op uses a handle invalidated by a "
523 "previously executed transform op";
524 diag.attachNote(otherHandle.getLoc()) <<
"handle to invalidated ops";
525 diag.attachNote(owner->getLoc())
526 <<
"invalidated by this transform op that consumes its operand #"
528 <<
" and invalidates all handles to payload IR entities associated "
529 "with this operand and entities nested in them";
530 diag.attachNote(ancestorLoc) <<
"ancestor payload op";
531 diag.attachNote(opLoc) <<
"nested payload op";
532 if (throughValueLoc) {
533 diag.attachNote(*throughValueLoc)
534 <<
"consumed handle points to this payload value";
540 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
547 if (invalidatedHandles.count(valueHandle) ||
548 newlyInvalidated.count(valueHandle))
551 for (
Operation *ancestor : potentialAncestors) {
553 std::optional<unsigned> resultNo;
557 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
558 definingOp = opResult.getOwner();
559 resultNo = opResult.getResultNumber();
561 auto arg = llvm::cast<BlockArgument>(payloadValue);
563 argumentNo = arg.getArgNumber();
564 blockNo = std::distance(arg.getOwner()->getParent()->begin(),
565 arg.getOwner()->getIterator());
566 regionNo = arg.getOwner()->getParent()->getRegionNumber();
568 assert(definingOp &&
"expected the value to be defined by an op as result "
569 "or block argument");
570 if (!ancestor->isAncestor(definingOp))
575 Location ancestorLoc = ancestor->getLoc();
578 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
579 argumentNo, blockNo, regionNo, ancestorLoc,
580 opLoc, valueLoc](
Location currentLoc) {
582 <<
"op uses a handle invalidated by a "
583 "previously executed transform op";
584 diag.attachNote(valueHandle.
getLoc()) <<
"invalidated handle";
585 diag.attachNote(owner->getLoc())
586 <<
"invalidated by this transform op that consumes its operand #"
588 <<
" and invalidates all handles to payload IR entities "
589 "associated with this operand and entities nested in them";
590 diag.attachNote(ancestorLoc)
591 <<
"ancestor op associated with the consumed handle";
593 diag.attachNote(opLoc)
594 <<
"op defining the value as result #" << *resultNo;
596 diag.attachNote(opLoc)
597 <<
"op defining the value as block argument #" << argumentNo
598 <<
" of block #" << blockNo <<
" in region #" << regionNo;
600 diag.attachNote(valueLoc) <<
"payload value";
605 void transform::TransformState::recordOpHandleInvalidation(
610 if (potentialAncestors.empty()) {
612 (
DBGS() <<
"----recording invalidation for empty handle: " << handle.
get()
618 newlyInvalidated[handle.
get()] = [owner, operandNo](
Location currentLoc) {
620 <<
"op uses a handle associated with empty "
621 "payload and invalidated by a "
622 "previously executed transform op";
623 diag.attachNote(owner->getLoc())
624 <<
"invalidated by this transform op that consumes its operand #"
637 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
641 for (
const auto &[payloadOp, otherHandles] : mapping->reverse) {
642 for (
Value otherHandle : otherHandles)
643 recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
644 otherHandle, throughValue,
652 for (
const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
653 for (
Value valueHandle : valueHandles)
654 recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
655 payloadValue, valueHandle,
665 void transform::TransformState::recordValueHandleInvalidation(
669 for (
Value payloadValue : getPayloadValuesView(valueHandle.
get())) {
671 (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
672 for (
Value otherHandle : otherValueHandles) {
676 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
679 <<
"op uses a handle invalidated by a "
680 "previously executed transform op";
681 diag.attachNote(otherHandle.getLoc()) <<
"invalidated handle";
682 diag.attachNote(owner->getLoc())
683 <<
"invalidated by this transform op that consumes its operand #"
685 <<
" and invalidates handles to the same values as associated with "
687 diag.attachNote(valueLoc) <<
"payload value";
691 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
692 Operation *payloadOp = opResult.getOwner();
693 recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
696 auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
697 for (
Operation &payloadOp : *arg.getOwner())
698 recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
708 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
709 transform::TransformOpInterface transform,
711 FULL_LDBG(
"--Start checkAndRecordHandleInvalidation\n");
712 auto memoryEffectsIface =
713 cast<MemoryEffectOpInterface>(transform.getOperation());
715 memoryEffectsIface.getEffectsOnResource(
718 for (
OpOperand &target : transform->getOpOperands()) {
720 (
DBGS() <<
"----iterate on handle: " << target.get() <<
"\n");
726 auto it = invalidatedHandles.find(target.get());
727 auto nit = newlyInvalidated.find(target.get());
728 if (it != invalidatedHandles.end()) {
729 FULL_LDBG(
"--End checkAndRecordHandleInvalidation, found already "
730 "invalidated -> FAILURE\n");
731 return it->getSecond()(transform->getLoc()), failure();
733 if (!transform.allowsRepeatedHandleOperands() &&
734 nit != newlyInvalidated.end()) {
735 FULL_LDBG(
"--End checkAndRecordHandleInvalidation, found newly "
736 "invalidated (by this op) -> FAILURE\n");
737 return nit->getSecond()(transform->getLoc()), failure();
743 return isa<MemoryEffects::Free>(effect.getEffect()) &&
744 effect.getValue() == target.get();
746 if (llvm::any_of(effects, consumesTarget)) {
748 if (llvm::isa<transform::TransformHandleTypeInterface>(
749 target.get().getType())) {
750 FULL_LDBG(
"----recordOpHandleInvalidation\n");
752 llvm::to_vector(getPayloadOps(target.get()));
753 recordOpHandleInvalidation(target, payloadOps,
nullptr,
755 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
756 target.get().getType())) {
757 FULL_LDBG(
"----recordValueHandleInvalidation\n");
758 recordValueHandleInvalidation(target, newlyInvalidated);
760 FULL_LDBG(
"----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
763 FULL_LDBG(
"----no consume effect -> SKIP\n");
767 FULL_LDBG(
"--End checkAndRecordHandleInvalidation -> SUCCESS\n");
771 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
772 transform::TransformOpInterface transform) {
773 InvalidatedHandleMap newlyInvalidated;
774 LogicalResult checkResult =
775 checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
776 invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
777 std::make_move_iterator(newlyInvalidated.end()));
781 template <
typename T>
784 transform::TransformOpInterface transform,
785 unsigned operandNumber) {
787 for (T p : payload) {
788 if (!seen.insert(p).second) {
790 transform.emitSilenceableError()
791 <<
"a handle passed as operand #" << operandNumber
792 <<
" and consumed by this operation points to a payload "
793 "entity more than once";
794 if constexpr (std::is_pointer_v<T>)
795 diag.attachNote(p->getLoc()) <<
"repeated target op";
797 diag.attachNote(p.getLoc()) <<
"repeated target value";
804 void transform::TransformState::compactOpHandles() {
805 for (
Value handle : opHandlesToCompact) {
806 Mappings &mappings = getMapping(handle,
true);
807 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
808 if (llvm::find(mappings.direct[handle],
nullptr) !=
809 mappings.direct[handle].end())
812 mappings.incrementTimestamp(handle);
814 llvm::erase(mappings.direct[handle],
nullptr);
816 opHandlesToCompact.clear();
822 DBGS() <<
"applying: ";
824 llvm::dbgs() <<
"\n";
827 DBGS() <<
"Top-level payload before application:\n"
828 << *getTopLevel() <<
"\n");
829 auto printOnFailureRAII = llvm::make_scope_exit([
this] {
831 LLVM_DEBUG(
DBGS() <<
"Failing Top-level payload:\n"; getTopLevel()->print(
836 regionStack.back()->currentTransform = transform;
839 if (
options.getExpensiveChecksEnabled()) {
841 if (failed(checkAndRecordHandleInvalidation(transform)))
844 for (
OpOperand &operand : transform->getOpOperands()) {
846 (
DBGS() <<
"iterate on handle: " << operand.get() <<
"\n");
849 FULL_LDBG(
"--handle not consumed -> SKIP\n");
852 if (transform.allowsRepeatedHandleOperands()) {
853 FULL_LDBG(
"--op allows repeated handles -> SKIP\n");
858 Type operandType = operand.get().getType();
859 if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
860 FULL_LDBG(
"--checkRepeatedConsumptionInOperand for Operation*\n");
862 checkRepeatedConsumptionInOperand<Operation *>(
863 getPayloadOpsView(operand.get()), transform,
864 operand.getOperandNumber());
869 }
else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
870 FULL_LDBG(
"--checkRepeatedConsumptionInOperand For Value\n");
872 checkRepeatedConsumptionInOperand<Value>(
873 getPayloadValuesView(operand.get()), transform,
874 operand.getOperandNumber());
880 FULL_LDBG(
"--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
887 transform.getConsumedHandleOpOperands();
896 for (
OpOperand *opOperand : consumedOperands) {
897 Value operand = opOperand->get();
898 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
899 for (
Operation *payloadOp : getPayloadOps(operand)) {
900 llvm::append_range(origOpFlatResults, payloadOp->
getResults());
904 if (llvm::isa<TransformValueHandleTypeInterface>(operand.
getType())) {
905 for (
Value payloadValue : getPayloadValuesView(operand)) {
906 if (llvm::isa<OpResult>(payloadValue)) {
912 llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
919 <<
"unexpectedly consumed a value that is not a handle as operand #"
920 << opOperand->getOperandNumber();
922 <<
"value defined here with type " << operand.
getType();
931 llvm::find_if(llvm::reverse(regionStack), [&](
RegionScope *scope) {
932 return handle.getParentRegion() == scope->region;
934 assert(scopeIt != regionStack.rend() &&
935 "could not find region scope for handle");
937 return llvm::all_of(handle.getUsers(), [&](
Operation *user) {
938 return user == scope->currentTransform ||
939 happensBefore(user, scope->currentTransform);
958 transform->hasAttr(FindPayloadReplacementOpInterface::
959 kSilenceTrackingFailuresAttrName)) {
963 (void)trackingFailure.
silence();
968 result = std::move(trackingFailure);
972 result.
attachNote() <<
"tracking listener also failed: "
974 (void)trackingFailure.
silence();
987 for (
OpOperand *opOperand : consumedOperands) {
988 Value operand = opOperand->get();
989 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
990 forgetMapping(operand, origOpFlatResults);
991 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
993 forgetValueMapping(operand, origAssociatedOps);
997 if (failed(updateStateFromResults(results, transform->getResults())))
1000 printOnFailureRAII.release();
1002 DBGS() <<
"Top-level payload:\n";
1003 getTopLevel()->print(llvm::dbgs());
1008 LogicalResult transform::TransformState::updateStateFromResults(
1010 for (
OpResult result : opResults) {
1011 if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1012 assert(results.isParam(result.getResultNumber()) &&
1013 "expected parameters for the parameter-typed result");
1015 setParams(result, results.getParams(result.getResultNumber())))) {
1018 }
else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1019 assert(results.isValue(result.getResultNumber()) &&
1020 "expected values for value-type-result");
1021 if (failed(setPayloadValues(
1022 result, results.getValues(result.getResultNumber())))) {
1026 assert(!results.isParam(result.getResultNumber()) &&
1027 "expected payload ops for the non-parameter typed result");
1029 setPayloadOps(result, results.get(result.getResultNumber())))) {
1048 return state.replacePayloadOp(op, replacement);
1053 Value replacement) {
1054 return state.replacePayloadValue(value, replacement);
1065 for (
Block &block : *region) {
1066 for (
Value handle : block.getArguments()) {
1067 state.invalidatedHandles.erase(handle);
1071 state.invalidatedHandles.erase(handle);
1076 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1080 llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1083 state.mappings.erase(region);
1084 state.regionStack.pop_back();
// Creates result storage for `numSegments` results. Each of the three
// storages (payload ops, params, payload values) gets one empty row per
// result. The accessors treat a row whose data() is null as "not set", and
// setParams asserts that only one kind is set per result.
1091 transform::TransformResults::TransformResults(
unsigned numSegments) {
1092 operations.appendEmptyRows(numSegments);
1093 params.appendEmptyRows(numSegments);
1094 values.appendEmptyRows(numSegments);
1100 assert(position <
static_cast<int64_t
>(this->params.size()) &&
1101 "setting params for a non-existent handle");
1102 assert(this->params[position].data() ==
nullptr &&
"params already set");
1103 assert(operations[position].data() ==
nullptr &&
1104 "another kind of results already set");
1105 assert(values[position].data() ==
nullptr &&
1106 "another kind of results already set");
1107 this->params.replace(position, params);
1115 return set(handle, operations), success();
1118 return setParams(handle, params), success();
1121 return setValues(handle, payloadValues), success();
1124 if (!
diag.succeeded())
1125 llvm::dbgs() <<
diag.getStatusString() <<
"\n";
1126 assert(
diag.succeeded() &&
"incorrect mapping");
1128 (void)
diag.silence();
1132 transform::TransformOpInterface transform) {
1133 for (
OpResult opResult : transform->getResults()) {
1134 if (!isSet(opResult.getResultNumber()))
1135 setMappedValues(opResult, {});
1140 transform::TransformResults::get(
unsigned resultNumber)
const {
1141 assert(resultNumber < operations.size() &&
1142 "querying results for a non-existent handle");
1143 assert(operations[resultNumber].data() !=
nullptr &&
1144 "querying unset results (values or params expected?)");
1145 return operations[resultNumber];
1149 transform::TransformResults::getParams(
unsigned resultNumber)
const {
1150 assert(resultNumber < params.size() &&
1151 "querying params for a non-existent handle");
1152 assert(params[resultNumber].data() !=
nullptr &&
1153 "querying unset params (ops or values expected?)");
1154 return params[resultNumber];
1158 transform::TransformResults::getValues(
unsigned resultNumber)
const {
1159 assert(resultNumber < values.size() &&
1160 "querying values for a non-existent handle");
1161 assert(values[resultNumber].data() !=
nullptr &&
1162 "querying unset values (ops or params expected?)");
1163 return values[resultNumber];
// Returns true if result #`resultNumber` has been populated with parameters,
// i.e. its params row has non-null data. Asserts that the result exists.
1166 bool transform::TransformResults::isParam(
unsigned resultNumber)
const {
1167 assert(resultNumber < params.size() &&
1168 "querying association for a non-existent handle");
1169 return params[resultNumber].data() !=
nullptr;
// Returns true if result #`resultNumber` has been populated with payload
// values, i.e. its values row has non-null data. Asserts that the result
// exists.
1172 bool transform::TransformResults::isValue(
unsigned resultNumber)
const {
1173 assert(resultNumber < values.size() &&
1174 "querying association for a non-existent handle");
1175 return values[resultNumber].data() !=
nullptr;
// Returns true if result #`resultNumber` has been populated with any kind of
// payload (params, payload ops, or payload values).
// The bounds assert checks only `params`, but `operations` and `values` are
// indexed too. This is safe because the constructor appends the same number
// of rows to all three storages.
1178 bool transform::TransformResults::isSet(
unsigned resultNumber)
const {
1179 assert(resultNumber < params.size() &&
1180 "querying association for a non-existent handle");
1181 return params[resultNumber].data() !=
nullptr ||
1182 operations[resultNumber].data() !=
nullptr ||
1183 values[resultNumber].data() !=
nullptr;
1191 TransformOpInterface op,
1193 :
TransformState::Extension(state), transformOp(op), config(config) {
1195 for (
OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1196 consumedHandles.insert(opOperand->get());
1203 for (
Value v : values) {
1208 defOp = v.getDefiningOp();
1211 if (defOp != v.getDefiningOp())
1220 "invalid number of replacement values");
1224 getTransformOp(),
"tracking listener failed to find replacement op "
1225 "during application of this transform op");
1229 Operation *defOp = getCommonDefiningOp(values);
1231 diag.attachNote() <<
"replacement values belong to different ops";
1236 if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1240 <<
"using output of 'CastOpInterface' op";
1246 if (!config.requireMatchingReplacementOpName ||
1262 if (
auto findReplacementOpInterface =
1263 dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1264 values.assign(findReplacementOpInterface.getNextOperands());
1265 diag.attachNote(defOp->
getLoc()) <<
"using operands provided by "
1266 "'FindPayloadReplacementOpInterface'";
1269 }
while (!values.empty());
1271 diag.attachNote() <<
"ran out of suitable replacement values";
1279 reasonCallback(
diag);
1280 DBGS() <<
"Match Failure : " <<
diag.str() <<
"\n";
1284 void transform::TrackingListener::notifyOperationErased(
Operation *op) {
1287 (void)replacePayloadValue(value,
nullptr);
1289 (void)replacePayloadOp(op,
nullptr);
1292 void transform::TrackingListener::notifyOperationReplaced(
1295 "invalid number of replacement values");
1298 for (
auto [oldValue, newValue] : llvm::zip(op->
getResults(), newValues))
1299 (void)replacePayloadValue(oldValue, newValue);
1303 if (failed(getTransformState().getHandlesForPayloadOp(
1304 op, opHandles,
true))) {
1318 auto handleWasConsumed = [&] {
1319 return llvm::any_of(opHandles,
1320 [&](
Value h) {
return consumedHandles.contains(h); });
1325 if (config.skipHandleFn) {
1326 auto it = llvm::find_if(opHandles,
1327 [&](
Value v) {
return !config.skipHandleFn(v); });
1328 if (it != opHandles.end())
1330 }
else if (!opHandles.empty()) {
1331 aliveHandle = opHandles.front();
1333 if (!aliveHandle || handleWasConsumed()) {
1336 (void)replacePayloadOp(op,
nullptr);
1342 findReplacementOp(replacement, op, newValues);
1345 if (!
diag.succeeded()) {
1347 <<
"replacement is required because this handle must be updated";
1348 notifyPayloadReplacementNotFound(op, newValues, std::move(
diag));
1349 (void)replacePayloadOp(op,
nullptr);
1353 (void)replacePayloadOp(op, replacement);
1360 assert(status.succeeded() &&
"listener state was not checked");
1372 return !status.succeeded();
1380 diag.takeDiagnostics(diags);
1381 if (!status.succeeded())
1382 status.takeDiagnostics(diags);
1386 status.attachNote(op->
getLoc()) <<
"[" << errorCounter <<
"] replaced op";
1388 status.attachNote(value.
getLoc())
1389 <<
"[" << errorCounter <<
"] replacement value " << index;
1404 return listener->failed();
1409 if (hasTrackingFailures()) {
1417 return listener->replacePayloadOp(op, replacement);
1428 for (
Operation *child : targets.drop_front(position + 1)) {
1429 if (parent->isAncestor(child)) {
1432 <<
"transform operation consumes a handle pointing to an ancestor "
1433 "payload operation before its descendant";
1435 <<
"the ancestor is likely erased or rewritten before the "
1436 "descendant is accessed, leading to undefined behavior";
1437 diag.attachNote(parent->getLoc()) <<
"ancestor payload op";
1438 diag.attachNote(child->getLoc()) <<
"descendant payload op";
1457 diag.attachNote(payloadOpLoc) <<
"when applied to this op";
1461 if (partialResult.
size() != expectedNumResults) {
1462 auto diag =
emitDiag() <<
"application of " << transformOpName
1463 <<
" expected to produce " << expectedNumResults
1464 <<
" results (actually produced "
1465 << partialResult.
size() <<
").";
1466 diag.attachNote(transformOpLoc)
1467 <<
"if you need variadic results, consider a generic `apply` "
1468 <<
"instead of the specialized `applyToOne`.";
1473 for (
const auto &[ptr, res] :
1474 llvm::zip(partialResult, transformOp->
getResults())) {
1477 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1479 return emitDiag() <<
"application of " << transformOpName
1480 <<
" expected to produce an Operation * for result #"
1481 << res.getResultNumber();
1483 if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1485 return emitDiag() <<
"application of " << transformOpName
1486 <<
" expected to produce an Attribute for result #"
1487 << res.getResultNumber();
1489 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1491 return emitDiag() <<
"application of " << transformOpName
1492 <<
" expected to produce a Value for result #"
1493 << res.getResultNumber();
1499 template <
typename T>
1501 return llvm::to_vector(llvm::map_range(
1511 if (llvm::any_of(partialResults,
1512 [](
MappedValue value) {
return value.isNull(); }))
1514 assert(transformOp->
getNumResults() == partialResults.size() &&
1515 "expected as many partial results as op as results");
1517 transposed[i].push_back(value);
1521 unsigned position = r.getResultNumber();
1522 if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1524 castVector<Attribute>(transposed[position]));
1525 }
else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1526 transformResults.
setValues(r, castVector<Value>(transposed[position]));
1528 transformResults.
set(r, castVector<Operation *>(transposed[position]));
1540 assert(mappings.size() == values.size() &&
"mismatching number of mappings");
1541 for (
auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1542 size_t mappedSize = mapped.size();
1543 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1544 llvm::append_range(mapped, state.getPayloadOps(operand));
1545 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
1546 operand.getType())) {
1547 llvm::append_range(mapped, state.getPayloadValues(operand));
1549 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1550 "unsupported kind of transform dialect value");
1551 llvm::append_range(mapped, state.getParams(operand));
1554 if (mapped.size() - mappedSize != 1 && !flatten)
1563 mappings.resize(mappings.size() + values.size());
1573 for (
auto &&[terminatorOperand, result] :
1576 if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1577 results.
set(result, state.getPayloadOps(terminatorOperand));
1578 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1579 result.getType())) {
1580 results.
setValues(result, state.getPayloadValues(terminatorOperand));
1583 llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1584 "unhandled transform type interface");
1585 results.
setParams(result, state.getParams(terminatorOperand));
1607 iface.getEffectsOnValue(source, nestedEffects);
1608 for (
const auto &effect : nestedEffects)
1609 effects.emplace_back(effect.getEffect(), target, effect.getResource());
1618 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1622 for (
auto &&[source, target] : llvm::zip(block.
getArguments(), operands)) {
1629 llvm::append_range(effects, nestedEffects);
1641 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1646 iface.getEffects(effects);
1661 llvm::append_range(targets, state.getPayloadOps(op->
getOperand(0)));
1664 if (state.getNumTopLevelMappings() !=
1668 <<
" extra value bindings, but " << state.getNumTopLevelMappings()
1669 <<
" were provided to the interpreter";
1672 targets.push_back(state.getTopLevel());
1674 for (
unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1675 extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1678 if (failed(state.mapBlockArguments(region.
front().
getArgument(0), targets)))
1682 if (failed(state.mapBlockArgument(
1683 argument, extraMappings[argument.getArgNumber() - 1])))
1695 assert(isa<TransformOpInterface>(op) &&
1696 "should implement TransformOpInterface to have "
1697 "PossibleTopLevelTransformOpTrait");
1700 return op->
emitOpError() <<
"expects at least one region";
1703 if (!llvm::hasNItems(*bodyRegion, 1))
1704 return op->
emitOpError() <<
"expects a single-block region";
1709 <<
"expects the entry block to have at least one argument";
1711 if (!llvm::isa<TransformHandleTypeInterface>(
1714 <<
"expects the first entry block argument to be of type "
1715 "implementing TransformHandleTypeInterface";
1721 <<
"expects the type of the block argument to match "
1722 "the type of the operand";
1726 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1727 TransformValueHandleTypeInterface>(arg.
getType()))
1732 <<
"expects trailing entry block arguments to be of type implementing "
1733 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1734 "TransformParamTypeInterface";
1744 <<
"expects operands to be provided for a nested op";
1745 diag.attachNote(parent->getLoc())
1746 <<
"nested in another possible top-level op";
1761 bool hasPayloadOperands =
false;
1764 if (llvm::isa<TransformHandleTypeInterface,
1765 TransformValueHandleTypeInterface>(operand.get().getType()))
1766 hasPayloadOperands =
true;
1768 if (hasPayloadOperands)
1777 llvm::report_fatal_error(
1778 Twine(
"ParamProducerTransformOpTrait must be attached to an op that "
1779 "implements MemoryEffectsOpInterface, found on ") +
1783 if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1786 <<
"ParamProducerTransformOpTrait attached to this op expects "
1787 "result types to implement TransformParamTypeInterface";
1809 template <
typename EffectTy,
typename ResourceTy,
typename Range>
1812 return isa<EffectTy>(effect.
getEffect()) &&
1818 transform::TransformOpInterface transform) {
1819 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1821 iface.getEffectsOnValue(handle, effects);
1822 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1823 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1869 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1871 iface.getEffects(effects);
1872 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1876 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1878 iface.getEffects(effects);
1879 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1883 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1886 auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1891 iface.getEffects(effects);
1894 dyn_cast_or_null<BlockArgument>(effect.getValue());
1895 if (!argument || argument.
getOwner() != &block ||
1896 !isa<MemoryEffects::Free>(effect.getEffect()) ||
1910 TransformOpInterface transformOp) {
1912 consumedOperands.reserve(transformOp->getNumOperands());
1913 auto memEffectInterface =
1914 cast<MemoryEffectOpInterface>(transformOp.getOperation());
1916 for (
OpOperand &target : transformOp->getOpOperands()) {
1918 memEffectInterface.getEffectsOnValue(target.get(), effects);
1920 return isa<transform::TransformMappingResource>(
1922 isa<MemoryEffects::Free>(effect.
getEffect());
1924 consumedOperands.push_back(&target);
1927 return consumedOperands;
1931 auto iface = cast<MemoryEffectOpInterface>(op);
1933 iface.getEffects(effects);
1935 auto effectsOn = [&](
Value value) {
1936 return llvm::make_filter_range(
1938 return instance.
getValue() == value;
1942 std::optional<unsigned> firstConsumedOperand;
1944 auto range = effectsOn(operand.get());
1945 if (range.empty()) {
1947 op->
emitError() <<
"TransformOpInterface requires memory effects "
1948 "on operands to be specified";
1949 diag.attachNote() <<
"no effects specified for operand #"
1950 << operand.getOperandNumber();
1953 if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1955 <<
"TransformOpInterface did not expect "
1956 "'allocate' memory effect on an operand";
1957 diag.attachNote() <<
"specified for operand #"
1958 << operand.getOperandNumber();
1961 if (!firstConsumedOperand &&
1962 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1963 firstConsumedOperand = operand.getOperandNumber();
1967 if (firstConsumedOperand &&
1968 !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1971 <<
"TransformOpInterface expects ops consuming operands to have a "
1972 "'write' effect on the payload resource";
1973 diag.attachNote() <<
"consumes operand #" << *firstConsumedOperand;
1978 auto range = effectsOn(result);
1979 if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1982 op->
emitError() <<
"TransformOpInterface requires 'allocate' memory "
1983 "effect to be specified for results";
1984 diag.attachNote() <<
"no 'allocate' effect specified for result #"
1985 << result.getResultNumber();
1998 Operation *payloadRoot, TransformOpInterface transform,
2003 if (enforceToplevelTransformOp) {
2005 transform->getNumOperands() != 0) {
2006 return transform->emitError()
2007 <<
"expected transform to start at the top-level transform op";
2014 TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2016 if (stateInitializer)
2017 stateInitializer(state);
2018 if (state.applyTransform(transform).checkAndReport().failed())
2021 return stateExporter(state);
2029 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2030 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A 2D array where each row may have different length.
size_t size() const
Returns the number of rows in the 2D array.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class implements the result iterators for the Operation class.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns true if op has an effect of type EffectTy.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...