37 #include "llvm/ADT/DenseSet.h"
38 #include "llvm/ADT/STLExtras.h"
39 #include "llvm/ADT/ScopeExit.h"
40 #include "llvm/ADT/SmallPtrSet.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Support/Debug.h"
43 #include "llvm/Support/ErrorHandling.h"
46 #define DEBUG_TYPE "transform-dialect"
47 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
49 #define DEBUG_TYPE_MATCHER "transform-matcher"
50 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
51 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
56 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
58 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
59 SmallVectorImpl<Type> &extraBindingTypes);
65 ArrayAttr matchers, ArrayAttr actions);
76 Operation *transformAncestor = transform.getOperation();
77 while (transformAncestor) {
78 if (transformAncestor == payload) {
80 transform.emitDefiniteFailure()
81 <<
"cannot apply transform to itself (or one of its ancestors)";
82 diag.attachNote(payload->
getLoc()) <<
"target payload op";
85 transformAncestor = transformAncestor->
getParentOp();
90 #define GET_OP_CLASSES
91 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
99 if (!point.
isParent() && getOperation()->getNumOperands() == 1)
100 return getOperation()->getOperands();
102 getOperation()->operand_end());
105 void transform::AlternativesOp::getSuccessorRegions(
107 for (
Region &alternative : llvm::drop_begin(
111 regions.emplace_back(&alternative, !getOperands().empty()
112 ? alternative.getArguments()
116 regions.emplace_back(getOperation()->getResults());
119 void transform::AlternativesOp::getRegionInvocationBounds(
124 bounds.reserve(getNumRegions());
125 bounds.emplace_back(1, 1);
132 results.
set(res, {});
140 if (
Value scopeHandle = getScope())
141 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
143 originals.push_back(state.getTopLevel());
146 if (original->isAncestor(getOperation())) {
148 <<
"scope must not contain the transforms being applied";
149 diag.attachNote(original->getLoc()) <<
"scope";
154 <<
"only isolated-from-above ops can be alternative scopes";
155 diag.attachNote(original->getLoc()) <<
"scope";
160 for (
Region ® : getAlternatives()) {
165 auto scope = state.make_region_scope(reg);
166 auto clones = llvm::to_vector(
167 llvm::map_range(originals, [](
Operation *op) {
return op->
clone(); }));
168 auto deleteClones = llvm::make_scope_exit([&] {
172 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
176 for (
Operation &transform : reg.front().without_terminator()) {
178 state.applyTransform(cast<TransformOpInterface>(transform));
180 LLVM_DEBUG(
DBGS() <<
"alternative failed: " << result.
getMessage()
186 if (::mlir::failed(result.
silence()))
195 deleteClones.release();
196 TrackingListener listener(state, *
this);
198 for (
const auto &kvp : llvm::zip(originals, clones)) {
209 return emitSilenceableError() <<
"all alternatives failed";
212 void transform::AlternativesOp::getEffects(
213 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
216 for (
Region *region : getRegions()) {
217 if (!region->empty())
224 for (
Region &alternative : getAlternatives()) {
229 <<
"expects terminator operands to have the "
230 "same type as results of the operation";
231 diag.attachNote(terminator->
getLoc()) <<
"terminator";
248 llvm::to_vector(state.getPayloadOps(getTarget()));
251 if (
auto paramH = getParam()) {
253 if (params.size() != 1) {
254 if (targets.size() != params.size()) {
255 return emitSilenceableError()
256 <<
"parameter and target have different payload lengths ("
257 << params.size() <<
" vs " << targets.size() <<
")";
259 for (
auto &&[target, attr] : llvm::zip_equal(targets, params))
260 target->setAttr(getName(), attr);
265 for (
auto *target : targets)
266 target->setAttr(getName(), attr);
270 void transform::AnnotateOp::getEffects(
271 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
282 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
297 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
298 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
322 auto addDefiningOpsToWorklist = [&](
Operation *op) {
325 if (
Operation *defOp = v.getDefiningOp())
327 worklist.insert(defOp);
335 const auto *it = llvm::find(worklist, op);
336 if (it != worklist.end())
345 addDefiningOpsToWorklist(op);
351 while (!worklist.empty()) {
355 addDefiningOpsToWorklist(op);
362 void transform::ApplyDeadCodeEliminationOp::getEffects(
363 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
387 if (!getRegion().empty()) {
388 for (
Operation &op : getRegion().front()) {
389 cast<transform::PatternDescriptorOpInterface>(&op)
390 .populatePatternsWithState(patterns, state);
400 config.
maxIterations = getMaxIterations() ==
static_cast<uint64_t
>(-1)
402 : getMaxIterations();
403 config.
maxNumRewrites = getMaxNumRewrites() ==
static_cast<uint64_t
>(-1)
405 : getMaxNumRewrites();
410 bool cseChanged =
false;
413 static const int64_t kNumMaxIterations = 50;
414 int64_t iteration = 0;
416 LogicalResult result = failure();
429 if (target != nestedOp)
430 ops.push_back(nestedOp);
437 if (failed(result)) {
439 <<
"greedy pattern application failed";
447 }
while (cseChanged && ++iteration < kNumMaxIterations);
449 if (iteration == kNumMaxIterations)
456 if (!getRegion().empty()) {
457 for (
Operation &op : getRegion().front()) {
458 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
460 <<
"expected children ops to implement "
461 "PatternDescriptorOpInterface";
462 diag.attachNote(op.
getLoc()) <<
"op without interface";
470 void transform::ApplyPatternsOp::getEffects(
471 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
476 void transform::ApplyPatternsOp::build(
485 bodyBuilder(builder, result.
location);
492 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
496 dialect->getCanonicalizationPatterns(patterns);
498 op.getCanonicalizationPatterns(patterns, ctx);
512 std::unique_ptr<TypeConverter> defaultTypeConverter;
513 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
514 getDefaultTypeConverter();
515 if (typeConverterBuilder)
516 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
521 for (
Attribute attr : cast<ArrayAttr>(*getLegalOps()))
522 conversionTarget.addLegalOp(
525 for (
Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
526 conversionTarget.addIllegalOp(
528 if (getLegalDialects())
529 for (
Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
530 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
531 if (getIllegalDialects())
532 for (
Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
533 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
541 if (!getPatterns().empty()) {
542 for (
Operation &op : getPatterns().front()) {
544 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
547 std::unique_ptr<TypeConverter> typeConverter =
548 descriptor.getTypeConverter();
551 keepAliveConverters.emplace_back(std::move(typeConverter));
552 converter = keepAliveConverters.back().get();
555 if (!defaultTypeConverter) {
557 <<
"pattern descriptor does not specify type "
558 "converter and apply_conversion_patterns op has "
559 "no default type converter";
560 diag.attachNote(op.
getLoc()) <<
"pattern descriptor op";
563 converter = defaultTypeConverter.get();
569 descriptor.populateConversionTargetRules(*converter, conversionTarget);
571 descriptor.populatePatterns(*converter, patterns);
579 TrackingListenerConfig trackingConfig;
580 trackingConfig.requireMatchingReplacementOpName =
false;
581 ErrorCheckingTrackingListener trackingListener(state, *
this, trackingConfig);
583 if (getPreserveHandles())
584 conversionConfig.
listener = &trackingListener;
587 for (
Operation *target : state.getPayloadOps(getTarget())) {
595 LogicalResult status = failure();
596 if (getPartialConversion()) {
606 if (failed(status)) {
607 diag = emitSilenceableError() <<
"dialect conversion failed";
608 diag.attachNote(target->
getLoc()) <<
"target op";
613 trackingListener.checkAndResetError();
615 if (
diag.succeeded()) {
617 return trackingFailure;
619 diag.attachNote() <<
"tracking listener also failed: "
621 (void)trackingFailure.
silence();
625 if (!
diag.succeeded())
633 if (getNumRegions() != 1 && getNumRegions() != 2)
634 return emitOpError() <<
"expected 1 or 2 regions";
635 if (!getPatterns().empty()) {
636 for (
Operation &op : getPatterns().front()) {
637 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
639 emitOpError() <<
"expected pattern children ops to implement "
640 "ConversionPatternDescriptorOpInterface";
641 diag.attachNote(op.
getLoc()) <<
"op without interface";
646 if (getNumRegions() == 2) {
647 Region &typeConverterRegion = getRegion(1);
648 if (!llvm::hasSingleElement(typeConverterRegion.
front()))
650 <<
"expected exactly one op in default type converter region";
652 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
654 if (!typeConverterOp) {
656 <<
"expected default converter child op to "
657 "implement TypeConverterBuilderOpInterface";
658 diag.attachNote(maybeTypeConverter->
getLoc()) <<
"op without interface";
662 if (!getPatterns().empty()) {
663 for (
Operation &op : getPatterns().front()) {
665 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
666 if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
674 void transform::ApplyConversionPatternsOp::getEffects(
675 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
676 if (!getPreserveHandles()) {
684 void transform::ApplyConversionPatternsOp::build(
694 if (patternsBodyBuilder)
695 patternsBodyBuilder(builder, result.
location);
701 if (typeConverterBodyBuilder)
702 typeConverterBodyBuilder(builder, result.
location);
710 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
713 assert(dialect &&
"expected that dialect is loaded");
714 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
718 iface->populateConvertToLLVMConversionPatterns(
722 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
723 transform::TypeConverterBuilderOpInterface builder) {
724 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
725 return emitOpError(
"expected LLVMTypeConverter");
732 return emitOpError(
"unknown dialect or dialect not loaded: ")
734 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
737 "dialect does not implement ConvertToLLVMPatternInterface or "
738 "extension was not loaded: ")
748 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
758 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
759 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
786 <<
"unknown pass or pass pipeline: " << getPassName();
790 if (failed(info->
addToPipeline(pm, getOptions(), [&](
const Twine &msg) {
795 <<
"failed to add pass or pass pipeline to pipeline: "
798 if (failed(pm.run(target))) {
799 auto diag = emitSilenceableError() <<
"pass pipeline failed";
800 diag.attachNote(target->
getLoc()) <<
"target op";
804 results.push_back(target);
814 Operation *target, ApplyToEachResultList &results,
816 results.push_back(target);
820 void transform::CastOp::getEffects(
821 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
828 assert(inputs.size() == 1 &&
"expected one input");
829 assert(outputs.size() == 1 &&
"expected one output");
831 std::initializer_list<Type>{inputs.front(), outputs.front()},
832 llvm::IsaPred<transform::TransformHandleTypeInterface>);
852 assert(block.
getParent() &&
"cannot match using a detached block");
853 auto matchScope = state.make_region_scope(*block.
getParent());
855 state.mapBlockArguments(block.
getArguments(), blockArgumentMapping)))
859 if (!isa<transform::MatchOpInterface>(match)) {
861 <<
"expected operations in the match part to "
862 "implement MatchOpInterface";
865 state.applyTransform(cast<transform::TransformOpInterface>(match));
866 if (
diag.succeeded())
884 template <
typename... Tys>
886 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... ||
false);
893 transform::TransformParamTypeInterface,
894 transform::TransformValueHandleTypeInterface>(
906 auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
907 getOperation(), getMatcher());
908 if (matcher.isExternal()) {
910 <<
"unresolved external symbol " << getMatcher();
914 rawResults.resize(getOperation()->getNumResults());
915 std::optional<DiagnosedSilenceableFailure> maybeFailure;
916 for (
Operation *root : state.getPayloadOps(getRoot())) {
920 op->
print(llvm::dbgs(),
922 llvm::dbgs() <<
" @" << op <<
"\n";
929 matcher.getFunctionBody().front(),
932 if (
diag.isDefiniteFailure())
934 if (
diag.isSilenceableFailure()) {
936 <<
" failed: " <<
diag.getMessage());
942 if (mapping.size() != 1) {
943 maybeFailure.emplace(emitSilenceableError()
944 <<
"result #" << i <<
", associated with "
946 <<
" payload objects, expected 1");
949 rawResults[i].push_back(mapping[0]);
954 return std::move(*maybeFailure);
955 assert(!maybeFailure &&
"failure set but the walk was not interrupted");
957 for (
auto &&[opResult, rawResult] :
958 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
965 void transform::CollectMatchingOp::getEffects(
966 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
972 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
974 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
976 if (!matcherSymbol ||
977 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
978 return emitError() <<
"unresolved matcher symbol " << getMatcher();
981 if (argumentTypes.size() != 1 ||
982 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
984 <<
"expected the matcher to take one operation handle argument";
986 if (!matcherSymbol.getArgAttr(
987 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
988 return emitError() <<
"expected the matcher argument to be marked readonly";
992 if (resultTypes.size() != getOperation()->getNumResults()) {
994 <<
"expected the matcher to yield as many values as op has results ("
995 << getOperation()->getNumResults() <<
"), got "
996 << resultTypes.size();
999 for (
auto &&[i, matcherType, resultType] :
1005 <<
"mismatching type interfaces for matcher result and op result #"
1017 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() {
return true; }
1025 matchActionPairs.reserve(getMatchers().size());
1027 for (
auto &&[matcher, action] :
1028 llvm::zip_equal(getMatchers(), getActions())) {
1029 auto matcherSymbol =
1031 getOperation(), cast<SymbolRefAttr>(matcher));
1034 getOperation(), cast<SymbolRefAttr>(action));
1035 assert(matcherSymbol && actionSymbol &&
1036 "unresolved symbols not caught by the verifier");
1038 if (matcherSymbol.isExternal())
1040 if (actionSymbol.isExternal())
1043 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1054 matchInputMapping.emplace_back();
1056 getForwardedInputs(), state);
1058 actionResultMapping.resize(getForwardedOutputs().size());
1060 for (
Operation *root : state.getPayloadOps(getRoot())) {
1064 if (!getRestrictRoot() && op == root)
1069 op->
print(llvm::dbgs(),
1071 llvm::dbgs() <<
" @" << op <<
"\n";
1074 firstMatchArgument.clear();
1075 firstMatchArgument.push_back(op);
1078 for (
auto [matcher, action] : matchActionPairs) {
1080 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1081 state, matchOutputMapping);
1082 if (
diag.isDefiniteFailure())
1084 if (
diag.isSilenceableFailure()) {
1086 <<
" failed: " <<
diag.getMessage());
1090 auto scope = state.make_region_scope(action.getFunctionBody());
1091 if (failed(state.mapBlockArguments(
1092 action.getFunctionBody().front().getArguments(),
1093 matchOutputMapping))) {
1098 action.getFunctionBody().front().without_terminator()) {
1100 state.applyTransform(cast<TransformOpInterface>(transform));
1105 overallDiag = emitSilenceableError() <<
"actions failed";
1110 <<
"when applied to this matching payload";
1117 action.getFunctionBody().front().getTerminator()->getOperands(),
1118 state, getFlattenResults()))) {
1120 <<
"action @" << action.getName()
1121 <<
" has results associated with multiple payload entities, "
1122 "but flattening was not requested";
1137 results.
set(llvm::cast<OpResult>(getUpdated()),
1138 state.getPayloadOps(getRoot()));
1139 for (
auto &&[result, mapping] :
1140 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1146 void transform::ForeachMatchOp::getAsmResultNames(
1148 setNameFn(getUpdated(),
"updated_root");
1149 for (
Value v : getForwardedOutputs()) {
1150 setNameFn(v,
"yielded");
1154 void transform::ForeachMatchOp::getEffects(
1155 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1157 if (getOperation()->getNumOperands() < 1 ||
1158 getOperation()->getNumResults() < 1) {
1171 ArrayAttr &matchers,
1172 ArrayAttr &actions) {
1194 ArrayAttr matchers, ArrayAttr actions) {
1197 for (
auto &&[matcher, action, idx] : llvm::zip_equal(
1198 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1200 printer << cast<SymbolRefAttr>(matcher) <<
" -> "
1201 << cast<SymbolRefAttr>(action);
1202 if (idx != matchers.size() - 1)
1210 if (getMatchers().size() != getActions().size())
1211 return emitOpError() <<
"expected the same number of matchers and actions";
1212 if (getMatchers().empty())
1213 return emitOpError() <<
"expected at least one match/action pair";
1217 if (matcherNames.insert(name).second)
1220 <<
" is used more than once, only the first match will apply";
1231 bool alsoVerifyInternal =
false) {
1232 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1233 llvm::SmallDenseSet<unsigned> consumedArguments;
1234 if (!op.isExternal()) {
1238 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1240 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1243 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1245 if (isConsumed && isReadOnly) {
1246 return transformOp.emitSilenceableError()
1247 <<
"argument #" << i <<
" cannot be both readonly and consumed";
1249 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1250 return transformOp.emitSilenceableError()
1251 <<
"must provide consumed/readonly status for arguments of "
1252 "external or called ops";
1254 if (op.isExternal())
1257 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1258 return transformOp.emitSilenceableError()
1259 <<
"argument #" << i
1260 <<
" is consumed in the body but is not marked as such";
1262 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1266 <<
"op argument #" << i
1267 <<
" is not consumed in the body but is marked as consumed";
1273 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1275 assert(getMatchers().size() == getActions().size());
1278 for (
auto &&[matcher, action] :
1279 llvm::zip_equal(getMatchers(), getActions())) {
1281 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1283 cast<SymbolRefAttr>(matcher)));
1284 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1286 cast<SymbolRefAttr>(action)));
1287 if (!matcherSymbol ||
1288 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1289 return emitError() <<
"unresolved matcher symbol " << matcher;
1290 if (!actionSymbol ||
1291 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1292 return emitError() <<
"unresolved action symbol " << action;
1297 .checkAndReport())) {
1303 .checkAndReport())) {
1308 TypeRange operandTypes = getOperandTypes();
1309 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1310 if (operandTypes.size() != matcherArguments.size()) {
1312 emitError() <<
"the number of operands (" << operandTypes.size()
1313 <<
") doesn't match the number of matcher arguments ("
1314 << matcherArguments.size() <<
") for " << matcher;
1315 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1318 for (
auto &&[i, operand, argument] :
1320 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1323 <<
"does not expect matcher symbol to consume its operand #" << i;
1324 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1333 <<
"mismatching type interfaces for operand and matcher argument #"
1334 << i <<
" of matcher " << matcher;
1335 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1340 TypeRange matcherResults = matcherSymbol.getResultTypes();
1341 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1342 if (matcherResults.size() != actionArguments.size()) {
1343 return emitError() <<
"mismatching number of matcher results and "
1344 "action arguments between "
1345 << matcher <<
" (" << matcherResults.size() <<
") and "
1346 << action <<
" (" << actionArguments.size() <<
")";
1348 for (
auto &&[i, matcherType, actionType] :
1353 return emitError() <<
"mismatching type interfaces for matcher result "
1354 "and action argument #"
1355 << i <<
"of matcher " << matcher <<
" and action "
1360 TypeRange actionResults = actionSymbol.getResultTypes();
1361 auto resultTypes =
TypeRange(getResultTypes()).drop_front();
1362 if (actionResults.size() != resultTypes.size()) {
1364 emitError() <<
"the number of action results ("
1365 << actionResults.size() <<
") for " << action
1366 <<
" doesn't match the number of extra op results ("
1367 << resultTypes.size() <<
")";
1368 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1371 for (
auto &&[i, resultType, actionType] :
1377 emitError() <<
"mismatching type interfaces for action result #" << i
1378 <<
" of action " << action <<
" and op result";
1379 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1398 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1399 bool withZipShortest = getWithZipShortest();
1403 if (withZipShortest) {
1407 return A.size() <
B.size();
1410 for (
size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1411 payloads[argIdx].resize(numIterations);
1417 for (
size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1419 if (payloads[argIdx].size() != numIterations) {
1420 return emitSilenceableError()
1421 <<
"prior targets' payload size (" << numIterations
1422 <<
") differs from payload size (" << payloads[argIdx].size()
1423 <<
") of target " << getTargets()[argIdx];
1432 for (
size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1433 auto scope = state.make_region_scope(getBody());
1439 if (failed(state.mapBlockArgument(blockArg, {argument})))
1444 for (
Operation &transform : getBody().front().without_terminator()) {
1446 llvm::cast<transform::TransformOpInterface>(transform));
1452 OperandRange yieldOperands = getYieldOp().getOperands();
1453 for (
auto &&[result, yieldOperand, resTuple] :
1454 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1456 if (isa<TransformHandleTypeInterface>(result.getType()))
1457 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1458 else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1459 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1460 else if (isa<TransformParamTypeInterface>(result.getType()))
1461 llvm::append_range(resTuple, state.getParams(yieldOperand));
1463 assert(
false &&
"unhandled handle type");
1467 for (
auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1473 void transform::ForeachOp::getEffects(
1474 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1477 for (
auto &&[target, blockArg] :
1478 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1480 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1482 cast<TransformOpInterface>(&op));
1490 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1494 }
else if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1503 void transform::ForeachOp::getSuccessorRegions(
1505 Region *bodyRegion = &getBody();
1507 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1512 assert(point == getBody() &&
"unexpected region index");
1513 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1514 regions.emplace_back();
1521 assert(point == getBody() &&
"unexpected region index");
1522 return getOperation()->getOperands();
1525 transform::YieldOp transform::ForeachOp::getYieldOp() {
1526 return cast<transform::YieldOp>(getBody().front().getTerminator());
1530 for (
auto [targetOpt, bodyArgOpt] :
1531 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1532 if (!targetOpt || !bodyArgOpt)
1533 return emitOpError() <<
"expects the same number of targets as the body "
1534 "has block arguments";
1535 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1537 "expects co-indexed targets and the body's "
1538 "block arguments to have the same op/value/param type");
1541 for (
auto [resultOpt, yieldOperandOpt] :
1542 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1543 if (!resultOpt || !yieldOperandOpt)
1544 return emitOpError() <<
"expects the same number of results as the "
1545 "yield terminator has operands";
1546 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1547 return emitOpError(
"expects co-indexed results and yield "
1548 "operands to have the same op/value/param type");
1564 for (
Operation *target : state.getPayloadOps(getTarget())) {
1566 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1569 bool checkIsolatedFromAbove =
1570 !getIsolatedFromAbove() ||
1572 bool checkOpName = !getOpName().has_value() ||
1574 if (checkIsolatedFromAbove && checkOpName)
1579 if (getAllowEmptyResults()) {
1580 results.
set(llvm::cast<OpResult>(getResult()), parents);
1584 emitSilenceableError()
1585 <<
"could not find a parent op that matches all requirements";
1586 diag.attachNote(target->
getLoc()) <<
"target op";
1590 if (getDeduplicate()) {
1591 if (!resultSet.contains(parent)) {
1592 parents.push_back(parent);
1593 resultSet.insert(parent);
1596 parents.push_back(parent);
1599 results.
set(llvm::cast<OpResult>(getResult()), parents);
1611 int64_t resultNumber = getResultNumber();
1612 auto payloadOps = state.getPayloadOps(getTarget());
1613 if (std::empty(payloadOps)) {
1614 results.
set(cast<OpResult>(getResult()), {});
1617 if (!llvm::hasSingleElement(payloadOps))
1619 <<
"handle must be mapped to exactly one payload op";
1621 Operation *target = *payloadOps.begin();
1624 results.
set(llvm::cast<OpResult>(getResult()),
1638 for (
Value v : state.getPayloadValues(getTarget())) {
1639 if (llvm::isa<BlockArgument>(v)) {
1641 emitSilenceableError() <<
"cannot get defining op of block argument";
1642 diag.attachNote(v.getLoc()) <<
"target value";
1645 definingOps.push_back(v.getDefiningOp());
1647 results.
set(llvm::cast<OpResult>(getResult()), definingOps);
1659 int64_t operandNumber = getOperandNumber();
1661 for (
Operation *target : state.getPayloadOps(getTarget())) {
1668 emitSilenceableError()
1669 <<
"could not find a producer for operand number: " << operandNumber
1670 <<
" of " << *target;
1671 diag.attachNote(target->getLoc()) <<
"target op";
1674 producers.push_back(producer);
1676 results.
set(llvm::cast<OpResult>(getResult()), producers);
1689 for (
Operation *target : state.getPayloadOps(getTarget())) {
1692 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1693 target->getNumOperands(), operandPositions);
1694 if (
diag.isSilenceableFailure()) {
1695 diag.attachNote(target->getLoc())
1696 <<
"while considering positions of this payload operation";
1699 llvm::append_range(operands,
1700 llvm::map_range(operandPositions, [&](int64_t pos) {
1701 return target->getOperand(pos);
1704 results.
setValues(cast<OpResult>(getResult()), operands);
1710 getIsInverted(), getIsAll());
1722 for (
Operation *target : state.getPayloadOps(getTarget())) {
1725 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1726 target->getNumResults(), resultPositions);
1727 if (
diag.isSilenceableFailure()) {
1728 diag.attachNote(target->getLoc())
1729 <<
"while considering positions of this payload operation";
1732 llvm::append_range(opResults,
1733 llvm::map_range(resultPositions, [&](int64_t pos) {
1734 return target->getResult(pos);
1737 results.
setValues(cast<OpResult>(getResult()), opResults);
1743 getIsInverted(), getIsAll());
1750 void transform::GetTypeOp::getEffects(
1751 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1762 for (
Value value : state.getPayloadValues(getValue())) {
1763 Type type = value.getType();
1764 if (getElemental()) {
1765 if (
auto shaped = dyn_cast<ShapedType>(type)) {
1766 type = shaped.getElementType();
1771 results.
setParams(cast<OpResult>(getResult()), params);
1788 state.applyTransform(cast<transform::TransformOpInterface>(transform));
1793 if (mode == transform::FailurePropagationMode::Propagate) {
1812 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1813 getOperation(), getTarget());
1814 assert(callee &&
"unverified reference to unknown symbol");
1816 if (callee.isExternal())
1822 auto scope = state.make_region_scope(callee.getBody());
1823 for (
auto &&[arg, map] :
1824 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
1825 if (failed(state.mapBlockArgument(arg, map)))
1830 callee.getBody().front(), getFailurePropagationMode(), state, results);
1833 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
1834 for (
auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
1842 void transform::IncludeOp::getEffects(
1843 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1857 auto defaultEffects = [&] {
1864 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
1866 return defaultEffects();
1867 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1868 getOperation(), getTarget());
1870 return defaultEffects();
1874 (void)earlyVerifierResult.
silence();
1875 return defaultEffects();
1878 for (
unsigned i = 0, e = getNumOperands(); i < e; ++i) {
1879 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
1890 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(
"target");
1892 return emitOpError() <<
"expects a 'target' symbol reference attribute";
1897 return emitOpError() <<
"does not reference a named transform sequence";
1899 FunctionType fnType = target.getFunctionType();
1900 if (fnType.getNumInputs() != getNumOperands())
1901 return emitError(
"incorrect number of operands for callee");
1903 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1904 if (getOperand(i).
getType() != fnType.getInput(i)) {
1905 return emitOpError(
"operand type mismatch: expected operand type ")
1906 << fnType.getInput(i) <<
", but provided "
1907 << getOperand(i).getType() <<
" for operand number " << i;
1911 if (fnType.getNumResults() != getNumResults())
1912 return emitError(
"incorrect number of results for callee");
1914 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1915 Type resultType = getResult(i).getType();
1916 Type funcType = fnType.getResult(i);
1918 return emitOpError() <<
"type of result #" << i
1919 <<
" must implement the same transform dialect "
1920 "interface as the corresponding callee result";
1925 cast<FunctionOpInterface>(*target),
false,
1935 ::std::optional<::mlir::Operation *> maybeCurrent,
1937 if (!maybeCurrent.has_value()) {
1942 return emitSilenceableError() <<
"operation is not empty";
1953 for (
auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
1954 if (acceptedAttr.getValue() == currentOpName)
1957 return emitSilenceableError() <<
"wrong operation name";
1968 auto signedAPIntAsString = [&](
const APInt &value) {
1970 llvm::raw_string_ostream os(str);
1971 value.print(os,
true);
1978 if (params.size() != references.size()) {
1979 return emitSilenceableError()
1980 <<
"parameters have different payload lengths (" << params.size()
1981 <<
" vs " << references.size() <<
")";
1984 for (
auto &&[i, param, reference] :
llvm::enumerate(params, references)) {
1985 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
1986 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
1987 if (!intAttr || !refAttr) {
1989 <<
"non-integer parameter value not expected";
1991 if (intAttr.getType() != refAttr.getType()) {
1993 <<
"mismatching integer attribute types in parameter #" << i;
1995 APInt value = intAttr.getValue();
1996 APInt refValue = refAttr.getValue();
1999 int64_t position = i;
2000 auto reportError = [&](StringRef direction) {
2002 emitSilenceableError() <<
"expected parameter to be " << direction
2003 <<
" " << signedAPIntAsString(refValue)
2004 <<
", got " << signedAPIntAsString(value);
2005 diag.attachNote(getParam().getLoc())
2006 <<
"value # " << position
2007 <<
" associated with the parameter defined here";
2011 switch (getPredicate()) {
2012 case MatchCmpIPredicate::eq:
2013 if (value.eq(refValue))
2015 return reportError(
"equal to");
2016 case MatchCmpIPredicate::ne:
2017 if (value.ne(refValue))
2019 return reportError(
"not equal to");
2020 case MatchCmpIPredicate::lt:
2021 if (value.slt(refValue))
2023 return reportError(
"less than");
2024 case MatchCmpIPredicate::le:
2025 if (value.sle(refValue))
2027 return reportError(
"less than or equal to");
2028 case MatchCmpIPredicate::gt:
2029 if (value.sgt(refValue))
2031 return reportError(
"greater than");
2032 case MatchCmpIPredicate::ge:
2033 if (value.sge(refValue))
2035 return reportError(
"greater than or equal to");
2041 void transform::MatchParamCmpIOp::getEffects(
2042 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2055 results.
setParams(cast<OpResult>(getParam()), {getValue()});
2068 if (isa<TransformHandleTypeInterface>(handles.front().
getType())) {
2070 for (
Value operand : handles)
2071 llvm::append_range(operations, state.getPayloadOps(operand));
2072 if (!getDeduplicate()) {
2073 results.
set(llvm::cast<OpResult>(getResult()), operations);
2078 results.
set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2082 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2084 for (
Value attribute : handles)
2085 llvm::append_range(attrs, state.getParams(attribute));
2086 if (!getDeduplicate()) {
2087 results.
setParams(cast<OpResult>(getResult()), attrs);
2092 results.
setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2097 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2098 "expected value handle type");
2100 for (
Value value : handles)
2101 llvm::append_range(payloadValues, state.getPayloadValues(value));
2102 if (!getDeduplicate()) {
2103 results.
setValues(cast<OpResult>(getResult()), payloadValues);
2108 results.
setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2112 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2114 return getDeduplicate();
2117 void transform::MergeHandlesOp::getEffects(
2118 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2126 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2127 if (getDeduplicate() || getHandles().size() != 1)
2132 return getHandles().front();
2150 auto scope = state.make_region_scope(getBody());
2152 state, this->getOperation(), getBody())))
2156 FailurePropagationMode::Propagate, state, results);
2159 void transform::NamedSequenceOp::getEffects(
2160 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2165 parser, result,
false,
2166 getFunctionTypeAttrName(result.
name),
2169 std::string &) { return builder.getFunctionType(inputs, results); },
2170 getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
2175 printer, cast<FunctionOpInterface>(getOperation()),
false,
2176 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2177 getResAttrsAttrName());
2187 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2190 <<
"cannot be defined inside another transform op";
2191 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2195 if (op.isExternal() || op.getFunctionBody().empty()) {
2202 if (op.getFunctionBody().front().empty())
2205 Operation *terminator = &op.getFunctionBody().front().back();
2206 if (!isa<transform::YieldOp>(terminator)) {
2209 << transform::YieldOp::getOperationName()
2210 <<
"' as terminator";
2211 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2217 <<
"expected terminator to have as many operands as the parent op "
2220 for (
auto [i, operandType, resultType] : llvm::zip_equal(
2223 if (operandType == resultType)
2226 <<
"the type of the terminator operand #" << i
2227 <<
" must match the type of the corresponding parent op result ("
2228 << operandType <<
" vs " << resultType <<
")";
2241 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2244 <<
"expects the parent symbol table to have the '"
2245 << transform::TransformDialect::kWithNamedSequenceAttrName
2247 diag.attachNote(parent->
getLoc()) <<
"symbol table operation";
2252 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2255 <<
"cannot be defined inside another transform op";
2256 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2260 if (op.isExternal() || op.getBody().empty())
2264 if (op.getBody().front().empty())
2267 Operation *terminator = &op.getBody().front().back();
2268 if (!isa<transform::YieldOp>(terminator)) {
2271 << transform::YieldOp::getOperationName()
2272 <<
"' as terminator";
2273 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2279 <<
"expected terminator to have as many operands as the parent op "
2282 for (
auto [i, operandType, resultType] :
2283 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->
getNumOperands()),
2286 if (operandType == resultType)
2289 <<
"the type of the terminator operand #" << i
2290 <<
" must match the type of the corresponding parent op result ("
2291 << operandType <<
" vs " << resultType <<
")";
2294 auto funcOp = cast<FunctionOpInterface>(*op);
2297 if (!
diag.succeeded())
2309 template <
typename FnTy>
2314 types.reserve(1 + extraBindingTypes.size());
2315 types.push_back(bbArgType);
2316 llvm::append_range(types, extraBindingTypes);
2319 Region *region = state.regions.back().get();
2326 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2327 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0));
2329 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0),
2334 void transform::NamedSequenceOp::build(
OpBuilder &builder,
2342 state.addAttribute(getFunctionTypeAttrName(state.name),
2344 rootType, resultTypes)));
2345 state.attributes.append(attrs.begin(), attrs.end());
2360 size_t numAssociations =
2362 .Case([&](TransformHandleTypeInterface opHandle) {
2363 return llvm::range_size(state.getPayloadOps(getHandle()));
2365 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2366 return llvm::range_size(state.getPayloadValues(getHandle()));
2368 .Case([&](TransformParamTypeInterface param) {
2369 return llvm::range_size(state.getParams(getHandle()));
2372 llvm_unreachable(
"unknown kind of transform dialect type");
2375 results.
setParams(cast<OpResult>(getNum()),
2382 auto resultType = cast<TransformParamTypeInterface>(getNum().
getType());
2397 auto payloadOps = state.getPayloadOps(getTarget());
2400 result.push_back(op);
2402 results.
set(cast<OpResult>(getResult()), result);
2411 Value target, int64_t numResultHandles) {
2420 int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
2421 auto produceNumOpsError = [&]() {
2422 return emitSilenceableError()
2423 << getHandle() <<
" expected to contain " << this->getNumResults()
2424 <<
" payload ops but it contains " << numPayloadOps
2430 if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
2431 return produceNumOpsError();
2436 if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
2437 (numPayloadOps != 0 || !getPassThroughEmptyHandle()))
2438 return produceNumOpsError();
2442 if (getOverflowResult())
2443 resultHandles[*getOverflowResult()].reserve(numPayloadOps -
2446 int64_t resultNum = en.index();
2447 if (resultNum >= getNumResults())
2448 resultNum = *getOverflowResult();
2449 resultHandles[resultNum].push_back(en.value());
2454 results.
set(llvm::cast<OpResult>(getResult(it.index())), it.value());
2459 void transform::SplitHandleOp::getEffects(
2460 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2468 if (getOverflowResult().has_value() &&
2469 !(*getOverflowResult() < getNumResults()))
2470 return emitOpError(
"overflow_result is not a valid result index");
2482 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2484 Value handle = en.value();
2485 if (isa<TransformHandleTypeInterface>(handle.
getType())) {
2487 llvm::to_vector(state.getPayloadOps(handle));
2489 payload.reserve(numRepetitions * current.size());
2490 for (
unsigned i = 0; i < numRepetitions; ++i)
2491 llvm::append_range(payload, current);
2492 results.
set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2494 assert(llvm::isa<TransformParamTypeInterface>(handle.
getType()) &&
2495 "expected param type");
2498 params.reserve(numRepetitions * current.size());
2499 for (
unsigned i = 0; i < numRepetitions; ++i)
2500 llvm::append_range(params, current);
2501 results.
setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2508 void transform::ReplicateOp::getEffects(
2509 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2524 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2525 if (failed(mapBlockArguments(state)))
2533 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2535 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2536 SmallVectorImpl<Type> &extraBindingTypes) {
2540 root = std::nullopt;
2543 if (failed(hasRoot.
value()))
2557 if (failed(parser.
parseType(rootType))) {
2561 if (!extraBindings.empty()) {
2566 if (extraBindingTypes.size() != extraBindings.size()) {
2568 "expected types to be provided for all operands");
2584 bool hasExtras = !extraBindings.empty();
2594 printer << rootType;
2597 llvm::interleaveComma(extraBindingTypes, printer.
getStream());
2606 auto iface = dyn_cast<transform::TransformOpInterface>(use.
getOwner());
2621 if (!potentialConsumer) {
2622 potentialConsumer = &use;
2627 <<
" has more than one potential consumer";
2630 diag.attachNote(use.getOwner()->getLoc())
2631 <<
"used here as operand #" << use.getOperandNumber();
2639 assert(getBodyBlock()->getNumArguments() >= 1 &&
2640 "the number of arguments must have been verified to be more than 1 by "
2641 "PossibleTopLevelTransformOpTrait");
2643 if (!getRoot() && !getExtraBindings().empty()) {
2644 return emitOpError()
2645 <<
"does not expect extra operands when used as top-level";
2651 return (emitOpError() <<
"block argument #" << arg.getArgNumber());
2658 for (
Operation &child : *getBodyBlock()) {
2659 if (!isa<TransformOpInterface>(child) &&
2660 &child != &getBodyBlock()->back()) {
2663 <<
"expected children ops to implement TransformOpInterface";
2664 diag.attachNote(child.getLoc()) <<
"op without interface";
2668 for (
OpResult result : child.getResults()) {
2669 auto report = [&]() {
2670 return (child.emitError() <<
"result #" << result.getResultNumber());
2677 if (!getBodyBlock()->mightHaveTerminator())
2678 return emitOpError() <<
"expects to have a terminator in the body";
2680 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2681 getOperation()->getResultTypes()) {
2683 <<
"expects the types of the terminator operands "
2684 "to match the types of the result";
2685 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) <<
"terminator";
2691 void transform::SequenceOp::getEffects(
2692 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2698 assert(point == getBody() &&
"unexpected region index");
2699 if (getOperation()->getNumOperands() > 0)
2700 return getOperation()->getOperands();
2702 getOperation()->operand_end());
2705 void transform::SequenceOp::getSuccessorRegions(
2708 Region *bodyRegion = &getBody();
2709 regions.emplace_back(bodyRegion, getNumOperands() != 0
2715 assert(point == getBody() &&
"unexpected region index");
2716 regions.emplace_back(getOperation()->getResults());
2719 void transform::SequenceOp::getRegionInvocationBounds(
2722 bounds.emplace_back(1, 1);
2727 FailurePropagationMode failurePropagationMode,
2730 build(builder, state, resultTypes, failurePropagationMode, root,
2739 FailurePropagationMode failurePropagationMode,
2742 build(builder, state, resultTypes, failurePropagationMode, root,
2750 FailurePropagationMode failurePropagationMode,
2753 build(builder, state, resultTypes, failurePropagationMode,
Value(),
2761 FailurePropagationMode failurePropagationMode,
2764 build(builder, state, resultTypes, failurePropagationMode,
Value(),
2780 Value target, StringRef name) {
2782 build(builder, result, name);
2789 llvm::outs() <<
"[[[ IR printer: ";
2790 if (getName().has_value())
2791 llvm::outs() << *getName() <<
" ";
2794 if (getAssumeVerified().value_or(
false))
2796 if (getUseLocalScope().value_or(
false))
2798 if (getSkipRegions().value_or(
false))
2802 llvm::outs() <<
"top-level ]]]\n";
2803 state.getTopLevel()->print(llvm::outs(), printFlags);
2804 llvm::outs() <<
"\n";
2808 llvm::outs() <<
"]]]\n";
2809 for (
Operation *target : state.getPayloadOps(getTarget())) {
2810 target->
print(llvm::outs(), printFlags);
2811 llvm::outs() <<
"\n";
2817 void transform::PrintOp::getEffects(
2818 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2822 if (!getTargetMutable().empty())
2842 <<
"failed to verify payload op";
2843 diag.attachNote(target->
getLoc()) <<
"payload op";
2849 void transform::VerifyOp::getEffects(
2850 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2858 void transform::YieldOp::getEffects(
2859 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
int64_t maxIterations
This specifies the maximum number of times the rewriter will iterate between applying patterns and si...
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
int64_t maxNumRewrites
This specifies the maximum number of rewrites within an iteration.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & assumeVerified()
Do not verify the operation when using custom operation printers.
OpPrintingFlags & useLocalScope()
Use local scope when printing the operation.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
This class implements the operand iterators for the Operation class.
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getRegionNumber()
Return the number of this region in the parent operation.
This is a "type erased" representation of a registered operation.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
size_t size() const
Return the size of this range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
A named class for passing around the variadic flag.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.