16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/ScopeExit.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/ErrorHandling.h"
// Debug-output categories for this file: DEBUG_TYPE is the regular category,
// DEBUG_TYPE_FULL a separate, more verbose one. DEBUG_PRINT_AFTER_ALL is a
// third category string that is not referenced in the lines visible here.
21 #define DEBUG_TYPE "transform-dialect"
22 #define DEBUG_TYPE_FULL "transform-dialect-full"
23 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
// DBGS(): llvm::dbgs() stream with a "[transform-dialect] " prefix.
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
// LDBG(X): print X through LLVM_DEBUG under DEBUG_TYPE.
25 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
// FULL_LDBG(X): print X only when the DEBUG_TYPE_FULL category is selected.
26 #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
// Out-of-line definition of the static constexpr member kTopLevelValue. It has
// no initializer here, so the value comes from the in-class declaration, which
// is not part of this chunk.
51 constexpr
const Value transform::TransformState::kTopLevelValue;
// Constructs the transform state.
// - Reserves room for `extraMappings` in `topLevelMappedValues` and appends
//   each `mapping` (the loop header is not visible here).
// - Opens the root RegionScope for `region`, handing ownership of the
//   freshly allocated scope to `topLevelRegionScope` via reset().
// NOTE(review): reset() adopting a raw `new` suggests topLevelRegionScope is a
// std::unique_ptr — confirm in the header.
53 transform::TransformState::TransformState(
56 const TransformOptions &
options)
58 topLevelMappedValues.reserve(extraMappings.
size());
60 topLevelMappedValues.push_back(mapping);
62 RegionScope *scope =
new RegionScope(*
this, *region);
63 topLevelRegionScope.reset(scope);
67 Operation *transform::TransformState::getTopLevel()
const {
return topLevel; }
// Returns the payload operations associated with the operation handle `value`.
// The lookup uses the `direct` map of the mapping that owns `value`, and it
// asserts that an entry exists (a param/value handle has no entry there).
70 transform::TransformState::getPayloadOpsView(
Value value)
const {
71 const TransformOpMapping &operationMapping = getMapping(value).direct;
72 auto iter = operationMapping.find(value);
73 assert(iter != operationMapping.end() &&
74 "cannot find mapping for payload handle (param/value handle "
76 return iter->getSecond();
// Looks up the parameters associated with a param handle and asserts that a
// mapping exists. The enclosing signature and the source of `mapping` are not
// visible in this chunk.
81 auto iter = mapping.find(value);
82 assert(iter != mapping.end() &&
"cannot find mapping for param handle "
83 "(operation/value handle provided?)");
84 return iter->getSecond();
// Returns the payload values associated with the value handle `handleValue`.
// The lookup uses the `values` map of the owning mapping, and it asserts that
// an entry exists.
88 transform::TransformState::getPayloadValuesView(
Value handleValue)
const {
89 const ValueMapping &mapping = getMapping(handleValue).values;
90 auto iter = mapping.find(handleValue);
91 assert(iter != mapping.end() &&
"cannot find mapping for value handle "
92 "(param/operation handle provided?)");
93 return iter->getSecond();
// Appends to `handles` every handle recorded in a `reverse` map entry for
// payload op `op`, visiting the per-region mappings in reverse order.
// `includeOutOfScope` takes part in the condition that presumably stops the
// scan at a scope boundary (the rest of that condition is outside this view).
// Returns success(found), where `found` presumably records whether any handle
// was appended.
98 bool includeOutOfScope)
const {
100 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
101 auto iterator = mapping->reverse.find(op);
102 if (iterator != mapping->reverse.end()) {
103 llvm::append_range(handles, iterator->getSecond());
107 if (!includeOutOfScope &&
112 return success(found);
// Value-handle counterpart of the payload-op lookup. It appends to `handles`
// every handle recorded in a `reverseValues` entry for `payloadValue`,
// visiting the per-region mappings in reverse order. `includeOutOfScope` takes
// part in the scope-boundary condition, which is partially outside this view.
// Returns success(found).
117 bool includeOutOfScope)
const {
119 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
120 auto iterator = mapping->reverseValues.find(payloadValue);
121 if (iterator != mapping->reverseValues.end()) {
122 llvm::append_range(handles, iterator->getSecond());
126 if (!includeOutOfScope &&
131 return success(found);
// Splits `values` by the kind of `handle`'s type and forwards them to the
// matching callback:
//  - TransformHandleTypeInterface: entries must be Operation * -> operationsFn
//  - TransformValueHandleTypeInterface: entries must be Value -> valuesFn
//  - otherwise (asserted TransformParamTypeInterface): entries must be
//    Attribute -> paramsFn
// dyn_cast_if_present also rejects null entries, which therefore take the
// "wrong kind of value" path (the emitting call itself is not visible here).
// The enclosing signature is not part of this chunk.
141 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.
getType())) {
143 operations.reserve(values.size());
145 if (
auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
146 operations.push_back(op);
150 <<
"wrong kind of value provided for top-level operation handle";
152 if (failed(operationsFn(operations)))
157 if (llvm::isa<transform::TransformValueHandleTypeInterface>(
160 payloadValues.reserve(values.size());
162 if (
auto v = llvm::dyn_cast_if_present<Value>(value)) {
163 payloadValues.push_back(v);
167 <<
"wrong kind of value provided for the top-level value handle";
169 if (failed(valuesFn(payloadValues)))
174 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.
getType()) &&
175 "unsupported kind of block argument");
177 parameters.reserve(values.size());
179 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
180 parameters.push_back(attr);
184 <<
"wrong kind of value provided for top-level parameter";
186 if (failed(paramsFn(parameters)))
// Per-kind callbacks that store the mapped entities for `argument`: payload
// ops, params and payload values respectively (the enclosing lambdas are not
// visible in this chunk).
197 return setPayloadOps(argument, operations);
200 return setParams(argument, params);
203 return setPayloadValues(argument, payloadValues);
// Maps each block argument to the corresponding entry of `mapping`.
// llvm::zip_equal requires `arguments` and `mapping` to have equal length.
211 for (
auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
212 if (failed(mapBlockArgument(argument, values)))
// Associates the operation handle `value` with the payload ops `targets`.
// - Asserts that `value` is not the top-level root and has a handle type.
// - Rejects null payload ops with an error.
// - Validates the payload through the handle type's checkPayload().
// - Records the forward entry direct[value] and, for each op, the reverse
//   entry reverse[op] -> value. Asserts `value` was not already mapped.
218 transform::TransformState::setPayloadOps(
Value value,
220 assert(value != kTopLevelValue &&
221 "attempting to reset the transformation root");
222 assert(llvm::isa<TransformHandleTypeInterface>(value.
getType()) &&
223 "wrong handle type");
229 <<
"attempting to assign a null payload op to this transform value";
232 auto iface = llvm::cast<TransformHandleTypeInterface>(value.
getType());
234 iface.checkPayload(value.
getLoc(), targets);
241 Mappings &mappings = getMapping(value);
243 mappings.direct.insert({value, std::move(storedTargets)}).second;
244 assert(inserted &&
"value is already associated with another list");
248 mappings.reverse[op].push_back(value);
// Associates the value handle `handle` with `payloadValues`.
// - Rejects null payload values with an error at the handle's location.
// - Validates the payload through the handle type's checkPayload().
// - Records values[handle] (asserting it was not already set) and the reverse
//   entries reverseValues[payload] -> handle.
// NOTE(review): the first assert's message says "set params" although this
// function sets payload values. It looks copied from setParams; the runtime
// text is left unchanged here.
254 transform::TransformState::setPayloadValues(
Value handle,
256 assert(handle !=
nullptr &&
"attempting to set params for a null value");
257 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.
getType()) &&
258 "wrong handle type");
260 for (
Value payload : payloadValues) {
263 return emitError(handle.
getLoc()) <<
"attempting to assign a null payload "
264 "value to this transform handle";
267 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.
getType());
270 iface.checkPayload(handle.
getLoc(), payloadValueVector);
274 Mappings &mappings = getMapping(handle);
276 mappings.values.insert({handle, std::move(payloadValueVector)}).second;
279 "value handle is already associated with another list of payload values");
282 for (
Value payload : payloadValues)
283 mappings.reverseValues[payload].push_back(handle);
// Associates the param handle `value` with the attributes `params`.
// - Rejects null parameters with an error.
// - Requires a TransformParamTypeInterface type and validates through its
//   checkPayload().
// - Stores a copy in params[value], asserting it was not already set.
// Unlike ops/values, no reverse entry is recorded in the visible lines.
288 LogicalResult transform::TransformState::setParams(
Value value,
290 assert(value !=
nullptr &&
"attempting to set params for a null value");
296 <<
"attempting to assign a null parameter to this transform value";
299 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.
getType());
301 "cannot associate parameter with a value of non-parameter type");
303 valueType.checkPayload(value.
getLoc(), params);
307 Mappings &mappings = getMapping(value);
309 mappings.params.insert({value, llvm::to_vector(params)}).second;
310 assert(inserted &&
"value is already associated with another list of params");
// Generic helper: removes every occurrence of `mapped` from the list stored
// under `key` in `mapping`. Nothing is erased when `key` is absent.
// When the list becomes empty, the following (not visible) line presumably
// drops the entry itself.
315 template <
typename Mapping,
typename Key,
typename Mapped>
317 auto it = mapping.find(key);
318 if (it == mapping.end())
321 llvm::erase(it->getSecond(), mapped);
322 if (it->getSecond().empty())
// Drops the association of op handle `opHandle` with its payload ops. It then
// visits every value handle pointing to one of `origOpFlatResults` (results
// of the formerly associated ops) so those handles can be updated too.
// NOTE(review): `#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS` is true whenever the
// macro is defined at all. LLVM's abi-breaking.h defines it as 0 or 1, and the
// RegionScope teardown in this file uses `#if` — verify that `#if` was
// intended here.
326 void transform::TransformState::forgetMapping(
Value opHandle,
328 bool allowOutOfScope) {
329 Mappings &mappings = getMapping(opHandle, allowOutOfScope);
330 for (
Operation *op : mappings.direct[opHandle])
332 mappings.direct.erase(opHandle);
333 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
336 mappings.incrementTimestamp(opHandle);
339 for (
Value opResult : origOpFlatResults) {
341 (void)getHandlesForPayloadValue(opResult, resultHandles);
342 for (
Value resultHandle : resultHandles) {
343 Mappings &localMappings = getMapping(resultHandle);
345 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// NOTE(review): the timestamp is bumped on `mappings` (opHandle's mapping),
// not on `localMappings` fetched for resultHandle just above. forgetValueMapping
// uses localMappings in the analogous spot — verify which one is intended.
348 mappings.incrementTimestamp(resultHandle);
// Drops the association of value handle `valueHandle` with its payload
// values. It then visits every op handle pointing to one of
// `payloadOperations` and bumps its timestamp.
// NOTE(review): the loop below iterates `mappings.reverseValues[valueHandle]`.
// setPayloadValues keys `reverseValues` by payload value
// (reverseValues[payload].push_back(handle)), so indexing it with a handle
// looks wrong, and operator[] may insert an empty entry. `mappings.values[...]`
// seems intended — verify.
// NOTE(review): `#ifdef` vs `#if` for LLVM_ENABLE_ABI_BREAKING_CHECKS, as in
// forgetMapping.
355 void transform::TransformState::forgetValueMapping(
357 Mappings &mappings = getMapping(valueHandle);
358 for (
Value payloadValue : mappings.reverseValues[valueHandle])
360 mappings.values.erase(valueHandle);
361 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
364 mappings.incrementTimestamp(valueHandle);
367 for (
Operation *payloadOp : payloadOperations) {
369 (void)getHandlesForPayloadOp(payloadOp, opHandles);
370 for (
Value opHandle : opHandles) {
371 Mappings &localMappings = getMapping(opHandle);
375 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
378 localMappings.incrementTimestamp(opHandle);
// Replaces payload op `op` with `replacement` in every op handle mapping that
// refers to it, including out-of-scope handles (`true` is passed to the
// lookups). `replacement` may be null: TrackingListener passes nullptr on
// erasure.
// - Asserts that no value handle still refers to `op`'s results.
// - Each touched handle is queued in `opHandlesToCompact`; compactOpHandles()
//   later erases null entries from those handles' `direct` lists.
385 transform::TransformState::replacePayloadOp(
Operation *op,
392 (void)getHandlesForPayloadValue(opResult, valueHandles,
394 assert(valueHandles.empty() &&
"expected no mapping to old results");
401 if (failed(getHandlesForPayloadOp(op, opHandles,
true)))
403 for (
Value handle : opHandles) {
404 Mappings &mappings = getMapping(handle,
true);
417 for (
Value handle : opHandles) {
418 Mappings &mappings = getMapping(handle,
true);
419 auto it = mappings.direct.find(handle);
420 if (it == mappings.direct.end())
427 mapped = replacement;
431 mappings.reverse[replacement].push_back(handle);
433 opHandlesToCompact.insert(handle);
// Replaces payload value `value` with `replacement` in every value handle
// that refers to it. For each such handle it:
// - bumps the mapping timestamp,
// - overwrites matching entries of values[handle], and
// - records reverseValues[replacement] -> handle.
// NOTE(review): `#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS` — likely `#if`
// intended (see the `#if` in the RegionScope teardown); verify.
441 transform::TransformState::replacePayloadValue(
Value value,
Value replacement) {
443 if (failed(getHandlesForPayloadValue(value, valueHandles,
447 for (
Value handle : valueHandles) {
448 Mappings &mappings = getMapping(handle,
true);
455 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
458 mappings.incrementTimestamp(handle);
461 auto it = mappings.values.find(handle);
462 if (it == mappings.values.end())
466 for (
Value &mapped : association) {
468 mapped = replacement;
470 mappings.reverseValues[replacement].push_back(handle);
// Checks whether consuming a handle invalidates `otherHandle`, which points to
// `payloadOp`. That happens when any of `potentialAncestors` is an ancestor of
// `payloadOp`.
// - Handles already present in invalidatedHandles/newlyInvalidated are skipped.
// - On a hit, stores in newlyInvalidated[otherHandle] a reporter lambda. The
//   lambda emits the "handle invalidated" diagnostic, with notes, at the
//   location where the handle is later used.
// - The lambda captures locations and values by copy, so it can outlive this
//   call.
477 void transform::TransformState::recordOpHandleInvalidationOne(
484 if (invalidatedHandles.count(otherHandle) ||
485 newlyInvalidated.count(otherHandle))
488 FULL_LDBG(
"--recordOpHandleInvalidationOne\n");
491 llvm::interleaveComma(potentialAncestors,
DBGS() <<
"--ancestors: ",
492 [](
Operation *op) { llvm::dbgs() << *op; });
493 llvm::dbgs() <<
"\n");
497 for (
Operation *ancestor : potentialAncestors) {
500 { (
DBGS() <<
"----handle one ancestor: " << *ancestor <<
"\n"); });
502 { (
DBGS() <<
"----of payload with name: "
505 { (
DBGS() <<
"----of payload: " << *payloadOp <<
"\n"); });
507 if (!ancestor->isAncestor(payloadOp))
514 Location ancestorLoc = ancestor->getLoc();
// The optional note points at the consumed payload value when invalidation
// was triggered through a value handle.
516 std::optional<Location> throughValueLoc =
517 throughValue ? std::make_optional(throughValue.
getLoc()) : std::nullopt;
518 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
520 throughValueLoc](
Location currentLoc) {
522 <<
"op uses a handle invalidated by a "
523 "previously executed transform op";
524 diag.attachNote(otherHandle.getLoc()) <<
"handle to invalidated ops";
525 diag.attachNote(owner->getLoc())
526 <<
"invalidated by this transform op that consumes its operand #"
528 <<
" and invalidates all handles to payload IR entities associated "
529 "with this operand and entities nested in them";
530 diag.attachNote(ancestorLoc) <<
"ancestor payload op";
531 diag.attachNote(opLoc) <<
"nested payload op";
532 if (throughValueLoc) {
533 diag.attachNote(*throughValueLoc)
534 <<
"consumed handle points to this payload value";
// Value-handle counterpart of recordOpHandleInvalidationOne. `valueHandle`
// points to `payloadValue`. It is invalidated when one of
// `potentialAncestors` is an ancestor of the op defining that value.
// - For an OpResult, the defining op is its owner and the result number is
//   recorded.
// - For a BlockArgument, the argument, block and region numbers are recorded
//   for the diagnostic.
// On a hit, a reporter lambda is stored in newlyInvalidated[valueHandle].
540 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
547 if (invalidatedHandles.count(valueHandle) ||
548 newlyInvalidated.count(valueHandle))
551 for (
Operation *ancestor : potentialAncestors) {
553 std::optional<unsigned> resultNo;
557 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
558 definingOp = opResult.getOwner();
559 resultNo = opResult.getResultNumber();
561 auto arg = llvm::cast<BlockArgument>(payloadValue);
563 argumentNo = arg.getArgNumber();
// Position of the owning block within its region, counted from the first block.
564 blockNo = std::distance(arg.getOwner()->getParent()->begin(),
565 arg.getOwner()->getIterator());
566 regionNo = arg.getOwner()->getParent()->getRegionNumber();
568 assert(definingOp &&
"expected the value to be defined by an op as result "
569 "or block argument");
570 if (!ancestor->isAncestor(definingOp))
575 Location ancestorLoc = ancestor->getLoc();
578 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
579 argumentNo, blockNo, regionNo, ancestorLoc,
580 opLoc, valueLoc](
Location currentLoc) {
582 <<
"op uses a handle invalidated by a "
583 "previously executed transform op";
584 diag.attachNote(valueHandle.
getLoc()) <<
"invalidated handle";
585 diag.attachNote(owner->getLoc())
586 <<
"invalidated by this transform op that consumes its operand #"
588 <<
" and invalidates all handles to payload IR entities "
589 "associated with this operand and entities nested in them";
590 diag.attachNote(ancestorLoc)
591 <<
"ancestor op associated with the consumed handle";
593 diag.attachNote(opLoc)
594 <<
"op defining the value as result #" << *resultNo;
596 diag.attachNote(opLoc)
597 <<
"op defining the value as block argument #" << argumentNo
598 <<
" of block #" << blockNo <<
" in region #" << regionNo;
600 diag.attachNote(valueLoc) <<
"payload value";
// Records every handle invalidated by consuming the op handle `handle`, whose
// payload is `potentialAncestors`.
// - Empty payload: only `handle` itself is recorded, with a reporter
//   explaining that it was consumed.
// - Otherwise, every op handle (via `reverse`) and every value handle (via
//   `reverseValues`) in all region mappings is tested by the *One helpers.
// NOTE(review): this scans every reverse entry of every mapping. The cost is
// proportional to the total number of mapped payload entities per consumed
// handle.
605 void transform::TransformState::recordOpHandleInvalidation(
610 if (potentialAncestors.empty()) {
612 (
DBGS() <<
"----recording invalidation for empty handle: " << handle.
get()
618 newlyInvalidated[handle.
get()] = [owner, operandNo](
Location currentLoc) {
620 <<
"op uses a handle associated with empty "
621 "payload and invalidated by a "
622 "previously executed transform op";
623 diag.attachNote(owner->getLoc())
624 <<
"invalidated by this transform op that consumes its operand #"
637 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
641 for (
const auto &[payloadOp, otherHandles] : mapping->reverse) {
642 for (
Value otherHandle : otherHandles)
643 recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
644 otherHandle, throughValue,
652 for (
const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
653 for (
Value valueHandle : valueHandles)
654 recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
655 payloadValue, valueHandle,
// Records the handles invalidated by consuming the value handle `valueHandle`.
// For each of its payload values:
// - every other value handle to that value gets a reporter lambda;
// - op-handle invalidation is recorded for the value's defining op (OpResult)
//   or for every op in the owning block (BlockArgument).
// NOTE(review): the BlockArgument branch uses dyn_cast and then dereferences
// the result without a null check. llvm::cast would state the same invariant
// and assert on violation.
665 void transform::TransformState::recordValueHandleInvalidation(
669 for (
Value payloadValue : getPayloadValuesView(valueHandle.
get())) {
671 (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
672 for (
Value otherHandle : otherValueHandles) {
676 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
679 <<
"op uses a handle invalidated by a "
680 "previously executed transform op";
681 diag.attachNote(otherHandle.getLoc()) <<
"invalidated handle";
682 diag.attachNote(owner->getLoc())
683 <<
"invalidated by this transform op that consumes its operand #"
685 <<
" and invalidates handles to the same values as associated with "
687 diag.attachNote(valueLoc) <<
"payload value";
691 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
692 Operation *payloadOp = opResult.getOwner();
693 recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
696 auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
697 for (
Operation &payloadOp : *arg.getOwner())
698 recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
// For each operand of `transform`:
// - Fails if the operand is already invalidated.
// - Fails if it was newly invalidated by this same op and the op does not
//   allow repeated handle operands. The stored reporter is invoked at the
//   op's location first (comma operator), then failure() is returned.
// - If the op frees (consumes) the operand per its memory effects, records
//   the handles this invalidates:
//     * op handles via recordOpHandleInvalidation,
//     * value handles via recordValueHandleInvalidation.
708 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
709 transform::TransformOpInterface transform,
711 FULL_LDBG(
"--Start checkAndRecordHandleInvalidation\n");
712 auto memoryEffectsIface =
713 cast<MemoryEffectOpInterface>(transform.getOperation());
715 memoryEffectsIface.getEffectsOnResource(
718 for (
OpOperand &target : transform->getOpOperands()) {
720 (
DBGS() <<
"----iterate on handle: " << target.get() <<
"\n");
726 auto it = invalidatedHandles.find(target.get());
727 auto nit = newlyInvalidated.find(target.get());
728 if (it != invalidatedHandles.end()) {
729 FULL_LDBG(
"--End checkAndRecordHandleInvalidation, found already "
730 "invalidated -> FAILURE\n");
731 return it->getSecond()(transform->getLoc()), failure();
733 if (!transform.allowsRepeatedHandleOperands() &&
734 nit != newlyInvalidated.end()) {
735 FULL_LDBG(
"--End checkAndRecordHandleInvalidation, found newly "
736 "invalidated (by this op) -> FAILURE\n");
737 return nit->getSecond()(transform->getLoc()), failure();
// An operand counts as consumed when a MemoryEffects::Free effect applies to it.
743 return isa<MemoryEffects::Free>(effect.getEffect()) &&
744 effect.getValue() == target.get();
746 if (llvm::any_of(effects, consumesTarget)) {
748 if (llvm::isa<transform::TransformHandleTypeInterface>(
749 target.get().getType())) {
750 FULL_LDBG(
"----recordOpHandleInvalidation\n");
752 llvm::to_vector(getPayloadOps(target.get()));
753 recordOpHandleInvalidation(target, payloadOps,
nullptr,
755 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
756 target.get().getType())) {
757 FULL_LDBG(
"----recordValueHandleInvalidation\n");
758 recordValueHandleInvalidation(target, newlyInvalidated);
760 FULL_LDBG(
"----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
763 FULL_LDBG(
"----no consume effect -> SKIP\n");
767 FULL_LDBG(
"--End checkAndRecordHandleInvalidation -> SUCCESS\n");
// Runs the invalidation check for `transform` into a fresh local map. It then
// moves every newly recorded reporter into `invalidatedHandles`, even when the
// check failed (the merge is not conditioned on checkResult).
771 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
772 transform::TransformOpInterface transform) {
773 InvalidatedHandleMap newlyInvalidated;
774 LogicalResult checkResult =
775 checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
776 invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
777 std::make_move_iterator(newlyInvalidated.end()));
// Emits a silenceable error if the payload of a consumed operand
// (`operandNumber`) contains the same entity more than once. T is
// Operation * or Value; `if constexpr (std::is_pointer_v<T>)` selects `->` vs
// `.` when attaching the note for the repeated entity.
781 template <
typename T>
784 transform::TransformOpInterface transform,
785 unsigned operandNumber) {
787 for (T p : payload) {
788 if (!seen.insert(p).second) {
790 transform.emitSilenceableError()
791 <<
"a handle passed as operand #" << operandNumber
792 <<
" and consumed by this operation points to a payload "
793 "entity more than once";
794 if constexpr (std::is_pointer_v<T>)
795 diag.attachNote(p->getLoc()) <<
"repeated target op";
797 diag.attachNote(p.getLoc()) <<
"repeated target value";
// Removes the null entries (left by replacePayloadOp with a null replacement)
// from the `direct` payload list of each handle queued in opHandlesToCompact.
// When checks are on, the timestamp is bumped if anything is removed. The
// queue is then cleared.
// NOTE(review): `#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS` — likely `#if`
// intended (see the `#if` in the RegionScope teardown); verify.
804 void transform::TransformState::compactOpHandles() {
805 for (
Value handle : opHandlesToCompact) {
806 Mappings &mappings = getMapping(handle,
true);
807 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
808 if (llvm::find(mappings.direct[handle],
nullptr) !=
809 mappings.direct[handle].end())
812 mappings.incrementTimestamp(handle);
814 llvm::erase(mappings.direct[handle],
nullptr);
816 opHandlesToCompact.clear();
// Body of the routine that applies one transform op `transform` to the payload.
// Its signature is not visible in this chunk. Visible steps, in order:
//  1. Debug-print the op and the top-level payload.
//  2. Arm a scope-exit guard that debug-prints the payload on failure. It is
//     disarmed via release() on the success path near the end.
//  3. Record `transform` as the current transform of the innermost region
//     scope.
//  4. With expensive checks on:
//     - check and record handle invalidation;
//     - diagnose consumed handles that reach one payload entity twice.
//  5. Collect results of payload ops behind consumed op handles
//     (origOpFlatResults) and ops related to consumed value handles.
//  6. Report or silence tracking-listener failures.
//  7. Forget mappings of consumed handles and record the op's results.
822 DBGS() <<
"applying: ";
824 llvm::dbgs() <<
"\n";
827 DBGS() <<
"Top-level payload before application:\n"
828 << *getTopLevel() <<
"\n");
829 auto printOnFailureRAII = llvm::make_scope_exit([
this] {
831 LLVM_DEBUG(
DBGS() <<
"Failing Top-level payload:\n"; getTopLevel()->print(
836 regionStack.back()->currentTransform = transform;
// Expensive checks: invalidation tracking and repeated-consumption detection.
839 if (
options.getExpensiveChecksEnabled()) {
841 if (failed(checkAndRecordHandleInvalidation(transform)))
844 for (
OpOperand &operand : transform->getOpOperands()) {
846 (
DBGS() <<
"iterate on handle: " << operand.get() <<
"\n");
849 FULL_LDBG(
"--handle not consumed -> SKIP\n");
852 if (transform.allowsRepeatedHandleOperands()) {
853 FULL_LDBG(
"--op allows repeated handles -> SKIP\n");
858 Type operandType = operand.get().getType();
859 if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
860 FULL_LDBG(
"--checkRepeatedConsumptionInOperand for Operation*\n");
862 checkRepeatedConsumptionInOperand<Operation *>(
863 getPayloadOpsView(operand.get()), transform,
864 operand.getOperandNumber());
869 }
else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
870 FULL_LDBG(
"--checkRepeatedConsumptionInOperand For Value\n");
872 checkRepeatedConsumptionInOperand<Value>(
873 getPayloadValuesView(operand.get()), transform,
874 operand.getOperandNumber());
880 FULL_LDBG(
"--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
887 transform.getConsumedHandleOpOperands();
// Capture what consumed handles point to *before* the op runs. For value
// handles, results use their owner op, and block arguments use the ops of
// the owning block.
896 for (
OpOperand *opOperand : consumedOperands) {
897 Value operand = opOperand->get();
898 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
899 for (
Operation *payloadOp : getPayloadOps(operand)) {
900 llvm::append_range(origOpFlatResults, payloadOp->
getResults());
904 if (llvm::isa<TransformValueHandleTypeInterface>(operand.
getType())) {
905 for (
Value payloadValue : getPayloadValuesView(operand)) {
906 if (llvm::isa<OpResult>(payloadValue)) {
912 llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
919 <<
"unexpectedly consumed a value that is not a handle as operand #"
920 << opOperand->getOperandNumber();
922 <<
"value defined here with type " << operand.
getType();
// Locate the region scope that owns `handle`, searching innermost-first, then
// inspect the handle's users against that scope's current transform (the rest
// of the condition is outside this view).
931 llvm::find_if(llvm::reverse(regionStack), [&](
RegionScope *scope) {
932 return handle.getParentRegion() == scope->region;
934 assert(scopeIt != regionStack.rend() &&
935 "could not find region scope for handle");
937 for (
Operation *user : handle.getUsers()) {
938 if (user != scope->currentTransform &&
// Tracking-listener failures are silenced when the op carries the
// kSilenceTrackingFailuresAttrName attribute. Otherwise they become the result
// or are attached as a note.
960 transform->hasAttr(FindPayloadReplacementOpInterface::
961 kSilenceTrackingFailuresAttrName)) {
965 (void)trackingFailure.
silence();
970 result = std::move(trackingFailure);
974 result.
attachNote() <<
"tracking listener also failed: "
976 (void)trackingFailure.
silence();
// Consumed handles no longer refer to their payload.
989 for (
OpOperand *opOperand : consumedOperands) {
990 Value operand = opOperand->get();
991 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
992 forgetMapping(operand, origOpFlatResults);
993 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
995 forgetValueMapping(operand, origAssociatedOps);
999 if (failed(updateStateFromResults(results, transform->getResults())))
// Success path: disarm the on-failure payload dump.
1002 printOnFailureRAII.release();
1004 DBGS() <<
"Top-level payload:\n";
1005 getTopLevel()->print(llvm::dbgs());
// Stores the results produced by a transform op into the state, dispatching on
// each result's type:
// - param types -> setParams
// - value-handle types -> setPayloadValues
// - otherwise -> setPayloadOps
// Each branch asserts that `results` holds the matching kind.
1010 LogicalResult transform::TransformState::updateStateFromResults(
1012 for (
OpResult result : opResults) {
1013 if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1014 assert(results.isParam(result.getResultNumber()) &&
1015 "expected parameters for the parameter-typed result");
1017 setParams(result, results.getParams(result.getResultNumber())))) {
1020 }
else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1021 assert(results.isValue(result.getResultNumber()) &&
1022 "expected values for value-type-result");
1023 if (failed(setPayloadValues(
1024 result, results.getValues(result.getResultNumber())))) {
1028 assert(!results.isParam(result.getResultNumber()) &&
1029 "expected payload ops for the non-parameter typed result");
1031 setPayloadOps(result, results.get(result.getResultNumber())))) {
// Extension hooks that forward payload replacement to the owning `state`
// (enclosing signatures are only partially visible).
1050 return state.replacePayloadOp(op, replacement);
1055 Value replacement) {
1056 return state.replacePayloadValue(value, replacement);
// Region-scope teardown (header not visible):
// - Drops invalidation records for handles defined in `region` (block
//   arguments, plus handles collected by a loop not visible here).
// - Erases the region's mapping and pops the scope off `regionStack`.
1067 for (
Block &block : *region) {
1068 for (
Value handle : block.getArguments()) {
1069 state.invalidatedHandles.erase(handle);
1073 state.invalidatedHandles.erase(handle);
1078 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1082 llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1085 state.mappings.erase(region);
1086 state.regionStack.pop_back();
// Creates `numSegments` empty rows in each of the three result tables
// (operations, params, values). All three therefore have the same row count,
// which isSet() relies on.
1093 transform::TransformResults::TransformResults(
unsigned numSegments) {
1094 operations.appendEmptyRows(numSegments);
1095 params.appendEmptyRows(numSegments);
1096 values.appendEmptyRows(numSegments);
// Stores `params` as the result at `position`. It asserts that the position
// exists and that no kind of result (params, ops or values) has been set there
// yet; a null data() marks an unset row.
1102 assert(position <
static_cast<int64_t
>(this->params.size()) &&
1103 "setting params for a non-existent handle");
1104 assert(this->params[position].data() ==
nullptr &&
"params already set");
1105 assert(operations[position].data() ==
nullptr &&
1106 "another kind of results already set");
1107 assert(values[position].data() ==
nullptr &&
1108 "another kind of results already set");
1109 this->params.replace(position, params);
// Per-kind callbacks storing dispatched mapped values as results, followed by
// a check that the dispatch diagnostic succeeded.
// NOTE(review): in the visible lines the status string is printed to
// llvm::dbgs() before the assert, without an LLVM_DEBUG guard — confirm this
// is intended.
1117 return set(handle, operations), success();
1120 return setParams(handle, params), success();
1123 return setValues(handle, payloadValues), success();
1126 if (!
diag.succeeded())
1127 llvm::dbgs() <<
diag.getStatusString() <<
"\n";
1128 assert(
diag.succeeded() &&
"incorrect mapping");
1130 (void)
diag.silence();
// Gives every result of `transform` that was not set an empty mapping.
1134 transform::TransformOpInterface transform) {
1135 for (
OpResult opResult : transform->getResults()) {
1136 if (!isSet(opResult.getResultNumber()))
1137 setMappedValues(opResult, {});
// Accessors returning the stored ops / params / values for `resultNumber`.
// Each one asserts that the index is in range and that this kind of result
// was actually set (non-null data()).
1142 transform::TransformResults::get(
unsigned resultNumber)
const {
1143 assert(resultNumber < operations.size() &&
1144 "querying results for a non-existent handle");
1145 assert(operations[resultNumber].data() !=
nullptr &&
1146 "querying unset results (values or params expected?)");
1147 return operations[resultNumber];
1151 transform::TransformResults::getParams(
unsigned resultNumber)
const {
1152 assert(resultNumber < params.size() &&
1153 "querying params for a non-existent handle");
1154 assert(params[resultNumber].data() !=
nullptr &&
1155 "querying unset params (ops or values expected?)");
1156 return params[resultNumber];
1160 transform::TransformResults::getValues(
unsigned resultNumber)
const {
1161 assert(resultNumber < values.size() &&
1162 "querying values for a non-existent handle");
1163 assert(values[resultNumber].data() !=
nullptr &&
1164 "querying unset values (ops or params expected?)");
1165 return values[resultNumber];
// Predicates on which kind of result was stored at `resultNumber`; an unset
// row has null data().
// isSet() bounds-checks only against params.size(). That is sufficient
// because the constructor appends the same number of rows to all three
// tables.
1168 bool transform::TransformResults::isParam(
unsigned resultNumber)
const {
1169 assert(resultNumber < params.size() &&
1170 "querying association for a non-existent handle");
1171 return params[resultNumber].data() !=
nullptr;
1174 bool transform::TransformResults::isValue(
unsigned resultNumber)
const {
1175 assert(resultNumber < values.size() &&
1176 "querying association for a non-existent handle");
1177 return values[resultNumber].data() !=
nullptr;
1180 bool transform::TransformResults::isSet(
unsigned resultNumber)
const {
1181 assert(resultNumber < params.size() &&
1182 "querying association for a non-existent handle");
1183 return params[resultNumber].data() !=
nullptr ||
1184 operations[resultNumber].data() !=
nullptr ||
1185 values[resultNumber].data() !=
nullptr;
// Tracking listener constructor. It stores the transform op and the config,
// and records the operands consumed by `transformOp` in `consumedHandles`.
// notifyOperationReplaced consults that set to skip updates of consumed
// handles.
1193 TransformOpInterface op,
1195 :
TransformState::Extension(state), transformOp(op), config(config) {
1197 for (
OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1198 consumedHandles.insert(opOperand->get());
// Determines whether all `values` share one defining op. `defOp` is taken from
// a value's defining op, and every value is compared against it. What is
// returned on mismatch lies on lines outside this view (presumably null).
1205 for (
Value v : values) {
1210 defOp = v.getDefiningOp();
1213 if (defOp != v.getDefiningOp())
// Searches for a replacement op for a replaced payload op, given its
// replacement `values`. The search loops (do/while) until `values` runs out.
// On each iteration:
// - the values must share a common defining op;
// - CastOpInterface ops are skipped when config.skipCastOps is set;
// - a FindPayloadReplacementOpInterface op supplies the next operands to try.
// Each step attaches a note to `diag` explaining why a candidate was rejected.
1222 "invalid number of replacement values");
1226 getTransformOp(),
"tracking listener failed to find replacement op "
1227 "during application of this transform op");
1231 Operation *defOp = getCommonDefiningOp(values);
1233 diag.attachNote() <<
"replacement values belong to different ops";
1238 if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1242 <<
"using output of 'CastOpInterface' op";
1248 if (!config.requireMatchingReplacementOpName ||
1264 if (
auto findReplacementOpInterface =
1265 dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1266 values.assign(findReplacementOpInterface.getNextOperands());
1267 diag.attachNote(defOp->
getLoc()) <<
"using operands provided by "
1268 "'FindPayloadReplacementOpInterface'";
1271 }
while (!values.empty());
1273 diag.attachNote() <<
"ran out of suitable replacement values";
// Match-failure hook (header not visible): fills `diag` through
// `reasonCallback` and prints the resulting message to DBGS().
1281 reasonCallback(
diag);
1282 DBGS() <<
"Match Failure : " <<
diag.str() <<
"\n";
// On payload op erasure:
// - the op's values (loop header not visible) are replaced by null in value
//   handles;
// - the op itself is replaced by null in op handles. The null entries are
//   later removed by compactOpHandles().
1286 void transform::TrackingListener::notifyOperationErased(
Operation *op) {
1289 (void)replacePayloadValue(value,
nullptr);
1291 (void)replacePayloadOp(op,
nullptr);
// On payload op replacement by `newValues`:
//  1. Remap value handles result-by-result.
//  2. Find the op handles pointing to `op` (including out-of-scope ones).
//  3. If no handle is alive (per config.skipHandleFn, or the first handle when
//     no filter is set), or a handle was consumed by the current transform,
//     drop `op` from the handles (null replacement).
//  4. Otherwise search for a replacement op. If none is found, report via
//     notifyPayloadReplacementNotFound and drop `op`; if one is found,
//     redirect the handles to it.
1294 void transform::TrackingListener::notifyOperationReplaced(
1297 "invalid number of replacement values");
1300 for (
auto [oldValue, newValue] : llvm::zip(op->
getResults(), newValues))
1301 (void)replacePayloadValue(oldValue, newValue);
1305 if (failed(getTransformState().getHandlesForPayloadOp(
1306 op, opHandles,
true))) {
1320 auto handleWasConsumed = [&] {
1321 return llvm::any_of(opHandles,
1322 [&](
Value h) {
return consumedHandles.contains(h); });
// Pick a handle that must be kept up to date.
1327 if (config.skipHandleFn) {
1328 auto it = llvm::find_if(opHandles,
1329 [&](
Value v) {
return !config.skipHandleFn(v); });
1330 if (it != opHandles.end())
1332 }
else if (!opHandles.empty()) {
1333 aliveHandle = opHandles.front();
1335 if (!aliveHandle || handleWasConsumed()) {
1338 (void)replacePayloadOp(op,
nullptr);
1344 findReplacementOp(replacement, op, newValues);
1347 if (!
diag.succeeded()) {
1349 <<
"replacement is required because this handle must be updated";
1350 notifyPayloadReplacementNotFound(op, newValues, std::move(
diag));
1351 (void)replacePayloadOp(op,
nullptr);
1355 (void)replacePayloadOp(op, replacement);
// Error-checking listener pieces (headers not visible):
// - An assert that the accumulated `status` was checked, i.e. that it
//   succeeded at that point.
// - failed(): true when `status` holds a failure.
// - A replacement-not-found hook that moves diagnostics from `diag` into
//   `status`. It also attaches numbered notes for the replaced op and each
//   replacement value.
1362 assert(status.succeeded() &&
"listener state was not checked");
1374 return !status.succeeded();
1382 diag.takeDiagnostics(diags);
1383 if (!status.succeeded())
1384 status.takeDiagnostics(diags);
1388 status.attachNote(op->
getLoc()) <<
"[" << errorCounter <<
"] replaced op";
1390 status.attachNote(value.
getLoc())
1391 <<
"[" << errorCounter <<
"] replacement value " << index;
// Forwarders to the tracking listener (headers not visible): report whether
// the listener recorded failures, and forward payload-op replacement to it.
1406 return listener->failed();
1411 if (hasTrackingFailures()) {
1419 return listener->replacePayloadOp(op, replacement);
// Diagnoses a consumed handle whose payload lists an ancestor op (`parent`)
// before one of its descendants (`child`, taken from the targets after
// `position`). Consuming the ancestor first may erase or rewrite IR the
// descendant still refers to.
// NOTE(review): combined with the (not visible) outer loop over `targets`,
// this is a pairwise check, quadratic in the number of targets.
1430 for (
Operation *child : targets.drop_front(position + 1)) {
1431 if (parent->isAncestor(child)) {
1434 <<
"transform operation consumes a handle pointing to an ancestor "
1435 "payload operation before its descendant";
1437 <<
"the ancestor is likely erased or rewritten before the "
1438 "descendant is accessed, leading to undefined behavior";
1439 diag.attachNote(parent->getLoc()) <<
"ancestor payload op";
1440 diag.attachNote(child->getLoc()) <<
"descendant payload op";
// Validates the partial result produced by an applyToOne-style transform for
// one payload op:
// - The number of entries must equal `expectedNumResults`; otherwise the note
//   suggests a generic `apply`.
// - Each entry's kind must match its result type: Operation * for handles,
//   Attribute for params, Value for value handles.
1459 diag.attachNote(payloadOpLoc) <<
"when applied to this op";
1463 if (partialResult.
size() != expectedNumResults) {
1464 auto diag =
emitDiag() <<
"application of " << transformOpName
1465 <<
" expected to produce " << expectedNumResults
1466 <<
" results (actually produced "
1467 << partialResult.
size() <<
").";
1468 diag.attachNote(transformOpLoc)
1469 <<
"if you need variadic results, consider a generic `apply` "
1470 <<
"instead of the specialized `applyToOne`.";
1475 for (
const auto &[ptr, res] :
1476 llvm::zip(partialResult, transformOp->
getResults())) {
1479 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1481 return emitDiag() <<
"application of " << transformOpName
1482 <<
" expected to produce an Operation * for result #"
1483 << res.getResultNumber();
1485 if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1487 return emitDiag() <<
"application of " << transformOpName
1488 <<
" expected to produce an Attribute for result #"
1489 << res.getResultNumber();
1491 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1493 return emitDiag() <<
"application of " << transformOpName
1494 <<
" expected to produce a Value for result #"
1495 << res.getResultNumber();
// Helper template converting a range of mapped values into a vector of T via
// llvm::map_range (the mapping function is not visible here).
1501 template <
typename T>
1503 return llvm::to_vector(llvm::map_range(
// Transposes per-payload-op partial results into per-result lists. Each list
// is stored with the setter that matches the result's type.
// A null entry in `partialResults` triggers the early path on the first line
// (its body is not visible).
1513 if (llvm::any_of(partialResults,
1514 [](
MappedValue value) {
return value.isNull(); }))
1516 assert(transformOp->
getNumResults() == partialResults.size() &&
1517 "expected as many partial results as op as results");
1519 transposed[i].push_back(value);
1523 unsigned position = r.getResultNumber();
1524 if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1526 castVector<Attribute>(transposed[position]));
1527 }
else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1528 transformResults.
setValues(r, castVector<Value>(transposed[position]));
1530 transformResults.
set(r, castVector<Operation *>(transposed[position]));
// Appends to `mappings[i]` whatever `values[i]` maps to in `state`: ops,
// payload values or params. The last line belongs to a separate helper that
// grows `mappings` by `values.size()` rows.
// When `flatten` is false, each value is expected to contribute exactly one
// entry (the action on violation is not visible here).
1542 assert(mappings.size() == values.size() &&
"mismatching number of mappings");
1543 for (
auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1544 size_t mappedSize = mapped.size();
1545 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1546 llvm::append_range(mapped, state.getPayloadOps(operand));
1547 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
1548 operand.getType())) {
1549 llvm::append_range(mapped, state.getPayloadValues(operand));
1551 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1552 "unsupported kind of transform dialect value");
1553 llvm::append_range(mapped, state.getParams(operand));
1556 if (mapped.size() - mappedSize != 1 && !flatten)
1565 mappings.resize(mappings.size() + values.size());
// Forwards what each terminator operand maps to as the corresponding result.
// The setter is chosen by the result's type: ops, payload values, or (as
// asserted) params.
1575 for (
auto &&[terminatorOperand, result] :
1578 if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1579 results.
set(result, state.getPayloadOps(terminatorOperand));
1580 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1581 result.getType())) {
1582 results.
setValues(result, state.getPayloadValues(terminatorOperand));
1585 llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1586 "unhandled transform type interface");
1587 results.
setParams(result, state.getParams(terminatorOperand));
// Memory-effect helpers (headers not visible):
// - Effects that nested ops have on a block argument `source` are re-attributed
//   to the corresponding outer operand `target`, keeping effect and resource.
// - Nested ops implementing MemoryEffectOpInterface contribute their effects.
1609 iface.getEffectsOnValue(source, nestedEffects);
1610 for (
const auto &effect : nestedEffects)
1611 effects.emplace_back(effect.getEffect(), target, effect.getResource());
1620 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1624 for (
auto &&[source, target] : llvm::zip(block.
getArguments(), operands)) {
1631 llvm::append_range(effects, nestedEffects);
1643 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1648 iface.getEffects(effects);
// Maps the entry block arguments of a possible top-level transform op:
// - Argument 0 maps to the payload of operand 0 or, in the other visible
//   branch, to the top-level op.
// - Trailing arguments map to the interpreter's extra top-level mappings
//   (argument N uses mapping N-1).
// The number of extra mappings must match what the op expects.
1663 llvm::append_range(targets, state.getPayloadOps(op->
getOperand(0)));
1666 if (state.getNumTopLevelMappings() !=
1670 <<
" extra value bindings, but " << state.getNumTopLevelMappings()
1671 <<
" were provided to the interpreter";
1674 targets.push_back(state.getTopLevel());
1676 for (
unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1677 extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1680 if (failed(state.mapBlockArguments(region.
front().
getArgument(0), targets)))
1684 if (failed(state.mapBlockArgument(
1685 argument, extraMappings[argument.getArgNumber() - 1])))
1697 assert(isa<TransformOpInterface>(op) &&
1698 "should implement TransformOpInterface to have "
1699 "PossibleTopLevelTransformOpTrait");
1702 return op->
emitOpError() <<
"expects at least one region";
1705 if (!llvm::hasNItems(*bodyRegion, 1))
1706 return op->
emitOpError() <<
"expects a single-block region";
1711 <<
"expects the entry block to have at least one argument";
1713 if (!llvm::isa<TransformHandleTypeInterface>(
1716 <<
"expects the first entry block argument to be of type "
1717 "implementing TransformHandleTypeInterface";
1723 <<
"expects the type of the block argument to match "
1724 "the type of the operand";
1728 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1729 TransformValueHandleTypeInterface>(arg.
getType()))
1734 <<
"expects trailing entry block arguments to be of type implementing "
1735 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1736 "TransformParamTypeInterface";
1746 <<
"expects operands to be provided for a nested op";
1747 diag.attachNote(parent->getLoc())
1748 <<
"nested in another possible top-level op";
1763 bool hasPayloadOperands =
false;
1766 if (llvm::isa<TransformHandleTypeInterface,
1767 TransformValueHandleTypeInterface>(operand.get().getType()))
1768 hasPayloadOperands =
true;
1770 if (hasPayloadOperands)
1779 llvm::report_fatal_error(
1780 Twine(
"ParamProducerTransformOpTrait must be attached to an op that "
1781 "implements MemoryEffectsOpInterface, found on ") +
1785 if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1788 <<
"ParamProducerTransformOpTrait attached to this op expects "
1789 "result types to implement TransformParamTypeInterface";
1811 template <
typename EffectTy,
typename ResourceTy,
typename Range>
1814 return isa<EffectTy>(effect.
getEffect()) &&
1820 transform::TransformOpInterface transform) {
1821 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1823 iface.getEffectsOnValue(handle, effects);
1824 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1825 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1871 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1873 iface.getEffects(effects);
1874 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1878 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1880 iface.getEffects(effects);
1881 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1885 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1888 auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1893 iface.getEffects(effects);
1896 dyn_cast_or_null<BlockArgument>(effect.getValue());
1897 if (!argument || argument.
getOwner() != &block ||
1898 !isa<MemoryEffects::Free>(effect.getEffect()) ||
1912 TransformOpInterface transformOp) {
1914 consumedOperands.reserve(transformOp->getNumOperands());
1915 auto memEffectInterface =
1916 cast<MemoryEffectOpInterface>(transformOp.getOperation());
1918 for (
OpOperand &target : transformOp->getOpOperands()) {
1920 memEffectInterface.getEffectsOnValue(target.get(), effects);
1922 return isa<transform::TransformMappingResource>(
1924 isa<MemoryEffects::Free>(effect.
getEffect());
1926 consumedOperands.push_back(&target);
1929 return consumedOperands;
1933 auto iface = cast<MemoryEffectOpInterface>(op);
1935 iface.getEffects(effects);
1937 auto effectsOn = [&](
Value value) {
1938 return llvm::make_filter_range(
1940 return instance.
getValue() == value;
1944 std::optional<unsigned> firstConsumedOperand;
1946 auto range = effectsOn(operand.get());
1947 if (range.empty()) {
1949 op->
emitError() <<
"TransformOpInterface requires memory effects "
1950 "on operands to be specified";
1951 diag.attachNote() <<
"no effects specified for operand #"
1952 << operand.getOperandNumber();
1955 if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1957 <<
"TransformOpInterface did not expect "
1958 "'allocate' memory effect on an operand";
1959 diag.attachNote() <<
"specified for operand #"
1960 << operand.getOperandNumber();
1963 if (!firstConsumedOperand &&
1964 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1965 firstConsumedOperand = operand.getOperandNumber();
1969 if (firstConsumedOperand &&
1970 !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1973 <<
"TransformOpInterface expects ops consuming operands to have a "
1974 "'write' effect on the payload resource";
1975 diag.attachNote() <<
"consumes operand #" << *firstConsumedOperand;
1980 auto range = effectsOn(result);
1981 if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1984 op->
emitError() <<
"TransformOpInterface requires 'allocate' memory "
1985 "effect to be specified for results";
1986 diag.attachNote() <<
"no 'allocate' effect specified for result #"
1987 << result.getResultNumber();
2000 Operation *payloadRoot, TransformOpInterface transform,
2003 if (enforceToplevelTransformOp) {
2005 transform->getNumOperands() != 0) {
2006 return transform->emitError()
2007 <<
"expected transform to start at the top-level transform op";
2014 TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2016 return state.applyTransform(transform).checkAndReport();
2023 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2024 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A 2D array where each row may have different length.
size_t size() const
Returns the number of rows in the 2D array.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class implements the result iterators for the Operation class.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns true if op has an effect of type EffectTy.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...