15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/ScopeExit.h"
17#include "llvm/ADT/iterator.h"
18#include "llvm/Support/Debug.h"
19#include "llvm/Support/DebugLog.h"
20#include "llvm/Support/ErrorHandling.h"
21#include "llvm/Support/InterleavedRange.h"
23#define DEBUG_TYPE "transform-dialect"
24#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
25#define FULL_LDBG() LDBG(4)
50transform::TransformState::TransformState(
53 const TransformOptions &
options)
55 topLevelMappedValues.reserve(extraMappings.
size());
57 topLevelMappedValues.push_back(mapping);
59 RegionScope *scope =
new RegionScope(*
this, *region);
60 topLevelRegionScope.reset(scope);
67transform::TransformState::getPayloadOpsView(
Value value)
const {
68 const TransformOpMapping &operationMapping = getMapping(value).direct;
69 auto iter = operationMapping.find(value);
70 assert(iter != operationMapping.end() &&
71 "cannot find mapping for payload handle (param/value handle "
73 return iter->getSecond();
77 const ParamMapping &mapping = getMapping(value).params;
78 auto iter = mapping.find(value);
79 assert(iter != mapping.end() &&
"cannot find mapping for param handle "
80 "(operation/value handle provided?)");
81 return iter->getSecond();
85transform::TransformState::getPayloadValuesView(
Value handleValue)
const {
86 const ValueMapping &mapping = getMapping(handleValue).values;
87 auto iter = mapping.find(handleValue);
88 assert(iter != mapping.end() &&
"cannot find mapping for value handle "
89 "(param/operation handle provided?)");
90 return iter->getSecond();
95 bool includeOutOfScope)
const {
97 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
98 auto iterator = mapping->reverse.find(op);
99 if (iterator != mapping->reverse.end()) {
100 llvm::append_range(handles, iterator->getSecond());
104 if (!includeOutOfScope &&
114 bool includeOutOfScope)
const {
116 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
117 auto iterator = mapping->reverseValues.find(payloadValue);
118 if (iterator != mapping->reverseValues.end()) {
119 llvm::append_range(handles, iterator->getSecond());
123 if (!includeOutOfScope &&
138 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
140 operations.reserve(values.size());
142 if (
auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
143 operations.push_back(op);
147 <<
"wrong kind of value provided for top-level operation handle";
149 if (failed(operationsFn(operations)))
154 if (llvm::isa<transform::TransformValueHandleTypeInterface>(
157 payloadValues.reserve(values.size());
159 if (
auto v = llvm::dyn_cast_if_present<Value>(value)) {
160 payloadValues.push_back(v);
164 <<
"wrong kind of value provided for the top-level value handle";
166 if (failed(valuesFn(payloadValues)))
171 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
172 "unsupported kind of block argument");
174 parameters.reserve(values.size());
176 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
177 parameters.push_back(attr);
181 <<
"wrong kind of value provided for top-level parameter";
183 if (failed(paramsFn(parameters)))
194 return setPayloadOps(argument, operations);
197 return setParams(argument, params);
200 return setPayloadValues(argument, payloadValues);
208 for (
auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
215transform::TransformState::setPayloadOps(
Value value,
217 assert(value != kTopLevelValue &&
218 "attempting to reset the transformation root");
219 assert(llvm::isa<TransformHandleTypeInterface>(value.
getType()) &&
220 "wrong handle type");
226 <<
"attempting to assign a null payload op to this transform value";
229 auto iface = llvm::cast<TransformHandleTypeInterface>(value.
getType());
231 iface.checkPayload(value.
getLoc(), targets);
232 if (failed(
result.checkAndReport()))
238 Mappings &mappings = getMapping(value);
240 mappings.direct.insert({value, std::move(storedTargets)}).second;
241 assert(
inserted &&
"value is already associated with another list");
245 mappings.reverse[op].push_back(value);
// Associates the value handle `handle` with a list of payload values.
// Visible steps:
//   - assert that `handle` is non-null and that its type implements
//     TransformValueHandleTypeInterface;
//   - return an error from inside the loop over `payloadValues` whose message
//     reports a null payload value;
//   - run the handle type's `checkPayload` on a copy of the values;
//   - record the copy in the forward map `mappings.values` (keyed by the
//     handle) and append the handle to `mappings.reverseValues[payload]` for
//     every payload value (reverse map keyed by payload value).
// NOTE(review): the first assertion's message reads "attempting to set params
// for a null value", which is the same text used by setParams; here the
// function sets payload values, so the message is misleading. It is runtime
// text and therefore left unchanged in this comment-only edit.
251transform::TransformState::setPayloadValues(Value handle,
253 assert(handle !=
nullptr &&
"attempting to set params for a null value");
254 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.
getType()) &&
255 "wrong handle type");
257 for (Value payload : payloadValues) {
260 return emitError(handle.
getLoc()) <<
"attempting to assign a null payload "
261 "value to this transform handle";
264 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.
getType());
265 SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
266 DiagnosedSilenceableFailure
result =
267 iface.checkPayload(handle.
getLoc(), payloadValueVector);
271 Mappings &mappings = getMapping(handle);
273 mappings.values.insert({handle, std::move(payloadValueVector)}).second;
276 "value handle is already associated with another list of payload values");
279 for (Value payload : payloadValues)
280 mappings.reverseValues[payload].push_back(handle);
285LogicalResult transform::TransformState::setParams(Value value,
286 ArrayRef<Param> params) {
287 assert(value !=
nullptr &&
"attempting to set params for a null value");
289 for (Attribute attr : params) {
293 <<
"attempting to assign a null parameter to this transform value";
296 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.
getType());
298 "cannot associate parameter with a value of non-parameter type");
299 DiagnosedSilenceableFailure
result =
300 valueType.checkPayload(value.
getLoc(), params);
304 Mappings &mappings = getMapping(value);
306 mappings.params.insert({value, llvm::to_vector(params)}).second;
307 assert(
inserted &&
"value is already associated with another list of params");
312template <
typename Mapping,
typename Key,
typename Mapped>
314 auto it = mapping.find(key);
315 if (it == mapping.end())
318 llvm::erase(it->getSecond(), mapped);
319 if (it->getSecond().empty())
// Drops the association of the operation handle `opHandle`.
// Visible steps:
//   - fetch the mapping that holds `opHandle` (forwarding `allowOutOfScope`);
//   - iterate the payload ops in its forward entry `mappings.direct[opHandle]`,
//     then erase that entry;
//   - when LLVM_ENABLE_ABI_BREAKING_CHECKS is set, bump the handle's timestamp
//     so that outstanding iterators over it can be detected as stale;
//   - for every value in `origOpFlatResults`, collect the value handles that
//     point to it and fetch each such handle's own mapping (`localMappings`).
323void transform::TransformState::forgetMapping(Value opHandle,
325 bool allowOutOfScope) {
326 Mappings &mappings = getMapping(opHandle, allowOutOfScope);
327 for (Operation *op : mappings.direct[opHandle])
329 mappings.direct.erase(opHandle);
330#if LLVM_ENABLE_ABI_BREAKING_CHECKS
333 mappings.incrementTimestamp(opHandle);
336 for (Value opResult : origOpFlatResults) {
337 SmallVector<Value> resultHandles;
338 (void)getHandlesForPayloadValue(opResult, resultHandles);
339 for (Value resultHandle : resultHandles) {
340 Mappings &localMappings = getMapping(resultHandle);
342#if LLVM_ENABLE_ABI_BREAKING_CHECKS
// NOTE(review): this bumps the timestamp on `mappings` (the scope holding
// `opHandle`), although the enclosing loop fetched `localMappings` for
// `resultHandle`. forgetValueMapping calls `localMappings.incrementTimestamp`
// in the analogous spot. If `resultHandle` lives in a different region scope,
// the wrong mapping's timestamp is incremented. Confirm whether this should be
// `localMappings.incrementTimestamp(resultHandle)`.
345 mappings.incrementTimestamp(resultHandle);
// Drops the association of the value handle `valueHandle`.
// Visible steps:
//   - fetch the mapping that holds `valueHandle`;
//   - iterate a list of payload values, then erase the handle's forward entry
//     in `mappings.values`;
//   - when LLVM_ENABLE_ABI_BREAKING_CHECKS is set, bump the handle's timestamp;
//   - for every op in `payloadOperations`, collect the operation handles that
//     point to it, fetch each handle's own mapping (`localMappings`), and bump
//     that handle's timestamp in `localMappings` under the same checks.
352void transform::TransformState::forgetValueMapping(
353 Value valueHandle, ArrayRef<Operation *> payloadOperations) {
354 Mappings &mappings = getMapping(valueHandle);
// NOTE(review): `reverseValues` is keyed by payload values (setPayloadValues
// does `mappings.reverseValues[payload].push_back(handle)`), whereas the map
// keyed by the handle is `mappings.values`. Indexing `reverseValues` with
// `valueHandle` therefore appears to look up the wrong key. `operator[]`
// also default-inserts an empty entry for it. forgetMapping iterates the
// forward map (`mappings.direct[opHandle]`) in the analogous loop. Confirm
// whether this should iterate `mappings.values[valueHandle]` instead.
355 for (Value payloadValue : mappings.reverseValues[valueHandle])
357 mappings.values.erase(valueHandle);
358#if LLVM_ENABLE_ABI_BREAKING_CHECKS
361 mappings.incrementTimestamp(valueHandle);
364 for (Operation *payloadOp : payloadOperations) {
365 SmallVector<Value> opHandles;
366 (void)getHandlesForPayloadOp(payloadOp, opHandles);
367 for (Value opHandle : opHandles) {
368 Mappings &localMappings = getMapping(opHandle);
372#if LLVM_ENABLE_ABI_BREAKING_CHECKS
375 localMappings.incrementTimestamp(opHandle);
382transform::TransformState::replacePayloadOp(Operation *op,
388 SmallVector<Value> valueHandles;
389 (void)getHandlesForPayloadValue(opResult, valueHandles,
391 assert(valueHandles.empty() &&
"expected no mapping to old results");
397 SmallVector<Value> opHandles;
398 if (
failed(getHandlesForPayloadOp(op, opHandles,
true)))
400 for (Value handle : opHandles) {
401 Mappings &mappings = getMapping(handle,
true);
414 for (Value handle : opHandles) {
415 Mappings &mappings = getMapping(handle,
true);
416 auto it = mappings.direct.find(handle);
417 if (it == mappings.direct.end())
420 SmallVector<Operation *, 2> &association = it->getSecond();
422 for (Operation *&mapped : association) {
430 opHandlesToCompact.insert(handle);
438transform::TransformState::replacePayloadValue(Value value, Value
replacement) {
439 SmallVector<Value> valueHandles;
440 if (
failed(getHandlesForPayloadValue(value, valueHandles,
444 for (Value handle : valueHandles) {
445 Mappings &mappings = getMapping(handle,
true);
452#if LLVM_ENABLE_ABI_BREAKING_CHECKS
455 mappings.incrementTimestamp(handle);
458 auto it = mappings.values.find(handle);
459 if (it == mappings.values.end())
462 SmallVector<Value> &association = it->getSecond();
463 for (Value &mapped : association) {
467 mappings.reverseValues[
replacement].push_back(handle);
474void transform::TransformState::recordOpHandleInvalidationOne(
475 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
476 Operation *payloadOp, Value otherHandle, Value throughValue,
477 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
481 if (invalidatedHandles.count(otherHandle) ||
482 newlyInvalidated.count(otherHandle))
485 FULL_LDBG() <<
"--recordOpHandleInvalidationOne";
487 << llvm::interleaved(
488 llvm::make_pointee_range(potentialAncestors));
490 Operation *owner = consumingHandle.
getOwner();
492 for (Operation *ancestor : potentialAncestors) {
494 FULL_LDBG() <<
"----handle one ancestor: " << *ancestor;;
496 FULL_LDBG() <<
"----of payload with name: "
498 FULL_LDBG() <<
"----of payload: " << *payloadOp;
500 if (!ancestor->isAncestor(payloadOp))
507 Location ancestorLoc = ancestor->
getLoc();
508 Location opLoc = payloadOp->
getLoc();
509 std::optional<Location> throughValueLoc =
510 throughValue ? std::make_optional(throughValue.
getLoc()) : std::nullopt;
511 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
513 throughValueLoc](Location currentLoc) {
515 <<
"op uses a handle invalidated by a "
516 "previously executed transform op";
517 diag.
attachNote(otherHandle.getLoc()) <<
"handle to invalidated ops";
519 <<
"invalidated by this transform op that consumes its operand #"
521 <<
" and invalidates all handles to payload IR entities associated "
522 "with this operand and entities nested in them";
525 if (throughValueLoc) {
526 diag.attachNote(*throughValueLoc)
527 <<
"consumed handle points to this payload value";
533void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
534 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
535 Value payloadValue, Value valueHandle,
536 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
540 if (invalidatedHandles.count(valueHandle) ||
541 newlyInvalidated.count(valueHandle))
544 for (Operation *ancestor : potentialAncestors) {
545 Operation *definingOp;
546 std::optional<unsigned> resultNo;
547 unsigned argumentNo = std::numeric_limits<unsigned>::max();
548 unsigned blockNo = std::numeric_limits<unsigned>::max();
549 unsigned regionNo = std::numeric_limits<unsigned>::max();
550 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
551 definingOp = opResult.getOwner();
552 resultNo = opResult.getResultNumber();
554 auto arg = llvm::cast<BlockArgument>(payloadValue);
556 argumentNo = arg.getArgNumber();
557 blockNo = std::distance(arg.getOwner()->getParent()->begin(),
558 arg.getOwner()->getIterator());
559 regionNo = arg.getOwner()->getParent()->getRegionNumber();
561 assert(definingOp &&
"expected the value to be defined by an op as result "
562 "or block argument");
563 if (!ancestor->isAncestor(definingOp))
566 Operation *owner = opHandle.
getOwner();
568 Location ancestorLoc = ancestor->getLoc();
569 Location opLoc = definingOp->
getLoc();
570 Location valueLoc = payloadValue.
getLoc();
571 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
572 argumentNo, blockNo, regionNo, ancestorLoc,
573 opLoc, valueLoc](Location currentLoc) {
575 <<
"op uses a handle invalidated by a "
576 "previously executed transform op";
579 <<
"invalidated by this transform op that consumes its operand #"
581 <<
" and invalidates all handles to payload IR entities "
582 "associated with this operand and entities nested in them";
584 <<
"ancestor op associated with the consumed handle";
586 diag.attachNote(opLoc)
587 <<
"op defining the value as result #" << *resultNo;
589 diag.attachNote(opLoc)
590 <<
"op defining the value as block argument #" << argumentNo
591 <<
" of block #" << blockNo <<
" in region #" << regionNo;
598void transform::TransformState::recordOpHandleInvalidation(
599 OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
601 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
603 if (potentialAncestors.empty()) {
604 FULL_LDBG() <<
"----recording invalidation for empty handle: "
607 Operation *owner = handle.
getOwner();
609 newlyInvalidated[handle.
get()] = [owner, operandNo](Location currentLoc) {
611 <<
"op uses a handle associated with empty "
612 "payload and invalidated by a "
613 "previously executed transform op";
615 <<
"invalidated by this transform op that consumes its operand #"
628 for (
const auto &[region, mapping] : llvm::reverse(mappings)) {
632 for (
const auto &[payloadOp, otherHandles] : mapping->reverse) {
633 for (Value otherHandle : otherHandles)
634 recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
635 otherHandle, throughValue,
643 for (
const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
644 for (Value valueHandle : valueHandles)
645 recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
646 payloadValue, valueHandle,
// Records the handles invalidated when the value handle in `valueHandle` is
// consumed by its owning transform op.
// For each payload value associated with the consumed handle:
//   - every other value handle pointing to the same payload value gets an
//     entry in `newlyInvalidated`; the entry is a callback that emits a
//     "handle invalidated by a previously executed transform op" diagnostic
//     at the later use location;
//   - if the payload value is an OpResult, recordOpHandleInvalidation is
//     called with the op that owns the result;
//   - otherwise it is treated as a block argument, and
//     recordOpHandleInvalidation is called for every op in the owning block.
// NOTE(review): the block-argument branch uses
// `llvm::dyn_cast<BlockArgument>` and then dereferences `arg.getOwner()`
// without a null check. Since this branch presumably only handles values
// that are not OpResults, `llvm::cast<BlockArgument>` (as used elsewhere in
// this file for the same case) would state and assert that intent instead.
656void transform::TransformState::recordValueHandleInvalidation(
657 OpOperand &valueHandle,
658 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
660 for (Value payloadValue : getPayloadValuesView(valueHandle.
get())) {
661 SmallVector<Value> otherValueHandles;
662 (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
663 for (Value otherHandle : otherValueHandles) {
664 Operation *owner = valueHandle.
getOwner();
666 Location valueLoc = payloadValue.
getLoc();
667 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
668 valueLoc](Location currentLoc) {
670 <<
"op uses a handle invalidated by a "
671 "previously executed transform op";
674 <<
"invalidated by this transform op that consumes its operand #"
676 <<
" and invalidates handles to the same values as associated with "
682 if (
auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
683 Operation *payloadOp = opResult.getOwner();
684 recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
687 auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
688 for (Operation &payloadOp : *arg.getOwner())
689 recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
699LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
700 transform::TransformOpInterface transform,
701 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const {
702 FULL_LDBG() <<
"--Start checkAndRecordHandleInvalidation";
703 auto memoryEffectsIface =
704 cast<MemoryEffectOpInterface>(transform.getOperation());
705 SmallVector<MemoryEffects::EffectInstance> effects;
706 memoryEffectsIface.getEffectsOnResource(
709 for (OpOperand &
target : transform->getOpOperands()) {
715 auto it = invalidatedHandles.find(
target.get());
716 auto nit = newlyInvalidated.find(
target.get());
717 if (it != invalidatedHandles.end()) {
718 FULL_LDBG() <<
"--End checkAndRecordHandleInvalidation, found already "
719 "invalidated -> FAILURE";
720 return it->getSecond()(transform->getLoc()), failure();
722 if (!transform.allowsRepeatedHandleOperands() &&
723 nit != newlyInvalidated.end()) {
724 FULL_LDBG() <<
"--End checkAndRecordHandleInvalidation, found newly "
725 "invalidated (by this op) -> FAILURE";
726 return nit->getSecond()(transform->getLoc()), failure();
732 return isa<MemoryEffects::Free>(effect.getEffect()) &&
733 effect.getValue() ==
target.get();
735 if (llvm::any_of(effects, consumesTarget)) {
736 FULL_LDBG() <<
"----found consume effect";
737 if (llvm::isa<transform::TransformHandleTypeInterface>(
738 target.get().getType())) {
739 FULL_LDBG() <<
"----recordOpHandleInvalidation";
740 SmallVector<Operation *> payloadOps =
741 llvm::to_vector(getPayloadOps(
target.get()));
742 recordOpHandleInvalidation(
target, payloadOps,
nullptr,
744 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
745 target.get().getType())) {
746 FULL_LDBG() <<
"----recordValueHandleInvalidation";
747 recordValueHandleInvalidation(
target, newlyInvalidated);
750 <<
"----not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
753 FULL_LDBG() <<
"----no consume effect -> SKIP";
757 FULL_LDBG() <<
"--End checkAndRecordHandleInvalidation -> SUCCESS";
761LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
762 transform::TransformOpInterface transform) {
763 InvalidatedHandleMap newlyInvalidated;
764 LogicalResult checkResult =
765 checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
766 invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
767 std::make_move_iterator(newlyInvalidated.end()));
772static DiagnosedSilenceableFailure
774 transform::TransformOpInterface
transform,
775 unsigned operandNumber) {
777 for (T p : payload) {
778 if (!seen.insert(p).second) {
781 <<
"a handle passed as operand #" << operandNumber
782 <<
" and consumed by this operation points to a payload "
783 "entity more than once";
784 if constexpr (std::is_pointer_v<T>)
785 diag.attachNote(p->getLoc()) <<
"repeated target op";
787 diag.attachNote(p.getLoc()) <<
"repeated target value";
794void transform::TransformState::compactOpHandles() {
795 for (Value handle : opHandlesToCompact) {
796 Mappings &mappings = getMapping(handle,
true);
797#if LLVM_ENABLE_ABI_BREAKING_CHECKS
798 if (llvm::is_contained(mappings.direct[handle],
nullptr))
801 mappings.incrementTimestamp(handle);
803 llvm::erase(mappings.direct[handle],
nullptr);
805 opHandlesToCompact.clear();
808DiagnosedSilenceableFailure
810 LDBG() <<
"applying: "
813 auto printOnFailureRAII = llvm::make_scope_exit([
this] {
815 LDBG() <<
"Failing Top-level payload:\n"
821 regionStack.back()->currentTransform =
transform;
824 if (options.getExpensiveChecksEnabled()) {
826 if (failed(checkAndRecordHandleInvalidation(
transform)))
830 FULL_LDBG() <<
"iterate on handle: " << operand.get();
832 FULL_LDBG() <<
"--handle not consumed -> SKIP";
835 if (
transform.allowsRepeatedHandleOperands()) {
836 FULL_LDBG() <<
"--op allows repeated handles -> SKIP";
841 Type operandType = operand.get().getType();
842 if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
843 FULL_LDBG() <<
"--checkRepeatedConsumptionInOperand for Operation*";
846 getPayloadOpsView(operand.get()),
transform,
847 operand.getOperandNumber());
852 }
else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
853 FULL_LDBG() <<
"--checkRepeatedConsumptionInOperand For Value";
856 getPayloadValuesView(operand.get()),
transform,
857 operand.getOperandNumber());
863 FULL_LDBG() <<
"--not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
879 for (
OpOperand *opOperand : consumedOperands) {
880 Value operand = opOperand->get();
881 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
883 llvm::append_range(origOpFlatResults, payloadOp->
getResults());
887 if (llvm::isa<TransformValueHandleTypeInterface>(operand.
getType())) {
888 for (
Value payloadValue : getPayloadValuesView(operand)) {
889 if (llvm::isa<OpResult>(payloadValue)) {
895 llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
902 <<
"unexpectedly consumed a value that is not a handle as operand #"
903 << opOperand->getOperandNumber();
905 <<
"value defined here with type " << operand.
getType();
914 llvm::find_if(llvm::reverse(regionStack), [&](
RegionScope *scope) {
915 return handle.getParentRegion() == scope->region;
917 assert(scopeIt != regionStack.rend() &&
918 "could not find region scope for handle");
920 return llvm::all_of(handle.getUsers(), [&](
Operation *user) {
921 return user == scope->currentTransform ||
922 happensBefore(user, scope->currentTransform);
941 transform->hasAttr(FindPayloadReplacementOpInterface::
942 kSilenceTrackingFailuresAttrName)) {
951 result = std::move(trackingFailure);
954 if (
result.isSilenceableFailure())
955 result.attachNote() <<
"tracking listener also failed: "
960 if (
result.isDefiniteFailure())
965 if (
result.isSilenceableFailure())
970 for (
OpOperand *opOperand : consumedOperands) {
971 Value operand = opOperand->get();
972 if (llvm::isa<TransformHandleTypeInterface>(operand.
getType())) {
973 forgetMapping(operand, origOpFlatResults);
974 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
976 forgetValueMapping(operand, origAssociatedOps);
980 if (failed(updateStateFromResults(results,
transform->getResults())))
983 printOnFailureRAII.release();
985 LDBG() <<
"Top-level payload:\n" << *getTopLevel();
990LogicalResult transform::TransformState::updateStateFromResults(
991 const TransformResults &results,
ResultRange opResults) {
993 if (llvm::isa<TransformParamTypeInterface>(
result.getType())) {
994 assert(results.isParam(
result.getResultNumber()) &&
995 "expected parameters for the parameter-typed result");
997 setParams(
result, results.getParams(
result.getResultNumber())))) {
1000 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
result.getType())) {
1001 assert(results.isValue(
result.getResultNumber()) &&
1002 "expected values for value-type-result");
1003 if (
failed(setPayloadValues(
1004 result, results.getValues(
result.getResultNumber())))) {
1008 assert(!results.isParam(
result.getResultNumber()) &&
1009 "expected payload ops for the non-parameter typed result");
1011 setPayloadOps(
result, results.get(
result.getResultNumber())))) {
1036 return state.replacePayloadValue(value,
replacement);
1047 for (
Block &block : *region) {
1048 for (
Value handle : block.getArguments()) {
1049 state.invalidatedHandles.erase(handle);
1053 state.invalidatedHandles.erase(handle);
1058#if LLVM_ENABLE_ABI_BREAKING_CHECKS
1062 llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1065 state.mappings.erase(region);
1066 state.regionStack.pop_back();
1073transform::TransformResults::TransformResults(
unsigned numSegments) {
1074 operations.appendEmptyRows(numSegments);
1075 params.appendEmptyRows(numSegments);
1076 values.appendEmptyRows(numSegments);
1082 assert(position <
static_cast<int64_t>(this->params.size()) &&
1083 "setting params for a non-existent handle");
1084 assert(this->params[position].data() ==
nullptr &&
"params already set");
1085 assert(operations[position].data() ==
nullptr &&
1086 "another kind of results already set");
1087 assert(values[position].data() ==
nullptr &&
1088 "another kind of results already set");
1089 this->params.replace(position, params);
1106 if (!
diag.succeeded())
1107 llvm::dbgs() <<
diag.getStatusString() <<
"\n";
1108 assert(
diag.succeeded() &&
"incorrect mapping");
1114 transform::TransformOpInterface
transform) {
1116 if (!isSet(opResult.getResultNumber()))
1122transform::TransformResults::get(
unsigned resultNumber)
const {
1123 assert(resultNumber < operations.size() &&
1124 "querying results for a non-existent handle");
1125 assert(operations[resultNumber].data() !=
nullptr &&
1126 "querying unset results (values or params expected?)");
1127 return operations[resultNumber];
1131transform::TransformResults::getParams(
unsigned resultNumber)
const {
1132 assert(resultNumber < params.size() &&
1133 "querying params for a non-existent handle");
1134 assert(params[resultNumber].data() !=
nullptr &&
1135 "querying unset params (ops or values expected?)");
1136 return params[resultNumber];
1140transform::TransformResults::getValues(
unsigned resultNumber)
const {
1141 assert(resultNumber < values.size() &&
1142 "querying values for a non-existent handle");
1143 assert(values[resultNumber].data() !=
nullptr &&
1144 "querying unset values (ops or params expected?)");
1145 return values[resultNumber];
1148bool transform::TransformResults::isParam(
unsigned resultNumber)
const {
1149 assert(resultNumber < params.size() &&
1150 "querying association for a non-existent handle");
1151 return params[resultNumber].data() !=
nullptr;
1154bool transform::TransformResults::isValue(
unsigned resultNumber)
const {
1155 assert(resultNumber < values.size() &&
1156 "querying association for a non-existent handle");
1157 return values[resultNumber].data() !=
nullptr;
1160bool transform::TransformResults::isSet(
unsigned resultNumber)
const {
1161 assert(resultNumber < params.size() &&
1162 "querying association for a non-existent handle");
1163 return params[resultNumber].data() !=
nullptr ||
1164 operations[resultNumber].data() !=
nullptr ||
1165 values[resultNumber].data() !=
nullptr;
1173 TransformOpInterface op,
1177 for (
OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1178 consumedHandles.insert(opOperand->get());
1185 for (
Value v : values) {
1190 defOp = v.getDefiningOp();
1193 if (defOp != v.getDefiningOp())
1202 "invalid number of replacement values");
1206 getTransformOp(),
"tracking listener failed to find replacement op "
1207 "during application of this transform op");
1213 diag.attachNote() <<
"replacement values belong to different ops";
1218 if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1222 <<
"using output of 'CastOpInterface' op";
1228 if (!config.requireMatchingReplacementOpName ||
1244 if (
auto findReplacementOpInterface =
1245 dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1246 values.assign(findReplacementOpInterface.getNextOperands());
1247 diag.attachNote(defOp->
getLoc()) <<
"using operands provided by "
1248 "'FindPayloadReplacementOpInterface'";
1251 }
while (!values.empty());
1253 diag.attachNote() <<
"ran out of suitable replacement values";
1261 reasonCallback(
diag);
1262 LDBG() <<
"Match Failure : " <<
diag.str();
1266void transform::TrackingListener::notifyOperationErased(
Operation *op) {
1269 (
void)replacePayloadValue(value,
nullptr);
1271 (
void)replacePayloadOp(op,
nullptr);
1274void transform::TrackingListener::notifyOperationReplaced(
1277 "invalid number of replacement values");
1280 for (
auto [oldValue, newValue] : llvm::zip(op->
getResults(), newValues))
1281 (
void)replacePayloadValue(oldValue, newValue);
1285 if (failed(getTransformState().getHandlesForPayloadOp(
1286 op, opHandles,
true))) {
1300 auto handleWasConsumed = [&] {
1301 return llvm::any_of(opHandles,
1302 [&](
Value h) {
return consumedHandles.contains(h); });
1307 if (
config.skipHandleFn) {
1308 auto it = llvm::find_if(opHandles,
1309 [&](Value v) {
return !
config.skipHandleFn(v); });
1310 if (it != opHandles.end())
1312 }
else if (!opHandles.empty()) {
1313 aliveHandle = opHandles.front();
1315 if (!aliveHandle || handleWasConsumed()) {
1318 (void)replacePayloadOp(op,
nullptr);
1323 DiagnosedSilenceableFailure
diag =
1327 if (!
diag.succeeded()) {
1329 <<
"replacement is required because this handle must be updated";
1330 notifyPayloadReplacementNotFound(op, newValues, std::move(
diag));
1331 (void)replacePayloadOp(op,
nullptr);
1342 assert(status.succeeded() &&
"listener state was not checked");
1354 return !status.succeeded();
1362 diag.takeDiagnostics(diags);
1363 if (!status.succeeded())
1364 status.takeDiagnostics(diags);
1368 status.attachNote(op->
getLoc()) <<
"[" << errorCounter <<
"] replaced op";
1369 for (
auto &&[
index, value] : llvm::enumerate(values))
1370 status.attachNote(value.
getLoc())
1371 <<
"[" << errorCounter <<
"] replacement value " <<
index;
1377 if (!matchFailure) {
1380 return matchFailure->str();
1386 reasonCallback(
diag);
1387 matchFailure = std::move(
diag);
1401 return listener->failed();
1414 return listener->replacePayloadOp(op,
replacement);
1424 for (
auto &&[position, parent] : llvm::enumerate(targets)) {
1425 for (
Operation *child : targets.drop_front(position + 1)) {
1426 if (parent->isAncestor(child)) {
1429 <<
"transform operation consumes a handle pointing to an ancestor "
1430 "payload operation before its descendant";
1432 <<
"the ancestor is likely erased or rewritten before the "
1433 "descendant is accessed, leading to undefined behavior";
1434 diag.attachNote(parent->getLoc()) <<
"ancestor payload op";
1435 diag.attachNote(child->getLoc()) <<
"descendant payload op";
1454 diag.attachNote(payloadOpLoc) <<
"when applied to this op";
1458 if (partialResult.
size() != expectedNumResults) {
1459 auto diag =
emitDiag() <<
"application of " << transformOpName
1460 <<
" expected to produce " << expectedNumResults
1461 <<
" results (actually produced "
1462 << partialResult.
size() <<
").";
1463 diag.attachNote(transformOpLoc)
1464 <<
"if you need variadic results, consider a generic `apply` "
1465 <<
"instead of the specialized `applyToOne`.";
1470 for (
const auto &[
ptr, res] :
1471 llvm::zip(partialResult, transformOp->
getResults())) {
1474 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1475 !isa<Operation *>(
ptr)) {
1476 return emitDiag() <<
"application of " << transformOpName
1477 <<
" expected to produce an Operation * for result #"
1478 << res.getResultNumber();
1480 if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1481 !isa<Attribute>(
ptr)) {
1482 return emitDiag() <<
"application of " << transformOpName
1483 <<
" expected to produce an Attribute for result #"
1484 << res.getResultNumber();
1486 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1488 return emitDiag() <<
"application of " << transformOpName
1489 <<
" expected to produce a Value for result #"
1490 << res.getResultNumber();
1496template <
typename T>
1498 return llvm::map_to_vector(range, llvm::CastTo<T>);
1507 if (llvm::any_of(partialResults,
1508 [](
MappedValue value) {
return value.isNull(); }))
1510 assert(transformOp->
getNumResults() == partialResults.size() &&
1511 "expected as many partial results as op as results");
1512 for (
auto [i, value] : llvm::enumerate(partialResults))
1513 transposed[i].push_back(value);
1517 unsigned position = r.getResultNumber();
1518 if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1521 }
else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1536 assert(mappings.size() == values.size() &&
"mismatching number of mappings");
1537 for (
auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1538 size_t mappedSize = mapped.size();
1539 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1541 }
else if (llvm::isa<TransformValueHandleTypeInterface>(
1542 operand.getType())) {
1545 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1546 "unsupported kind of transform dialect value");
1547 llvm::append_range(mapped, state.
getParams(operand));
1550 if (mapped.size() - mappedSize != 1 && !flatten)
1559 mappings.resize(mappings.size() + values.size());
1569 for (
auto &&[terminatorOperand,
result] :
1572 if (llvm::isa<transform::TransformHandleTypeInterface>(
result.getType())) {
1574 }
else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1579 llvm::isa<transform::TransformParamTypeInterface>(
result.getType()) &&
1580 "unhandled transform type interface");
1603 iface.getEffectsOnValue(source, nestedEffects);
1604 for (
const auto &effect : nestedEffects)
1605 effects.emplace_back(effect.getEffect(),
target, effect.getResource());
1614 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1625 llvm::append_range(effects, nestedEffects);
1637 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1641 iface.getEffects(effects);
1664 <<
" were provided to the interpreter";
1678 argument, extraMappings[argument.getArgNumber() - 1])))
1690 assert(isa<TransformOpInterface>(op) &&
1691 "should implement TransformOpInterface to have "
1692 "PossibleTopLevelTransformOpTrait");
1695 return op->
emitOpError() <<
"expects at least one region";
1698 if (!llvm::hasNItems(*bodyRegion, 1))
1699 return op->
emitOpError() <<
"expects a single-block region";
1704 <<
"expects the entry block to have at least one argument";
1706 if (!llvm::isa<TransformHandleTypeInterface>(
1709 <<
"expects the first entry block argument to be of type "
1710 "implementing TransformHandleTypeInterface";
1716 <<
"expects the type of the block argument to match "
1717 "the type of the operand";
1721 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1722 TransformValueHandleTypeInterface>(arg.
getType()))
1727 <<
"expects trailing entry block arguments to be of type implementing "
1728 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1729 "TransformParamTypeInterface";
1739 <<
"expects operands to be provided for a nested op";
1740 diag.attachNote(parent->getLoc())
1741 <<
"nested in another possible top-level op";
1756 bool hasPayloadOperands =
false;
1759 if (llvm::isa<TransformHandleTypeInterface,
1760 TransformValueHandleTypeInterface>(operand.get().getType()))
1761 hasPayloadOperands =
true;
1763 if (hasPayloadOperands)
1772 llvm::report_fatal_error(
1773 Twine(
"ParamProducerTransformOpTrait must be attached to an op that "
1774 "implements MemoryEffectsOpInterface, found on ") +
1778 if (llvm::isa<TransformParamTypeInterface>(
result.getType()))
1781 <<
"ParamProducerTransformOpTrait attached to this op expects "
1782 "result types to implement TransformParamTypeInterface";
1804template <
typename EffectTy,
typename ResourceTy,
typename Range>
1807 return isa<EffectTy>(effect.
getEffect()) &&
1813 transform::TransformOpInterface
transform) {
1814 auto iface = cast<MemoryEffectOpInterface>(
transform.getOperation());
1816 iface.getEffectsOnValue(handle, effects);
1817 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1864 auto iface = cast<MemoryEffectOpInterface>(
transform.getOperation());
1866 iface.getEffects(effects);
1867 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1871 auto iface = cast<MemoryEffectOpInterface>(
transform.getOperation());
1873 iface.getEffects(effects);
1874 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1878 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1881 auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1886 iface.getEffects(effects);
1889 dyn_cast_or_null<BlockArgument>(effect.getValue());
1890 if (!argument || argument.
getOwner() != &block ||
1891 !isa<MemoryEffects::Free>(effect.getEffect()) ||
1905 TransformOpInterface transformOp) {
1906 SmallVector<OpOperand *> consumedOperands;
1907 consumedOperands.reserve(transformOp->getNumOperands());
1908 auto memEffectInterface =
1909 cast<MemoryEffectOpInterface>(transformOp.getOperation());
1910 SmallVector<MemoryEffects::EffectInstance, 2> effects;
1911 for (OpOperand &
target : transformOp->getOpOperands()) {
1913 memEffectInterface.getEffectsOnValue(
target.get(), effects);
1915 return isa<transform::TransformMappingResource>(
1917 isa<MemoryEffects::Free>(effect.
getEffect());
1919 consumedOperands.push_back(&
target);
1922 return consumedOperands;
1926 auto iface = cast<MemoryEffectOpInterface>(op);
1928 iface.getEffects(effects);
1930 auto effectsOn = [&](
Value value) {
1931 return llvm::make_filter_range(
1933 return instance.
getValue() == value;
1937 std::optional<unsigned> firstConsumedOperand;
1939 auto range = effectsOn(operand.get());
1940 if (range.empty()) {
1942 op->
emitError() <<
"TransformOpInterface requires memory effects "
1943 "on operands to be specified";
1944 diag.attachNote() <<
"no effects specified for operand #"
1945 << operand.getOperandNumber();
1950 <<
"TransformOpInterface did not expect "
1951 "'allocate' memory effect on an operand";
1952 diag.attachNote() <<
"specified for operand #"
1953 << operand.getOperandNumber();
1956 if (!firstConsumedOperand &&
1958 firstConsumedOperand = operand.getOperandNumber();
1962 if (firstConsumedOperand &&
1966 <<
"TransformOpInterface expects ops consuming operands to have a "
1967 "'write' effect on the payload resource";
1968 diag.attachNote() <<
"consumes operand #" << *firstConsumedOperand;
1973 auto range = effectsOn(
result);
1977 op->
emitError() <<
"TransformOpInterface requires 'allocate' memory "
1978 "effect to be specified for results";
1979 diag.attachNote() <<
"no 'allocate' effect specified for result #"
1980 <<
result.getResultNumber();
1998 if (enforceToplevelTransformOp) {
2002 <<
"expected transform to start at the top-level transform op";
2011 if (stateInitializer)
2012 stateInitializer(state);
2016 return stateExporter(state);
2024#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2025#include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A 2D array where each row may have different length.
size_t size() const
Returns the number of rows in the 2D array.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class implements the result iterators for the Operation class.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
static DerivedEffect * get()
static TransformMappingResource * get()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
SideEffects::EffectInstance< Effect > EffectInstance
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
llvm::function_ref< Fn > function_ref
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...