37 #include "llvm/ADT/DenseSet.h"
38 #include "llvm/ADT/STLExtras.h"
39 #include "llvm/ADT/ScopeExit.h"
40 #include "llvm/ADT/SmallPtrSet.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Support/Debug.h"
43 #include "llvm/Support/ErrorHandling.h"
44 #include "llvm/Support/InterleavedRange.h"
47 #define DEBUG_TYPE "transform-dialect"
48 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
50 #define DEBUG_TYPE_MATCHER "transform-matcher"
51 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
52 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
57 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
59 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
60 SmallVectorImpl<Type> &extraBindingTypes);
66 ArrayAttr matchers, ArrayAttr actions);
77 Operation *transformAncestor = transform.getOperation();
78 while (transformAncestor) {
79 if (transformAncestor == payload) {
81 transform.emitDefiniteFailure()
82 <<
"cannot apply transform to itself (or one of its ancestors)";
83 diag.attachNote(payload->
getLoc()) <<
"target payload op";
86 transformAncestor = transformAncestor->
getParentOp();
91 #define GET_OP_CLASSES
92 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
100 if (!point.
isParent() && getOperation()->getNumOperands() == 1)
101 return getOperation()->getOperands();
103 getOperation()->operand_end());
106 void transform::AlternativesOp::getSuccessorRegions(
108 for (
Region &alternative : llvm::drop_begin(
112 regions.emplace_back(&alternative, !getOperands().empty()
113 ? alternative.getArguments()
117 regions.emplace_back(getOperation()->getResults());
120 void transform::AlternativesOp::getRegionInvocationBounds(
125 bounds.reserve(getNumRegions());
126 bounds.emplace_back(1, 1);
133 results.
set(res, {});
141 if (
Value scopeHandle = getScope())
142 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
144 originals.push_back(state.getTopLevel());
147 if (original->isAncestor(getOperation())) {
149 <<
"scope must not contain the transforms being applied";
150 diag.attachNote(original->getLoc()) <<
"scope";
155 <<
"only isolated-from-above ops can be alternative scopes";
156 diag.attachNote(original->getLoc()) <<
"scope";
161 for (
Region ® : getAlternatives()) {
166 auto scope = state.make_region_scope(reg);
167 auto clones = llvm::to_vector(
168 llvm::map_range(originals, [](
Operation *op) {
return op->
clone(); }));
169 auto deleteClones = llvm::make_scope_exit([&] {
173 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
177 for (
Operation &transform : reg.front().without_terminator()) {
179 state.applyTransform(cast<TransformOpInterface>(transform));
181 LLVM_DEBUG(
DBGS() <<
"alternative failed: " << result.
getMessage()
187 if (::mlir::failed(result.
silence()))
196 deleteClones.release();
197 TrackingListener listener(state, *
this);
199 for (
const auto &kvp : llvm::zip(originals, clones)) {
210 return emitSilenceableError() <<
"all alternatives failed";
213 void transform::AlternativesOp::getEffects(
214 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
217 for (
Region *region : getRegions()) {
218 if (!region->empty())
225 for (
Region &alternative : getAlternatives()) {
230 <<
"expects terminator operands to have the "
231 "same type as results of the operation";
232 diag.attachNote(terminator->
getLoc()) <<
"terminator";
249 llvm::to_vector(state.getPayloadOps(getTarget()));
252 if (
auto paramH = getParam()) {
254 if (params.size() != 1) {
255 if (targets.size() != params.size()) {
256 return emitSilenceableError()
257 <<
"parameter and target have different payload lengths ("
258 << params.size() <<
" vs " << targets.size() <<
")";
260 for (
auto &&[target, attr] : llvm::zip_equal(targets, params))
261 target->setAttr(getName(), attr);
266 for (
auto *target : targets)
267 target->setAttr(getName(), attr);
271 void transform::AnnotateOp::getEffects(
272 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
283 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
298 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
299 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
323 auto addDefiningOpsToWorklist = [&](
Operation *op) {
326 if (
Operation *defOp = v.getDefiningOp())
328 worklist.insert(defOp);
336 const auto *it = llvm::find(worklist, op);
337 if (it != worklist.end())
346 addDefiningOpsToWorklist(op);
352 while (!worklist.empty()) {
356 addDefiningOpsToWorklist(op);
363 void transform::ApplyDeadCodeEliminationOp::getEffects(
364 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
388 if (!getRegion().empty()) {
389 for (
Operation &op : getRegion().front()) {
390 cast<transform::PatternDescriptorOpInterface>(&op)
391 .populatePatternsWithState(
patterns, state);
401 config.maxIterations = getMaxIterations() ==
static_cast<uint64_t
>(-1)
403 : getMaxIterations();
404 config.maxNumRewrites = getMaxNumRewrites() ==
static_cast<uint64_t
>(-1)
406 : getMaxNumRewrites();
411 bool cseChanged =
false;
414 static const int64_t kNumMaxIterations = 50;
415 int64_t iteration = 0;
417 LogicalResult result = failure();
430 if (target != nestedOp)
431 ops.push_back(nestedOp);
438 if (failed(result)) {
440 <<
"greedy pattern application failed";
448 }
while (cseChanged && ++iteration < kNumMaxIterations);
450 if (iteration == kNumMaxIterations)
457 if (!getRegion().empty()) {
458 for (
Operation &op : getRegion().front()) {
459 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
461 <<
"expected children ops to implement "
462 "PatternDescriptorOpInterface";
463 diag.attachNote(op.
getLoc()) <<
"op without interface";
471 void transform::ApplyPatternsOp::getEffects(
472 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
477 void transform::ApplyPatternsOp::build(
486 bodyBuilder(builder, result.
location);
493 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
497 dialect->getCanonicalizationPatterns(
patterns);
499 op.getCanonicalizationPatterns(
patterns, ctx);
513 std::unique_ptr<TypeConverter> defaultTypeConverter;
514 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
515 getDefaultTypeConverter();
516 if (typeConverterBuilder)
517 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
522 for (
Attribute attr : cast<ArrayAttr>(*getLegalOps()))
523 conversionTarget.addLegalOp(
526 for (
Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
527 conversionTarget.addIllegalOp(
529 if (getLegalDialects())
530 for (
Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
531 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
532 if (getIllegalDialects())
533 for (
Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
534 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
542 if (!getPatterns().empty()) {
543 for (
Operation &op : getPatterns().front()) {
545 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
548 std::unique_ptr<TypeConverter> typeConverter =
549 descriptor.getTypeConverter();
552 keepAliveConverters.emplace_back(std::move(typeConverter));
553 converter = keepAliveConverters.back().get();
556 if (!defaultTypeConverter) {
558 <<
"pattern descriptor does not specify type "
559 "converter and apply_conversion_patterns op has "
560 "no default type converter";
561 diag.attachNote(op.
getLoc()) <<
"pattern descriptor op";
564 converter = defaultTypeConverter.get();
570 descriptor.populateConversionTargetRules(*converter, conversionTarget);
572 descriptor.populatePatterns(*converter,
patterns);
580 TrackingListenerConfig trackingConfig;
581 trackingConfig.requireMatchingReplacementOpName =
false;
582 ErrorCheckingTrackingListener trackingListener(state, *
this, trackingConfig);
584 if (getPreserveHandles())
585 conversionConfig.
listener = &trackingListener;
588 for (
Operation *target : state.getPayloadOps(getTarget())) {
596 LogicalResult status = failure();
597 if (getPartialConversion()) {
607 if (failed(status)) {
608 diag = emitSilenceableError() <<
"dialect conversion failed";
609 diag.attachNote(target->
getLoc()) <<
"target op";
614 trackingListener.checkAndResetError();
616 if (
diag.succeeded()) {
618 return trackingFailure;
620 diag.attachNote() <<
"tracking listener also failed: "
622 (void)trackingFailure.
silence();
626 if (!
diag.succeeded())
634 if (getNumRegions() != 1 && getNumRegions() != 2)
635 return emitOpError() <<
"expected 1 or 2 regions";
636 if (!getPatterns().empty()) {
637 for (
Operation &op : getPatterns().front()) {
638 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
640 emitOpError() <<
"expected pattern children ops to implement "
641 "ConversionPatternDescriptorOpInterface";
642 diag.attachNote(op.
getLoc()) <<
"op without interface";
647 if (getNumRegions() == 2) {
648 Region &typeConverterRegion = getRegion(1);
649 if (!llvm::hasSingleElement(typeConverterRegion.
front()))
651 <<
"expected exactly one op in default type converter region";
653 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
655 if (!typeConverterOp) {
657 <<
"expected default converter child op to "
658 "implement TypeConverterBuilderOpInterface";
659 diag.attachNote(maybeTypeConverter->
getLoc()) <<
"op without interface";
663 if (!getPatterns().empty()) {
664 for (
Operation &op : getPatterns().front()) {
666 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
667 if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
675 void transform::ApplyConversionPatternsOp::getEffects(
676 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
677 if (!getPreserveHandles()) {
685 void transform::ApplyConversionPatternsOp::build(
695 if (patternsBodyBuilder)
696 patternsBodyBuilder(builder, result.
location);
702 if (typeConverterBodyBuilder)
703 typeConverterBodyBuilder(builder, result.
location);
711 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
714 assert(dialect &&
"expected that dialect is loaded");
715 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
719 iface->populateConvertToLLVMConversionPatterns(
723 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
724 transform::TypeConverterBuilderOpInterface builder) {
725 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
726 return emitOpError(
"expected LLVMTypeConverter");
733 return emitOpError(
"unknown dialect or dialect not loaded: ")
735 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
738 "dialect does not implement ConvertToLLVMPatternInterface or "
739 "extension was not loaded: ")
749 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
759 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
760 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
787 <<
"unknown pass or pass pipeline: " << getPassName();
791 if (failed(info->
addToPipeline(pm, getOptions(), [&](
const Twine &msg) {
796 <<
"failed to add pass or pass pipeline to pipeline: "
799 if (failed(pm.run(target))) {
800 auto diag = emitSilenceableError() <<
"pass pipeline failed";
801 diag.attachNote(target->
getLoc()) <<
"target op";
805 results.push_back(target);
815 Operation *target, ApplyToEachResultList &results,
817 results.push_back(target);
821 void transform::CastOp::getEffects(
822 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
829 assert(inputs.size() == 1 &&
"expected one input");
830 assert(outputs.size() == 1 &&
"expected one output");
832 std::initializer_list<Type>{inputs.front(), outputs.front()},
833 llvm::IsaPred<transform::TransformHandleTypeInterface>);
853 assert(block.
getParent() &&
"cannot match using a detached block");
854 auto matchScope = state.make_region_scope(*block.
getParent());
856 state.mapBlockArguments(block.
getArguments(), blockArgumentMapping)))
860 if (!isa<transform::MatchOpInterface>(match)) {
862 <<
"expected operations in the match part to "
863 "implement MatchOpInterface";
866 state.applyTransform(cast<transform::TransformOpInterface>(match));
867 if (
diag.succeeded())
885 template <
typename... Tys>
887 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... ||
false);
894 transform::TransformParamTypeInterface,
895 transform::TransformValueHandleTypeInterface>(
907 auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
908 getOperation(), getMatcher());
909 if (matcher.isExternal()) {
911 <<
"unresolved external symbol " << getMatcher();
915 rawResults.resize(getOperation()->getNumResults());
916 std::optional<DiagnosedSilenceableFailure> maybeFailure;
917 for (
Operation *root : state.getPayloadOps(getRoot())) {
921 op->
print(llvm::dbgs(),
923 llvm::dbgs() <<
" @" << op <<
"\n";
930 matcher.getFunctionBody().front(),
933 if (
diag.isDefiniteFailure())
935 if (
diag.isSilenceableFailure()) {
937 <<
" failed: " <<
diag.getMessage());
943 if (mapping.size() != 1) {
944 maybeFailure.emplace(emitSilenceableError()
945 <<
"result #" << i <<
", associated with "
947 <<
" payload objects, expected 1");
950 rawResults[i].push_back(mapping[0]);
955 return std::move(*maybeFailure);
956 assert(!maybeFailure &&
"failure set but the walk was not interrupted");
958 for (
auto &&[opResult, rawResult] :
959 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
966 void transform::CollectMatchingOp::getEffects(
967 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
973 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
975 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
977 if (!matcherSymbol ||
978 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
979 return emitError() <<
"unresolved matcher symbol " << getMatcher();
982 if (argumentTypes.size() != 1 ||
983 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
985 <<
"expected the matcher to take one operation handle argument";
987 if (!matcherSymbol.getArgAttr(
988 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
989 return emitError() <<
"expected the matcher argument to be marked readonly";
993 if (resultTypes.size() != getOperation()->getNumResults()) {
995 <<
"expected the matcher to yield as many values as op has results ("
996 << getOperation()->getNumResults() <<
"), got "
997 << resultTypes.size();
1000 for (
auto &&[i, matcherType, resultType] :
1006 <<
"mismatching type interfaces for matcher result and op result #"
// Unconditionally returns true: ForeachMatchOp permits the same handle to be
// passed more than once among its operands.
1018 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() {
return true; }
1026 matchActionPairs.reserve(getMatchers().size());
1028 for (
auto &&[matcher, action] :
1029 llvm::zip_equal(getMatchers(), getActions())) {
1030 auto matcherSymbol =
1032 getOperation(), cast<SymbolRefAttr>(matcher));
1035 getOperation(), cast<SymbolRefAttr>(action));
1036 assert(matcherSymbol && actionSymbol &&
1037 "unresolved symbols not caught by the verifier");
1039 if (matcherSymbol.isExternal())
1041 if (actionSymbol.isExternal())
1044 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1055 matchInputMapping.emplace_back();
1057 getForwardedInputs(), state);
1059 actionResultMapping.resize(getForwardedOutputs().size());
1061 for (
Operation *root : state.getPayloadOps(getRoot())) {
1065 if (!getRestrictRoot() && op == root)
1070 op->
print(llvm::dbgs(),
1072 llvm::dbgs() <<
" @" << op <<
"\n";
1075 firstMatchArgument.clear();
1076 firstMatchArgument.push_back(op);
1079 for (
auto [matcher, action] : matchActionPairs) {
1081 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1082 state, matchOutputMapping);
1083 if (
diag.isDefiniteFailure())
1085 if (
diag.isSilenceableFailure()) {
1087 <<
" failed: " <<
diag.getMessage());
1091 auto scope = state.make_region_scope(action.getFunctionBody());
1092 if (failed(state.mapBlockArguments(
1093 action.getFunctionBody().front().getArguments(),
1094 matchOutputMapping))) {
1099 action.getFunctionBody().front().without_terminator()) {
1101 state.applyTransform(cast<TransformOpInterface>(transform));
1106 overallDiag = emitSilenceableError() <<
"actions failed";
1111 <<
"when applied to this matching payload";
1118 action.getFunctionBody().front().getTerminator()->getOperands(),
1119 state, getFlattenResults()))) {
1121 <<
"action @" << action.getName()
1122 <<
" has results associated with multiple payload entities, "
1123 "but flattening was not requested";
1138 results.
set(llvm::cast<OpResult>(getUpdated()),
1139 state.getPayloadOps(getRoot()));
1140 for (
auto &&[result, mapping] :
1141 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1147 void transform::ForeachMatchOp::getAsmResultNames(
1149 setNameFn(getUpdated(),
"updated_root");
1150 for (
Value v : getForwardedOutputs()) {
1151 setNameFn(v,
"yielded");
1155 void transform::ForeachMatchOp::getEffects(
1156 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1158 if (getOperation()->getNumOperands() < 1 ||
1159 getOperation()->getNumResults() < 1) {
1172 ArrayAttr &matchers,
1173 ArrayAttr &actions) {
1195 ArrayAttr matchers, ArrayAttr actions) {
1198 for (
auto &&[matcher, action, idx] : llvm::zip_equal(
1199 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1201 printer << cast<SymbolRefAttr>(matcher) <<
" -> "
1202 << cast<SymbolRefAttr>(action);
1203 if (idx != matchers.size() - 1)
1211 if (getMatchers().size() != getActions().size())
1212 return emitOpError() <<
"expected the same number of matchers and actions";
1213 if (getMatchers().empty())
1214 return emitOpError() <<
"expected at least one match/action pair";
1218 if (matcherNames.insert(name).second)
1221 <<
" is used more than once, only the first match will apply";
1232 bool alsoVerifyInternal =
false) {
1233 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1234 llvm::SmallDenseSet<unsigned> consumedArguments;
1235 if (!op.isExternal()) {
1239 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1241 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1244 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1246 if (isConsumed && isReadOnly) {
1247 return transformOp.emitSilenceableError()
1248 <<
"argument #" << i <<
" cannot be both readonly and consumed";
1250 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1251 return transformOp.emitSilenceableError()
1252 <<
"must provide consumed/readonly status for arguments of "
1253 "external or called ops";
1255 if (op.isExternal())
1258 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1259 return transformOp.emitSilenceableError()
1260 <<
"argument #" << i
1261 <<
" is consumed in the body but is not marked as such";
1263 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1267 <<
"op argument #" << i
1268 <<
" is not consumed in the body but is marked as consumed";
1274 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1276 assert(getMatchers().size() == getActions().size());
1279 for (
auto &&[matcher, action] :
1280 llvm::zip_equal(getMatchers(), getActions())) {
1282 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1284 cast<SymbolRefAttr>(matcher)));
1285 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1287 cast<SymbolRefAttr>(action)));
1288 if (!matcherSymbol ||
1289 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1290 return emitError() <<
"unresolved matcher symbol " << matcher;
1291 if (!actionSymbol ||
1292 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1293 return emitError() <<
"unresolved action symbol " << action;
1298 .checkAndReport())) {
1304 .checkAndReport())) {
1309 TypeRange operandTypes = getOperandTypes();
1310 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1311 if (operandTypes.size() != matcherArguments.size()) {
1313 emitError() <<
"the number of operands (" << operandTypes.size()
1314 <<
") doesn't match the number of matcher arguments ("
1315 << matcherArguments.size() <<
") for " << matcher;
1316 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1319 for (
auto &&[i, operand, argument] :
1321 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1324 <<
"does not expect matcher symbol to consume its operand #" << i;
1325 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1334 <<
"mismatching type interfaces for operand and matcher argument #"
1335 << i <<
" of matcher " << matcher;
1336 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1341 TypeRange matcherResults = matcherSymbol.getResultTypes();
1342 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1343 if (matcherResults.size() != actionArguments.size()) {
1344 return emitError() <<
"mismatching number of matcher results and "
1345 "action arguments between "
1346 << matcher <<
" (" << matcherResults.size() <<
") and "
1347 << action <<
" (" << actionArguments.size() <<
")";
1349 for (
auto &&[i, matcherType, actionType] :
1354 return emitError() <<
"mismatching type interfaces for matcher result "
1355 "and action argument #"
1356 << i <<
"of matcher " << matcher <<
" and action "
1361 TypeRange actionResults = actionSymbol.getResultTypes();
1362 auto resultTypes =
TypeRange(getResultTypes()).drop_front();
1363 if (actionResults.size() != resultTypes.size()) {
1365 emitError() <<
"the number of action results ("
1366 << actionResults.size() <<
") for " << action
1367 <<
" doesn't match the number of extra op results ("
1368 << resultTypes.size() <<
")";
1369 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1372 for (
auto &&[i, resultType, actionType] :
1378 emitError() <<
"mismatching type interfaces for action result #" << i
1379 <<
" of action " << action <<
" and op result";
1380 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1399 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1400 bool withZipShortest = getWithZipShortest();
1404 if (withZipShortest) {
1408 return A.size() <
B.size();
1411 for (
size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1412 payloads[argIdx].resize(numIterations);
1418 for (
size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1420 if (payloads[argIdx].size() != numIterations) {
1421 return emitSilenceableError()
1422 <<
"prior targets' payload size (" << numIterations
1423 <<
") differs from payload size (" << payloads[argIdx].size()
1424 <<
") of target " << getTargets()[argIdx];
1433 for (
size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1434 auto scope = state.make_region_scope(getBody());
1440 if (failed(state.mapBlockArgument(blockArg, {argument})))
1445 for (
Operation &transform : getBody().front().without_terminator()) {
1447 llvm::cast<transform::TransformOpInterface>(transform));
1453 OperandRange yieldOperands = getYieldOp().getOperands();
1454 for (
auto &&[result, yieldOperand, resTuple] :
1455 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1457 if (isa<TransformHandleTypeInterface>(result.getType()))
1458 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1459 else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1460 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1461 else if (isa<TransformParamTypeInterface>(result.getType()))
1462 llvm::append_range(resTuple, state.getParams(yieldOperand));
1464 assert(
false &&
"unhandled handle type");
1468 for (
auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1474 void transform::ForeachOp::getEffects(
1475 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1478 for (
auto &&[target, blockArg] :
1479 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1481 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1483 cast<TransformOpInterface>(&op));
1491 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1495 }
else if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1504 void transform::ForeachOp::getSuccessorRegions(
1506 Region *bodyRegion = &getBody();
1508 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1513 assert(point == getBody() &&
"unexpected region index");
1514 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1515 regions.emplace_back();
1522 assert(point == getBody() &&
"unexpected region index");
1523 return getOperation()->getOperands();
1526 transform::YieldOp transform::ForeachOp::getYieldOp() {
1527 return cast<transform::YieldOp>(getBody().front().getTerminator());
1531 for (
auto [targetOpt, bodyArgOpt] :
1532 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1533 if (!targetOpt || !bodyArgOpt)
1534 return emitOpError() <<
"expects the same number of targets as the body "
1535 "has block arguments";
1536 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1538 "expects co-indexed targets and the body's "
1539 "block arguments to have the same op/value/param type");
1542 for (
auto [resultOpt, yieldOperandOpt] :
1543 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1544 if (!resultOpt || !yieldOperandOpt)
1545 return emitOpError() <<
"expects the same number of results as the "
1546 "yield terminator has operands";
1547 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1548 return emitOpError(
"expects co-indexed results and yield "
1549 "operands to have the same op/value/param type");
1565 for (
Operation *target : state.getPayloadOps(getTarget())) {
1567 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1570 bool checkIsolatedFromAbove =
1571 !getIsolatedFromAbove() ||
1573 bool checkOpName = !getOpName().has_value() ||
1575 if (checkIsolatedFromAbove && checkOpName)
1580 if (getAllowEmptyResults()) {
1581 results.
set(llvm::cast<OpResult>(getResult()), parents);
1585 emitSilenceableError()
1586 <<
"could not find a parent op that matches all requirements";
1587 diag.attachNote(target->
getLoc()) <<
"target op";
1591 if (getDeduplicate()) {
1592 if (resultSet.insert(parent).second)
1593 parents.push_back(parent);
1595 parents.push_back(parent);
1598 results.
set(llvm::cast<OpResult>(getResult()), parents);
1610 int64_t resultNumber = getResultNumber();
1611 auto payloadOps = state.getPayloadOps(getTarget());
1612 if (std::empty(payloadOps)) {
1613 results.
set(cast<OpResult>(getResult()), {});
1616 if (!llvm::hasSingleElement(payloadOps))
1618 <<
"handle must be mapped to exactly one payload op";
1620 Operation *target = *payloadOps.begin();
1623 results.
set(llvm::cast<OpResult>(getResult()),
1637 for (
Value v : state.getPayloadValues(getTarget())) {
1638 if (llvm::isa<BlockArgument>(v)) {
1640 emitSilenceableError() <<
"cannot get defining op of block argument";
1641 diag.attachNote(v.getLoc()) <<
"target value";
1644 definingOps.push_back(v.getDefiningOp());
1646 results.
set(llvm::cast<OpResult>(getResult()), definingOps);
1658 int64_t operandNumber = getOperandNumber();
1660 for (
Operation *target : state.getPayloadOps(getTarget())) {
1667 emitSilenceableError()
1668 <<
"could not find a producer for operand number: " << operandNumber
1669 <<
" of " << *target;
1670 diag.attachNote(target->getLoc()) <<
"target op";
1673 producers.push_back(producer);
1675 results.
set(llvm::cast<OpResult>(getResult()), producers);
1688 for (
Operation *target : state.getPayloadOps(getTarget())) {
1691 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1692 target->getNumOperands(), operandPositions);
1693 if (
diag.isSilenceableFailure()) {
1694 diag.attachNote(target->getLoc())
1695 <<
"while considering positions of this payload operation";
1698 llvm::append_range(operands,
1699 llvm::map_range(operandPositions, [&](int64_t pos) {
1700 return target->getOperand(pos);
1703 results.
setValues(cast<OpResult>(getResult()), operands);
1709 getIsInverted(), getIsAll());
1721 for (
Operation *target : state.getPayloadOps(getTarget())) {
1724 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1725 target->getNumResults(), resultPositions);
1726 if (
diag.isSilenceableFailure()) {
1727 diag.attachNote(target->getLoc())
1728 <<
"while considering positions of this payload operation";
1731 llvm::append_range(opResults,
1732 llvm::map_range(resultPositions, [&](int64_t pos) {
1733 return target->getResult(pos);
1736 results.
setValues(cast<OpResult>(getResult()), opResults);
1742 getIsInverted(), getIsAll());
1749 void transform::GetTypeOp::getEffects(
1750 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1761 for (
Value value : state.getPayloadValues(getValue())) {
1762 Type type = value.getType();
1763 if (getElemental()) {
1764 if (
auto shaped = dyn_cast<ShapedType>(type)) {
1765 type = shaped.getElementType();
1770 results.
setParams(cast<OpResult>(getResult()), params);
1787 state.applyTransform(cast<transform::TransformOpInterface>(transform));
1792 if (mode == transform::FailurePropagationMode::Propagate) {
1811 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1812 getOperation(), getTarget());
1813 assert(callee &&
"unverified reference to unknown symbol");
1815 if (callee.isExternal())
1821 auto scope = state.make_region_scope(callee.getBody());
1822 for (
auto &&[arg, map] :
1823 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
1824 if (failed(state.mapBlockArgument(arg, map)))
1829 callee.getBody().front(), getFailurePropagationMode(), state, results);
1832 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
1833 for (
auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
1841 void transform::IncludeOp::getEffects(
1842 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1856 auto defaultEffects = [&] {
1863 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
1865 return defaultEffects();
1866 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1867 getOperation(), getTarget());
1869 return defaultEffects();
1873 (void)earlyVerifierResult.
silence();
1874 return defaultEffects();
1877 for (
unsigned i = 0, e = getNumOperands(); i < e; ++i) {
1878 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
1889 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(
"target");
1891 return emitOpError() <<
"expects a 'target' symbol reference attribute";
1896 return emitOpError() <<
"does not reference a named transform sequence";
1898 FunctionType fnType = target.getFunctionType();
1899 if (fnType.getNumInputs() != getNumOperands())
1900 return emitError(
"incorrect number of operands for callee");
1902 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1903 if (getOperand(i).
getType() != fnType.getInput(i)) {
1904 return emitOpError(
"operand type mismatch: expected operand type ")
1905 << fnType.getInput(i) <<
", but provided "
1906 << getOperand(i).getType() <<
" for operand number " << i;
1910 if (fnType.getNumResults() != getNumResults())
1911 return emitError(
"incorrect number of results for callee");
1913 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1914 Type resultType = getResult(i).getType();
1915 Type funcType = fnType.getResult(i);
1917 return emitOpError() <<
"type of result #" << i
1918 <<
" must implement the same transform dialect "
1919 "interface as the corresponding callee result";
1924 cast<FunctionOpInterface>(*target),
false,
1934 ::std::optional<::mlir::Operation *> maybeCurrent,
1936 if (!maybeCurrent.has_value()) {
1941 return emitSilenceableError() <<
"operation is not empty";
1952 for (
auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
1953 if (acceptedAttr.getValue() == currentOpName)
1956 return emitSilenceableError() <<
"wrong operation name";
1967 auto signedAPIntAsString = [&](
const APInt &value) {
1969 llvm::raw_string_ostream os(str);
1970 value.print(os,
true);
1977 if (params.size() != references.size()) {
1978 return emitSilenceableError()
1979 <<
"parameters have different payload lengths (" << params.size()
1980 <<
" vs " << references.size() <<
")";
1983 for (
auto &&[i, param, reference] :
llvm::enumerate(params, references)) {
1984 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
1985 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
1986 if (!intAttr || !refAttr) {
1988 <<
"non-integer parameter value not expected";
1990 if (intAttr.getType() != refAttr.getType()) {
1992 <<
"mismatching integer attribute types in parameter #" << i;
1994 APInt value = intAttr.getValue();
1995 APInt refValue = refAttr.getValue();
1998 int64_t position = i;
1999 auto reportError = [&](StringRef direction) {
2001 emitSilenceableError() <<
"expected parameter to be " << direction
2002 <<
" " << signedAPIntAsString(refValue)
2003 <<
", got " << signedAPIntAsString(value);
2004 diag.attachNote(getParam().getLoc())
2005 <<
"value # " << position
2006 <<
" associated with the parameter defined here";
2010 switch (getPredicate()) {
2011 case MatchCmpIPredicate::eq:
2012 if (value.eq(refValue))
2014 return reportError(
"equal to");
2015 case MatchCmpIPredicate::ne:
2016 if (value.ne(refValue))
2018 return reportError(
"not equal to");
2019 case MatchCmpIPredicate::lt:
2020 if (value.slt(refValue))
2022 return reportError(
"less than");
2023 case MatchCmpIPredicate::le:
2024 if (value.sle(refValue))
2026 return reportError(
"less than or equal to");
2027 case MatchCmpIPredicate::gt:
2028 if (value.sgt(refValue))
2030 return reportError(
"greater than");
2031 case MatchCmpIPredicate::ge:
2032 if (value.sge(refValue))
2034 return reportError(
"greater than or equal to");
2040 void transform::MatchParamCmpIOp::getEffects(
2041 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2054 results.
setParams(cast<OpResult>(getParam()), {getValue()});
2067 if (isa<TransformHandleTypeInterface>(handles.front().
getType())) {
2069 for (
Value operand : handles)
2070 llvm::append_range(operations, state.getPayloadOps(operand));
2071 if (!getDeduplicate()) {
2072 results.
set(llvm::cast<OpResult>(getResult()), operations);
2077 results.
set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2081 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2083 for (
Value attribute : handles)
2084 llvm::append_range(attrs, state.getParams(attribute));
2085 if (!getDeduplicate()) {
2086 results.
setParams(cast<OpResult>(getResult()), attrs);
2091 results.
setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2096 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2097 "expected value handle type");
2099 for (
Value value : handles)
2100 llvm::append_range(payloadValues, state.getPayloadValues(value));
2101 if (!getDeduplicate()) {
2102 results.
setValues(cast<OpResult>(getResult()), payloadValues);
2107 results.
setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2111 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2113 return getDeduplicate();
2116 void transform::MergeHandlesOp::getEffects(
2117 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2125 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2126 if (getDeduplicate() || getHandles().size() != 1)
2131 return getHandles().front();
2149 auto scope = state.make_region_scope(getBody());
2151 state, this->getOperation(), getBody())))
2155 FailurePropagationMode::Propagate, state, results);
2158 void transform::NamedSequenceOp::getEffects(
2159 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2164 parser, result,
false,
2165 getFunctionTypeAttrName(result.
name),
2168 std::string &) { return builder.getFunctionType(inputs, results); },
2169 getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
2174 printer, cast<FunctionOpInterface>(getOperation()),
false,
2175 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2176 getResAttrsAttrName());
2186 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2189 <<
"cannot be defined inside another transform op";
2190 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2194 if (op.isExternal() || op.getFunctionBody().empty()) {
2201 if (op.getFunctionBody().front().empty())
2204 Operation *terminator = &op.getFunctionBody().front().back();
2205 if (!isa<transform::YieldOp>(terminator)) {
2208 << transform::YieldOp::getOperationName()
2209 <<
"' as terminator";
2210 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2214 if (terminator->
getNumOperands() != op.getResultTypes().size()) {
2216 <<
"expected terminator to have as many operands as the parent op "
2219 for (
auto [i, operandType, resultType] : llvm::zip_equal(
2222 if (operandType == resultType)
2225 <<
"the type of the terminator operand #" << i
2226 <<
" must match the type of the corresponding parent op result ("
2227 << operandType <<
" vs " << resultType <<
")";
2240 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2243 <<
"expects the parent symbol table to have the '"
2244 << transform::TransformDialect::kWithNamedSequenceAttrName
2246 diag.attachNote(parent->
getLoc()) <<
"symbol table operation";
2251 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2254 <<
"cannot be defined inside another transform op";
2255 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2259 if (op.isExternal() || op.getBody().empty())
2263 if (op.getBody().front().empty())
2266 Operation *terminator = &op.getBody().front().back();
2267 if (!isa<transform::YieldOp>(terminator)) {
2270 << transform::YieldOp::getOperationName()
2271 <<
"' as terminator";
2272 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2276 if (terminator->
getNumOperands() != op.getFunctionType().getNumResults()) {
2278 <<
"expected terminator to have as many operands as the parent op "
2281 for (
auto [i, operandType, resultType] :
2282 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->
getNumOperands()),
2284 op.getFunctionType().getResults())) {
2285 if (operandType == resultType)
2288 <<
"the type of the terminator operand #" << i
2289 <<
" must match the type of the corresponding parent op result ("
2290 << operandType <<
" vs " << resultType <<
")";
2293 auto funcOp = cast<FunctionOpInterface>(*op);
2296 if (!
diag.succeeded())
2308 template <
typename FnTy>
2313 types.reserve(1 + extraBindingTypes.size());
2314 types.push_back(bbArgType);
2315 llvm::append_range(types, extraBindingTypes);
2318 Region *region = state.regions.back().get();
2325 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2326 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0));
2328 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0),
2333 void transform::NamedSequenceOp::build(
OpBuilder &builder,
2341 state.addAttribute(getFunctionTypeAttrName(state.name),
2343 rootType, resultTypes)));
2344 state.attributes.append(attrs.begin(), attrs.end());
2359 size_t numAssociations =
2361 .Case([&](TransformHandleTypeInterface opHandle) {
2362 return llvm::range_size(state.getPayloadOps(getHandle()));
2364 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2365 return llvm::range_size(state.getPayloadValues(getHandle()));
2367 .Case([&](TransformParamTypeInterface param) {
2368 return llvm::range_size(state.getParams(getHandle()));
2371 llvm_unreachable(
"unknown kind of transform dialect type");
2374 results.
setParams(cast<OpResult>(getNum()),
2381 auto resultType = cast<TransformParamTypeInterface>(getNum().
getType());
2396 auto payloadOps = state.getPayloadOps(getTarget());
2399 result.push_back(op);
2401 results.
set(cast<OpResult>(getResult()), result);
2410 Value target, int64_t numResultHandles) {
2419 int64_t numPayloads =
2421 .Case<TransformHandleTypeInterface>([&](
auto x) {
2422 return llvm::range_size(state.getPayloadOps(getHandle()));
2424 .Case<TransformValueHandleTypeInterface>([&](
auto x) {
2425 return llvm::range_size(state.getPayloadValues(getHandle()));
2427 .Case<TransformParamTypeInterface>([&](
auto x) {
2428 return llvm::range_size(state.getParams(getHandle()));
2430 .Default([](
auto x) {
2431 llvm_unreachable(
"unknown transform dialect type interface");
2435 auto produceNumOpsError = [&]() {
2436 return emitSilenceableError()
2437 << getHandle() <<
" expected to contain " << this->getNumResults()
2438 <<
" payloads but it contains " << numPayloads <<
" payloads";
2443 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2444 return produceNumOpsError();
2449 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2450 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2451 return produceNumOpsError();
2455 if (getOverflowResult())
2456 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2458 auto container = [&]() {
2459 if (isa<TransformHandleTypeInterface>(getHandle().
getType())) {
2460 return llvm::map_to_vector(
2461 state.getPayloadOps(getHandle()),
2464 if (isa<TransformValueHandleTypeInterface>(getHandle().
getType())) {
2465 return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2468 assert(isa<TransformParamTypeInterface>(getHandle().
getType()) &&
2469 "unsupported kind of transform dialect type");
2470 return llvm::map_to_vector(state.getParams(getHandle()),
2475 int64_t resultNum = en.index();
2476 if (resultNum >= getNumResults())
2477 resultNum = *getOverflowResult();
2478 resultHandles[resultNum].push_back(en.value());
2489 void transform::SplitHandleOp::getEffects(
2490 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2498 if (getOverflowResult().has_value() &&
2499 !(*getOverflowResult() < getNumResults()))
2500 return emitOpError(
"overflow_result is not a valid result index");
2502 for (
Type resultType : getResultTypes()) {
2506 return emitOpError(
"expects result types to implement the same transform "
2507 "interface as the operand type");
2521 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2523 Value handle = en.value();
2524 if (isa<TransformHandleTypeInterface>(handle.
getType())) {
2526 llvm::to_vector(state.getPayloadOps(handle));
2528 payload.reserve(numRepetitions * current.size());
2529 for (
unsigned i = 0; i < numRepetitions; ++i)
2530 llvm::append_range(payload, current);
2531 results.
set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2533 assert(llvm::isa<TransformParamTypeInterface>(handle.
getType()) &&
2534 "expected param type");
2537 params.reserve(numRepetitions * current.size());
2538 for (
unsigned i = 0; i < numRepetitions; ++i)
2539 llvm::append_range(params, current);
2540 results.
setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2547 void transform::ReplicateOp::getEffects(
2548 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2563 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2564 if (failed(mapBlockArguments(state)))
2572 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2574 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2575 SmallVectorImpl<Type> &extraBindingTypes) {
2579 root = std::nullopt;
2582 if (failed(hasRoot.
value()))
2596 if (failed(parser.
parseType(rootType))) {
2600 if (!extraBindings.empty()) {
2605 if (extraBindingTypes.size() != extraBindings.size()) {
2607 "expected types to be provided for all operands");
2623 bool hasExtras = !extraBindings.empty();
2633 printer << rootType;
2635 printer <<
", " << llvm::interleaved(extraBindingTypes) <<
')';
2642 auto iface = dyn_cast<transform::TransformOpInterface>(use.
getOwner());
2657 if (!potentialConsumer) {
2658 potentialConsumer = &use;
2663 <<
" has more than one potential consumer";
2666 diag.attachNote(use.getOwner()->getLoc())
2667 <<
"used here as operand #" << use.getOperandNumber();
2675 assert(getBodyBlock()->getNumArguments() >= 1 &&
2676 "the number of arguments must have been verified to be more than 1 by "
2677 "PossibleTopLevelTransformOpTrait");
2679 if (!getRoot() && !getExtraBindings().empty()) {
2680 return emitOpError()
2681 <<
"does not expect extra operands when used as top-level";
2687 return (emitOpError() <<
"block argument #" << arg.getArgNumber());
2694 for (
Operation &child : *getBodyBlock()) {
2695 if (!isa<TransformOpInterface>(child) &&
2696 &child != &getBodyBlock()->back()) {
2699 <<
"expected children ops to implement TransformOpInterface";
2700 diag.attachNote(child.getLoc()) <<
"op without interface";
2704 for (
OpResult result : child.getResults()) {
2705 auto report = [&]() {
2706 return (child.emitError() <<
"result #" << result.getResultNumber());
2713 if (!getBodyBlock()->mightHaveTerminator())
2714 return emitOpError() <<
"expects to have a terminator in the body";
2716 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2717 getOperation()->getResultTypes()) {
2719 <<
"expects the types of the terminator operands "
2720 "to match the types of the result";
2721 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) <<
"terminator";
2727 void transform::SequenceOp::getEffects(
2728 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2734 assert(point == getBody() &&
"unexpected region index");
2735 if (getOperation()->getNumOperands() > 0)
2736 return getOperation()->getOperands();
2738 getOperation()->operand_end());
2741 void transform::SequenceOp::getSuccessorRegions(
2744 Region *bodyRegion = &getBody();
2745 regions.emplace_back(bodyRegion, getNumOperands() != 0
2751 assert(point == getBody() &&
"unexpected region index");
2752 regions.emplace_back(getOperation()->getResults());
2755 void transform::SequenceOp::getRegionInvocationBounds(
2758 bounds.emplace_back(1, 1);
2763 FailurePropagationMode failurePropagationMode,
2766 build(builder, state, resultTypes, failurePropagationMode, root,
2775 FailurePropagationMode failurePropagationMode,
2778 build(builder, state, resultTypes, failurePropagationMode, root,
2786 FailurePropagationMode failurePropagationMode,
2789 build(builder, state, resultTypes, failurePropagationMode,
Value(),
2797 FailurePropagationMode failurePropagationMode,
2800 build(builder, state, resultTypes, failurePropagationMode,
Value(),
2816 Value target, StringRef name) {
2818 build(builder, result, name);
2825 llvm::outs() <<
"[[[ IR printer: ";
2826 if (getName().has_value())
2827 llvm::outs() << *getName() <<
" ";
2830 if (getAssumeVerified().value_or(
false))
2832 if (getUseLocalScope().value_or(
false))
2834 if (getSkipRegions().value_or(
false))
2838 llvm::outs() <<
"top-level ]]]\n";
2839 state.getTopLevel()->print(llvm::outs(), printFlags);
2840 llvm::outs() <<
"\n";
2841 llvm::outs().flush();
2845 llvm::outs() <<
"]]]\n";
2846 for (
Operation *target : state.getPayloadOps(getTarget())) {
2847 target->
print(llvm::outs(), printFlags);
2848 llvm::outs() <<
"\n";
2851 llvm::outs().flush();
2855 void transform::PrintOp::getEffects(
2856 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2860 if (!getTargetMutable().empty())
2880 <<
"failed to verify payload op";
2881 diag.attachNote(target->
getLoc()) <<
"payload op";
2887 void transform::VerifyOp::getEffects(
2888 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2896 void transform::YieldOp::getEffects(
2897 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
This class implements the operand iterators for the Operation class.
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getRegionNumber()
Return the number of this region in the parent operation.
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
A named class for passing around the variadic flag.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.