34 #include "llvm/ADT/DenseSet.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/ScopeExit.h"
37 #include "llvm/ADT/SmallPtrSet.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/Support/Debug.h"
40 #include "llvm/Support/DebugLog.h"
41 #include "llvm/Support/ErrorHandling.h"
42 #include "llvm/Support/InterleavedRange.h"
45 #define DEBUG_TYPE "transform-dialect"
46 #define DEBUG_TYPE_MATCHER "transform-matcher"
52 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
58 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
60 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
61 SmallVectorImpl<Type> &extraBindingTypes);
67 ArrayAttr matchers, ArrayAttr actions);
78 Operation *transformAncestor = transform.getOperation();
79 while (transformAncestor) {
80 if (transformAncestor == payload) {
82 transform.emitDefiniteFailure()
83 <<
"cannot apply transform to itself (or one of its ancestors)";
84 diag.attachNote(payload->
getLoc()) <<
"target payload op";
87 transformAncestor = transformAncestor->
getParentOp();
92 #define GET_OP_CLASSES
93 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
99 OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
101 if (!successor.
isParent() && getOperation()->getNumOperands() == 1)
102 return getOperation()->getOperands();
104 getOperation()->operand_end());
107 void transform::AlternativesOp::getSuccessorRegions(
109 for (
Region &alternative : llvm::drop_begin(
116 regions.emplace_back(&alternative, !getOperands().empty()
117 ? alternative.getArguments()
121 regions.emplace_back(getOperation(), getOperation()->getResults());
124 void transform::AlternativesOp::getRegionInvocationBounds(
129 bounds.reserve(getNumRegions());
130 bounds.emplace_back(1, 1);
137 results.
set(res, {});
145 if (
Value scopeHandle = getScope())
146 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
148 originals.push_back(state.getTopLevel());
151 if (original->isAncestor(getOperation())) {
153 <<
"scope must not contain the transforms being applied";
154 diag.attachNote(original->getLoc()) <<
"scope";
159 <<
"only isolated-from-above ops can be alternative scopes";
160 diag.attachNote(original->getLoc()) <<
"scope";
165 for (
Region ® : getAlternatives()) {
170 auto scope = state.make_region_scope(reg);
171 auto clones = llvm::to_vector(
172 llvm::map_range(originals, [](
Operation *op) {
return op->
clone(); }));
173 auto deleteClones = llvm::make_scope_exit([&] {
177 if (
failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
181 for (
Operation &transform : reg.front().without_terminator()) {
183 state.applyTransform(cast<TransformOpInterface>(transform));
185 LDBG() <<
"alternative failed: " << result.
getMessage();
199 deleteClones.release();
200 TrackingListener listener(state, *
this);
202 for (
const auto &kvp : llvm::zip(originals, clones)) {
213 return emitSilenceableError() <<
"all alternatives failed";
216 void transform::AlternativesOp::getEffects(
217 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
220 for (
Region *region : getRegions()) {
221 if (!region->empty())
228 for (
Region &alternative : getAlternatives()) {
233 <<
"expects terminator operands to have the "
234 "same type as results of the operation";
235 diag.attachNote(terminator->
getLoc()) <<
"terminator";
252 llvm::to_vector(state.getPayloadOps(getTarget()));
255 if (
auto paramH = getParam()) {
257 if (params.size() != 1) {
258 if (targets.size() != params.size()) {
259 return emitSilenceableError()
260 <<
"parameter and target have different payload lengths ("
261 << params.size() <<
" vs " << targets.size() <<
")";
263 for (
auto &&[target, attr] : llvm::zip_equal(targets, params))
264 target->setAttr(getName(), attr);
269 for (
auto *target : targets)
270 target->setAttr(getName(), attr);
274 void transform::AnnotateOp::getEffects(
275 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
286 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
301 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
302 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
326 auto addDefiningOpsToWorklist = [&](
Operation *op) {
329 if (
Operation *defOp = v.getDefiningOp())
331 worklist.insert(defOp);
339 const auto *it = llvm::find(worklist, op);
340 if (it != worklist.end())
349 addDefiningOpsToWorklist(op);
355 while (!worklist.empty()) {
359 addDefiningOpsToWorklist(op);
366 void transform::ApplyDeadCodeEliminationOp::getEffects(
367 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
391 if (!getRegion().empty()) {
392 for (
Operation &op : getRegion().front()) {
393 cast<transform::PatternDescriptorOpInterface>(&op)
394 .populatePatternsWithState(
patterns, state);
404 config.setMaxIterations(getMaxIterations() ==
static_cast<uint64_t
>(-1)
406 : getMaxIterations());
407 config.setMaxNumRewrites(getMaxNumRewrites() ==
static_cast<uint64_t
>(-1)
409 : getMaxNumRewrites());
414 bool cseChanged =
false;
417 static const int64_t kNumMaxIterations = 50;
418 int64_t iteration = 0;
420 LogicalResult result = failure();
433 if (target != nestedOp)
434 ops.push_back(nestedOp);
443 <<
"greedy pattern application failed";
451 }
while (cseChanged && ++iteration < kNumMaxIterations);
453 if (iteration == kNumMaxIterations)
460 if (!getRegion().empty()) {
461 for (
Operation &op : getRegion().front()) {
462 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
464 <<
"expected children ops to implement "
465 "PatternDescriptorOpInterface";
466 diag.attachNote(op.
getLoc()) <<
"op without interface";
474 void transform::ApplyPatternsOp::getEffects(
475 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
480 void transform::ApplyPatternsOp::build(
489 bodyBuilder(builder, result.
location);
496 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
500 dialect->getCanonicalizationPatterns(
patterns);
502 op.getCanonicalizationPatterns(
patterns, ctx);
516 std::unique_ptr<TypeConverter> defaultTypeConverter;
517 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
518 getDefaultTypeConverter();
519 if (typeConverterBuilder)
520 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
525 for (
Attribute attr : cast<ArrayAttr>(*getLegalOps()))
526 conversionTarget.addLegalOp(
529 for (
Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
530 conversionTarget.addIllegalOp(
532 if (getLegalDialects())
533 for (
Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
534 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
535 if (getIllegalDialects())
536 for (
Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
537 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
545 if (!getPatterns().empty()) {
546 for (
Operation &op : getPatterns().front()) {
548 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
551 std::unique_ptr<TypeConverter> typeConverter =
552 descriptor.getTypeConverter();
555 keepAliveConverters.emplace_back(std::move(typeConverter));
556 converter = keepAliveConverters.back().get();
559 if (!defaultTypeConverter) {
561 <<
"pattern descriptor does not specify type "
562 "converter and apply_conversion_patterns op has "
563 "no default type converter";
564 diag.attachNote(op.
getLoc()) <<
"pattern descriptor op";
567 converter = defaultTypeConverter.get();
573 descriptor.populateConversionTargetRules(*converter, conversionTarget);
575 descriptor.populatePatterns(*converter,
patterns);
583 TrackingListenerConfig trackingConfig;
584 trackingConfig.requireMatchingReplacementOpName =
false;
585 ErrorCheckingTrackingListener trackingListener(state, *
this, trackingConfig);
587 if (getPreserveHandles())
588 conversionConfig.
listener = &trackingListener;
591 for (
Operation *target : state.getPayloadOps(getTarget())) {
599 LogicalResult status = failure();
600 if (getPartialConversion()) {
611 diag = emitSilenceableError() <<
"dialect conversion failed";
612 diag.attachNote(target->
getLoc()) <<
"target op";
617 trackingListener.checkAndResetError();
619 if (
diag.succeeded()) {
621 return trackingFailure;
623 diag.attachNote() <<
"tracking listener also failed: "
625 (void)trackingFailure.
silence();
628 if (!
diag.succeeded())
636 if (getNumRegions() != 1 && getNumRegions() != 2)
637 return emitOpError() <<
"expected 1 or 2 regions";
638 if (!getPatterns().empty()) {
639 for (
Operation &op : getPatterns().front()) {
640 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
642 emitOpError() <<
"expected pattern children ops to implement "
643 "ConversionPatternDescriptorOpInterface";
644 diag.attachNote(op.
getLoc()) <<
"op without interface";
649 if (getNumRegions() == 2) {
650 Region &typeConverterRegion = getRegion(1);
651 if (!llvm::hasSingleElement(typeConverterRegion.
front()))
653 <<
"expected exactly one op in default type converter region";
655 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
657 if (!typeConverterOp) {
659 <<
"expected default converter child op to "
660 "implement TypeConverterBuilderOpInterface";
661 diag.attachNote(maybeTypeConverter->
getLoc()) <<
"op without interface";
665 if (!getPatterns().empty()) {
666 for (
Operation &op : getPatterns().front()) {
668 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
669 if (
failed(descriptor.verifyTypeConverter(typeConverterOp)))
677 void transform::ApplyConversionPatternsOp::getEffects(
678 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
679 if (!getPreserveHandles()) {
687 void transform::ApplyConversionPatternsOp::build(
697 if (patternsBodyBuilder)
698 patternsBodyBuilder(builder, result.
location);
704 if (typeConverterBodyBuilder)
705 typeConverterBodyBuilder(builder, result.
location);
713 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
716 assert(dialect &&
"expected that dialect is loaded");
717 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
721 iface->populateConvertToLLVMConversionPatterns(
725 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
726 transform::TypeConverterBuilderOpInterface builder) {
727 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
728 return emitOpError(
"expected LLVMTypeConverter");
735 return emitOpError(
"unknown dialect or dialect not loaded: ")
737 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
740 "dialect does not implement ConvertToLLVMPatternInterface or "
741 "extension was not loaded: ")
751 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
761 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
762 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
771 void transform::ApplyRegisteredPassOp::getEffects(
772 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
789 llvm::raw_string_ostream optionsStream(
options);
794 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
797 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
798 assert(dynamicOptionIdx <
static_cast<int64_t
>(dynamicOptions.size()) &&
799 "the number of ParamOperandAttrs in the options DictionaryAttr"
800 "should be the same as the number of options passed as params");
802 state.getParams(dynamicOptions[dynamicOptionIdx]);
804 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
806 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
808 llvm::interleave(arrayAttr, optionsStream, appendValueAttr,
",");
809 }
else if (
auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
811 optionsStream << strAttr.getValue().str();
814 valueAttr.print(optionsStream,
true);
820 getOptions(), optionsStream,
821 [&](
auto namedAttribute) {
822 optionsStream << namedAttribute.getName().str();
823 optionsStream <<
"=";
824 appendValueAttr(namedAttribute.getValue());
827 optionsStream.flush();
835 <<
"unknown pass or pass pipeline: " << getPassName();
844 <<
"failed to add pass or pass pipeline to pipeline: "
860 if (
failed(pm.run(target))) {
861 auto diag = emitSilenceableError() <<
"pass pipeline failed";
862 diag.attachNote(target->
getLoc()) <<
"target op";
868 results.
set(llvm::cast<OpResult>(getResult()), targets);
874 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
877 size_t dynamicOptionsIdx = 0;
883 std::function<ParseResult(
Attribute &)> parseValue =
884 [&](
Attribute &valueAttr) -> ParseResult {
892 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
893 " in options dictionary") ||
907 ParseResult parsedOperand = parser.
parseOperand(operand);
908 if (
failed(parsedOperand))
914 dynamicOptions.push_back(operand);
921 }
else if (isa<transform::ParamOperandAttr>(valueAttr)) {
923 <<
"the param_operand attribute is a marker reserved for "
924 <<
"indicating a value will be passed via params and is only used "
925 <<
"in the generic print format";
939 <<
"expected key to either be an identifier or a string";
943 <<
"expected '=' after key in key-value pair";
945 if (
failed(parseValue(valueAttr)))
947 <<
"expected a valid attribute or operand as value associated "
948 <<
"to key '" << key <<
"'";
957 " in options dictionary"))
960 if (DictionaryAttr::findDuplicate(
961 keyValuePairs,
false)
964 <<
"duplicate keys found in options dictionary";
979 if (
auto paramOperandAttr =
980 dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
983 dynamicOptions[paramOperandAttr.getIndex().getInt()]);
984 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
996 printer << namedAttribute.
getName();
1010 std::function<LogicalResult(
Attribute)> checkOptionValue =
1011 [&](
Attribute valueAttr) -> LogicalResult {
1012 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1013 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1014 if (dynamicOptionIdx < 0 ||
1015 dynamicOptionIdx >=
static_cast<int64_t
>(dynamicOptions.size()))
1016 return emitOpError()
1017 <<
"dynamic option index " << dynamicOptionIdx
1018 <<
" is out of bounds for the number of dynamic options: "
1019 << dynamicOptions.size();
1020 if (dynamicOptions[dynamicOptionIdx] ==
nullptr)
1021 return emitOpError() <<
"dynamic option index " << dynamicOptionIdx
1022 <<
" is already used in options";
1023 dynamicOptions[dynamicOptionIdx] =
nullptr;
1024 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1026 for (
auto eltAttr : arrayAttr)
1027 if (
failed(checkOptionValue(eltAttr)))
1034 if (
failed(checkOptionValue(namedAttr.getValue())))
1038 for (
Value dynamicOption : dynamicOptions)
1040 return emitOpError() <<
"a param operand does not have a corresponding "
1041 <<
"param_operand attr in the options dict";
1052 Operation *target, ApplyToEachResultList &results,
1054 results.push_back(target);
1058 void transform::CastOp::getEffects(
1059 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1066 assert(inputs.size() == 1 &&
"expected one input");
1067 assert(outputs.size() == 1 &&
"expected one output");
1068 return llvm::all_of(
1069 std::initializer_list<Type>{inputs.front(), outputs.front()},
1070 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1090 assert(block.
getParent() &&
"cannot match using a detached block");
1091 auto matchScope = state.make_region_scope(*block.
getParent());
1093 state.mapBlockArguments(block.
getArguments(), blockArgumentMapping)))
1097 if (!isa<transform::MatchOpInterface>(match)) {
1099 <<
"expected operations in the match part to "
1100 "implement MatchOpInterface";
1103 state.applyTransform(cast<transform::TransformOpInterface>(match));
1104 if (
diag.succeeded())
1122 template <
typename... Tys>
1124 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... ||
false);
1131 transform::TransformParamTypeInterface,
1132 transform::TransformValueHandleTypeInterface>(
1144 auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
1145 getOperation(), getMatcher());
1146 if (matcher.isExternal()) {
1148 <<
"unresolved external symbol " << getMatcher();
1152 rawResults.resize(getOperation()->getNumResults());
1153 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1154 for (
Operation *root : state.getPayloadOps(getRoot())) {
1165 matcher.getFunctionBody().front(),
1168 if (
diag.isDefiniteFailure())
1170 if (
diag.isSilenceableFailure()) {
1172 <<
" failed: " <<
diag.getMessage();
1178 if (mapping.size() != 1) {
1179 maybeFailure.emplace(emitSilenceableError()
1180 <<
"result #" << i <<
", associated with "
1182 <<
" payload objects, expected 1");
1185 rawResults[i].push_back(mapping[0]);
1190 return std::move(*maybeFailure);
1191 assert(!maybeFailure &&
"failure set but the walk was not interrupted");
1193 for (
auto &&[opResult, rawResult] :
1194 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1201 void transform::CollectMatchingOp::getEffects(
1202 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1208 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1210 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1212 if (!matcherSymbol ||
1213 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1214 return emitError() <<
"unresolved matcher symbol " << getMatcher();
1217 if (argumentTypes.size() != 1 ||
1218 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1220 <<
"expected the matcher to take one operation handle argument";
1222 if (!matcherSymbol.getArgAttr(
1223 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1224 return emitError() <<
"expected the matcher argument to be marked readonly";
1228 if (resultTypes.size() != getOperation()->getNumResults()) {
1230 <<
"expected the matcher to yield as many values as op has results ("
1231 << getOperation()->getNumResults() <<
"), got "
1232 << resultTypes.size();
1235 for (
auto &&[i, matcherType, resultType] :
1241 <<
"mismatching type interfaces for matcher result and op result #"
1253 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() {
return true; }
1261 matchActionPairs.reserve(getMatchers().size());
1263 for (
auto &&[matcher, action] :
1264 llvm::zip_equal(getMatchers(), getActions())) {
1265 auto matcherSymbol =
1267 getOperation(), cast<SymbolRefAttr>(matcher));
1270 getOperation(), cast<SymbolRefAttr>(action));
1271 assert(matcherSymbol && actionSymbol &&
1272 "unresolved symbols not caught by the verifier");
1274 if (matcherSymbol.isExternal())
1276 if (actionSymbol.isExternal())
1279 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1290 matchInputMapping.emplace_back();
1292 getForwardedInputs(), state);
1294 actionResultMapping.resize(getForwardedOutputs().size());
1296 for (
Operation *root : state.getPayloadOps(getRoot())) {
1300 if (!getRestrictRoot() && op == root)
1308 firstMatchArgument.clear();
1309 firstMatchArgument.push_back(op);
1312 for (
auto [matcher, action] : matchActionPairs) {
1314 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1315 state, matchOutputMapping);
1316 if (
diag.isDefiniteFailure())
1318 if (
diag.isSilenceableFailure()) {
1320 <<
" failed: " <<
diag.getMessage();
1324 auto scope = state.make_region_scope(action.getFunctionBody());
1325 if (
failed(state.mapBlockArguments(
1326 action.getFunctionBody().front().getArguments(),
1327 matchOutputMapping))) {
1332 action.getFunctionBody().front().without_terminator()) {
1334 state.applyTransform(cast<TransformOpInterface>(transform));
1339 overallDiag = emitSilenceableError() <<
"actions failed";
1344 <<
"when applied to this matching payload";
1351 action.getFunctionBody().front().getTerminator()->getOperands(),
1352 state, getFlattenResults()))) {
1354 <<
"action @" << action.getName()
1355 <<
" has results associated with multiple payload entities, "
1356 "but flattening was not requested";
1371 results.
set(llvm::cast<OpResult>(getUpdated()),
1372 state.getPayloadOps(getRoot()));
1373 for (
auto &&[result, mapping] :
1374 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1380 void transform::ForeachMatchOp::getAsmResultNames(
1382 setNameFn(getUpdated(),
"updated_root");
1383 for (
Value v : getForwardedOutputs()) {
1384 setNameFn(v,
"yielded");
1388 void transform::ForeachMatchOp::getEffects(
1389 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1391 if (getOperation()->getNumOperands() < 1 ||
1392 getOperation()->getNumResults() < 1) {
1405 ArrayAttr &matchers,
1406 ArrayAttr &actions) {
1428 ArrayAttr matchers, ArrayAttr actions) {
1431 for (
auto &&[matcher, action, idx] : llvm::zip_equal(
1432 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1434 printer << cast<SymbolRefAttr>(matcher) <<
" -> "
1435 << cast<SymbolRefAttr>(action);
1436 if (idx != matchers.size() - 1)
1444 if (getMatchers().size() != getActions().size())
1445 return emitOpError() <<
"expected the same number of matchers and actions";
1446 if (getMatchers().empty())
1447 return emitOpError() <<
"expected at least one match/action pair";
1451 if (matcherNames.insert(name).second)
1454 <<
" is used more than once, only the first match will apply";
1465 bool alsoVerifyInternal =
false) {
1466 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1467 llvm::SmallDenseSet<unsigned> consumedArguments;
1468 if (!op.isExternal()) {
1472 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1474 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1477 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1479 if (isConsumed && isReadOnly) {
1480 return transformOp.emitSilenceableError()
1481 <<
"argument #" << i <<
" cannot be both readonly and consumed";
1483 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1484 return transformOp.emitSilenceableError()
1485 <<
"must provide consumed/readonly status for arguments of "
1486 "external or called ops";
1488 if (op.isExternal())
1491 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1492 return transformOp.emitSilenceableError()
1493 <<
"argument #" << i
1494 <<
" is consumed in the body but is not marked as such";
1496 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1500 <<
"op argument #" << i
1501 <<
" is not consumed in the body but is marked as consumed";
1507 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1509 assert(getMatchers().size() == getActions().size());
1512 for (
auto &&[matcher, action] :
1513 llvm::zip_equal(getMatchers(), getActions())) {
1515 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1517 cast<SymbolRefAttr>(matcher)));
1518 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1520 cast<SymbolRefAttr>(action)));
1521 if (!matcherSymbol ||
1522 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1523 return emitError() <<
"unresolved matcher symbol " << matcher;
1524 if (!actionSymbol ||
1525 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1526 return emitError() <<
"unresolved action symbol " << action;
1531 .checkAndReport())) {
1537 .checkAndReport())) {
1542 TypeRange operandTypes = getOperandTypes();
1543 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1544 if (operandTypes.size() != matcherArguments.size()) {
1546 emitError() <<
"the number of operands (" << operandTypes.size()
1547 <<
") doesn't match the number of matcher arguments ("
1548 << matcherArguments.size() <<
") for " << matcher;
1549 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1552 for (
auto &&[i, operand, argument] :
1554 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1557 <<
"does not expect matcher symbol to consume its operand #" << i;
1558 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1567 <<
"mismatching type interfaces for operand and matcher argument #"
1568 << i <<
" of matcher " << matcher;
1569 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1574 TypeRange matcherResults = matcherSymbol.getResultTypes();
1575 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1576 if (matcherResults.size() != actionArguments.size()) {
1577 return emitError() <<
"mismatching number of matcher results and "
1578 "action arguments between "
1579 << matcher <<
" (" << matcherResults.size() <<
") and "
1580 << action <<
" (" << actionArguments.size() <<
")";
1582 for (
auto &&[i, matcherType, actionType] :
1587 return emitError() <<
"mismatching type interfaces for matcher result "
1588 "and action argument #"
1589 << i <<
"of matcher " << matcher <<
" and action "
1594 TypeRange actionResults = actionSymbol.getResultTypes();
1595 auto resultTypes =
TypeRange(getResultTypes()).drop_front();
1596 if (actionResults.size() != resultTypes.size()) {
1598 emitError() <<
"the number of action results ("
1599 << actionResults.size() <<
") for " << action
1600 <<
" doesn't match the number of extra op results ("
1601 << resultTypes.size() <<
")";
1602 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1605 for (
auto &&[i, resultType, actionType] :
1611 emitError() <<
"mismatching type interfaces for action result #" << i
1612 <<
" of action " << action <<
" and op result";
1613 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1632 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1633 bool withZipShortest = getWithZipShortest();
1637 if (withZipShortest) {
1641 return a.size() < b.size();
1644 for (
auto &payload : payloads)
1645 payload.resize(numIterations);
1651 for (
size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1653 if (payloads[argIdx].size() != numIterations) {
1654 return emitSilenceableError()
1655 <<
"prior targets' payload size (" << numIterations
1656 <<
") differs from payload size (" << payloads[argIdx].size()
1657 <<
") of target " << getTargets()[argIdx];
1666 for (
size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1667 auto scope = state.make_region_scope(getBody());
1673 if (
failed(state.mapBlockArgument(blockArg, {argument})))
1678 for (
Operation &transform : getBody().front().without_terminator()) {
1680 llvm::cast<transform::TransformOpInterface>(transform));
1686 OperandRange yieldOperands = getYieldOp().getOperands();
1687 for (
auto &&[result, yieldOperand, resTuple] :
1688 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1690 if (isa<TransformHandleTypeInterface>(result.getType()))
1691 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1692 else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1693 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1694 else if (isa<TransformParamTypeInterface>(result.getType()))
1695 llvm::append_range(resTuple, state.getParams(yieldOperand));
1697 assert(
false &&
"unhandled handle type");
1701 for (
auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1707 void transform::ForeachOp::getEffects(
1708 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1711 for (
auto &&[target, blockArg] :
1712 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1714 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1716 cast<TransformOpInterface>(&op));
1724 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1728 }
else if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1737 void transform::ForeachOp::getSuccessorRegions(
1739 Region *bodyRegion = &getBody();
1741 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1748 "unexpected region index");
1749 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1750 regions.emplace_back(getOperation(), getOperation()->getResults());
1754 transform::ForeachOp::getEntrySuccessorOperands(
RegionSuccessor successor) {
1757 assert(successor.
getSuccessor() == &getBody() &&
"unexpected region index");
1758 return getOperation()->getOperands();
1761 transform::YieldOp transform::ForeachOp::getYieldOp() {
1762 return cast<transform::YieldOp>(getBody().front().getTerminator());
1766 for (
auto [targetOpt, bodyArgOpt] :
1767 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1768 if (!targetOpt || !bodyArgOpt)
1769 return emitOpError() <<
"expects the same number of targets as the body "
1770 "has block arguments";
1771 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1773 "expects co-indexed targets and the body's "
1774 "block arguments to have the same op/value/param type");
1777 for (
auto [resultOpt, yieldOperandOpt] :
1778 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1779 if (!resultOpt || !yieldOperandOpt)
1780 return emitOpError() <<
"expects the same number of results as the "
1781 "yield terminator has operands";
1782 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1783 return emitOpError(
"expects co-indexed results and yield "
1784 "operands to have the same op/value/param type");
1800 for (
Operation *target : state.getPayloadOps(getTarget())) {
1802 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1805 bool checkIsolatedFromAbove =
1806 !getIsolatedFromAbove() ||
1808 bool checkOpName = !getOpName().has_value() ||
1810 if (checkIsolatedFromAbove && checkOpName)
1815 if (getAllowEmptyResults()) {
1816 results.
set(llvm::cast<OpResult>(getResult()), parents);
1820 emitSilenceableError()
1821 <<
"could not find a parent op that matches all requirements";
1822 diag.attachNote(target->
getLoc()) <<
"target op";
1826 if (getDeduplicate()) {
1827 if (resultSet.insert(parent).second)
1828 parents.push_back(parent);
1830 parents.push_back(parent);
1833 results.
set(llvm::cast<OpResult>(getResult()), parents);
1845 int64_t resultNumber = getResultNumber();
1846 auto payloadOps = state.getPayloadOps(getTarget());
1847 if (std::empty(payloadOps)) {
1848 results.
set(cast<OpResult>(getResult()), {});
1851 if (!llvm::hasSingleElement(payloadOps))
1853 <<
"handle must be mapped to exactly one payload op";
1855 Operation *target = *payloadOps.begin();
1858 results.
set(llvm::cast<OpResult>(getResult()),
1872 for (
Value v : state.getPayloadValues(getTarget())) {
1873 if (llvm::isa<BlockArgument>(v)) {
1875 emitSilenceableError() <<
"cannot get defining op of block argument";
1876 diag.attachNote(v.getLoc()) <<
"target value";
1879 definingOps.push_back(v.getDefiningOp());
1881 results.
set(llvm::cast<OpResult>(getResult()), definingOps);
1893 int64_t operandNumber = getOperandNumber();
1895 for (
Operation *target : state.getPayloadOps(getTarget())) {
1902 emitSilenceableError()
1903 <<
"could not find a producer for operand number: " << operandNumber
1904 <<
" of " << *target;
1905 diag.attachNote(target->getLoc()) <<
"target op";
1908 producers.push_back(producer);
1910 results.
set(llvm::cast<OpResult>(getResult()), producers);
1923 for (
Operation *target : state.getPayloadOps(getTarget())) {
1926 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1927 target->getNumOperands(), operandPositions);
1928 if (
diag.isSilenceableFailure()) {
1929 diag.attachNote(target->getLoc())
1930 <<
"while considering positions of this payload operation";
1933 llvm::append_range(operands,
1934 llvm::map_range(operandPositions, [&](int64_t pos) {
1935 return target->getOperand(pos);
1938 results.
setValues(cast<OpResult>(getResult()), operands);
1944 getIsInverted(), getIsAll());
1956 for (
Operation *target : state.getPayloadOps(getTarget())) {
1959 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1960 target->getNumResults(), resultPositions);
1961 if (
diag.isSilenceableFailure()) {
1962 diag.attachNote(target->getLoc())
1963 <<
"while considering positions of this payload operation";
1966 llvm::append_range(opResults,
1967 llvm::map_range(resultPositions, [&](int64_t pos) {
1968 return target->getResult(pos);
1971 results.
setValues(cast<OpResult>(getResult()), opResults);
1977 getIsInverted(), getIsAll());
1984 void transform::GetTypeOp::getEffects(
1985 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1996 for (
Value value : state.getPayloadValues(getValue())) {
1997 Type type = value.getType();
1998 if (getElemental()) {
1999 if (
auto shaped = dyn_cast<ShapedType>(type)) {
2000 type = shaped.getElementType();
2005 results.
setParams(cast<OpResult>(getResult()), params);
2022 state.applyTransform(cast<transform::TransformOpInterface>(transform));
2027 if (mode == transform::FailurePropagationMode::Propagate) {
2046 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2047 getOperation(), getTarget());
2048 assert(callee &&
"unverified reference to unknown symbol");
2050 if (callee.isExternal())
2056 auto scope = state.make_region_scope(callee.getBody());
2057 for (
auto &&[arg, map] :
2058 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2059 if (
failed(state.mapBlockArgument(arg, map)))
2064 callee.getBody().front(), getFailurePropagationMode(), state, results);
2067 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2068 for (
auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
2076 void transform::IncludeOp::getEffects(
2077 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2091 auto defaultEffects = [&] {
2098 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2100 return defaultEffects();
2101 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2102 getOperation(), getTarget());
2104 return defaultEffects();
2106 for (
unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2107 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2109 else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2118 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(
"target");
2120 return emitOpError() <<
"expects a 'target' symbol reference attribute";
2125 return emitOpError() <<
"does not reference a named transform sequence";
2127 FunctionType fnType = target.getFunctionType();
2128 if (fnType.getNumInputs() != getNumOperands())
2129 return emitError(
"incorrect number of operands for callee");
2131 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2132 if (getOperand(i).
getType() != fnType.getInput(i)) {
2133 return emitOpError(
"operand type mismatch: expected operand type ")
2134 << fnType.getInput(i) <<
", but provided "
2135 << getOperand(i).getType() <<
" for operand number " << i;
2139 if (fnType.getNumResults() != getNumResults())
2140 return emitError(
"incorrect number of results for callee");
2142 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2143 Type resultType = getResult(i).getType();
2144 Type funcType = fnType.getResult(i);
2146 return emitOpError() <<
"type of result #" << i
2147 <<
" must implement the same transform dialect "
2148 "interface as the corresponding callee result";
2153 cast<FunctionOpInterface>(*target),
false,
2163 ::std::optional<::mlir::Operation *> maybeCurrent,
2165 if (!maybeCurrent.has_value()) {
2170 return emitSilenceableError() <<
"operation is not empty";
2181 for (
auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2182 if (acceptedAttr.getValue() == currentOpName)
2185 return emitSilenceableError() <<
"wrong operation name";
2196 auto signedAPIntAsString = [&](
const APInt &value) {
2198 llvm::raw_string_ostream os(str);
2199 value.print(os,
true);
2206 if (params.size() != references.size()) {
2207 return emitSilenceableError()
2208 <<
"parameters have different payload lengths (" << params.size()
2209 <<
" vs " << references.size() <<
")";
2212 for (
auto &&[i, param, reference] :
llvm::enumerate(params, references)) {
2213 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2214 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2215 if (!intAttr || !refAttr) {
2217 <<
"non-integer parameter value not expected";
2219 if (intAttr.getType() != refAttr.getType()) {
2221 <<
"mismatching integer attribute types in parameter #" << i;
2223 APInt value = intAttr.getValue();
2224 APInt refValue = refAttr.getValue();
2227 int64_t position = i;
2228 auto reportError = [&](StringRef direction) {
2230 emitSilenceableError() <<
"expected parameter to be " << direction
2231 <<
" " << signedAPIntAsString(refValue)
2232 <<
", got " << signedAPIntAsString(value);
2233 diag.attachNote(getParam().getLoc())
2234 <<
"value # " << position
2235 <<
" associated with the parameter defined here";
2239 switch (getPredicate()) {
2240 case MatchCmpIPredicate::eq:
2241 if (value.eq(refValue))
2243 return reportError(
"equal to");
2244 case MatchCmpIPredicate::ne:
2245 if (value.ne(refValue))
2247 return reportError(
"not equal to");
2248 case MatchCmpIPredicate::lt:
2249 if (value.slt(refValue))
2251 return reportError(
"less than");
2252 case MatchCmpIPredicate::le:
2253 if (value.sle(refValue))
2255 return reportError(
"less than or equal to");
2256 case MatchCmpIPredicate::gt:
2257 if (value.sgt(refValue))
2259 return reportError(
"greater than");
2260 case MatchCmpIPredicate::ge:
2261 if (value.sge(refValue))
2263 return reportError(
"greater than or equal to");
2269 void transform::MatchParamCmpIOp::getEffects(
2270 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2283 results.
setParams(cast<OpResult>(getParam()), {getValue()});
2296 if (isa<TransformHandleTypeInterface>(handles.front().
getType())) {
2298 for (
Value operand : handles)
2299 llvm::append_range(operations, state.getPayloadOps(operand));
2300 if (!getDeduplicate()) {
2301 results.
set(llvm::cast<OpResult>(getResult()), operations);
2306 results.
set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2310 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2312 for (
Value attribute : handles)
2313 llvm::append_range(attrs, state.getParams(attribute));
2314 if (!getDeduplicate()) {
2315 results.
setParams(cast<OpResult>(getResult()), attrs);
2320 results.
setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2325 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2326 "expected value handle type");
2328 for (
Value value : handles)
2329 llvm::append_range(payloadValues, state.getPayloadValues(value));
2330 if (!getDeduplicate()) {
2331 results.
setValues(cast<OpResult>(getResult()), payloadValues);
2336 results.
setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2340 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2342 return getDeduplicate();
2345 void transform::MergeHandlesOp::getEffects(
2346 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2354 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2355 if (getDeduplicate() || getHandles().size() != 1)
2360 return getHandles().front();
2378 auto scope = state.make_region_scope(getBody());
2380 state, this->getOperation(), getBody())))
2384 FailurePropagationMode::Propagate, state, results);
2387 void transform::NamedSequenceOp::getEffects(
2388 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2393 parser, result,
false,
2394 getFunctionTypeAttrName(result.
name),
2397 std::string &) { return builder.getFunctionType(inputs, results); },
2398 getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
2403 printer, cast<FunctionOpInterface>(getOperation()),
false,
2404 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2405 getResAttrsAttrName());
2415 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2418 <<
"cannot be defined inside another transform op";
2419 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2423 if (op.isExternal() || op.getFunctionBody().empty()) {
2430 if (op.getFunctionBody().front().empty())
2433 Operation *terminator = &op.getFunctionBody().front().back();
2434 if (!isa<transform::YieldOp>(terminator)) {
2437 << transform::YieldOp::getOperationName()
2438 <<
"' as terminator";
2439 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2443 if (terminator->
getNumOperands() != op.getResultTypes().size()) {
2445 <<
"expected terminator to have as many operands as the parent op "
2448 for (
auto [i, operandType, resultType] : llvm::zip_equal(
2451 if (operandType == resultType)
2454 <<
"the type of the terminator operand #" << i
2455 <<
" must match the type of the corresponding parent op result ("
2456 << operandType <<
" vs " << resultType <<
")";
2469 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2472 <<
"expects the parent symbol table to have the '"
2473 << transform::TransformDialect::kWithNamedSequenceAttrName
2475 diag.attachNote(parent->
getLoc()) <<
"symbol table operation";
2480 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2483 <<
"cannot be defined inside another transform op";
2484 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2488 if (op.isExternal() || op.getBody().empty())
2492 if (op.getBody().front().empty())
2495 Operation *terminator = &op.getBody().front().back();
2496 if (!isa<transform::YieldOp>(terminator)) {
2499 << transform::YieldOp::getOperationName()
2500 <<
"' as terminator";
2501 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2505 if (terminator->
getNumOperands() != op.getFunctionType().getNumResults()) {
2507 <<
"expected terminator to have as many operands as the parent op "
2510 for (
auto [i, operandType, resultType] :
2511 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->
getNumOperands()),
2513 op.getFunctionType().getResults())) {
2514 if (operandType == resultType)
2517 <<
"the type of the terminator operand #" << i
2518 <<
" must match the type of the corresponding parent op result ("
2519 << operandType <<
" vs " << resultType <<
")";
2522 auto funcOp = cast<FunctionOpInterface>(*op);
2525 if (!
diag.succeeded())
2537 template <
typename FnTy>
2542 types.reserve(1 + extraBindingTypes.size());
2543 types.push_back(bbArgType);
2544 llvm::append_range(types, extraBindingTypes);
2547 Region *region = state.regions.back().get();
2554 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2555 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0));
2557 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0),
2562 void transform::NamedSequenceOp::build(
OpBuilder &builder,
2570 state.addAttribute(getFunctionTypeAttrName(state.name),
2572 rootType, resultTypes)));
2573 state.attributes.append(attrs.begin(), attrs.end());
2588 size_t numAssociations =
2590 .Case([&](TransformHandleTypeInterface opHandle) {
2591 return llvm::range_size(state.getPayloadOps(getHandle()));
2593 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2594 return llvm::range_size(state.getPayloadValues(getHandle()));
2596 .Case([&](TransformParamTypeInterface param) {
2597 return llvm::range_size(state.getParams(getHandle()));
2599 .DefaultUnreachable(
"unknown kind of transform dialect type");
2600 results.
setParams(cast<OpResult>(getNum()),
2607 auto resultType = cast<TransformParamTypeInterface>(getNum().
getType());
2622 auto payloadOps = state.getPayloadOps(getTarget());
2625 result.push_back(op);
2627 results.
set(cast<OpResult>(getResult()), result);
2636 Value target, int64_t numResultHandles) {
2645 int64_t numPayloads =
2647 .Case<TransformHandleTypeInterface>([&](
auto x) {
2648 return llvm::range_size(state.getPayloadOps(getHandle()));
2650 .Case<TransformValueHandleTypeInterface>([&](
auto x) {
2651 return llvm::range_size(state.getPayloadValues(getHandle()));
2653 .Case<TransformParamTypeInterface>([&](
auto x) {
2654 return llvm::range_size(state.getParams(getHandle()));
2656 .DefaultUnreachable(
"unknown transform dialect type interface");
2658 auto produceNumOpsError = [&]() {
2659 return emitSilenceableError()
2660 << getHandle() <<
" expected to contain " << this->getNumResults()
2661 <<
" payloads but it contains " << numPayloads <<
" payloads";
2666 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2667 return produceNumOpsError();
2672 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2673 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2674 return produceNumOpsError();
2678 if (getOverflowResult())
2679 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2681 auto container = [&]() {
2682 if (isa<TransformHandleTypeInterface>(getHandle().
getType())) {
2683 return llvm::map_to_vector(
2684 state.getPayloadOps(getHandle()),
2687 if (isa<TransformValueHandleTypeInterface>(getHandle().
getType())) {
2688 return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2691 assert(isa<TransformParamTypeInterface>(getHandle().
getType()) &&
2692 "unsupported kind of transform dialect type");
2693 return llvm::map_to_vector(state.getParams(getHandle()),
2698 int64_t resultNum = en.index();
2699 if (resultNum >= getNumResults())
2700 resultNum = *getOverflowResult();
2701 resultHandles[resultNum].push_back(en.value());
2712 void transform::SplitHandleOp::getEffects(
2713 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2721 if (getOverflowResult().has_value() &&
2722 !(*getOverflowResult() < getNumResults()))
2723 return emitOpError(
"overflow_result is not a valid result index");
2725 for (
Type resultType : getResultTypes()) {
2729 return emitOpError(
"expects result types to implement the same transform "
2730 "interface as the operand type");
2744 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2746 Value handle = en.value();
2747 if (isa<TransformHandleTypeInterface>(handle.
getType())) {
2749 llvm::to_vector(state.getPayloadOps(handle));
2751 payload.reserve(numRepetitions * current.size());
2752 for (
unsigned i = 0; i < numRepetitions; ++i)
2753 llvm::append_range(payload, current);
2754 results.
set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2756 assert(llvm::isa<TransformParamTypeInterface>(handle.
getType()) &&
2757 "expected param type");
2760 params.reserve(numRepetitions * current.size());
2761 for (
unsigned i = 0; i < numRepetitions; ++i)
2762 llvm::append_range(params, current);
2763 results.
setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2770 void transform::ReplicateOp::getEffects(
2771 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2786 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2787 if (
failed(mapBlockArguments(state)))
2795 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2797 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2798 SmallVectorImpl<Type> &extraBindingTypes) {
2802 root = std::nullopt;
2823 if (!extraBindings.empty()) {
2828 if (extraBindingTypes.size() != extraBindings.size()) {
2830 "expected types to be provided for all operands");
2846 bool hasExtras = !extraBindings.empty();
2856 printer << rootType;
2858 printer <<
", " << llvm::interleaved(extraBindingTypes) <<
')';
2865 auto iface = dyn_cast<transform::TransformOpInterface>(use.
getOwner());
2880 if (!potentialConsumer) {
2881 potentialConsumer = &use;
2886 <<
" has more than one potential consumer";
2889 diag.attachNote(use.getOwner()->getLoc())
2890 <<
"used here as operand #" << use.getOperandNumber();
2898 assert(getBodyBlock()->getNumArguments() >= 1 &&
2899 "the number of arguments must have been verified to be more than 1 by "
2900 "PossibleTopLevelTransformOpTrait");
2902 if (!getRoot() && !getExtraBindings().empty()) {
2903 return emitOpError()
2904 <<
"does not expect extra operands when used as top-level";
2910 return (emitOpError() <<
"block argument #" << arg.getArgNumber());
2917 for (
Operation &child : *getBodyBlock()) {
2918 if (!isa<TransformOpInterface>(child) &&
2919 &child != &getBodyBlock()->back()) {
2922 <<
"expected children ops to implement TransformOpInterface";
2923 diag.attachNote(child.getLoc()) <<
"op without interface";
2927 for (
OpResult result : child.getResults()) {
2928 auto report = [&]() {
2929 return (child.emitError() <<
"result #" << result.getResultNumber());
2936 if (!getBodyBlock()->mightHaveTerminator())
2937 return emitOpError() <<
"expects to have a terminator in the body";
2939 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2940 getOperation()->getResultTypes()) {
2942 <<
"expects the types of the terminator operands "
2943 "to match the types of the result";
2944 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) <<
"terminator";
2950 void transform::SequenceOp::getEffects(
2951 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2956 transform::SequenceOp::getEntrySuccessorOperands(
RegionSuccessor successor) {
2957 assert(successor.
getSuccessor() == &getBody() &&
"unexpected region index");
2958 if (getOperation()->getNumOperands() > 0)
2959 return getOperation()->getOperands();
2961 getOperation()->operand_end());
2964 void transform::SequenceOp::getSuccessorRegions(
2967 Region *bodyRegion = &getBody();
2968 regions.emplace_back(bodyRegion, getNumOperands() != 0
2976 "unexpected region index");
2977 regions.emplace_back(getOperation(), getOperation()->getResults());
2980 void transform::SequenceOp::getRegionInvocationBounds(
2983 bounds.emplace_back(1, 1);
2988 FailurePropagationMode failurePropagationMode,
2991 build(builder, state, resultTypes, failurePropagationMode, root,
3000 FailurePropagationMode failurePropagationMode,
3003 build(builder, state, resultTypes, failurePropagationMode, root,
3011 FailurePropagationMode failurePropagationMode,
3014 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3022 FailurePropagationMode failurePropagationMode,
3025 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3041 Value target, StringRef name) {
3043 build(builder, result, name);
3050 llvm::outs() <<
"[[[ IR printer: ";
3051 if (getName().has_value())
3052 llvm::outs() << *getName() <<
" ";
3055 if (getAssumeVerified().value_or(
false))
3057 if (getUseLocalScope().value_or(
false))
3059 if (getSkipRegions().value_or(
false))
3063 llvm::outs() <<
"top-level ]]]\n";
3064 state.getTopLevel()->print(llvm::outs(), printFlags);
3065 llvm::outs() <<
"\n";
3066 llvm::outs().flush();
3070 llvm::outs() <<
"]]]\n";
3071 for (
Operation *target : state.getPayloadOps(getTarget())) {
3072 target->
print(llvm::outs(), printFlags);
3073 llvm::outs() <<
"\n";
3076 llvm::outs().flush();
3080 void transform::PrintOp::getEffects(
3081 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3085 if (!getTargetMutable().empty())
3105 <<
"failed to verify payload op";
3106 diag.attachNote(target->
getLoc()) <<
"payload op";
3112 void transform::VerifyOp::getEffects(
3113 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3121 void transform::YieldOp::getEffects(
3122 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a #dlti.dl_entry attribute.
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
Region * getParentRegion()
Returns the region to which the instruction belongs.
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getRegionNumber()
Return the number of this region in the parent operation.
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
A named class for passing around the variadic flag.
static void printOptionValue(raw_ostream &os, const bool &value)
Utility methods for printing option values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.