15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/ScopeExit.h"
17#include "llvm/ADT/iterator.h"
18#include "llvm/Support/Debug.h"
19#include "llvm/Support/DebugLog.h"
20#include "llvm/Support/ErrorHandling.h"
21#include "llvm/Support/InterleavedRange.h"
23#define DEBUG_TYPE "transform-dialect"
24#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
25#define FULL_LDBG() LDBG(4)
50transform::TransformState::TransformState(
53 const TransformOptions &
options)
55 topLevelMappedValues.reserve(extraMappings.
size());
57 topLevelMappedValues.push_back(mapping);
59 RegionScope *scope =
new RegionScope(*
this, *region);
60 topLevelRegionScope.reset(scope);
67transform::TransformState::getPayloadOpsView(
Value value)
const {
68 const TransformOpMapping &operationMapping = getMapping(value).direct;
69 auto iter = operationMapping.find(value);
70 assert(iter != operationMapping.end() &&
71 "cannot find mapping for payload handle (param/value handle "
73 return iter->getSecond();
77 const ParamMapping &mapping = getMapping(value).params;
78 auto iter = mapping.find(value);
79 assert(iter != mapping.end() &&
"cannot find mapping for param handle "
80 "(operation/value handle provided?)");
81 return iter->getSecond();
85transform::TransformState::getPayloadValuesView(
Value handleValue)
const {
86 const ValueMapping &mapping = getMapping(handleValue).values;
87 auto iter = mapping.find(handleValue);
88 assert(iter != mapping.end() &&
"cannot find mapping for value handle "
89 "(param/operation handle provided?)");
90 return iter->getSecond();
95 bool includeOutOfScope)
const {
97 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
98 auto iterator = mapping->reverse.find(op);
99 if (iterator != mapping->reverse.end()) {
100 llvm::append_range(handles, iterator->getSecond());
104 if (!includeOutOfScope &&
114 bool includeOutOfScope)
const {
116 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
117 auto iterator = mapping->reverseValues.find(payloadValue);
118 if (iterator != mapping->reverseValues.end()) {
119 llvm::append_range(handles, iterator->getSecond());
123 if (!includeOutOfScope &&
138 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
140 operations.reserve(values.size());
142 if (
auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
143 operations.push_back(op);
147 <<
"wrong kind of value provided for top-level operation handle";
149 if (failed(operationsFn(operations)))
154 if (llvm::isa<transform::TransformValueHandleTypeInterface>(
157 payloadValues.reserve(values.size());
159 if (
auto v = llvm::dyn_cast_if_present<Value>(value)) {
160 payloadValues.push_back(v);
164 <<
"wrong kind of value provided for the top-level value handle";
166 if (failed(valuesFn(payloadValues)))
171 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
172 "unsupported kind of block argument");
174 parameters.reserve(values.size());
176 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
177 parameters.push_back(attr);
181 <<
"wrong kind of value provided for top-level parameter";
183 if (failed(paramsFn(parameters)))
194 return setPayloadOps(argument, operations);
197 return setParams(argument, params);
200 return setPayloadValues(argument, payloadValues);
208 for (
auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
215transform::TransformState::setPayloadOps(
Value value,
217 assert(value != kTopLevelValue &&
218 "attempting to reset the transformation root");
219 assert(llvm::isa<TransformHandleTypeInterface>(value.
getType()) &&
220 "wrong handle type");
226 <<
"attempting to assign a null payload op to this transform value";
229 auto iface = llvm::cast<TransformHandleTypeInterface>(value.
getType());
231 iface.checkPayload(value.
getLoc(), targets);
232 if (failed(
result.checkAndReport()))
238 Mappings &mappings = getMapping(value);
240 mappings.direct.insert({value, std::move(storedTargets)}).second;
241 assert(
inserted &&
"value is already associated with another list");
245 mappings.reverse[op].push_back(value);
251transform::TransformState::setPayloadValues(Value handle,
253 assert(handle !=
nullptr &&
"attempting to set params for a null value");
254 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.
getType()) &&
255 "wrong handle type");
257 for (Value payload : payloadValues) {
260 return emitError(handle.
getLoc()) <<
"attempting to assign a null payload "
261 "value to this transform handle";
264 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.
getType());
265 SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
266 DiagnosedSilenceableFailure
result =
267 iface.checkPayload(handle.
getLoc(), payloadValueVector);
271 Mappings &mappings = getMapping(handle);
273 mappings.values.insert({handle, std::move(payloadValueVector)}).second;
276 "value handle is already associated with another list of payload values");
279 for (Value payload : payloadValues)
280 mappings.reverseValues[payload].push_back(handle);
285LogicalResult transform::TransformState::setParams(Value value,
286 ArrayRef<Param> params) {
287 assert(value !=
nullptr &&
"attempting to set params for a null value");
289 for (Attribute attr : params) {
293 <<
"attempting to assign a null parameter to this transform value";
296 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.
getType());
298 "cannot associate parameter with a value of non-parameter type");
299 DiagnosedSilenceableFailure
result =
300 valueType.checkPayload(value.
getLoc(), params);
304 Mappings &mappings = getMapping(value);
306 mappings.params.insert({value, llvm::to_vector(params)}).second;
307 assert(
inserted &&
"value is already associated with another list of params");
312template <
typename Mapping,
typename Key,
typename Mapped>
314 auto it = mapping.find(key);
315 if (it == mapping.end())
318 llvm::erase(it->getSecond(), mapped);
319 if (it->getSecond().empty())
323void transform::TransformState::forgetMapping(Value opHandle,
325 bool allowOutOfScope) {
326 Mappings &mappings = getMapping(opHandle, allowOutOfScope);
327 for (Operation *op : mappings.direct[opHandle])
329 mappings.direct.erase(opHandle);
330#if LLVM_ENABLE_ABI_BREAKING_CHECKS
333 mappings.incrementTimestamp(opHandle);
336 for (Value opResult : origOpFlatResults) {
337 SmallVector<Value> resultHandles;
338 (void)getHandlesForPayloadValue(opResult, resultHandles);
339 for (Value resultHandle : resultHandles) {
340 Mappings &localMappings = getMapping(resultHandle);
342#if LLVM_ENABLE_ABI_BREAKING_CHECKS
345 mappings.incrementTimestamp(resultHandle);
352void transform::TransformState::forgetValueMapping(
353 Value valueHandle, ArrayRef<Operation *> payloadOperations) {
354 Mappings &mappings = getMapping(valueHandle);
355 for (Value payloadValue : mappings.reverseValues[valueHandle])
357 mappings.values.erase(valueHandle);
358#if LLVM_ENABLE_ABI_BREAKING_CHECKS
361 mappings.incrementTimestamp(valueHandle);
364 for (Operation *payloadOp : payloadOperations) {
365 SmallVector<Value> opHandles;
366 (void)getHandlesForPayloadOp(payloadOp, opHandles);
367 for (Value opHandle : opHandles) {
368 Mappings &localMappings = getMapping(opHandle);
372#if LLVM_ENABLE_ABI_BREAKING_CHECKS
375 localMappings.incrementTimestamp(opHandle);
382transform::TransformState::replacePayloadOp(Operation *op,
388 SmallVector<Value> valueHandles;
389 (void)getHandlesForPayloadValue(opResult, valueHandles,
391 assert(valueHandles.empty() &&
"expected no mapping to old results");
397 SmallVector<Value> opHandles;
398 if (
failed(getHandlesForPayloadOp(op, opHandles,
true)))
400 for (Value handle : opHandles) {
401 Mappings &mappings = getMapping(handle,
true);
414 for (Value handle : opHandles) {
415 Mappings &mappings = getMapping(handle,
true);
416 auto it = mappings.direct.find(handle);
417 if (it == mappings.direct.end())
420 SmallVector<Operation *, 2> &association = it->getSecond();
422 for (Operation *&mapped : association) {
430 opHandlesToCompact.insert(handle);
438transform::TransformState::replacePayloadValue(Value value, Value
replacement) {
439 SmallVector<Value> valueHandles;
440 if (
failed(getHandlesForPayloadValue(value, valueHandles,
444 for (Value handle : valueHandles) {
445 Mappings &mappings = getMapping(handle,
true);
452#if LLVM_ENABLE_ABI_BREAKING_CHECKS
455 mappings.incrementTimestamp(handle);
458 auto it = mappings.values.find(handle);
459 if (it == mappings.values.end())
462 SmallVector<Value> &association = it->getSecond();
463 for (Value &mapped : association) {
467 mappings.reverseValues[
replacement].push_back(handle);
474void transform::TransformState::recordOpHandleInvalidationOne(
475 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
476 Operation *payloadOp, Value otherHandle, Value throughValue,
477 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
481 if (invalidatedHandles.count(otherHandle) ||
482 newlyInvalidated.count(otherHandle))
485 FULL_LDBG() <<
"--recordOpHandleInvalidationOne";
487 << llvm::interleaved(
488 llvm::make_pointee_range(potentialAncestors));
490 Operation *owner = consumingHandle.
getOwner();
492 for (Operation *ancestor : potentialAncestors) {
494 FULL_LDBG() <<
"----handle one ancestor: " << *ancestor;;
496 FULL_LDBG() <<
"----of payload with name: "
498 FULL_LDBG() <<
"----of payload: " << *payloadOp;
500 if (!ancestor->isAncestor(payloadOp))
507 Location ancestorLoc = ancestor->
getLoc();
508 Location opLoc = payloadOp->
getLoc();
509 std::optional<Location> throughValueLoc =
510 throughValue ? std::make_optional(throughValue.
getLoc()) : std::nullopt;
511 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
513 throughValueLoc](Location currentLoc) {
515 <<
"op uses a handle invalidated by a "
516 "previously executed transform op";
517 diag.
attachNote(otherHandle.getLoc()) <<
"handle to invalidated ops";
519 <<
"invalidated by this transform op that consumes its operand #"
521 <<
" and invalidates all handles to payload IR entities associated "
522 "with this operand and entities nested in them";
525 if (throughValueLoc) {
526 diag.attachNote(*throughValueLoc)
527 <<
"consumed handle points to this payload value";
533void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
534 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
535 Value payloadValue, Value valueHandle,
536 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
540 if (invalidatedHandles.count(valueHandle) ||
541 newlyInvalidated.count(valueHandle))
544 for (Operation *ancestor : potentialAncestors) {
545 Operation *definingOp;
546 std::optional<unsigned> resultNo;
547 unsigned argumentNo = std::numeric_limits<unsigned>::max();
548 unsigned blockNo = std::numeric_limits<unsigned>::max();
549 unsigned regionNo = std::numeric_limits<unsigned>::max();
550 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
551 definingOp = opResult.getOwner();
552 resultNo = opResult.getResultNumber();
554 auto arg = llvm::cast<BlockArgument>(payloadValue);
556 argumentNo = arg.getArgNumber();
557 blockNo = arg.getOwner()->computeBlockNumber();
558 regionNo = arg.getOwner()->getParent()->getRegionNumber();
560 assert(definingOp &&
"expected the value to be defined by an op as result "
561 "or block argument");
562 if (!ancestor->isAncestor(definingOp))
565 Operation *owner = opHandle.
getOwner();
567 Location ancestorLoc = ancestor->getLoc();
568 Location opLoc = definingOp->
getLoc();
569 Location valueLoc = payloadValue.
getLoc();
570 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
571 argumentNo, blockNo, regionNo, ancestorLoc,
572 opLoc, valueLoc](Location currentLoc) {
574 <<
"op uses a handle invalidated by a "
575 "previously executed transform op";
578 <<
"invalidated by this transform op that consumes its operand #"
580 <<
" and invalidates all handles to payload IR entities "
581 "associated with this operand and entities nested in them";
583 <<
"ancestor op associated with the consumed handle";
585 diag.attachNote(opLoc)
586 <<
"op defining the value as result #" << *resultNo;
588 diag.attachNote(opLoc)
589 <<
"op defining the value as block argument #" << argumentNo
590 <<
" of block #" << blockNo <<
" in region #" << regionNo;
597void transform::TransformState::recordOpHandleInvalidation(
598 OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
600 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
602 if (potentialAncestors.empty()) {
603 FULL_LDBG() <<
"----recording invalidation for empty handle: "
606 Operation *owner = handle.
getOwner();
608 newlyInvalidated[handle.
get()] = [owner, operandNo](Location currentLoc) {
610 <<
"op uses a handle associated with empty "
611 "payload and invalidated by a "
612 "previously executed transform op";
614 <<
"invalidated by this transform op that consumes its operand #"
627 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
631 for (
const auto &[payloadOp, otherHandles] : mapping->reverse) {
632 for (Value otherHandle : otherHandles)
633 recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
634 otherHandle, throughValue,
642 for (
const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
643 for (Value valueHandle : valueHandles)
644 recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
645 payloadValue, valueHandle,
655void transform::TransformState::recordValueHandleInvalidation(
656 OpOperand &valueHandle,
657 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
659 for (Value payloadValue : getPayloadValuesView(valueHandle.
get())) {
660 SmallVector<Value> otherValueHandles;
661 (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
662 for (Value otherHandle : otherValueHandles) {
663 Operation *owner = valueHandle.
getOwner();
665 Location valueLoc = payloadValue.
getLoc();
666 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
667 valueLoc](Location currentLoc) {
669 <<
"op uses a handle invalidated by a "
670 "previously executed transform op";
673 <<
"invalidated by this transform op that consumes its operand #"
675 <<
" and invalidates handles to the same values as associated with "
681 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
682 Operation *payloadOp = opResult.getOwner();
683 recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
686 auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
687 for (Operation &payloadOp : *arg.getOwner())
688 recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
698LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
699 transform::TransformOpInterface transform,
700 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
701 FULL_LDBG() <<
"--Start checkAndRecordHandleInvalidation";
702 auto memoryEffectsIface =
703 cast<MemoryEffectOpInterface>(transform.getOperation());
704 SmallVector<MemoryEffects::EffectInstance> effects;
705 memoryEffectsIface.getEffectsOnResource(
708 for (OpOperand &
target : transform->getOpOperands()) {
714 auto it = invalidatedHandles.find(
target.get());
715 auto nit = newlyInvalidated.find(
target.get());
716 if (it != invalidatedHandles.end()) {
717 FULL_LDBG() <<
"--End checkAndRecordHandleInvalidation, found already "
718 "invalidated -> FAILURE";
719 return it->getSecond()(transform->getLoc()), failure();
721 if (!transform.allowsRepeatedHandleOperands() &&
722 nit != newlyInvalidated.end()) {
723 FULL_LDBG() <<
"--End checkAndRecordHandleInvalidation, found newly "
724 "invalidated (by this op) -> FAILURE";
725 return nit->getSecond()(transform->getLoc()), failure();
731 return isa<MemoryEffects::Free>(effect.getEffect()) &&
732 effect.getValue() ==
target.get();
734 if (llvm::any_of(effects, consumesTarget)) {
735 FULL_LDBG() <<
"----found consume effect";
736 if (llvm::isa<transform::TransformHandleTypeInterface>(
737 target.get().getType())) {
738 FULL_LDBG() <<
"----recordOpHandleInvalidation";
739 SmallVector<Operation *> payloadOps =
740 llvm::to_vector(getPayloadOps(
target.get()));
741 recordOpHandleInvalidation(
target, payloadOps,
nullptr,
743 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
744 target.get().getType())) {
745 FULL_LDBG() <<
"----recordValueHandleInvalidation";
746 recordValueHandleInvalidation(
target, newlyInvalidated);
749 <<
"----not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
752 FULL_LDBG() <<
"----no consume effect -> SKIP";
756 FULL_LDBG() <<
"--End checkAndRecordHandleInvalidation -> SUCCESS";
760LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
761 transform::TransformOpInterface transform) {
762 InvalidatedHandleMap newlyInvalidated;
763 LogicalResult checkResult =
764 checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
765 invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
766 std::make_move_iterator(newlyInvalidated.end()));
771static DiagnosedSilenceableFailure
773 transform::TransformOpInterface
transform,
774 unsigned operandNumber) {
776 for (T p : payload) {
777 if (!seen.insert(p).second) {
780 <<
"a handle passed as operand #" << operandNumber
781 <<
" and consumed by this operation points to a payload "
782 "entity more than once";
783 if constexpr (std::is_pointer_v<T>)
784 diag.attachNote(p->getLoc()) <<
"repeated target op";
786 diag.attachNote(p.getLoc()) <<
"repeated target value";
793void transform::TransformState::compactOpHandles() {
794 for (Value handle : opHandlesToCompact) {
795 Mappings &mappings = getMapping(handle,
true);
796#if LLVM_ENABLE_ABI_BREAKING_CHECKS
797 if (llvm::is_contained(mappings.direct[handle],
nullptr))
800 mappings.incrementTimestamp(handle);
802 llvm::erase(mappings.direct[handle],
nullptr);
804 opHandlesToCompact.clear();
807DiagnosedSilenceableFailure
809 LDBG() <<
"applying: "
812 llvm::scope_exit printOnFailureRAII([
this] {
814 LDBG() <<
"Failing Top-level payload:\n"
820 regionStack.back()->currentTransform =
transform;
823 if (options.getExpensiveChecksEnabled()) {
825 if (failed(checkAndRecordHandleInvalidation(
transform)))
829 FULL_LDBG() <<
"iterate on handle: " << operand.get();
831 FULL_LDBG() <<
"--handle not consumed -> SKIP";
834 if (
transform.allowsRepeatedHandleOperands()) {
835 FULL_LDBG() <<
"--op allows repeated handles -> SKIP";
840 Type operandType = operand.get().getType();
841 if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
842 FULL_LDBG() <<
"--checkRepeatedConsumptionInOperand for Operation*";
845 getPayloadOpsView(operand.get()),
transform,
846 operand.getOperandNumber());
851 }
else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
852 FULL_LDBG() <<
"--checkRepeatedConsumptionInOperand For Value";
855 getPayloadValuesView(operand.get()),
transform,
856 operand.getOperandNumber());
862 FULL_LDBG() <<
"--not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
878 for (
OpOperand *opOperand : consumedOperands) {
879 Value operand = opOperand->get();
880 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
882 llvm::append_range(origOpFlatResults, payloadOp->
getResults());
886 if (llvm::isa<TransformValueHandleTypeInterface>(operand.
getType())) {
887 for (
Value payloadValue : getPayloadValuesView(operand)) {
888 if (llvm::isa<OpResult>(payloadValue)) {
894 llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
901 <<
"unexpectedly consumed a value that is not a handle as operand #"
902 << opOperand->getOperandNumber();
904 <<
"value defined here with type " << operand.
getType();
913 llvm::find_if(llvm::reverse(regionStack), [&](
RegionScope *scope) {
914 return handle.getParentRegion() == scope->region;
916 assert(scopeIt != regionStack.rend() &&
917 "could not find region scope for handle");
919 return llvm::all_of(handle.getUsers(), [&](
Operation *user) {
920 return user == scope->currentTransform ||
921 happensBefore(user, scope->currentTransform);
940 transform->hasAttr(FindPayloadReplacementOpInterface::
941 kSilenceTrackingFailuresAttrName)) {
950 result = std::move(trackingFailure);
953 if (
result.isSilenceableFailure())
954 result.attachNote() <<
"tracking listener also failed: "
959 if (
result.isDefiniteFailure())
964 if (
result.isSilenceableFailure())
970 Value operand = opOperand->get();
971 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
972 forgetMapping(operand, origOpFlatResults);
973 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
975 forgetValueMapping(operand, origAssociatedOps);
979 if (failed(updateStateFromResults(results,
transform->getResults())))
982 printOnFailureRAII.release();
984 LDBG() <<
"Top-level payload:\n" << *getTopLevel();
989LogicalResult transform::TransformState::updateStateFromResults(
990 const TransformResults &results,
ResultRange opResults) {
992 if (llvm::isa<TransformParamTypeInterface>(
result.getType())) {
993 assert(results.isParam(
result.getResultNumber()) &&
994 "expected parameters for the parameter-typed result");
996 setParams(
result, results.getParams(
result.getResultNumber())))) {
999 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
result.getType())) {
1000 assert(results.isValue(
result.getResultNumber()) &&
1001 "expected values for value-type-result");
1002 if (
failed(setPayloadValues(
1003 result, results.getValues(
result.getResultNumber())))) {
1007 assert(!results.isParam(
result.getResultNumber()) &&
1008 "expected payload ops for the non-parameter typed result");
1010 setPayloadOps(
result, results.get(
result.getResultNumber())))) {
1035 return state.replacePayloadValue(value,
replacement);
1046 for (
Block &block : *region) {
1047 for (
Value handle : block.getArguments()) {
1048 state.invalidatedHandles.erase(handle);
1052 state.invalidatedHandles.erase(handle);
1057#if LLVM_ENABLE_ABI_BREAKING_CHECKS
1061 llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1064 state.mappings.erase(region);
1065 state.regionStack.pop_back();
1072transform::TransformResults::TransformResults(
unsigned numSegments) {
1073 operations.appendEmptyRows(numSegments);
1074 params.appendEmptyRows(numSegments);
1075 values.appendEmptyRows(numSegments);
1081 assert(position <
static_cast<int64_t>(this->params.size()) &&
1082 "setting params for a non-existent handle");
1083 assert(this->params[position].data() ==
nullptr &&
"params already set");
1084 assert(operations[position].data() ==
nullptr &&
1085 "another kind of results already set");
1086 assert(values[position].data() ==
nullptr &&
1087 "another kind of results already set");
1088 this->params.replace(position, params);
1105 if (!
diag.succeeded())
1106 llvm::dbgs() <<
diag.getStatusString() <<
"\n";
1107 assert(
diag.succeeded() &&
"incorrect mapping");
1113 transform::TransformOpInterface
transform) {
1115 if (!isSet(opResult.getResultNumber()))
1121transform::TransformResults::get(
unsigned resultNumber)
const {
1122 assert(resultNumber < operations.size() &&
1123 "querying results for a non-existent handle");
1124 assert(operations[resultNumber].data() !=
nullptr &&
1125 "querying unset results (values or params expected?)");
1126 return operations[resultNumber];
1130transform::TransformResults::getParams(
unsigned resultNumber)
const {
1131 assert(resultNumber < params.size() &&
1132 "querying params for a non-existent handle");
1133 assert(params[resultNumber].data() !=
nullptr &&
1134 "querying unset params (ops or values expected?)");
1135 return params[resultNumber];
1139transform::TransformResults::getValues(
unsigned resultNumber)
const {
1140 assert(resultNumber < values.size() &&
1141 "querying values for a non-existent handle");
1142 assert(values[resultNumber].data() !=
nullptr &&
1143 "querying unset values (ops or params expected?)");
1144 return values[resultNumber];
1147bool transform::TransformResults::isParam(
unsigned resultNumber)
const {
1148 assert(resultNumber < params.size() &&
1149 "querying association for a non-existent handle");
1150 return params[resultNumber].data() !=
nullptr;
1153bool transform::TransformResults::isValue(
unsigned resultNumber)
const {
1154 assert(resultNumber < values.size() &&
1155 "querying association for a non-existent handle");
1156 return values[resultNumber].data() !=
nullptr;
1159bool transform::TransformResults::isSet(
unsigned resultNumber)
const {
1160 assert(resultNumber < params.size() &&
1161 "querying association for a non-existent handle");
1162 return params[resultNumber].data() !=
nullptr ||
1163 operations[resultNumber].data() !=
nullptr ||
1164 values[resultNumber].data() !=
nullptr;
1172 TransformOpInterface op,
1176 for (
OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1177 consumedHandles.insert(opOperand->get());
1184 for (
Value v : values) {
1189 defOp = v.getDefiningOp();
1192 if (defOp != v.getDefiningOp())
1201 "invalid number of replacement values");
1205 getTransformOp(),
"tracking listener failed to find replacement op "
1206 "during application of this transform op");
1212 diag.attachNote() <<
"replacement values belong to different ops";
1217 if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1221 <<
"using output of 'CastOpInterface' op";
1227 if (!config.requireMatchingReplacementOpName ||
1243 if (
auto findReplacementOpInterface =
1244 dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1245 values.assign(findReplacementOpInterface.getNextOperands());
1246 diag.attachNote(defOp->
getLoc()) <<
"using operands provided by "
1247 "'FindPayloadReplacementOpInterface'";
1250 }
while (!values.empty());
1252 diag.attachNote() <<
"ran out of suitable replacement values";
1260 reasonCallback(
diag);
1261 LDBG() <<
"Match Failure : " <<
diag.str();
1265void transform::TrackingListener::notifyOperationErased(
Operation *op) {
1268 (
void)replacePayloadValue(value,
nullptr);
1270 (
void)replacePayloadOp(op,
nullptr);
1273void transform::TrackingListener::notifyOperationReplaced(
1276 "invalid number of replacement values");
1279 for (
auto [oldValue, newValue] : llvm::zip(op->
getResults(), newValues))
1280 (
void)replacePayloadValue(oldValue, newValue);
1284 if (failed(getTransformState().getHandlesForPayloadOp(
1285 op, opHandles,
true))) {
1299 auto handleWasConsumed = [&] {
1300 return llvm::any_of(opHandles,
1301 [&](
Value h) {
return consumedHandles.contains(h); });
1306 if (
config.skipHandleFn) {
1307 auto *it = llvm::find_if(opHandles,
1308 [&](Value v) {
return !
config.skipHandleFn(v); });
1309 if (it != opHandles.end())
1311 }
else if (!opHandles.empty()) {
1312 aliveHandle = opHandles.front();
1314 if (!aliveHandle || handleWasConsumed()) {
1317 (void)replacePayloadOp(op,
nullptr);
1322 DiagnosedSilenceableFailure
diag =
1326 if (!
diag.succeeded()) {
1328 <<
"replacement is required because this handle must be updated";
1329 notifyPayloadReplacementNotFound(op, newValues, std::move(
diag));
1330 (void)replacePayloadOp(op,
nullptr);
1341 assert(status.succeeded() &&
"listener state was not checked");
1353 return !status.succeeded();
1361 diag.takeDiagnostics(diags);
1362 if (!status.succeeded())
1363 status.takeDiagnostics(diags);
1367 status.attachNote(op->
getLoc()) <<
"[" << errorCounter <<
"] replaced op";
1368 for (
auto &&[
index, value] : llvm::enumerate(values))
1369 status.attachNote(value.
getLoc())
1370 <<
"[" << errorCounter <<
"] replacement value " <<
index;
1376 if (!matchFailure) {
1379 return matchFailure->str();
1385 reasonCallback(
diag);
1386 matchFailure = std::move(
diag);
1400 return listener->failed();
1413 return listener->replacePayloadOp(op,
replacement);
1423 for (
auto &&[position, parent] : llvm::enumerate(targets)) {
1424 for (
Operation *child : targets.drop_front(position + 1)) {
1425 if (parent->isAncestor(child)) {
1428 <<
"transform operation consumes a handle pointing to an ancestor "
1429 "payload operation before its descendant";
1431 <<
"the ancestor is likely erased or rewritten before the "
1432 "descendant is accessed, leading to undefined behavior";
1433 diag.attachNote(parent->getLoc()) <<
"ancestor payload op";
1434 diag.attachNote(child->getLoc()) <<
"descendant payload op";
1453 diag.attachNote(payloadOpLoc) <<
"when applied to this op";
1457 if (partialResult.
size() != expectedNumResults) {
1458 auto diag =
emitDiag() <<
"application of " << transformOpName
1459 <<
" expected to produce " << expectedNumResults
1460 <<
" results (actually produced "
1461 << partialResult.
size() <<
").";
1462 diag.attachNote(transformOpLoc)
1463 <<
"if you need variadic results, consider a generic `apply` "
1464 <<
"instead of the specialized `applyToOne`.";
1469 for (
const auto &[
ptr, res] :
1470 llvm::zip(partialResult, transformOp->
getResults())) {
1473 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1474 !isa<Operation *>(
ptr)) {
1475 return emitDiag() <<
"application of " << transformOpName
1476 <<
" expected to produce an Operation * for result #"
1477 << res.getResultNumber();
1479 if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1480 !isa<Attribute>(
ptr)) {
1481 return emitDiag() <<
"application of " << transformOpName
1482 <<
" expected to produce an Attribute for result #"
1483 << res.getResultNumber();
1485 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1487 return emitDiag() <<
"application of " << transformOpName
1488 <<
" expected to produce a Value for result #"
1489 << res.getResultNumber();
1495template <
typename T>
1497 return llvm::map_to_vector(range, llvm::CastTo<T>);
1506 if (llvm::any_of(partialResults,
1507 [](
MappedValue value) {
return value.isNull(); }))
1509 assert(transformOp->
getNumResults() == partialResults.size() &&
1510 "expected as many partial results as op as results");
1511 for (
auto [i, value] : llvm::enumerate(partialResults))
1512 transposed[i].push_back(value);
1516 unsigned position = r.getResultNumber();
1517 if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1520 }
else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1535 assert(mappings.size() == values.size() &&
"mismatching number of mappings");
1536 for (
auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1537 size_t mappedSize = mapped.size();
1538 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1540 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
1541 operand.getType())) {
1544 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1545 "unsupported kind of transform dialect value");
1546 llvm::append_range(mapped, state.
getParams(operand));
1549 if (mapped.size() - mappedSize != 1 && !flatten)
1558 mappings.resize(mappings.size() + values.size());
1568 for (
auto &&[terminatorOperand,
result] :
1571 if (llvm::isa<transform::TransformHandleTypeInterface>(
result.getType())) {
1573 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1578 llvm::isa<transform::TransformParamTypeInterface>(
result.getType()) &&
1579 "unhandled transform type interface");
1602 iface.getEffectsOnValue(source, nestedEffects);
1603 for (
const auto &effect : nestedEffects)
1604 effects.emplace_back(effect.getEffect(),
target, effect.getResource());
1613 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1624 llvm::append_range(effects, nestedEffects);
1636 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1640 iface.getEffects(effects);
1663 <<
" were provided to the interpreter";
1677 argument, extraMappings[argument.getArgNumber() - 1])))
1689 assert(isa<TransformOpInterface>(op) &&
1690 "should implement TransformOpInterface to have "
1691 "PossibleTopLevelTransformOpTrait");
1694 return op->
emitOpError() <<
"expects at least one region";
1697 if (!llvm::hasNItems(*bodyRegion, 1))
1698 return op->
emitOpError() <<
"expects a single-block region";
1703 <<
"expects the entry block to have at least one argument";
1705 if (!llvm::isa<TransformHandleTypeInterface>(
1708 <<
"expects the first entry block argument to be of type "
1709 "implementing TransformHandleTypeInterface";
1715 <<
"expects the type of the block argument to match "
1716 "the type of the operand";
1720 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1721 TransformValueHandleTypeInterface>(arg.
getType()))
1726 <<
"expects trailing entry block arguments to be of type implementing "
1727 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1728 "TransformParamTypeInterface";
1738 <<
"expects operands to be provided for a nested op";
1739 diag.attachNote(parent->getLoc())
1740 <<
"nested in another possible top-level op";
1755 bool hasPayloadOperands =
false;
1758 if (llvm::isa<TransformHandleTypeInterface,
1759 TransformValueHandleTypeInterface>(operand.get().getType()))
1760 hasPayloadOperands =
true;
1762 if (hasPayloadOperands)
1771 llvm::report_fatal_error(
1772 Twine(
"ParamProducerTransformOpTrait must be attached to an op that "
1773 "implements MemoryEffectsOpInterface, found on ") +
1777 if (llvm::isa<TransformParamTypeInterface>(
result.getType()))
1780 <<
"ParamProducerTransformOpTrait attached to this op expects "
1781 "result types to implement TransformParamTypeInterface";
1803template <
typename EffectTy,
typename ResourceTy,
typename Range>
1806 return isa<EffectTy>(effect.
getEffect()) &&
1812 transform::TransformOpInterface
transform) {
1813 auto iface = cast<MemoryEffectOpInterface>(
transform.getOperation());
1815 iface.getEffectsOnValue(handle, effects);
1816 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1863 auto iface = cast<MemoryEffectOpInterface>(
transform.getOperation());
1865 iface.getEffects(effects);
1866 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1870 auto iface = cast<MemoryEffectOpInterface>(
transform.getOperation());
1872 iface.getEffects(effects);
1873 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1877 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1880 auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1885 iface.getEffects(effects);
1888 dyn_cast_or_null<BlockArgument>(effect.getValue());
1889 if (!argument || argument.
getOwner() != &block ||
1890 !isa<MemoryEffects::Free>(effect.getEffect()) ||
1904 TransformOpInterface transformOp) {
1905 SmallVector<OpOperand *> consumedOperands;
1906 consumedOperands.reserve(transformOp->getNumOperands());
1907 auto memEffectInterface =
1908 cast<MemoryEffectOpInterface>(transformOp.getOperation());
1909 SmallVector<MemoryEffects::EffectInstance, 2> effects;
1910 for (OpOperand &
target : transformOp->getOpOperands()) {
1912 memEffectInterface.getEffectsOnValue(
target.get(), effects);
1914 return isa<transform::TransformMappingResource>(
1916 isa<MemoryEffects::Free>(effect.
getEffect());
1918 consumedOperands.push_back(&
target);
1921 return consumedOperands;
1925 auto iface = cast<MemoryEffectOpInterface>(op);
1927 iface.getEffects(effects);
1929 auto effectsOn = [&](
Value value) {
1930 return llvm::make_filter_range(
1932 return instance.
getValue() == value;
1936 std::optional<unsigned> firstConsumedOperand;
1938 auto range = effectsOn(operand.get());
1939 if (range.empty()) {
1941 op->
emitError() <<
"TransformOpInterface requires memory effects "
1942 "on operands to be specified";
1943 diag.attachNote() <<
"no effects specified for operand #"
1944 << operand.getOperandNumber();
1949 <<
"TransformOpInterface did not expect "
1950 "'allocate' memory effect on an operand";
1951 diag.attachNote() <<
"specified for operand #"
1952 << operand.getOperandNumber();
1955 if (!firstConsumedOperand &&
1957 firstConsumedOperand = operand.getOperandNumber();
1961 if (firstConsumedOperand &&
1965 <<
"TransformOpInterface expects ops consuming operands to have a "
1966 "'write' effect on the payload resource";
1967 diag.attachNote() <<
"consumes operand #" << *firstConsumedOperand;
1972 auto range = effectsOn(
result);
1976 op->
emitError() <<
"TransformOpInterface requires 'allocate' memory "
1977 "effect to be specified for results";
1978 diag.attachNote() <<
"no 'allocate' effect specified for result #"
1979 <<
result.getResultNumber();
1997 if (enforceToplevelTransformOp) {
2001 <<
"expected transform to start at the top-level transform op";
2010 if (stateInitializer)
2011 stateInitializer(state);
2015 return stateExporter(state);
2023#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2024#include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A 2D array where each row may have different length.
size_t size() const
Returns the number of rows in the 2D array.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class implements the result iterators for the Operation class.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
static TransformMappingResource * get()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
llvm::function_ref< Fn > function_ref
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...