37 #include "llvm/ADT/DenseSet.h"
38 #include "llvm/ADT/STLExtras.h"
39 #include "llvm/ADT/ScopeExit.h"
40 #include "llvm/ADT/SmallPtrSet.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Support/Debug.h"
43 #include "llvm/Support/ErrorHandling.h"
46 #define DEBUG_TYPE "transform-dialect"
47 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
49 #define DEBUG_TYPE_MATCHER "transform-matcher"
50 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
51 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
56 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
58 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
59 SmallVectorImpl<Type> &extraBindingTypes);
65 ArrayAttr matchers, ArrayAttr actions);
76 Operation *transformAncestor = transform.getOperation();
77 while (transformAncestor) {
78 if (transformAncestor == payload) {
80 transform.emitDefiniteFailure()
81 <<
"cannot apply transform to itself (or one of its ancestors)";
82 diag.attachNote(payload->
getLoc()) <<
"target payload op";
85 transformAncestor = transformAncestor->
getParentOp();
90 #define GET_OP_CLASSES
91 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
99 if (!point.
isParent() && getOperation()->getNumOperands() == 1)
100 return getOperation()->getOperands();
102 getOperation()->operand_end());
105 void transform::AlternativesOp::getSuccessorRegions(
107 for (
Region &alternative : llvm::drop_begin(
111 regions.emplace_back(&alternative, !getOperands().empty()
112 ? alternative.getArguments()
116 regions.emplace_back(getOperation()->getResults());
119 void transform::AlternativesOp::getRegionInvocationBounds(
124 bounds.reserve(getNumRegions());
125 bounds.emplace_back(1, 1);
132 results.
set(res, {});
140 if (
Value scopeHandle = getScope())
141 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
143 originals.push_back(state.getTopLevel());
146 if (original->isAncestor(getOperation())) {
148 <<
"scope must not contain the transforms being applied";
149 diag.attachNote(original->getLoc()) <<
"scope";
154 <<
"only isolated-from-above ops can be alternative scopes";
155 diag.attachNote(original->getLoc()) <<
"scope";
160 for (
Region ® : getAlternatives()) {
165 auto scope = state.make_region_scope(reg);
166 auto clones = llvm::to_vector(
167 llvm::map_range(originals, [](
Operation *op) {
return op->
clone(); }));
168 auto deleteClones = llvm::make_scope_exit([&] {
172 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
176 for (
Operation &transform : reg.front().without_terminator()) {
178 state.applyTransform(cast<TransformOpInterface>(transform));
180 LLVM_DEBUG(
DBGS() <<
"alternative failed: " << result.
getMessage()
186 if (::mlir::failed(result.
silence()))
195 deleteClones.release();
196 TrackingListener listener(state, *
this);
198 for (
const auto &kvp : llvm::zip(originals, clones)) {
209 return emitSilenceableError() <<
"all alternatives failed";
212 void transform::AlternativesOp::getEffects(
213 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
216 for (
Region *region : getRegions()) {
217 if (!region->empty())
224 for (
Region &alternative : getAlternatives()) {
229 <<
"expects terminator operands to have the "
230 "same type as results of the operation";
231 diag.attachNote(terminator->
getLoc()) <<
"terminator";
248 llvm::to_vector(state.getPayloadOps(getTarget()));
251 if (
auto paramH = getParam()) {
253 if (params.size() != 1) {
254 if (targets.size() != params.size()) {
255 return emitSilenceableError()
256 <<
"parameter and target have different payload lengths ("
257 << params.size() <<
" vs " << targets.size() <<
")";
259 for (
auto &&[target, attr] : llvm::zip_equal(targets, params))
260 target->setAttr(getName(), attr);
265 for (
auto *target : targets)
266 target->setAttr(getName(), attr);
270 void transform::AnnotateOp::getEffects(
271 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
282 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
297 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
298 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
322 auto addDefiningOpsToWorklist = [&](
Operation *op) {
325 if (
Operation *defOp = v.getDefiningOp())
327 worklist.insert(defOp);
335 const auto *it = llvm::find(worklist, op);
336 if (it != worklist.end())
345 addDefiningOpsToWorklist(op);
351 while (!worklist.empty()) {
355 addDefiningOpsToWorklist(op);
362 void transform::ApplyDeadCodeEliminationOp::getEffects(
363 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
387 if (!getRegion().empty()) {
388 for (
Operation &op : getRegion().front()) {
389 cast<transform::PatternDescriptorOpInterface>(&op)
390 .populatePatternsWithState(
patterns, state);
400 config.maxIterations = getMaxIterations() ==
static_cast<uint64_t
>(-1)
402 : getMaxIterations();
403 config.maxNumRewrites = getMaxNumRewrites() ==
static_cast<uint64_t
>(-1)
405 : getMaxNumRewrites();
410 bool cseChanged =
false;
413 static const int64_t kNumMaxIterations = 50;
414 int64_t iteration = 0;
416 LogicalResult result = failure();
429 if (target != nestedOp)
430 ops.push_back(nestedOp);
437 if (failed(result)) {
439 <<
"greedy pattern application failed";
447 }
while (cseChanged && ++iteration < kNumMaxIterations);
449 if (iteration == kNumMaxIterations)
456 if (!getRegion().empty()) {
457 for (
Operation &op : getRegion().front()) {
458 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
460 <<
"expected children ops to implement "
461 "PatternDescriptorOpInterface";
462 diag.attachNote(op.
getLoc()) <<
"op without interface";
470 void transform::ApplyPatternsOp::getEffects(
471 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
476 void transform::ApplyPatternsOp::build(
485 bodyBuilder(builder, result.
location);
492 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
496 dialect->getCanonicalizationPatterns(
patterns);
498 op.getCanonicalizationPatterns(
patterns, ctx);
512 std::unique_ptr<TypeConverter> defaultTypeConverter;
513 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
514 getDefaultTypeConverter();
515 if (typeConverterBuilder)
516 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
521 for (
Attribute attr : cast<ArrayAttr>(*getLegalOps()))
522 conversionTarget.addLegalOp(
525 for (
Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
526 conversionTarget.addIllegalOp(
528 if (getLegalDialects())
529 for (
Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
530 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
531 if (getIllegalDialects())
532 for (
Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
533 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
541 if (!getPatterns().empty()) {
542 for (
Operation &op : getPatterns().front()) {
544 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
547 std::unique_ptr<TypeConverter> typeConverter =
548 descriptor.getTypeConverter();
551 keepAliveConverters.emplace_back(std::move(typeConverter));
552 converter = keepAliveConverters.back().get();
555 if (!defaultTypeConverter) {
557 <<
"pattern descriptor does not specify type "
558 "converter and apply_conversion_patterns op has "
559 "no default type converter";
560 diag.attachNote(op.
getLoc()) <<
"pattern descriptor op";
563 converter = defaultTypeConverter.get();
569 descriptor.populateConversionTargetRules(*converter, conversionTarget);
571 descriptor.populatePatterns(*converter,
patterns);
579 TrackingListenerConfig trackingConfig;
580 trackingConfig.requireMatchingReplacementOpName =
false;
581 ErrorCheckingTrackingListener trackingListener(state, *
this, trackingConfig);
583 if (getPreserveHandles())
584 conversionConfig.
listener = &trackingListener;
587 for (
Operation *target : state.getPayloadOps(getTarget())) {
595 LogicalResult status = failure();
596 if (getPartialConversion()) {
606 if (failed(status)) {
607 diag = emitSilenceableError() <<
"dialect conversion failed";
608 diag.attachNote(target->
getLoc()) <<
"target op";
613 trackingListener.checkAndResetError();
615 if (
diag.succeeded()) {
617 return trackingFailure;
619 diag.attachNote() <<
"tracking listener also failed: "
621 (void)trackingFailure.
silence();
625 if (!
diag.succeeded())
633 if (getNumRegions() != 1 && getNumRegions() != 2)
634 return emitOpError() <<
"expected 1 or 2 regions";
635 if (!getPatterns().empty()) {
636 for (
Operation &op : getPatterns().front()) {
637 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
639 emitOpError() <<
"expected pattern children ops to implement "
640 "ConversionPatternDescriptorOpInterface";
641 diag.attachNote(op.
getLoc()) <<
"op without interface";
646 if (getNumRegions() == 2) {
647 Region &typeConverterRegion = getRegion(1);
648 if (!llvm::hasSingleElement(typeConverterRegion.
front()))
650 <<
"expected exactly one op in default type converter region";
652 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
654 if (!typeConverterOp) {
656 <<
"expected default converter child op to "
657 "implement TypeConverterBuilderOpInterface";
658 diag.attachNote(maybeTypeConverter->
getLoc()) <<
"op without interface";
662 if (!getPatterns().empty()) {
663 for (
Operation &op : getPatterns().front()) {
665 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
666 if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
674 void transform::ApplyConversionPatternsOp::getEffects(
675 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
676 if (!getPreserveHandles()) {
684 void transform::ApplyConversionPatternsOp::build(
694 if (patternsBodyBuilder)
695 patternsBodyBuilder(builder, result.
location);
701 if (typeConverterBodyBuilder)
702 typeConverterBodyBuilder(builder, result.
location);
710 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
713 assert(dialect &&
"expected that dialect is loaded");
714 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
718 iface->populateConvertToLLVMConversionPatterns(
722 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
723 transform::TypeConverterBuilderOpInterface builder) {
724 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
725 return emitOpError(
"expected LLVMTypeConverter");
732 return emitOpError(
"unknown dialect or dialect not loaded: ")
734 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
737 "dialect does not implement ConvertToLLVMPatternInterface or "
738 "extension was not loaded: ")
748 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
758 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
759 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
786 <<
"unknown pass or pass pipeline: " << getPassName();
790 if (failed(info->
addToPipeline(pm, getOptions(), [&](
const Twine &msg) {
795 <<
"failed to add pass or pass pipeline to pipeline: "
798 if (failed(pm.run(target))) {
799 auto diag = emitSilenceableError() <<
"pass pipeline failed";
800 diag.attachNote(target->
getLoc()) <<
"target op";
804 results.push_back(target);
814 Operation *target, ApplyToEachResultList &results,
816 results.push_back(target);
820 void transform::CastOp::getEffects(
821 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
828 assert(inputs.size() == 1 &&
"expected one input");
829 assert(outputs.size() == 1 &&
"expected one output");
831 std::initializer_list<Type>{inputs.front(), outputs.front()},
832 llvm::IsaPred<transform::TransformHandleTypeInterface>);
852 assert(block.
getParent() &&
"cannot match using a detached block");
853 auto matchScope = state.make_region_scope(*block.
getParent());
855 state.mapBlockArguments(block.
getArguments(), blockArgumentMapping)))
859 if (!isa<transform::MatchOpInterface>(match)) {
861 <<
"expected operations in the match part to "
862 "implement MatchOpInterface";
865 state.applyTransform(cast<transform::TransformOpInterface>(match));
866 if (
diag.succeeded())
884 template <
typename... Tys>
886 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... ||
false);
893 transform::TransformParamTypeInterface,
894 transform::TransformValueHandleTypeInterface>(
906 auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
907 getOperation(), getMatcher());
908 if (matcher.isExternal()) {
910 <<
"unresolved external symbol " << getMatcher();
914 rawResults.resize(getOperation()->getNumResults());
915 std::optional<DiagnosedSilenceableFailure> maybeFailure;
916 for (
Operation *root : state.getPayloadOps(getRoot())) {
920 op->
print(llvm::dbgs(),
922 llvm::dbgs() <<
" @" << op <<
"\n";
929 matcher.getFunctionBody().front(),
932 if (
diag.isDefiniteFailure())
934 if (
diag.isSilenceableFailure()) {
936 <<
" failed: " <<
diag.getMessage());
942 if (mapping.size() != 1) {
943 maybeFailure.emplace(emitSilenceableError()
944 <<
"result #" << i <<
", associated with "
946 <<
" payload objects, expected 1");
949 rawResults[i].push_back(mapping[0]);
954 return std::move(*maybeFailure);
955 assert(!maybeFailure &&
"failure set but the walk was not interrupted");
957 for (
auto &&[opResult, rawResult] :
958 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
965 void transform::CollectMatchingOp::getEffects(
966 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
972 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
974 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
976 if (!matcherSymbol ||
977 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
978 return emitError() <<
"unresolved matcher symbol " << getMatcher();
981 if (argumentTypes.size() != 1 ||
982 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
984 <<
"expected the matcher to take one operation handle argument";
986 if (!matcherSymbol.getArgAttr(
987 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
988 return emitError() <<
"expected the matcher argument to be marked readonly";
992 if (resultTypes.size() != getOperation()->getNumResults()) {
994 <<
"expected the matcher to yield as many values as op has results ("
995 << getOperation()->getNumResults() <<
"), got "
996 << resultTypes.size();
999 for (
auto &&[i, matcherType, resultType] :
1005 <<
"mismatching type interfaces for matcher result and op result #"
1017 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() {
return true; }
1025 matchActionPairs.reserve(getMatchers().size());
1027 for (
auto &&[matcher, action] :
1028 llvm::zip_equal(getMatchers(), getActions())) {
1029 auto matcherSymbol =
1031 getOperation(), cast<SymbolRefAttr>(matcher));
1034 getOperation(), cast<SymbolRefAttr>(action));
1035 assert(matcherSymbol && actionSymbol &&
1036 "unresolved symbols not caught by the verifier");
1038 if (matcherSymbol.isExternal())
1040 if (actionSymbol.isExternal())
1043 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1054 matchInputMapping.emplace_back();
1056 getForwardedInputs(), state);
1058 actionResultMapping.resize(getForwardedOutputs().size());
1060 for (
Operation *root : state.getPayloadOps(getRoot())) {
1064 if (!getRestrictRoot() && op == root)
1069 op->
print(llvm::dbgs(),
1071 llvm::dbgs() <<
" @" << op <<
"\n";
1074 firstMatchArgument.clear();
1075 firstMatchArgument.push_back(op);
1078 for (
auto [matcher, action] : matchActionPairs) {
1080 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1081 state, matchOutputMapping);
1082 if (
diag.isDefiniteFailure())
1084 if (
diag.isSilenceableFailure()) {
1086 <<
" failed: " <<
diag.getMessage());
1090 auto scope = state.make_region_scope(action.getFunctionBody());
1091 if (failed(state.mapBlockArguments(
1092 action.getFunctionBody().front().getArguments(),
1093 matchOutputMapping))) {
1098 action.getFunctionBody().front().without_terminator()) {
1100 state.applyTransform(cast<TransformOpInterface>(transform));
1105 overallDiag = emitSilenceableError() <<
"actions failed";
1110 <<
"when applied to this matching payload";
1117 action.getFunctionBody().front().getTerminator()->getOperands(),
1118 state, getFlattenResults()))) {
1120 <<
"action @" << action.getName()
1121 <<
" has results associated with multiple payload entities, "
1122 "but flattening was not requested";
1137 results.
set(llvm::cast<OpResult>(getUpdated()),
1138 state.getPayloadOps(getRoot()));
1139 for (
auto &&[result, mapping] :
1140 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1146 void transform::ForeachMatchOp::getAsmResultNames(
1148 setNameFn(getUpdated(),
"updated_root");
1149 for (
Value v : getForwardedOutputs()) {
1150 setNameFn(v,
"yielded");
1154 void transform::ForeachMatchOp::getEffects(
1155 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1157 if (getOperation()->getNumOperands() < 1 ||
1158 getOperation()->getNumResults() < 1) {
1171 ArrayAttr &matchers,
1172 ArrayAttr &actions) {
1194 ArrayAttr matchers, ArrayAttr actions) {
1197 for (
auto &&[matcher, action, idx] : llvm::zip_equal(
1198 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1200 printer << cast<SymbolRefAttr>(matcher) <<
" -> "
1201 << cast<SymbolRefAttr>(action);
1202 if (idx != matchers.size() - 1)
1210 if (getMatchers().size() != getActions().size())
1211 return emitOpError() <<
"expected the same number of matchers and actions";
1212 if (getMatchers().empty())
1213 return emitOpError() <<
"expected at least one match/action pair";
1217 if (matcherNames.insert(name).second)
1220 <<
" is used more than once, only the first match will apply";
1231 bool alsoVerifyInternal =
false) {
1232 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1233 llvm::SmallDenseSet<unsigned> consumedArguments;
1234 if (!op.isExternal()) {
1238 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1240 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1243 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1245 if (isConsumed && isReadOnly) {
1246 return transformOp.emitSilenceableError()
1247 <<
"argument #" << i <<
" cannot be both readonly and consumed";
1249 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1250 return transformOp.emitSilenceableError()
1251 <<
"must provide consumed/readonly status for arguments of "
1252 "external or called ops";
1254 if (op.isExternal())
1257 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1258 return transformOp.emitSilenceableError()
1259 <<
"argument #" << i
1260 <<
" is consumed in the body but is not marked as such";
1262 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1266 <<
"op argument #" << i
1267 <<
" is not consumed in the body but is marked as consumed";
1273 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1275 assert(getMatchers().size() == getActions().size());
1278 for (
auto &&[matcher, action] :
1279 llvm::zip_equal(getMatchers(), getActions())) {
1281 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1283 cast<SymbolRefAttr>(matcher)));
1284 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1286 cast<SymbolRefAttr>(action)));
1287 if (!matcherSymbol ||
1288 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1289 return emitError() <<
"unresolved matcher symbol " << matcher;
1290 if (!actionSymbol ||
1291 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1292 return emitError() <<
"unresolved action symbol " << action;
1297 .checkAndReport())) {
1303 .checkAndReport())) {
1308 TypeRange operandTypes = getOperandTypes();
1309 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1310 if (operandTypes.size() != matcherArguments.size()) {
1312 emitError() <<
"the number of operands (" << operandTypes.size()
1313 <<
") doesn't match the number of matcher arguments ("
1314 << matcherArguments.size() <<
") for " << matcher;
1315 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1318 for (
auto &&[i, operand, argument] :
1320 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1323 <<
"does not expect matcher symbol to consume its operand #" << i;
1324 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1333 <<
"mismatching type interfaces for operand and matcher argument #"
1334 << i <<
" of matcher " << matcher;
1335 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1340 TypeRange matcherResults = matcherSymbol.getResultTypes();
1341 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1342 if (matcherResults.size() != actionArguments.size()) {
1343 return emitError() <<
"mismatching number of matcher results and "
1344 "action arguments between "
1345 << matcher <<
" (" << matcherResults.size() <<
") and "
1346 << action <<
" (" << actionArguments.size() <<
")";
1348 for (
auto &&[i, matcherType, actionType] :
1353 return emitError() <<
"mismatching type interfaces for matcher result "
1354 "and action argument #"
1355 << i <<
"of matcher " << matcher <<
" and action "
1360 TypeRange actionResults = actionSymbol.getResultTypes();
1361 auto resultTypes =
TypeRange(getResultTypes()).drop_front();
1362 if (actionResults.size() != resultTypes.size()) {
1364 emitError() <<
"the number of action results ("
1365 << actionResults.size() <<
") for " << action
1366 <<
" doesn't match the number of extra op results ("
1367 << resultTypes.size() <<
")";
1368 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1371 for (
auto &&[i, resultType, actionType] :
1377 emitError() <<
"mismatching type interfaces for action result #" << i
1378 <<
" of action " << action <<
" and op result";
1379 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1398 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1399 bool withZipShortest = getWithZipShortest();
1403 if (withZipShortest) {
1407 return A.size() <
B.size();
1410 for (
size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1411 payloads[argIdx].resize(numIterations);
1417 for (
size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1419 if (payloads[argIdx].size() != numIterations) {
1420 return emitSilenceableError()
1421 <<
"prior targets' payload size (" << numIterations
1422 <<
") differs from payload size (" << payloads[argIdx].size()
1423 <<
") of target " << getTargets()[argIdx];
1432 for (
size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1433 auto scope = state.make_region_scope(getBody());
1439 if (failed(state.mapBlockArgument(blockArg, {argument})))
1444 for (
Operation &transform : getBody().front().without_terminator()) {
1446 llvm::cast<transform::TransformOpInterface>(transform));
1452 OperandRange yieldOperands = getYieldOp().getOperands();
1453 for (
auto &&[result, yieldOperand, resTuple] :
1454 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1456 if (isa<TransformHandleTypeInterface>(result.getType()))
1457 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1458 else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1459 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1460 else if (isa<TransformParamTypeInterface>(result.getType()))
1461 llvm::append_range(resTuple, state.getParams(yieldOperand));
1463 assert(
false &&
"unhandled handle type");
1467 for (
auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1473 void transform::ForeachOp::getEffects(
1474 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1477 for (
auto &&[target, blockArg] :
1478 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1480 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1482 cast<TransformOpInterface>(&op));
1490 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1494 }
else if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1503 void transform::ForeachOp::getSuccessorRegions(
1505 Region *bodyRegion = &getBody();
1507 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1512 assert(point == getBody() &&
"unexpected region index");
1513 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1514 regions.emplace_back();
1521 assert(point == getBody() &&
"unexpected region index");
1522 return getOperation()->getOperands();
1525 transform::YieldOp transform::ForeachOp::getYieldOp() {
1526 return cast<transform::YieldOp>(getBody().front().getTerminator());
1530 for (
auto [targetOpt, bodyArgOpt] :
1531 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1532 if (!targetOpt || !bodyArgOpt)
1533 return emitOpError() <<
"expects the same number of targets as the body "
1534 "has block arguments";
1535 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1537 "expects co-indexed targets and the body's "
1538 "block arguments to have the same op/value/param type");
1541 for (
auto [resultOpt, yieldOperandOpt] :
1542 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1543 if (!resultOpt || !yieldOperandOpt)
1544 return emitOpError() <<
"expects the same number of results as the "
1545 "yield terminator has operands";
1546 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1547 return emitOpError(
"expects co-indexed results and yield "
1548 "operands to have the same op/value/param type");
1564 for (
Operation *target : state.getPayloadOps(getTarget())) {
1566 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1569 bool checkIsolatedFromAbove =
1570 !getIsolatedFromAbove() ||
1572 bool checkOpName = !getOpName().has_value() ||
1574 if (checkIsolatedFromAbove && checkOpName)
1579 if (getAllowEmptyResults()) {
1580 results.
set(llvm::cast<OpResult>(getResult()), parents);
1584 emitSilenceableError()
1585 <<
"could not find a parent op that matches all requirements";
1586 diag.attachNote(target->
getLoc()) <<
"target op";
1590 if (getDeduplicate()) {
1591 if (resultSet.insert(parent).second)
1592 parents.push_back(parent);
1594 parents.push_back(parent);
1597 results.
set(llvm::cast<OpResult>(getResult()), parents);
1609 int64_t resultNumber = getResultNumber();
1610 auto payloadOps = state.getPayloadOps(getTarget());
1611 if (std::empty(payloadOps)) {
1612 results.
set(cast<OpResult>(getResult()), {});
1615 if (!llvm::hasSingleElement(payloadOps))
1617 <<
"handle must be mapped to exactly one payload op";
1619 Operation *target = *payloadOps.begin();
1622 results.
set(llvm::cast<OpResult>(getResult()),
1636 for (
Value v : state.getPayloadValues(getTarget())) {
1637 if (llvm::isa<BlockArgument>(v)) {
1639 emitSilenceableError() <<
"cannot get defining op of block argument";
1640 diag.attachNote(v.getLoc()) <<
"target value";
1643 definingOps.push_back(v.getDefiningOp());
1645 results.
set(llvm::cast<OpResult>(getResult()), definingOps);
1657 int64_t operandNumber = getOperandNumber();
1659 for (
Operation *target : state.getPayloadOps(getTarget())) {
1666 emitSilenceableError()
1667 <<
"could not find a producer for operand number: " << operandNumber
1668 <<
" of " << *target;
1669 diag.attachNote(target->getLoc()) <<
"target op";
1672 producers.push_back(producer);
1674 results.
set(llvm::cast<OpResult>(getResult()), producers);
1687 for (
Operation *target : state.getPayloadOps(getTarget())) {
1690 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1691 target->getNumOperands(), operandPositions);
1692 if (
diag.isSilenceableFailure()) {
1693 diag.attachNote(target->getLoc())
1694 <<
"while considering positions of this payload operation";
1697 llvm::append_range(operands,
1698 llvm::map_range(operandPositions, [&](int64_t pos) {
1699 return target->getOperand(pos);
1702 results.
setValues(cast<OpResult>(getResult()), operands);
1708 getIsInverted(), getIsAll());
1720 for (
Operation *target : state.getPayloadOps(getTarget())) {
1723 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1724 target->getNumResults(), resultPositions);
1725 if (
diag.isSilenceableFailure()) {
1726 diag.attachNote(target->getLoc())
1727 <<
"while considering positions of this payload operation";
1730 llvm::append_range(opResults,
1731 llvm::map_range(resultPositions, [&](int64_t pos) {
1732 return target->getResult(pos);
1735 results.
setValues(cast<OpResult>(getResult()), opResults);
1741 getIsInverted(), getIsAll());
1748 void transform::GetTypeOp::getEffects(
1749 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1760 for (
Value value : state.getPayloadValues(getValue())) {
1761 Type type = value.getType();
1762 if (getElemental()) {
1763 if (
auto shaped = dyn_cast<ShapedType>(type)) {
1764 type = shaped.getElementType();
1769 results.
setParams(cast<OpResult>(getResult()), params);
1786 state.applyTransform(cast<transform::TransformOpInterface>(transform));
1791 if (mode == transform::FailurePropagationMode::Propagate) {
1810 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1811 getOperation(), getTarget());
1812 assert(callee &&
"unverified reference to unknown symbol");
1814 if (callee.isExternal())
1820 auto scope = state.make_region_scope(callee.getBody());
1821 for (
auto &&[arg, map] :
1822 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
1823 if (failed(state.mapBlockArgument(arg, map)))
1828 callee.getBody().front(), getFailurePropagationMode(), state, results);
1831 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
1832 for (
auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
1840 void transform::IncludeOp::getEffects(
1841 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1855 auto defaultEffects = [&] {
1862 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
1864 return defaultEffects();
1865 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1866 getOperation(), getTarget());
1868 return defaultEffects();
1872 (void)earlyVerifierResult.
silence();
1873 return defaultEffects();
1876 for (
unsigned i = 0, e = getNumOperands(); i < e; ++i) {
1877 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
1888 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(
"target");
1890 return emitOpError() <<
"expects a 'target' symbol reference attribute";
1895 return emitOpError() <<
"does not reference a named transform sequence";
1897 FunctionType fnType = target.getFunctionType();
1898 if (fnType.getNumInputs() != getNumOperands())
1899 return emitError(
"incorrect number of operands for callee");
1901 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1902 if (getOperand(i).
getType() != fnType.getInput(i)) {
1903 return emitOpError(
"operand type mismatch: expected operand type ")
1904 << fnType.getInput(i) <<
", but provided "
1905 << getOperand(i).getType() <<
" for operand number " << i;
1909 if (fnType.getNumResults() != getNumResults())
1910 return emitError(
"incorrect number of results for callee");
1912 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1913 Type resultType = getResult(i).getType();
1914 Type funcType = fnType.getResult(i);
1916 return emitOpError() <<
"type of result #" << i
1917 <<
" must implement the same transform dialect "
1918 "interface as the corresponding callee result";
1923 cast<FunctionOpInterface>(*target),
false,
1933 ::std::optional<::mlir::Operation *> maybeCurrent,
1935 if (!maybeCurrent.has_value()) {
1940 return emitSilenceableError() <<
"operation is not empty";
1951 for (
auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
1952 if (acceptedAttr.getValue() == currentOpName)
1955 return emitSilenceableError() <<
"wrong operation name";
1966 auto signedAPIntAsString = [&](
const APInt &value) {
1968 llvm::raw_string_ostream os(str);
1969 value.print(os,
true);
1976 if (params.size() != references.size()) {
1977 return emitSilenceableError()
1978 <<
"parameters have different payload lengths (" << params.size()
1979 <<
" vs " << references.size() <<
")";
1982 for (
auto &&[i, param, reference] :
llvm::enumerate(params, references)) {
1983 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
1984 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
1985 if (!intAttr || !refAttr) {
1987 <<
"non-integer parameter value not expected";
1989 if (intAttr.getType() != refAttr.getType()) {
1991 <<
"mismatching integer attribute types in parameter #" << i;
1993 APInt value = intAttr.getValue();
1994 APInt refValue = refAttr.getValue();
1997 int64_t position = i;
1998 auto reportError = [&](StringRef direction) {
2000 emitSilenceableError() <<
"expected parameter to be " << direction
2001 <<
" " << signedAPIntAsString(refValue)
2002 <<
", got " << signedAPIntAsString(value);
2003 diag.attachNote(getParam().getLoc())
2004 <<
"value # " << position
2005 <<
" associated with the parameter defined here";
2009 switch (getPredicate()) {
2010 case MatchCmpIPredicate::eq:
2011 if (value.eq(refValue))
2013 return reportError(
"equal to");
2014 case MatchCmpIPredicate::ne:
2015 if (value.ne(refValue))
2017 return reportError(
"not equal to");
2018 case MatchCmpIPredicate::lt:
2019 if (value.slt(refValue))
2021 return reportError(
"less than");
2022 case MatchCmpIPredicate::le:
2023 if (value.sle(refValue))
2025 return reportError(
"less than or equal to");
2026 case MatchCmpIPredicate::gt:
2027 if (value.sgt(refValue))
2029 return reportError(
"greater than");
2030 case MatchCmpIPredicate::ge:
2031 if (value.sge(refValue))
2033 return reportError(
"greater than or equal to");
2039 void transform::MatchParamCmpIOp::getEffects(
2040 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2053 results.
setParams(cast<OpResult>(getParam()), {getValue()});
2066 if (isa<TransformHandleTypeInterface>(handles.front().
getType())) {
2068 for (
Value operand : handles)
2069 llvm::append_range(operations, state.getPayloadOps(operand));
2070 if (!getDeduplicate()) {
2071 results.
set(llvm::cast<OpResult>(getResult()), operations);
2076 results.
set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2080 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2082 for (
Value attribute : handles)
2083 llvm::append_range(attrs, state.getParams(attribute));
2084 if (!getDeduplicate()) {
2085 results.
setParams(cast<OpResult>(getResult()), attrs);
2090 results.
setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2095 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2096 "expected value handle type");
2098 for (
Value value : handles)
2099 llvm::append_range(payloadValues, state.getPayloadValues(value));
2100 if (!getDeduplicate()) {
2101 results.
setValues(cast<OpResult>(getResult()), payloadValues);
2106 results.
setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2110 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2112 return getDeduplicate();
2115 void transform::MergeHandlesOp::getEffects(
2116 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2124 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2125 if (getDeduplicate() || getHandles().size() != 1)
2130 return getHandles().front();
2148 auto scope = state.make_region_scope(getBody());
2150 state, this->getOperation(), getBody())))
2154 FailurePropagationMode::Propagate, state, results);
2157 void transform::NamedSequenceOp::getEffects(
2158 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2163 parser, result,
false,
2164 getFunctionTypeAttrName(result.
name),
2167 std::string &) { return builder.getFunctionType(inputs, results); },
2168 getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
2173 printer, cast<FunctionOpInterface>(getOperation()),
false,
2174 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2175 getResAttrsAttrName());
2185 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2188 <<
"cannot be defined inside another transform op";
2189 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2193 if (op.isExternal() || op.getFunctionBody().empty()) {
2200 if (op.getFunctionBody().front().empty())
2203 Operation *terminator = &op.getFunctionBody().front().back();
2204 if (!isa<transform::YieldOp>(terminator)) {
2207 << transform::YieldOp::getOperationName()
2208 <<
"' as terminator";
2209 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2213 if (terminator->
getNumOperands() != op.getResultTypes().size()) {
2215 <<
"expected terminator to have as many operands as the parent op "
2218 for (
auto [i, operandType, resultType] : llvm::zip_equal(
2221 if (operandType == resultType)
2224 <<
"the type of the terminator operand #" << i
2225 <<
" must match the type of the corresponding parent op result ("
2226 << operandType <<
" vs " << resultType <<
")";
2239 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2242 <<
"expects the parent symbol table to have the '"
2243 << transform::TransformDialect::kWithNamedSequenceAttrName
2245 diag.attachNote(parent->
getLoc()) <<
"symbol table operation";
2250 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2253 <<
"cannot be defined inside another transform op";
2254 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2258 if (op.isExternal() || op.getBody().empty())
2262 if (op.getBody().front().empty())
2265 Operation *terminator = &op.getBody().front().back();
2266 if (!isa<transform::YieldOp>(terminator)) {
2269 << transform::YieldOp::getOperationName()
2270 <<
"' as terminator";
2271 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2275 if (terminator->
getNumOperands() != op.getFunctionType().getNumResults()) {
2277 <<
"expected terminator to have as many operands as the parent op "
2280 for (
auto [i, operandType, resultType] :
2281 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->
getNumOperands()),
2283 op.getFunctionType().getResults())) {
2284 if (operandType == resultType)
2287 <<
"the type of the terminator operand #" << i
2288 <<
" must match the type of the corresponding parent op result ("
2289 << operandType <<
" vs " << resultType <<
")";
2292 auto funcOp = cast<FunctionOpInterface>(*op);
2295 if (!
diag.succeeded())
2307 template <
typename FnTy>
2312 types.reserve(1 + extraBindingTypes.size());
2313 types.push_back(bbArgType);
2314 llvm::append_range(types, extraBindingTypes);
2317 Region *region = state.regions.back().get();
2324 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2325 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0));
2327 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0),
2332 void transform::NamedSequenceOp::build(
OpBuilder &builder,
2340 state.addAttribute(getFunctionTypeAttrName(state.name),
2342 rootType, resultTypes)));
2343 state.attributes.append(attrs.begin(), attrs.end());
2358 size_t numAssociations =
2360 .Case([&](TransformHandleTypeInterface opHandle) {
2361 return llvm::range_size(state.getPayloadOps(getHandle()));
2363 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2364 return llvm::range_size(state.getPayloadValues(getHandle()));
2366 .Case([&](TransformParamTypeInterface param) {
2367 return llvm::range_size(state.getParams(getHandle()));
2370 llvm_unreachable(
"unknown kind of transform dialect type");
2373 results.
setParams(cast<OpResult>(getNum()),
2380 auto resultType = cast<TransformParamTypeInterface>(getNum().
getType());
2395 auto payloadOps = state.getPayloadOps(getTarget());
2398 result.push_back(op);
2400 results.
set(cast<OpResult>(getResult()), result);
2409 Value target, int64_t numResultHandles) {
2418 int64_t numPayloads =
2420 .Case<TransformHandleTypeInterface>([&](
auto x) {
2421 return llvm::range_size(state.getPayloadOps(getHandle()));
2423 .Case<TransformValueHandleTypeInterface>([&](
auto x) {
2424 return llvm::range_size(state.getPayloadValues(getHandle()));
2426 .Case<TransformParamTypeInterface>([&](
auto x) {
2427 return llvm::range_size(state.getParams(getHandle()));
2429 .Default([](
auto x) {
2430 llvm_unreachable(
"unknown transform dialect type interface");
2434 auto produceNumOpsError = [&]() {
2435 return emitSilenceableError()
2436 << getHandle() <<
" expected to contain " << this->getNumResults()
2437 <<
" payloads but it contains " << numPayloads <<
" payloads";
2442 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2443 return produceNumOpsError();
2448 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2449 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2450 return produceNumOpsError();
2454 if (getOverflowResult())
2455 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2457 auto container = [&]() {
2458 if (isa<TransformHandleTypeInterface>(getHandle().
getType())) {
2459 return llvm::map_to_vector(
2460 state.getPayloadOps(getHandle()),
2463 if (isa<TransformValueHandleTypeInterface>(getHandle().
getType())) {
2464 return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2467 assert(isa<TransformParamTypeInterface>(getHandle().
getType()) &&
2468 "unsupported kind of transform dialect type");
2469 return llvm::map_to_vector(state.getParams(getHandle()),
2474 int64_t resultNum = en.index();
2475 if (resultNum >= getNumResults())
2476 resultNum = *getOverflowResult();
2477 resultHandles[resultNum].push_back(en.value());
2488 void transform::SplitHandleOp::getEffects(
2489 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2497 if (getOverflowResult().has_value() &&
2498 !(*getOverflowResult() < getNumResults()))
2499 return emitOpError(
"overflow_result is not a valid result index");
2501 for (
Type resultType : getResultTypes()) {
2505 return emitOpError(
"expects result types to implement the same transform "
2506 "interface as the operand type");
2520 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2522 Value handle = en.value();
2523 if (isa<TransformHandleTypeInterface>(handle.
getType())) {
2525 llvm::to_vector(state.getPayloadOps(handle));
2527 payload.reserve(numRepetitions * current.size());
2528 for (
unsigned i = 0; i < numRepetitions; ++i)
2529 llvm::append_range(payload, current);
2530 results.
set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2532 assert(llvm::isa<TransformParamTypeInterface>(handle.
getType()) &&
2533 "expected param type");
2536 params.reserve(numRepetitions * current.size());
2537 for (
unsigned i = 0; i < numRepetitions; ++i)
2538 llvm::append_range(params, current);
2539 results.
setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2546 void transform::ReplicateOp::getEffects(
2547 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2562 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2563 if (failed(mapBlockArguments(state)))
2571 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2573 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2574 SmallVectorImpl<Type> &extraBindingTypes) {
2578 root = std::nullopt;
2581 if (failed(hasRoot.
value()))
2595 if (failed(parser.
parseType(rootType))) {
2599 if (!extraBindings.empty()) {
2604 if (extraBindingTypes.size() != extraBindings.size()) {
2606 "expected types to be provided for all operands");
2622 bool hasExtras = !extraBindings.empty();
2632 printer << rootType;
2635 llvm::interleaveComma(extraBindingTypes, printer.
getStream());
2644 auto iface = dyn_cast<transform::TransformOpInterface>(use.
getOwner());
2659 if (!potentialConsumer) {
2660 potentialConsumer = &use;
2665 <<
" has more than one potential consumer";
2668 diag.attachNote(use.getOwner()->getLoc())
2669 <<
"used here as operand #" << use.getOperandNumber();
2677 assert(getBodyBlock()->getNumArguments() >= 1 &&
2678 "the number of arguments must have been verified to be more than 1 by "
2679 "PossibleTopLevelTransformOpTrait");
2681 if (!getRoot() && !getExtraBindings().empty()) {
2682 return emitOpError()
2683 <<
"does not expect extra operands when used as top-level";
2689 return (emitOpError() <<
"block argument #" << arg.getArgNumber());
2696 for (
Operation &child : *getBodyBlock()) {
2697 if (!isa<TransformOpInterface>(child) &&
2698 &child != &getBodyBlock()->back()) {
2701 <<
"expected children ops to implement TransformOpInterface";
2702 diag.attachNote(child.getLoc()) <<
"op without interface";
2706 for (
OpResult result : child.getResults()) {
2707 auto report = [&]() {
2708 return (child.emitError() <<
"result #" << result.getResultNumber());
2715 if (!getBodyBlock()->mightHaveTerminator())
2716 return emitOpError() <<
"expects to have a terminator in the body";
2718 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2719 getOperation()->getResultTypes()) {
2721 <<
"expects the types of the terminator operands "
2722 "to match the types of the result";
2723 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) <<
"terminator";
2729 void transform::SequenceOp::getEffects(
2730 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2736 assert(point == getBody() &&
"unexpected region index");
2737 if (getOperation()->getNumOperands() > 0)
2738 return getOperation()->getOperands();
2740 getOperation()->operand_end());
2743 void transform::SequenceOp::getSuccessorRegions(
2746 Region *bodyRegion = &getBody();
2747 regions.emplace_back(bodyRegion, getNumOperands() != 0
2753 assert(point == getBody() &&
"unexpected region index");
2754 regions.emplace_back(getOperation()->getResults());
2757 void transform::SequenceOp::getRegionInvocationBounds(
2760 bounds.emplace_back(1, 1);
2765 FailurePropagationMode failurePropagationMode,
2768 build(builder, state, resultTypes, failurePropagationMode, root,
2777 FailurePropagationMode failurePropagationMode,
2780 build(builder, state, resultTypes, failurePropagationMode, root,
2788 FailurePropagationMode failurePropagationMode,
2791 build(builder, state, resultTypes, failurePropagationMode,
Value(),
2799 FailurePropagationMode failurePropagationMode,
2802 build(builder, state, resultTypes, failurePropagationMode,
Value(),
2818 Value target, StringRef name) {
2820 build(builder, result, name);
2827 llvm::outs() <<
"[[[ IR printer: ";
2828 if (getName().has_value())
2829 llvm::outs() << *getName() <<
" ";
2832 if (getAssumeVerified().value_or(
false))
2834 if (getUseLocalScope().value_or(
false))
2836 if (getSkipRegions().value_or(
false))
2840 llvm::outs() <<
"top-level ]]]\n";
2841 state.getTopLevel()->print(llvm::outs(), printFlags);
2842 llvm::outs() <<
"\n";
2846 llvm::outs() <<
"]]]\n";
2847 for (
Operation *target : state.getPayloadOps(getTarget())) {
2848 target->
print(llvm::outs(), printFlags);
2849 llvm::outs() <<
"\n";
2855 void transform::PrintOp::getEffects(
2856 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2860 if (!getTargetMutable().empty())
2880 <<
"failed to verify payload op";
2881 diag.attachNote(target->
getLoc()) <<
"payload op";
2887 void transform::VerifyOp::getEffects(
2888 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2896 void transform::YieldOp::getEffects(
2897 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & assumeVerified()
Do not verify the operation when using custom operation printers.
OpPrintingFlags & useLocalScope()
Use local scope when printing the operation.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
This class implements the operand iterators for the Operation class.
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getRegionNumber()
Return the number of this region in the parent operation.
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
A named class for passing around the variadic flag.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.