15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/ScopeExit.h"
17 #include "llvm/ADT/iterator.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/DebugLog.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include "llvm/Support/InterleavedRange.h"
// Debug-type tag picked up by LLVM's debug-logging macros (e.g. LDBG) in this
// file; output is enabled with `-debug-only=transform-dialect`.
23 #define DEBUG_TYPE "transform-dialect"
// Separate debug-type name. Its use is not visible in this chunk; presumably it
// gates printing the top-level payload after every transform — confirm.
24 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
// Verbose (level 4) logging, used for the fine-grained trace of handle
// invalidation checks and repeated-consumption checks.
25 #define FULL_LDBG() LDBG(4)
// Out-of-class definition of the static constexpr member `kTopLevelValue`.
// The initializer lives in the class declaration. Since C++17, static constexpr
// data members are implicitly inline, so this definition is redundant but
// harmless. The value serves as the key of the transformation root, which
// setPayloadOps refuses to reassign.
50 constexpr
const Value transform::TransformState::kTopLevelValue;
52 transform::TransformState::TransformState(
55 const TransformOptions &
options)
57 topLevelMappedValues.reserve(extraMappings.
size());
59 topLevelMappedValues.push_back(mapping);
61 RegionScope *scope =
new RegionScope(*
this, *region);
62 topLevelRegionScope.reset(scope);
/// Returns the top-level payload operation held by this state (the `topLevel`
/// member). This is the op printed as "Top-level payload" in debug output.
66 Operation *transform::TransformState::getTopLevel()
const {
return topLevel; }
69 transform::TransformState::getPayloadOpsView(
Value value)
const {
70 const TransformOpMapping &operationMapping = getMapping(value).direct;
71 auto iter = operationMapping.find(value);
72 assert(iter != operationMapping.end() &&
73 "cannot find mapping for payload handle (param/value handle "
75 return iter->getSecond();
80 auto iter = mapping.find(value);
81 assert(iter != mapping.end() &&
"cannot find mapping for param handle "
82 "(operation/value handle provided?)");
83 return iter->getSecond();
87 transform::TransformState::getPayloadValuesView(
Value handleValue)
const {
88 const ValueMapping &mapping = getMapping(handleValue).values;
89 auto iter = mapping.find(handleValue);
90 assert(iter != mapping.end() &&
"cannot find mapping for value handle "
91 "(param/operation handle provided?)");
92 return iter->getSecond();
97 bool includeOutOfScope)
const {
99 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
100 auto iterator = mapping->reverse.find(op);
101 if (iterator != mapping->reverse.end()) {
102 llvm::append_range(handles, iterator->getSecond());
106 if (!includeOutOfScope &&
111 return success(found);
116 bool includeOutOfScope)
const {
118 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
119 auto iterator = mapping->reverseValues.find(payloadValue);
120 if (iterator != mapping->reverseValues.end()) {
121 llvm::append_range(handles, iterator->getSecond());
125 if (!includeOutOfScope &&
130 return success(found);
140 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.
getType())) {
142 operations.reserve(values.size());
144 if (
auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
145 operations.push_back(op);
149 <<
"wrong kind of value provided for top-level operation handle";
151 if (
failed(operationsFn(operations)))
156 if (llvm::isa<transform::TransformValueHandleTypeInterface>(
159 payloadValues.reserve(values.size());
161 if (
auto v = llvm::dyn_cast_if_present<Value>(value)) {
162 payloadValues.push_back(v);
166 <<
"wrong kind of value provided for the top-level value handle";
168 if (
failed(valuesFn(payloadValues)))
173 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.
getType()) &&
174 "unsupported kind of block argument");
176 parameters.reserve(values.size());
178 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
179 parameters.push_back(attr);
183 <<
"wrong kind of value provided for top-level parameter";
185 if (
failed(paramsFn(parameters)))
196 return setPayloadOps(argument, operations);
199 return setParams(argument, params);
202 return setPayloadValues(argument, payloadValues);
210 for (
auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
211 if (
failed(mapBlockArgument(argument, values)))
217 transform::TransformState::setPayloadOps(
Value value,
219 assert(value != kTopLevelValue &&
220 "attempting to reset the transformation root");
221 assert(llvm::isa<TransformHandleTypeInterface>(value.
getType()) &&
222 "wrong handle type");
228 <<
"attempting to assign a null payload op to this transform value";
231 auto iface = llvm::cast<TransformHandleTypeInterface>(value.
getType());
233 iface.checkPayload(value.
getLoc(), targets);
240 Mappings &mappings = getMapping(value);
242 mappings.direct.insert({value, std::move(storedTargets)}).second;
243 assert(inserted &&
"value is already associated with another list");
247 mappings.reverse[op].push_back(value);
253 transform::TransformState::setPayloadValues(
Value handle,
255 assert(handle !=
nullptr &&
"attempting to set params for a null value");
256 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.
getType()) &&
257 "wrong handle type");
259 for (
Value payload : payloadValues) {
262 return emitError(handle.
getLoc()) <<
"attempting to assign a null payload "
263 "value to this transform handle";
266 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.
getType());
269 iface.checkPayload(handle.
getLoc(), payloadValueVector);
273 Mappings &mappings = getMapping(handle);
275 mappings.values.insert({handle, std::move(payloadValueVector)}).second;
278 "value handle is already associated with another list of payload values");
281 for (
Value payload : payloadValues)
282 mappings.reverseValues[payload].push_back(handle);
287 LogicalResult transform::TransformState::setParams(
Value value,
289 assert(value !=
nullptr &&
"attempting to set params for a null value");
295 <<
"attempting to assign a null parameter to this transform value";
298 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.
getType());
300 "cannot associate parameter with a value of non-parameter type");
302 valueType.checkPayload(value.
getLoc(), params);
306 Mappings &mappings = getMapping(value);
308 mappings.params.insert({value, llvm::to_vector(params)}).second;
309 assert(inserted &&
"value is already associated with another list of params");
314 template <
typename Mapping,
typename Key,
typename Mapped>
316 auto it = mapping.find(key);
317 if (it == mapping.end())
320 llvm::erase(it->getSecond(), mapped);
321 if (it->getSecond().empty())
325 void transform::TransformState::forgetMapping(
Value opHandle,
327 bool allowOutOfScope) {
328 Mappings &mappings = getMapping(opHandle, allowOutOfScope);
329 for (
Operation *op : mappings.direct[opHandle])
331 mappings.direct.erase(opHandle);
332 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
335 mappings.incrementTimestamp(opHandle);
338 for (
Value opResult : origOpFlatResults) {
340 (void)getHandlesForPayloadValue(opResult, resultHandles);
341 for (
Value resultHandle : resultHandles) {
342 Mappings &localMappings = getMapping(resultHandle);
344 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
347 mappings.incrementTimestamp(resultHandle);
354 void transform::TransformState::forgetValueMapping(
356 Mappings &mappings = getMapping(valueHandle);
357 for (
Value payloadValue : mappings.reverseValues[valueHandle])
359 mappings.values.erase(valueHandle);
360 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
363 mappings.incrementTimestamp(valueHandle);
366 for (
Operation *payloadOp : payloadOperations) {
368 (void)getHandlesForPayloadOp(payloadOp, opHandles);
369 for (
Value opHandle : opHandles) {
370 Mappings &localMappings = getMapping(opHandle);
374 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
377 localMappings.incrementTimestamp(opHandle);
384 transform::TransformState::replacePayloadOp(
Operation *op,
391 (void)getHandlesForPayloadValue(opResult, valueHandles,
393 assert(valueHandles.empty() &&
"expected no mapping to old results");
400 if (
failed(getHandlesForPayloadOp(op, opHandles,
true)))
402 for (
Value handle : opHandles) {
403 Mappings &mappings = getMapping(handle,
true);
416 for (
Value handle : opHandles) {
417 Mappings &mappings = getMapping(handle,
true);
418 auto it = mappings.direct.find(handle);
419 if (it == mappings.direct.end())
426 mapped = replacement;
430 mappings.reverse[replacement].push_back(handle);
432 opHandlesToCompact.insert(handle);
440 transform::TransformState::replacePayloadValue(
Value value,
Value replacement) {
442 if (
failed(getHandlesForPayloadValue(value, valueHandles,
446 for (
Value handle : valueHandles) {
447 Mappings &mappings = getMapping(handle,
true);
454 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
457 mappings.incrementTimestamp(handle);
460 auto it = mappings.values.find(handle);
461 if (it == mappings.values.end())
465 for (
Value &mapped : association) {
467 mapped = replacement;
469 mappings.reverseValues[replacement].push_back(handle);
476 void transform::TransformState::recordOpHandleInvalidationOne(
483 if (invalidatedHandles.count(otherHandle) ||
484 newlyInvalidated.count(otherHandle))
487 FULL_LDBG() <<
"--recordOpHandleInvalidationOne";
489 << llvm::interleaved(
490 llvm::make_pointee_range(potentialAncestors));
494 for (
Operation *ancestor : potentialAncestors) {
496 FULL_LDBG() <<
"----handle one ancestor: " << *ancestor;;
498 FULL_LDBG() <<
"----of payload with name: "
500 FULL_LDBG() <<
"----of payload: " << *payloadOp;
502 if (!ancestor->isAncestor(payloadOp))
509 Location ancestorLoc = ancestor->getLoc();
511 std::optional<Location> throughValueLoc =
512 throughValue ? std::make_optional(throughValue.
getLoc()) : std::nullopt;
513 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
515 throughValueLoc](
Location currentLoc) {
517 <<
"op uses a handle invalidated by a "
518 "previously executed transform op";
519 diag.attachNote(otherHandle.getLoc()) <<
"handle to invalidated ops";
520 diag.attachNote(owner->getLoc())
521 <<
"invalidated by this transform op that consumes its operand #"
523 <<
" and invalidates all handles to payload IR entities associated "
524 "with this operand and entities nested in them";
525 diag.attachNote(ancestorLoc) <<
"ancestor payload op";
526 diag.attachNote(opLoc) <<
"nested payload op";
527 if (throughValueLoc) {
528 diag.attachNote(*throughValueLoc)
529 <<
"consumed handle points to this payload value";
535 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
542 if (invalidatedHandles.count(valueHandle) ||
543 newlyInvalidated.count(valueHandle))
546 for (
Operation *ancestor : potentialAncestors) {
548 std::optional<unsigned> resultNo;
552 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
553 definingOp = opResult.getOwner();
554 resultNo = opResult.getResultNumber();
556 auto arg = llvm::cast<BlockArgument>(payloadValue);
558 argumentNo = arg.getArgNumber();
559 blockNo = std::distance(arg.getOwner()->getParent()->begin(),
560 arg.getOwner()->getIterator());
561 regionNo = arg.getOwner()->getParent()->getRegionNumber();
563 assert(definingOp &&
"expected the value to be defined by an op as result "
564 "or block argument");
565 if (!ancestor->isAncestor(definingOp))
570 Location ancestorLoc = ancestor->getLoc();
573 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
574 argumentNo, blockNo, regionNo, ancestorLoc,
575 opLoc, valueLoc](
Location currentLoc) {
577 <<
"op uses a handle invalidated by a "
578 "previously executed transform op";
579 diag.attachNote(valueHandle.
getLoc()) <<
"invalidated handle";
580 diag.attachNote(owner->getLoc())
581 <<
"invalidated by this transform op that consumes its operand #"
583 <<
" and invalidates all handles to payload IR entities "
584 "associated with this operand and entities nested in them";
585 diag.attachNote(ancestorLoc)
586 <<
"ancestor op associated with the consumed handle";
588 diag.attachNote(opLoc)
589 <<
"op defining the value as result #" << *resultNo;
591 diag.attachNote(opLoc)
592 <<
"op defining the value as block argument #" << argumentNo
593 <<
" of block #" << blockNo <<
" in region #" << regionNo;
595 diag.attachNote(valueLoc) <<
"payload value";
600 void transform::TransformState::recordOpHandleInvalidation(
605 if (potentialAncestors.empty()) {
606 FULL_LDBG() <<
"----recording invalidation for empty handle: "
611 newlyInvalidated[handle.
get()] = [owner, operandNo](
Location currentLoc) {
613 <<
"op uses a handle associated with empty "
614 "payload and invalidated by a "
615 "previously executed transform op";
617 <<
"invalidated by this transform op that consumes its operand #"
630 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
634 for (
const auto &[payloadOp, otherHandles] : mapping->reverse) {
635 for (
Value otherHandle : otherHandles)
636 recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
637 otherHandle, throughValue,
645 for (
const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
646 for (
Value valueHandle : valueHandles)
647 recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
648 payloadValue, valueHandle,
658 void transform::TransformState::recordValueHandleInvalidation(
662 for (
Value payloadValue : getPayloadValuesView(valueHandle.
get())) {
664 (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
665 for (
Value otherHandle : otherValueHandles) {
669 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
672 <<
"op uses a handle invalidated by a "
673 "previously executed transform op";
674 diag.attachNote(otherHandle.getLoc()) <<
"invalidated handle";
675 diag.attachNote(owner->getLoc())
676 <<
"invalidated by this transform op that consumes its operand #"
678 <<
" and invalidates handles to the same values as associated with "
680 diag.attachNote(valueLoc) <<
"payload value";
684 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
685 Operation *payloadOp = opResult.getOwner();
686 recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
689 auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
690 for (
Operation &payloadOp : *arg.getOwner())
691 recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
701 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
702 transform::TransformOpInterface transform,
704 FULL_LDBG() <<
"--Start checkAndRecordHandleInvalidation";
705 auto memoryEffectsIface =
706 cast<MemoryEffectOpInterface>(transform.getOperation());
708 memoryEffectsIface.getEffectsOnResource(
711 for (
OpOperand &target : transform->getOpOperands()) {
712 FULL_LDBG() <<
"----iterate on handle: " << target.get();
717 auto it = invalidatedHandles.find(target.get());
718 auto nit = newlyInvalidated.find(target.get());
719 if (it != invalidatedHandles.end()) {
720 FULL_LDBG() <<
"--End checkAndRecordHandleInvalidation, found already "
721 "invalidated -> FAILURE";
722 return it->getSecond()(transform->getLoc()), failure();
724 if (!transform.allowsRepeatedHandleOperands() &&
725 nit != newlyInvalidated.end()) {
726 FULL_LDBG() <<
"--End checkAndRecordHandleInvalidation, found newly "
727 "invalidated (by this op) -> FAILURE";
728 return nit->getSecond()(transform->getLoc()), failure();
734 return isa<MemoryEffects::Free>(effect.getEffect()) &&
735 effect.getValue() == target.get();
737 if (llvm::any_of(effects, consumesTarget)) {
738 FULL_LDBG() <<
"----found consume effect";
739 if (llvm::isa<transform::TransformHandleTypeInterface>(
740 target.get().getType())) {
741 FULL_LDBG() <<
"----recordOpHandleInvalidation";
743 llvm::to_vector(getPayloadOps(target.get()));
744 recordOpHandleInvalidation(target, payloadOps,
nullptr,
746 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
747 target.get().getType())) {
748 FULL_LDBG() <<
"----recordValueHandleInvalidation";
749 recordValueHandleInvalidation(target, newlyInvalidated);
752 <<
"----not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
755 FULL_LDBG() <<
"----no consume effect -> SKIP";
759 FULL_LDBG() <<
"--End checkAndRecordHandleInvalidation -> SUCCESS";
763 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
764 transform::TransformOpInterface transform) {
765 InvalidatedHandleMap newlyInvalidated;
766 LogicalResult checkResult =
767 checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
768 invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
769 std::make_move_iterator(newlyInvalidated.end()));
773 template <
typename T>
776 transform::TransformOpInterface transform,
777 unsigned operandNumber) {
779 for (T p : payload) {
780 if (!seen.insert(p).second) {
782 transform.emitSilenceableError()
783 <<
"a handle passed as operand #" << operandNumber
784 <<
" and consumed by this operation points to a payload "
785 "entity more than once";
786 if constexpr (std::is_pointer_v<T>)
787 diag.attachNote(p->getLoc()) <<
"repeated target op";
789 diag.attachNote(p.getLoc()) <<
"repeated target value";
796 void transform::TransformState::compactOpHandles() {
797 for (
Value handle : opHandlesToCompact) {
798 Mappings &mappings = getMapping(handle,
true);
799 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
800 if (llvm::is_contained(mappings.direct[handle],
nullptr))
803 mappings.incrementTimestamp(handle);
805 llvm::erase(mappings.direct[handle],
nullptr);
807 opHandlesToCompact.clear();
812 LDBG() <<
"applying: "
814 FULL_LDBG() <<
"Top-level payload before application:\n" << *getTopLevel();
815 auto printOnFailureRAII = llvm::make_scope_exit([
this] {
817 LDBG() <<
"Failing Top-level payload:\n"
823 regionStack.back()->currentTransform = transform;
826 if (
options.getExpensiveChecksEnabled()) {
828 if (
failed(checkAndRecordHandleInvalidation(transform)))
831 for (
OpOperand &operand : transform->getOpOperands()) {
832 FULL_LDBG() <<
"iterate on handle: " << operand.get();
834 FULL_LDBG() <<
"--handle not consumed -> SKIP";
837 if (transform.allowsRepeatedHandleOperands()) {
838 FULL_LDBG() <<
"--op allows repeated handles -> SKIP";
843 Type operandType = operand.get().getType();
844 if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
845 FULL_LDBG() <<
"--checkRepeatedConsumptionInOperand for Operation*";
847 checkRepeatedConsumptionInOperand<Operation *>(
848 getPayloadOpsView(operand.get()), transform,
849 operand.getOperandNumber());
854 }
else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
855 FULL_LDBG() <<
"--checkRepeatedConsumptionInOperand For Value";
857 checkRepeatedConsumptionInOperand<Value>(
858 getPayloadValuesView(operand.get()), transform,
859 operand.getOperandNumber());
865 FULL_LDBG() <<
"--not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
872 transform.getConsumedHandleOpOperands();
881 for (
OpOperand *opOperand : consumedOperands) {
882 Value operand = opOperand->get();
883 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
884 for (
Operation *payloadOp : getPayloadOps(operand)) {
885 llvm::append_range(origOpFlatResults, payloadOp->
getResults());
889 if (llvm::isa<TransformValueHandleTypeInterface>(operand.
getType())) {
890 for (
Value payloadValue : getPayloadValuesView(operand)) {
891 if (llvm::isa<OpResult>(payloadValue)) {
897 llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
904 <<
"unexpectedly consumed a value that is not a handle as operand #"
905 << opOperand->getOperandNumber();
907 <<
"value defined here with type " << operand.
getType();
916 llvm::find_if(llvm::reverse(regionStack), [&](
RegionScope *scope) {
917 return handle.getParentRegion() == scope->region;
919 assert(scopeIt != regionStack.rend() &&
920 "could not find region scope for handle");
922 return llvm::all_of(handle.getUsers(), [&](
Operation *user) {
923 return user == scope->currentTransform ||
924 happensBefore(user, scope->currentTransform);
943 transform->hasAttr(FindPayloadReplacementOpInterface::
944 kSilenceTrackingFailuresAttrName)) {
948 (void)trackingFailure.
silence();
953 result = std::move(trackingFailure);
957 result.
attachNote() <<
"tracking listener also failed: "
959 (void)trackingFailure.
silence();
972 for (
OpOperand *opOperand : consumedOperands) {
973 Value operand = opOperand->get();
974 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
975 forgetMapping(operand, origOpFlatResults);
976 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
978 forgetValueMapping(operand, origAssociatedOps);
982 if (
failed(updateStateFromResults(results, transform->getResults())))
985 printOnFailureRAII.release();
987 LDBG() <<
"Top-level payload:\n" << *getTopLevel();
992 LogicalResult transform::TransformState::updateStateFromResults(
995 if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
996 assert(results.isParam(result.getResultNumber()) &&
997 "expected parameters for the parameter-typed result");
999 setParams(result, results.getParams(result.getResultNumber())))) {
1002 }
else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1003 assert(results.isValue(result.getResultNumber()) &&
1004 "expected values for value-type-result");
1005 if (
failed(setPayloadValues(
1006 result, results.getValues(result.getResultNumber())))) {
1010 assert(!results.isParam(result.getResultNumber()) &&
1011 "expected payload ops for the non-parameter typed result");
1013 setPayloadOps(result, results.get(result.getResultNumber())))) {
1032 return state.replacePayloadOp(op, replacement);
1037 Value replacement) {
1038 return state.replacePayloadValue(value, replacement);
1049 for (
Block &block : *region) {
1050 for (
Value handle : block.getArguments()) {
1051 state.invalidatedHandles.erase(handle);
1055 state.invalidatedHandles.erase(handle);
1060 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1064 llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1067 state.mappings.erase(region);
1068 state.regionStack.pop_back();
1075 transform::TransformResults::TransformResults(
unsigned numSegments) {
1076 operations.appendEmptyRows(numSegments);
1077 params.appendEmptyRows(numSegments);
1078 values.appendEmptyRows(numSegments);
1084 assert(position <
static_cast<int64_t
>(this->params.size()) &&
1085 "setting params for a non-existent handle");
1086 assert(this->params[position].data() ==
nullptr &&
"params already set");
1087 assert(operations[position].data() ==
nullptr &&
1088 "another kind of results already set");
1089 assert(values[position].data() ==
nullptr &&
1090 "another kind of results already set");
1091 this->params.replace(position, params);
1099 return set(handle, operations), success();
1102 return setParams(handle, params), success();
1105 return setValues(handle, payloadValues), success();
1108 if (!
diag.succeeded())
1109 llvm::dbgs() <<
diag.getStatusString() <<
"\n";
1110 assert(
diag.succeeded() &&
"incorrect mapping");
1112 (void)
diag.silence();
1116 transform::TransformOpInterface transform) {
1117 for (
OpResult opResult : transform->getResults()) {
1118 if (!isSet(opResult.getResultNumber()))
1119 setMappedValues(opResult, {});
1124 transform::TransformResults::get(
unsigned resultNumber)
const {
1125 assert(resultNumber < operations.size() &&
1126 "querying results for a non-existent handle");
1127 assert(operations[resultNumber].data() !=
nullptr &&
1128 "querying unset results (values or params expected?)");
1129 return operations[resultNumber];
1133 transform::TransformResults::getParams(
unsigned resultNumber)
const {
1134 assert(resultNumber < params.size() &&
1135 "querying params for a non-existent handle");
1136 assert(params[resultNumber].data() !=
nullptr &&
1137 "querying unset params (ops or values expected?)");
1138 return params[resultNumber];
1142 transform::TransformResults::getValues(
unsigned resultNumber)
const {
1143 assert(resultNumber < values.size() &&
1144 "querying values for a non-existent handle");
1145 assert(values[resultNumber].data() !=
nullptr &&
1146 "querying unset values (ops or params expected?)");
1147 return values[resultNumber];
1150 bool transform::TransformResults::isParam(
unsigned resultNumber)
const {
1151 assert(resultNumber < params.size() &&
1152 "querying association for a non-existent handle");
1153 return params[resultNumber].data() !=
nullptr;
1156 bool transform::TransformResults::isValue(
unsigned resultNumber)
const {
1157 assert(resultNumber < values.size() &&
1158 "querying association for a non-existent handle");
1159 return values[resultNumber].data() !=
nullptr;
1162 bool transform::TransformResults::isSet(
unsigned resultNumber)
const {
1163 assert(resultNumber < params.size() &&
1164 "querying association for a non-existent handle");
1165 return params[resultNumber].data() !=
nullptr ||
1166 operations[resultNumber].data() !=
nullptr ||
1167 values[resultNumber].data() !=
nullptr;
1175 TransformOpInterface op,
1179 for (
OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1180 consumedHandles.insert(opOperand->get());
1187 for (
Value v : values) {
1192 defOp = v.getDefiningOp();
1195 if (defOp != v.getDefiningOp())
1204 "invalid number of replacement values");
1208 getTransformOp(),
"tracking listener failed to find replacement op "
1209 "during application of this transform op");
1215 diag.attachNote() <<
"replacement values belong to different ops";
1220 if (
config.skipCastOps && isa<CastOpInterface>(defOp)) {
1224 <<
"using output of 'CastOpInterface' op";
1230 if (!
config.requireMatchingReplacementOpName ||
1246 if (
auto findReplacementOpInterface =
1247 dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1248 values.assign(findReplacementOpInterface.getNextOperands());
1249 diag.attachNote(defOp->
getLoc()) <<
"using operands provided by "
1250 "'FindPayloadReplacementOpInterface'";
1253 }
while (!values.empty());
1255 diag.attachNote() <<
"ran out of suitable replacement values";
1263 reasonCallback(
diag);
1264 LDBG() <<
"Match Failure : " <<
diag.str();
1268 void transform::TrackingListener::notifyOperationErased(
Operation *op) {
1271 (void)replacePayloadValue(value,
nullptr);
1273 (void)replacePayloadOp(op,
nullptr);
1276 void transform::TrackingListener::notifyOperationReplaced(
1279 "invalid number of replacement values");
1282 for (
auto [oldValue, newValue] : llvm::zip(op->
getResults(), newValues))
1283 (void)replacePayloadValue(oldValue, newValue);
1287 if (
failed(getTransformState().getHandlesForPayloadOp(
1288 op, opHandles,
true))) {
1302 auto handleWasConsumed = [&] {
1303 return llvm::any_of(opHandles,
1304 [&](
Value h) {
return consumedHandles.contains(h); });
1309 if (
config.skipHandleFn) {
1310 auto it = llvm::find_if(opHandles,
1311 [&](
Value v) {
return !
config.skipHandleFn(v); });
1312 if (it != opHandles.end())
1314 }
else if (!opHandles.empty()) {
1315 aliveHandle = opHandles.front();
1317 if (!aliveHandle || handleWasConsumed()) {
1320 (void)replacePayloadOp(op,
nullptr);
1326 findReplacementOp(replacement, op, newValues);
1329 if (!
diag.succeeded()) {
1331 <<
"replacement is required because this handle must be updated";
1332 notifyPayloadReplacementNotFound(op, newValues, std::move(
diag));
1333 (void)replacePayloadOp(op,
nullptr);
1337 (void)replacePayloadOp(op, replacement);
1344 assert(status.succeeded() &&
"listener state was not checked");
1356 return !status.succeeded();
1364 diag.takeDiagnostics(diags);
1365 if (!status.succeeded())
1366 status.takeDiagnostics(diags);
1370 status.attachNote(op->
getLoc()) <<
"[" << errorCounter <<
"] replaced op";
1372 status.attachNote(value.
getLoc())
1373 <<
"[" << errorCounter <<
"] replacement value " << index;
1379 if (!matchFailure) {
1382 return matchFailure->str();
1388 reasonCallback(
diag);
1389 matchFailure = std::move(
diag);
1403 return listener->failed();
1408 if (hasTrackingFailures()) {
1416 return listener->replacePayloadOp(op, replacement);
1427 for (
Operation *child : targets.drop_front(position + 1)) {
1428 if (parent->isAncestor(child)) {
1431 <<
"transform operation consumes a handle pointing to an ancestor "
1432 "payload operation before its descendant";
1434 <<
"the ancestor is likely erased or rewritten before the "
1435 "descendant is accessed, leading to undefined behavior";
1436 diag.attachNote(parent->getLoc()) <<
"ancestor payload op";
1437 diag.attachNote(child->getLoc()) <<
"descendant payload op";
1456 diag.attachNote(payloadOpLoc) <<
"when applied to this op";
1460 if (partialResult.
size() != expectedNumResults) {
1461 auto diag =
emitDiag() <<
"application of " << transformOpName
1462 <<
" expected to produce " << expectedNumResults
1463 <<
" results (actually produced "
1464 << partialResult.
size() <<
").";
1465 diag.attachNote(transformOpLoc)
1466 <<
"if you need variadic results, consider a generic `apply` "
1467 <<
"instead of the specialized `applyToOne`.";
1472 for (
const auto &[ptr, res] :
1473 llvm::zip(partialResult, transformOp->
getResults())) {
1476 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1477 !isa<Operation *>(ptr)) {
1478 return emitDiag() <<
"application of " << transformOpName
1479 <<
" expected to produce an Operation * for result #"
1480 << res.getResultNumber();
1482 if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1483 !isa<Attribute>(ptr)) {
1484 return emitDiag() <<
"application of " << transformOpName
1485 <<
" expected to produce an Attribute for result #"
1486 << res.getResultNumber();
1488 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1490 return emitDiag() <<
"application of " << transformOpName
1491 <<
" expected to produce a Value for result #"
1492 << res.getResultNumber();
1498 template <
typename T>
1500 return llvm::to_vector(llvm::map_range(
1510 if (llvm::any_of(partialResults,
1511 [](
MappedValue value) {
return value.isNull(); }))
1513 assert(transformOp->
getNumResults() == partialResults.size() &&
1514 "expected as many partial results as op as results");
1516 transposed[i].push_back(value);
1520 unsigned position = r.getResultNumber();
1521 if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1523 castVector<Attribute>(transposed[position]));
1524 }
else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1525 transformResults.
setValues(r, castVector<Value>(transposed[position]));
1527 transformResults.
set(r, castVector<Operation *>(transposed[position]));
1539 assert(mappings.size() == values.size() &&
"mismatching number of mappings");
1540 for (
auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1541 size_t mappedSize = mapped.size();
1542 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1543 llvm::append_range(mapped, state.getPayloadOps(operand));
1544 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
1545 operand.getType())) {
1546 llvm::append_range(mapped, state.getPayloadValues(operand));
1548 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1549 "unsupported kind of transform dialect value");
1550 llvm::append_range(mapped, state.getParams(operand));
1553 if (mapped.size() - mappedSize != 1 && !flatten)
1562 mappings.resize(mappings.size() + values.size());
1572 for (
auto &&[terminatorOperand, result] :
1575 if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1576 results.
set(result, state.getPayloadOps(terminatorOperand));
1577 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1578 result.getType())) {
1579 results.
setValues(result, state.getPayloadValues(terminatorOperand));
1582 llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1583 "unhandled transform type interface");
1584 results.
setParams(result, state.getParams(terminatorOperand));
1606 iface.getEffectsOnValue(source, nestedEffects);
1607 for (
const auto &effect : nestedEffects)
1608 effects.emplace_back(effect.getEffect(), target, effect.getResource());
1617 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1621 for (
auto &&[source, target] : llvm::zip(block.
getArguments(), operands)) {
1628 llvm::append_range(effects, nestedEffects);
1640 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1644 iface.getEffects(effects);
1659 llvm::append_range(targets, state.getPayloadOps(op->
getOperand(0)));
1662 if (state.getNumTopLevelMappings() !=
1666 <<
" extra value bindings, but " << state.getNumTopLevelMappings()
1667 <<
" were provided to the interpreter";
1670 targets.push_back(state.getTopLevel());
1672 for (
unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1673 extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1680 if (
failed(state.mapBlockArgument(
1681 argument, extraMappings[argument.getArgNumber() - 1])))
1693 assert(isa<TransformOpInterface>(op) &&
1694 "should implement TransformOpInterface to have "
1695 "PossibleTopLevelTransformOpTrait");
1698 return op->
emitOpError() <<
"expects at least one region";
1701 if (!llvm::hasNItems(*bodyRegion, 1))
1702 return op->
emitOpError() <<
"expects a single-block region";
1707 <<
"expects the entry block to have at least one argument";
1709 if (!llvm::isa<TransformHandleTypeInterface>(
1712 <<
"expects the first entry block argument to be of type "
1713 "implementing TransformHandleTypeInterface";
1719 <<
"expects the type of the block argument to match "
1720 "the type of the operand";
1724 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1725 TransformValueHandleTypeInterface>(arg.
getType()))
1730 <<
"expects trailing entry block arguments to be of type implementing "
1731 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1732 "TransformParamTypeInterface";
1742 <<
"expects operands to be provided for a nested op";
1743 diag.attachNote(parent->getLoc())
1744 <<
"nested in another possible top-level op";
1759 bool hasPayloadOperands =
false;
1762 if (llvm::isa<TransformHandleTypeInterface,
1763 TransformValueHandleTypeInterface>(operand.get().getType()))
1764 hasPayloadOperands =
true;
1766 if (hasPayloadOperands)
1775 llvm::report_fatal_error(
1776 Twine(
"ParamProducerTransformOpTrait must be attached to an op that "
1777 "implements MemoryEffectsOpInterface, found on ") +
1781 if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1784 <<
"ParamProducerTransformOpTrait attached to this op expects "
1785 "result types to implement TransformParamTypeInterface";
1807 template <
typename EffectTy,
typename ResourceTy,
typename Range>
1810 return isa<EffectTy>(effect.
getEffect()) &&
1816 transform::TransformOpInterface transform) {
1817 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1819 iface.getEffectsOnValue(handle, effects);
1820 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1821 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1867 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1869 iface.getEffects(effects);
1870 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1874 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1876 iface.getEffects(effects);
1877 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1881 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1884 auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1889 iface.getEffects(effects);
1892 dyn_cast_or_null<BlockArgument>(effect.getValue());
1893 if (!argument || argument.
getOwner() != &block ||
1894 !isa<MemoryEffects::Free>(effect.getEffect()) ||
1908 TransformOpInterface transformOp) {
1910 consumedOperands.reserve(transformOp->getNumOperands());
1911 auto memEffectInterface =
1912 cast<MemoryEffectOpInterface>(transformOp.getOperation());
1914 for (
OpOperand &target : transformOp->getOpOperands()) {
1916 memEffectInterface.getEffectsOnValue(target.get(), effects);
1918 return isa<transform::TransformMappingResource>(
1920 isa<MemoryEffects::Free>(effect.
getEffect());
1922 consumedOperands.push_back(&target);
1925 return consumedOperands;
1929 auto iface = cast<MemoryEffectOpInterface>(op);
1931 iface.getEffects(effects);
1933 auto effectsOn = [&](
Value value) {
1934 return llvm::make_filter_range(
1936 return instance.
getValue() == value;
1940 std::optional<unsigned> firstConsumedOperand;
1942 auto range = effectsOn(operand.get());
1943 if (range.empty()) {
1945 op->
emitError() <<
"TransformOpInterface requires memory effects "
1946 "on operands to be specified";
1947 diag.attachNote() <<
"no effects specified for operand #"
1948 << operand.getOperandNumber();
1951 if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1953 <<
"TransformOpInterface did not expect "
1954 "'allocate' memory effect on an operand";
1955 diag.attachNote() <<
"specified for operand #"
1956 << operand.getOperandNumber();
1959 if (!firstConsumedOperand &&
1960 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1961 firstConsumedOperand = operand.getOperandNumber();
1965 if (firstConsumedOperand &&
1966 !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1969 <<
"TransformOpInterface expects ops consuming operands to have a "
1970 "'write' effect on the payload resource";
1971 diag.attachNote() <<
"consumes operand #" << *firstConsumedOperand;
1976 auto range = effectsOn(result);
1977 if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1980 op->
emitError() <<
"TransformOpInterface requires 'allocate' memory "
1981 "effect to be specified for results";
1982 diag.attachNote() <<
"no 'allocate' effect specified for result #"
1983 << result.getResultNumber();
1996 Operation *payloadRoot, TransformOpInterface transform,
2001 if (enforceToplevelTransformOp) {
2003 transform->getNumOperands() != 0) {
2004 return transform->emitError()
2005 <<
"expected transform to start at the top-level transform op";
2012 TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2014 if (stateInitializer)
2015 stateInitializer(state);
2016 if (state.applyTransform(transform).checkAndReport().failed())
2019 return stateExporter(state);
2027 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2028 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A 2D array where each row may have different length.
size_t size() const
Returns the number of rows in the 2D array.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class implements the result iterators for the Operation class.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...