35#include "llvm/ADT/DenseSet.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/ScopeExit.h"
38#include "llvm/ADT/SmallPtrSet.h"
39#include "llvm/ADT/SmallVectorExtras.h"
40#include "llvm/ADT/TypeSwitch.h"
41#include "llvm/Support/Debug.h"
42#include "llvm/Support/DebugLog.h"
43#include "llvm/Support/ErrorHandling.h"
44#include "llvm/Support/InterleavedRange.h"
47#define DEBUG_TYPE "transform-dialect"
48#define DEBUG_TYPE_MATCHER "transform-matcher"
60 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
81 while (transformAncestor) {
82 if (transformAncestor == payload) {
85 <<
"cannot apply transform to itself (or one of its ancestors)";
86 diag.attachNote(payload->
getLoc()) <<
"target payload op";
89 transformAncestor = transformAncestor->
getParentOp();
95#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
101OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
103 if (!successor.
isParent() && getOperation()->getNumOperands() == 1)
104 return getOperation()->getOperands();
106 getOperation()->operand_end());
109void transform::AlternativesOp::getSuccessorRegions(
111 for (
Region &alternative : llvm::drop_begin(
116 ->getRegionNumber() +
118 regions.emplace_back(&alternative);
125transform::AlternativesOp::getSuccessorInputs(
RegionSuccessor successor) {
127 return getOperation()->getResults();
131void transform::AlternativesOp::getRegionInvocationBounds(
136 bounds.reserve(getNumRegions());
137 bounds.emplace_back(1, 1);
144 results.
set(res, {});
152 if (
Value scopeHandle = getScope())
153 llvm::append_range(originals, state.
getPayloadOps(scopeHandle));
158 if (original->isAncestor(getOperation())) {
160 <<
"scope must not contain the transforms being applied";
161 diag.attachNote(original->getLoc()) <<
"scope";
166 <<
"only isolated-from-above ops can be alternative scopes";
167 diag.attachNote(original->getLoc()) <<
"scope";
172 for (
Region ® : getAlternatives()) {
178 auto clones = llvm::map_to_vector(
180 llvm::scope_exit deleteClones([&] {
191 if (
result.isSilenceableFailure()) {
192 LDBG() <<
"alternative failed: " <<
result.getMessage();
197 if (::mlir::failed(
result.silence()))
206 deleteClones.release();
207 TrackingListener listener(state, *
this);
209 for (
const auto &kvp : llvm::zip(originals, clones)) {
216 detail::forwardTerminatorOperands(®.front(), state, results);
220 return emitSilenceableError() <<
"all alternatives failed";
223void transform::AlternativesOp::getEffects(
227 for (
Region *region : getRegions()) {
228 if (!region->empty())
234LogicalResult transform::AlternativesOp::verify() {
235 for (
Region &alternative : getAlternatives()) {
240 <<
"expects terminator operands to have the "
241 "same type as results of the operation";
242 diag.attachNote(terminator->
getLoc()) <<
"terminator";
262 if (
auto paramH = getParam()) {
264 if (params.size() != 1) {
265 if (targets.size() != params.size()) {
266 return emitSilenceableError()
267 <<
"parameter and target have different payload lengths ("
268 << params.size() <<
" vs " << targets.size() <<
")";
270 for (
auto &&[
target, attr] : llvm::zip_equal(targets, params))
271 target->setAttr(getName(), attr);
276 for (
auto *
target : targets)
277 target->setAttr(getName(), attr);
281void transform::AnnotateOp::getEffects(
293transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
308void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
334void transform::ApplyDeadCodeEliminationOp::getEffects(
359 if (!getRegion().empty()) {
360 for (
Operation &op : getRegion().front()) {
361 cast<transform::PatternDescriptorOpInterface>(&op)
362 .populatePatternsWithState(patterns, state);
374 : getMaxIterations());
377 : getMaxNumRewrites());
385 <<
"greedy pattern application failed";
396 ops.push_back(nestedOp);
401 static const int64_t kNumMaxIterations = 50;
403 bool cseChanged =
false;
407 <<
"greedy pattern application failed";
415 }
while (cseChanged && ++iteration < kNumMaxIterations);
417 if (iteration == kNumMaxIterations)
423LogicalResult transform::ApplyPatternsOp::verify() {
424 if (!getRegion().empty()) {
425 for (
Operation &op : getRegion().front()) {
426 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
428 <<
"expected children ops to implement "
429 "PatternDescriptorOpInterface";
430 diag.attachNote(op.
getLoc()) <<
"op without interface";
438void transform::ApplyPatternsOp::getEffects(
444void transform::ApplyPatternsOp::build(
453 bodyBuilder(builder,
result.location);
460void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
464 dialect->getCanonicalizationPatterns(patterns);
466 op.getCanonicalizationPatterns(patterns, ctx);
480 std::unique_ptr<TypeConverter> defaultTypeConverter;
481 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
482 getDefaultTypeConverter();
483 if (typeConverterBuilder)
484 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
489 for (
Attribute attr : cast<ArrayAttr>(*getLegalOps()))
490 conversionTarget.addLegalOp(
493 for (
Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
494 conversionTarget.addIllegalOp(
496 if (getLegalDialects())
497 for (
Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
498 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
499 if (getIllegalDialects())
500 for (
Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
501 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
509 if (!getPatterns().empty()) {
510 for (
Operation &op : getPatterns().front()) {
512 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
515 std::unique_ptr<TypeConverter> typeConverter =
516 descriptor.getTypeConverter();
519 keepAliveConverters.emplace_back(std::move(typeConverter));
520 converter = keepAliveConverters.back().get();
523 if (!defaultTypeConverter) {
525 <<
"pattern descriptor does not specify type "
526 "converter and apply_conversion_patterns op has "
527 "no default type converter";
528 diag.attachNote(op.
getLoc()) <<
"pattern descriptor op";
531 converter = defaultTypeConverter.get();
537 descriptor.populateConversionTargetRules(*converter, conversionTarget);
539 descriptor.populatePatterns(*converter, patterns);
547 TrackingListenerConfig trackingConfig;
548 trackingConfig.requireMatchingReplacementOpName =
false;
549 ErrorCheckingTrackingListener trackingListener(state, *
this, trackingConfig);
550 ConversionConfig conversionConfig;
551 if (getPreserveHandles())
552 conversionConfig.listener = &trackingListener;
563 LogicalResult status = failure();
564 if (getPartialConversion()) {
565 status = applyPartialConversion(
target, conversionTarget, frozenPatterns,
568 status = applyFullConversion(
target, conversionTarget, frozenPatterns,
575 diag = emitSilenceableError() <<
"dialect conversion failed";
576 diag.attachNote(
target->getLoc()) <<
"target op";
581 trackingListener.checkAndResetError();
583 if (
diag.succeeded()) {
585 return trackingFailure;
587 diag.attachNote() <<
"tracking listener also failed: "
592 if (!
diag.succeeded())
599LogicalResult transform::ApplyConversionPatternsOp::verify() {
600 if (getNumRegions() != 1 && getNumRegions() != 2)
602 if (!getPatterns().empty()) {
603 for (
Operation &op : getPatterns().front()) {
604 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
606 emitOpError() <<
"expected pattern children ops to implement "
607 "ConversionPatternDescriptorOpInterface";
608 diag.attachNote(op.
getLoc()) <<
"op without interface";
613 if (getNumRegions() == 2) {
614 Region &typeConverterRegion = getRegion(1);
615 if (!llvm::hasSingleElement(typeConverterRegion.
front()))
617 <<
"expected exactly one op in default type converter region";
619 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
621 if (!typeConverterOp) {
623 <<
"expected default converter child op to "
624 "implement TypeConverterBuilderOpInterface";
625 diag.attachNote(maybeTypeConverter->
getLoc()) <<
"op without interface";
629 if (!getPatterns().empty()) {
630 for (
Operation &op : getPatterns().front()) {
632 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
633 if (
failed(descriptor.verifyTypeConverter(typeConverterOp)))
641void transform::ApplyConversionPatternsOp::getEffects(
643 if (!getPreserveHandles()) {
651void transform::ApplyConversionPatternsOp::build(
661 if (patternsBodyBuilder)
662 patternsBodyBuilder(builder,
result.location);
668 if (typeConverterBodyBuilder)
669 typeConverterBodyBuilder(builder,
result.location);
677void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
680 assert(dialect &&
"expected that dialect is loaded");
681 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
685 iface->populateConvertToLLVMConversionPatterns(
689LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
690 transform::TypeConverterBuilderOpInterface builder) {
691 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
696LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
699 return emitOpError(
"unknown dialect or dialect not loaded: ")
701 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
704 "dialect does not implement ConvertToLLVMPatternInterface or "
705 "extension was not loaded: ")
715transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
725void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
735void transform::ApplyRegisteredPassOp::getEffects(
753 llvm::raw_string_ostream optionsStream(
options);
758 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
761 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
762 assert(dynamicOptionIdx <
static_cast<int64_t>(dynamicOptions.size()) &&
763 "the number of ParamOperandAttrs in the options DictionaryAttr"
764 "should be the same as the number of options passed as params");
766 state.
getParams(dynamicOptions[dynamicOptionIdx]);
768 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
770 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
772 llvm::interleave(arrayAttr, optionsStream, appendValueAttr,
",");
773 }
else if (
auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
775 optionsStream << strAttr.getValue().str();
778 valueAttr.print(optionsStream,
true);
784 getOptions(), optionsStream,
785 [&](
auto namedAttribute) {
786 optionsStream << namedAttribute.getName().str();
787 optionsStream <<
"=";
788 appendValueAttr(namedAttribute.getValue());
791 optionsStream.flush();
799 <<
"unknown pass or pass pipeline: " << getPassName();
808 <<
"failed to add pass or pass pipeline to pipeline: "
825 auto diag = emitSilenceableError() <<
"pass pipeline failed";
826 diag.attachNote(
target->getLoc()) <<
"target op";
832 results.
set(llvm::cast<OpResult>(getResult()), targets);
841 size_t dynamicOptionsIdx = 0;
847 std::function<ParseResult(
Attribute &)> parseValue =
848 [&](
Attribute &valueAttr) -> ParseResult {
856 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
857 " in options dictionary") ||
861 valueAttr = ArrayAttr::get(parser.
getContext(), attrs);
871 ParseResult parsedOperand = parser.
parseOperand(operand);
872 if (failed(parsedOperand))
878 dynamicOptions.push_back(operand);
879 auto wrappedIndex = IntegerAttr::get(
880 IntegerType::get(parser.
getContext(), 64), dynamicOptionsIdx++);
882 transform::ParamOperandAttr::get(parser.
getContext(), wrappedIndex);
883 }
else if (failed(parsedValueAttr.
value())) {
885 }
else if (isa<transform::ParamOperandAttr>(valueAttr)) {
887 <<
"the param_operand attribute is a marker reserved for "
888 <<
"indicating a value will be passed via params and is only used "
889 <<
"in the generic print format";
903 <<
"expected key to either be an identifier or a string";
907 <<
"expected '=' after key in key-value pair";
909 if (failed(parseValue(valueAttr)))
911 <<
"expected a valid attribute or operand as value associated "
912 <<
"to key '" << key <<
"'";
921 " in options dictionary"))
924 if (DictionaryAttr::findDuplicate(
925 keyValuePairs,
false)
928 <<
"duplicate keys found in options dictionary";
943 if (
auto paramOperandAttr =
944 dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
947 dynamicOptions[paramOperandAttr.getIndex().getInt()]);
948 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
951 llvm::interleaveComma(arrayAttr, printer, printOptionValue);
960 printer << namedAttribute.
getName();
962 printOptionValue(namedAttribute.
getValue());
967LogicalResult transform::ApplyRegisteredPassOp::verify() {
974 std::function<LogicalResult(
Attribute)> checkOptionValue =
975 [&](
Attribute valueAttr) -> LogicalResult {
976 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
977 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
978 if (dynamicOptionIdx < 0 ||
979 dynamicOptionIdx >=
static_cast<int64_t>(dynamicOptions.size()))
981 <<
"dynamic option index " << dynamicOptionIdx
982 <<
" is out of bounds for the number of dynamic options: "
983 << dynamicOptions.size();
984 if (dynamicOptions[dynamicOptionIdx] ==
nullptr)
985 return emitOpError() <<
"dynamic option index " << dynamicOptionIdx
986 <<
" is already used in options";
987 dynamicOptions[dynamicOptionIdx] =
nullptr;
988 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
990 for (
auto eltAttr : arrayAttr)
991 if (
failed(checkOptionValue(eltAttr)))
998 if (
failed(checkOptionValue(namedAttr.getValue())))
1002 for (
Value dynamicOption : dynamicOptions)
1004 return emitOpError() <<
"a param operand does not have a corresponding "
1005 <<
"param_operand attr in the options dict";
1018 results.push_back(
target);
1022void transform::CastOp::getEffects(
1030 assert(inputs.size() == 1 &&
"expected one input");
1031 assert(outputs.size() == 1 &&
"expected one output");
1032 return llvm::all_of(
1033 std::initializer_list<Type>{inputs.front(), outputs.front()},
1034 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1054 assert(block.
getParent() &&
"cannot match using a detached block");
1061 if (!isa<transform::MatchOpInterface>(match)) {
1063 <<
"expected operations in the match part to "
1064 "implement MatchOpInterface";
1067 state.
applyTransform(cast<transform::TransformOpInterface>(match));
1068 if (
diag.succeeded())
1086template <
typename... Tys>
1088 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... ||
false);
1095 transform::TransformParamTypeInterface,
1096 transform::TransformValueHandleTypeInterface>(
1109 getOperation(), getMatcher());
1110 if (matcher.isExternal()) {
1112 <<
"unresolved external symbol " << getMatcher();
1116 rawResults.resize(getOperation()->getNumResults());
1117 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1129 matcher.getFunctionBody().front(),
1132 if (
diag.isDefiniteFailure())
1134 if (
diag.isSilenceableFailure()) {
1136 <<
" failed: " <<
diag.getMessage();
1141 for (
auto &&[i, mapping] : llvm::enumerate(mappings)) {
1142 if (mapping.size() != 1) {
1143 maybeFailure.emplace(emitSilenceableError()
1144 <<
"result #" << i <<
", associated with "
1146 <<
" payload objects, expected 1");
1149 rawResults[i].push_back(mapping[0]);
1154 return std::move(*maybeFailure);
1155 assert(!maybeFailure &&
"failure set but the walk was not interrupted");
1157 for (
auto &&[opResult, rawResult] :
1158 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1165void transform::CollectMatchingOp::getEffects(
1172LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1174 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1176 if (!matcherSymbol ||
1177 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1178 return emitError() <<
"unresolved matcher symbol " << getMatcher();
1181 if (argumentTypes.size() != 1 ||
1182 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1184 <<
"expected the matcher to take one operation handle argument";
1186 if (!matcherSymbol.getArgAttr(
1187 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1188 return emitError() <<
"expected the matcher argument to be marked readonly";
1192 if (resultTypes.size() != getOperation()->getNumResults()) {
1194 <<
"expected the matcher to yield as many values as op has results ("
1195 << getOperation()->getNumResults() <<
"), got "
1196 << resultTypes.size();
1199 for (
auto &&[i, matcherType, resultType] :
1200 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1205 <<
"mismatching type interfaces for matcher result and op result #"
1217bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() {
return true; }
1225 matchActionPairs.reserve(getMatchers().size());
1227 for (
auto &&[matcher, action] :
1228 llvm::zip_equal(getMatchers(), getActions())) {
1229 auto matcherSymbol =
1231 getOperation(), cast<SymbolRefAttr>(matcher));
1234 getOperation(), cast<SymbolRefAttr>(action));
1235 assert(matcherSymbol && actionSymbol &&
1236 "unresolved symbols not caught by the verifier");
1238 if (matcherSymbol.isExternal())
1240 if (actionSymbol.isExternal())
1243 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1254 matchInputMapping.emplace_back();
1256 getForwardedInputs(), state);
1258 actionResultMapping.resize(getForwardedOutputs().size());
1264 if (!getRestrictRoot() && op == root)
1272 firstMatchArgument.clear();
1273 firstMatchArgument.push_back(op);
1276 for (
auto [matcher, action] : matchActionPairs) {
1278 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1279 state, matchOutputMapping);
1280 if (
diag.isDefiniteFailure())
1282 if (
diag.isSilenceableFailure()) {
1284 <<
" failed: " <<
diag.getMessage();
1290 action.getFunctionBody().front().getArguments(),
1291 matchOutputMapping))) {
1296 action.getFunctionBody().front().without_terminator()) {
1299 if (
result.isDefiniteFailure())
1301 if (
result.isSilenceableFailure()) {
1303 overallDiag = emitSilenceableError() <<
"actions failed";
1306 <<
"failed action: " <<
result.getMessage();
1308 <<
"when applied to this matching payload";
1313 if (
failed(detail::appendValueMappings(
1315 action.getFunctionBody().front().getTerminator()->getOperands(),
1316 state, getFlattenResults()))) {
1318 <<
"action @" << action.getName()
1319 <<
" has results associated with multiple payload entities, "
1320 "but flattening was not requested";
1335 results.
set(llvm::cast<OpResult>(getUpdated()),
1337 for (
auto &&[
result, mapping] :
1338 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1344void transform::ForeachMatchOp::getAsmResultNames(
1346 setNameFn(getUpdated(),
"updated_root");
1347 for (
Value v : getForwardedOutputs()) {
1348 setNameFn(v,
"yielded");
1352void transform::ForeachMatchOp::getEffects(
1355 if (getOperation()->getNumOperands() < 1 ||
1356 getOperation()->getNumResults() < 1) {
1380 matcherList.push_back(SymbolRefAttr::get(matcher));
1381 actionList.push_back(SymbolRefAttr::get(action));
1395 for (
auto &&[matcher, action, idx] : llvm::zip_equal(
1398 printer << cast<SymbolRefAttr>(matcher) <<
" -> "
1399 << cast<SymbolRefAttr>(action);
1407LogicalResult transform::ForeachMatchOp::verify() {
1408 if (getMatchers().size() != getActions().size())
1409 return emitOpError() <<
"expected the same number of matchers and actions";
1410 if (getMatchers().empty())
1411 return emitOpError() <<
"expected at least one match/action pair";
1415 if (matcherNames.insert(name).second)
1418 <<
" is used more than once, only the first match will apply";
1429 bool alsoVerifyInternal =
false) {
1430 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1431 llvm::SmallDenseSet<unsigned> consumedArguments;
1432 if (!op.isExternal()) {
1436 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1438 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1441 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1443 if (isConsumed && isReadOnly) {
1444 return transformOp.emitSilenceableError()
1445 <<
"argument #" << i <<
" cannot be both readonly and consumed";
1447 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1448 return transformOp.emitSilenceableError()
1449 <<
"must provide consumed/readonly status for arguments of "
1450 "external or called ops";
1452 if (op.isExternal())
1455 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1456 return transformOp.emitSilenceableError()
1457 <<
"argument #" << i
1458 <<
" is consumed in the body but is not marked as such";
1460 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1464 <<
"op argument #" << i
1465 <<
" is not consumed in the body but is marked as consumed";
1471LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1473 assert(getMatchers().size() == getActions().size());
1475 StringAttr::get(
getContext(), TransformDialect::kArgConsumedAttrName);
1476 for (
auto &&[matcher, action] :
1477 llvm::zip_equal(getMatchers(), getActions())) {
1479 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1481 cast<SymbolRefAttr>(matcher)));
1482 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1484 cast<SymbolRefAttr>(action)));
1485 if (!matcherSymbol ||
1486 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1487 return emitError() <<
"unresolved matcher symbol " << matcher;
1488 if (!actionSymbol ||
1489 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1490 return emitError() <<
"unresolved action symbol " << action;
1495 .checkAndReport())) {
1501 .checkAndReport())) {
1506 TypeRange operandTypes = getOperandTypes();
1507 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1508 if (operandTypes.size() != matcherArguments.size()) {
1510 emitError() <<
"the number of operands (" << operandTypes.size()
1511 <<
") doesn't match the number of matcher arguments ("
1512 << matcherArguments.size() <<
") for " << matcher;
1513 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1516 for (
auto &&[i, operand, argument] :
1517 llvm::enumerate(operandTypes, matcherArguments)) {
1518 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1521 <<
"does not expect matcher symbol to consume its operand #" << i;
1522 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1531 <<
"mismatching type interfaces for operand and matcher argument #"
1532 << i <<
" of matcher " << matcher;
1533 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1538 TypeRange matcherResults = matcherSymbol.getResultTypes();
1539 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1540 if (matcherResults.size() != actionArguments.size()) {
1541 return emitError() <<
"mismatching number of matcher results and "
1542 "action arguments between "
1543 << matcher <<
" (" << matcherResults.size() <<
") and "
1544 << action <<
" (" << actionArguments.size() <<
")";
1546 for (
auto &&[i, matcherType, actionType] :
1547 llvm::enumerate(matcherResults, actionArguments)) {
1551 return emitError() <<
"mismatching type interfaces for matcher result "
1552 "and action argument #"
1553 << i <<
"of matcher " << matcher <<
" and action "
1558 TypeRange actionResults = actionSymbol.getResultTypes();
1559 auto resultTypes =
TypeRange(getResultTypes()).drop_front();
1560 if (actionResults.size() != resultTypes.size()) {
1562 emitError() <<
"the number of action results ("
1563 << actionResults.size() <<
") for " << action
1564 <<
" doesn't match the number of extra op results ("
1565 << resultTypes.size() <<
")";
1566 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1569 for (
auto &&[i, resultType, actionType] :
1570 llvm::enumerate(resultTypes, actionResults)) {
1575 emitError() <<
"mismatching type interfaces for action result #" << i
1576 <<
" of action " << action <<
" and op result";
1577 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1595 detail::prepareValueMappings(payloads, getTargets(), state);
1596 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1597 bool withZipShortest = getWithZipShortest();
1601 if (withZipShortest) {
1605 return a.size() <
b.size();
1608 for (
auto &payload : payloads)
1609 payload.resize(numIterations);
1615 for (
size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1617 if (payloads[argIdx].size() != numIterations) {
1618 return emitSilenceableError()
1619 <<
"prior targets' payload size (" << numIterations
1620 <<
") differs from payload size (" << payloads[argIdx].size()
1621 <<
") of target " << getTargets()[argIdx];
1630 for (
size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1633 for (
auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1644 llvm::cast<transform::TransformOpInterface>(
transform));
1650 OperandRange yieldOperands = getYieldOp().getOperands();
1651 for (
auto &&[
result, yieldOperand, resTuple] :
1652 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1654 if (isa<TransformHandleTypeInterface>(
result.getType()))
1655 llvm::append_range(resTuple, state.
getPayloadOps(yieldOperand));
1656 else if (isa<TransformValueHandleTypeInterface>(
result.getType()))
1658 else if (isa<TransformParamTypeInterface>(
result.getType()))
1659 llvm::append_range(resTuple, state.
getParams(yieldOperand));
1661 assert(
false &&
"unhandled handle type");
1665 for (
auto &&[
result, resPayload] : zip_equal(getResults(), zippedResults))
1671void transform::ForeachOp::getEffects(
1675 for (
auto &&[
target, blockArg] :
1676 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1678 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1680 cast<TransformOpInterface>(&op));
1688 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1692 }
else if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1701void transform::ForeachOp::getSuccessorRegions(
1703 Region *bodyRegion = &getBody();
1705 regions.emplace_back(bodyRegion);
1712 "unexpected region index");
1713 regions.emplace_back(bodyRegion);
1723transform::ForeachOp::getEntrySuccessorOperands(
RegionSuccessor successor) {
1726 assert(successor.
getSuccessor() == &getBody() &&
"unexpected region index");
1727 return getOperation()->getOperands();
1730transform::YieldOp transform::ForeachOp::getYieldOp() {
1731 return cast<transform::YieldOp>(getBody().front().getTerminator());
1734LogicalResult transform::ForeachOp::verify() {
1735 for (
auto [targetOpt, bodyArgOpt] :
1736 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1737 if (!targetOpt || !bodyArgOpt)
1738 return emitOpError() <<
"expects the same number of targets as the body "
1739 "has block arguments";
1740 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1742 "expects co-indexed targets and the body's "
1743 "block arguments to have the same op/value/param type");
1746 for (
auto [resultOpt, yieldOperandOpt] :
1747 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1748 if (!resultOpt || !yieldOperandOpt)
1749 return emitOpError() <<
"expects the same number of results as the "
1750 "yield terminator has operands";
1751 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1752 return emitOpError(
"expects co-indexed results and yield "
1753 "operands to have the same op/value/param type");
1771 for (
int64_t i = 0, e = getNthParent(); i < e; ++i) {
1774 bool checkIsolatedFromAbove =
1775 !getIsolatedFromAbove() ||
1777 bool checkOpName = !getOpName().has_value() ||
1779 if (checkIsolatedFromAbove && checkOpName)
1784 if (getAllowEmptyResults()) {
1785 results.
set(llvm::cast<OpResult>(getResult()), parents);
1789 emitSilenceableError()
1790 <<
"could not find a parent op that matches all requirements";
1791 diag.attachNote(
target->getLoc()) <<
"target op";
1795 if (getDeduplicate()) {
1796 if (resultSet.insert(parent).second)
1797 parents.push_back(parent);
1799 parents.push_back(parent);
1802 results.
set(llvm::cast<OpResult>(getResult()), parents);
1814 int64_t resultNumber = getResultNumber();
1816 if (std::empty(payloadOps)) {
1817 results.
set(cast<OpResult>(getResult()), {});
1820 if (!llvm::hasSingleElement(payloadOps))
1822 <<
"handle must be mapped to exactly one payload op";
1825 if (
target->getNumResults() <= resultNumber)
1827 results.
set(llvm::cast<OpResult>(getResult()),
1828 llvm::to_vector(
target->getResult(resultNumber).getUsers()));
1842 if (llvm::isa<BlockArgument>(v)) {
1844 emitSilenceableError() <<
"cannot get defining op of block argument";
1845 diag.attachNote(v.getLoc()) <<
"target value";
1848 definingOps.push_back(v.getDefiningOp());
1850 results.
set(llvm::cast<OpResult>(getResult()), definingOps);
1862 int64_t operandNumber = getOperandNumber();
1866 target->getNumOperands() <= operandNumber
1868 :
target->getOperand(operandNumber).getDefiningOp();
1871 emitSilenceableError()
1872 <<
"could not find a producer for operand number: " << operandNumber
1874 diag.attachNote(
target->getLoc()) <<
"target op";
1877 producers.push_back(producer);
1879 results.
set(llvm::cast<OpResult>(getResult()), producers);
1895 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1896 target->getNumOperands(), operandPositions);
1897 if (
diag.isSilenceableFailure()) {
1899 <<
"while considering positions of this payload operation";
1902 llvm::append_range(operands,
1903 llvm::map_range(operandPositions, [&](
int64_t pos) {
1904 return target->getOperand(pos);
1907 results.
setValues(cast<OpResult>(getResult()), operands);
1911LogicalResult transform::GetOperandOp::verify() {
1913 getIsInverted(), getIsAll());
1928 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1929 target->getNumResults(), resultPositions);
1930 if (
diag.isSilenceableFailure()) {
1932 <<
"while considering positions of this payload operation";
1935 llvm::append_range(opResults,
1936 llvm::map_range(resultPositions, [&](
int64_t pos) {
1937 return target->getResult(pos);
1940 results.
setValues(cast<OpResult>(getResult()), opResults);
1944LogicalResult transform::GetResultOp::verify() {
1946 getIsInverted(), getIsAll());
1953void transform::GetTypeOp::getEffects(
1966 Type type = value.getType();
1967 if (getElemental()) {
1968 if (
auto shaped = dyn_cast<ShapedType>(type)) {
1969 type = shaped.getElementType();
1972 params.push_back(TypeAttr::get(type));
1974 results.
setParams(cast<OpResult>(getResult()), params);
1992 if (
result.isDefiniteFailure())
1995 if (
result.isSilenceableFailure()) {
1996 if (mode == transform::FailurePropagationMode::Propagate) {
2016 getOperation(), getTarget());
2017 assert(callee &&
"unverified reference to unknown symbol");
2019 if (callee.isExternal())
2024 detail::prepareValueMappings(mappings, getOperands(), state);
2026 for (
auto &&[arg, map] :
2027 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2033 callee.getBody().front(), getFailurePropagationMode(), state, results);
2039 detail::prepareValueMappings(
2040 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2041 for (
auto &&[
result, mapping] : llvm::zip_equal(getResults(), mappings))
2049void transform::IncludeOp::getEffects(
2064 auto defaultEffects = [&] {
2071 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2073 return defaultEffects();
2075 getOperation(), getTarget());
2077 return defaultEffects();
2079 for (
unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2080 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2082 else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2091 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(
"target");
2093 return emitOpError() <<
"expects a 'target' symbol reference attribute";
2098 return emitOpError() <<
"does not reference a named transform sequence";
2100 FunctionType fnType =
target.getFunctionType();
2101 if (fnType.getNumInputs() != getNumOperands())
2102 return emitError(
"incorrect number of operands for callee");
2104 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2105 if (getOperand(i).
getType() != fnType.getInput(i)) {
2106 return emitOpError(
"operand type mismatch: expected operand type ")
2107 << fnType.getInput(i) <<
", but provided "
2108 << getOperand(i).getType() <<
" for operand number " << i;
2112 if (fnType.getNumResults() != getNumResults())
2113 return emitError(
"incorrect number of results for callee");
2115 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2116 Type resultType = getResult(i).getType();
2117 Type funcType = fnType.getResult(i);
2120 <<
" must implement the same transform dialect "
2121 "interface as the corresponding callee result";
2126 cast<FunctionOpInterface>(*
target),
false,
2136 ::std::optional<::mlir::Operation *> maybeCurrent,
2138 if (!maybeCurrent.has_value()) {
2143 return emitSilenceableError() <<
"operation is not empty";
2154 for (
auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2155 if (acceptedAttr.getValue() == currentOpName)
2158 return emitSilenceableError() <<
"wrong operation name";
2169 auto signedAPIntAsString = [&](
const APInt &value) {
2171 llvm::raw_string_ostream os(str);
2172 value.print(os,
true);
2179 if (params.size() != references.size()) {
2180 return emitSilenceableError()
2181 <<
"parameters have different payload lengths (" << params.size()
2182 <<
" vs " << references.size() <<
")";
2185 for (
auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2186 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2187 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2188 if (!intAttr || !refAttr) {
2190 <<
"non-integer parameter value not expected";
2192 if (intAttr.getType() != refAttr.getType()) {
2194 <<
"mismatching integer attribute types in parameter #" << i;
2196 APInt value = intAttr.getValue();
2197 APInt refValue = refAttr.getValue();
2201 auto reportError = [&](StringRef direction) {
2203 emitSilenceableError() <<
"expected parameter to be " << direction
2204 <<
" " << signedAPIntAsString(refValue)
2205 <<
", got " << signedAPIntAsString(value);
2206 diag.attachNote(getParam().getLoc())
2207 <<
"value # " << position
2208 <<
" associated with the parameter defined here";
2212 switch (getPredicate()) {
2213 case MatchCmpIPredicate::eq:
2214 if (value.eq(refValue))
2216 return reportError(
"equal to");
2217 case MatchCmpIPredicate::ne:
2218 if (value.ne(refValue))
2220 return reportError(
"not equal to");
2221 case MatchCmpIPredicate::lt:
2222 if (value.slt(refValue))
2224 return reportError(
"less than");
2225 case MatchCmpIPredicate::le:
2226 if (value.sle(refValue))
2228 return reportError(
"less than or equal to");
2229 case MatchCmpIPredicate::gt:
2230 if (value.sgt(refValue))
2232 return reportError(
"greater than");
2233 case MatchCmpIPredicate::ge:
2234 if (value.sge(refValue))
2236 return reportError(
"greater than or equal to");
2242void transform::MatchParamCmpIOp::getEffects(
2256 results.
setParams(cast<OpResult>(getParam()), {getValue()});
2269 if (isa<TransformHandleTypeInterface>(handles.front().
getType())) {
2271 for (
Value operand : handles)
2272 llvm::append_range(operations, state.
getPayloadOps(operand));
2273 if (!getDeduplicate()) {
2274 results.
set(llvm::cast<OpResult>(getResult()), operations);
2279 results.
set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2283 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2285 for (
Value attribute : handles)
2286 llvm::append_range(attrs, state.
getParams(attribute));
2287 if (!getDeduplicate()) {
2288 results.
setParams(cast<OpResult>(getResult()), attrs);
2293 results.
setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2298 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2299 "expected value handle type");
2301 for (
Value value : handles)
2303 if (!getDeduplicate()) {
2304 results.
setValues(cast<OpResult>(getResult()), payloadValues);
2309 results.
setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2313bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2315 return getDeduplicate();
2318void transform::MergeHandlesOp::getEffects(
2327OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2328 if (getDeduplicate() || getHandles().size() != 1)
2333 return getHandles().front();
2352 if (
failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
2353 state, this->getOperation(), getBody())))
2357 FailurePropagationMode::Propagate, state, results);
2360void transform::NamedSequenceOp::getEffects(
2363ParseResult transform::NamedSequenceOp::parse(
OpAsmParser &parser,
2367 getFunctionTypeAttrName(
result.name),
2370 std::string &) { return builder.getFunctionType(inputs, results); },
2371 getArgAttrsAttrName(
result.name), getResAttrsAttrName(
result.name));
2374void transform::NamedSequenceOp::print(
OpAsmPrinter &printer) {
2376 printer, cast<FunctionOpInterface>(getOperation()),
false,
2377 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2378 getResAttrsAttrName());
2388 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2391 <<
"cannot be defined inside another transform op";
2392 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2396 if (op.isExternal() || op.getFunctionBody().empty()) {
2403 if (op.getFunctionBody().front().empty())
2406 Operation *terminator = &op.getFunctionBody().front().back();
2407 if (!isa<transform::YieldOp>(terminator)) {
2410 << transform::YieldOp::getOperationName()
2411 <<
"' as terminator";
2412 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2416 if (terminator->
getNumOperands() != op.getResultTypes().size()) {
2418 <<
"expected terminator to have as many operands as the parent op "
2421 for (
auto [i, operandType, resultType] : llvm::zip_equal(
2424 if (operandType == resultType)
2427 <<
"the type of the terminator operand #" << i
2428 <<
" must match the type of the corresponding parent op result ("
2429 << operandType <<
" vs " << resultType <<
")";
2442 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2445 <<
"expects the parent symbol table to have the '"
2446 << transform::TransformDialect::kWithNamedSequenceAttrName
2448 diag.attachNote(parent->
getLoc()) <<
"symbol table operation";
2453 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2456 <<
"cannot be defined inside another transform op";
2457 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2461 if (op.isExternal() || op.getBody().empty())
2465 if (op.getBody().front().empty())
2469 for (
Operation &child : op.getBody().front().without_terminator()) {
2470 if (!isa<transform::TransformOpInterface>(child)) {
2473 <<
"expected children ops to implement TransformOpInterface";
2474 diag.attachNote(child.getLoc()) <<
"op without interface";
2479 Operation *terminator = &op.getBody().front().back();
2480 if (!isa<transform::YieldOp>(terminator)) {
2483 << transform::YieldOp::getOperationName()
2484 <<
"' as terminator";
2485 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2489 if (terminator->
getNumOperands() != op.getFunctionType().getNumResults()) {
2491 <<
"expected terminator to have as many operands as the parent op "
2494 for (
auto [i, operandType, resultType] :
2495 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->
getNumOperands()),
2497 op.getFunctionType().getResults())) {
2498 if (operandType == resultType)
2501 <<
"the type of the terminator operand #" << i
2502 <<
" must match the type of the corresponding parent op result ("
2503 << operandType <<
" vs " << resultType <<
")";
2506 auto funcOp = cast<FunctionOpInterface>(*op);
2509 if (!
diag.succeeded())
2516LogicalResult transform::NamedSequenceOp::verify() {
2521template <
typename FnTy>
2526 types.reserve(1 + extraBindingTypes.size());
2527 types.push_back(bbArgType);
2528 llvm::append_range(types, extraBindingTypes);
2538 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2546void transform::NamedSequenceOp::build(
OpBuilder &builder,
2549 SequenceBodyBuilderFn bodyBuilder,
2555 TypeAttr::get(FunctionType::get(builder.
getContext(),
2556 rootType, resultTypes)));
2572 size_t numAssociations =
2574 .Case([&](TransformHandleTypeInterface opHandle) {
2577 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2580 .Case([&](TransformParamTypeInterface param) {
2581 return llvm::range_size(state.
getParams(getHandle()));
2583 .DefaultUnreachable(
"unknown kind of transform dialect type");
2584 results.
setParams(cast<OpResult>(getNum()),
2589LogicalResult transform::NumAssociationsOp::verify() {
2591 auto resultType = cast<TransformParamTypeInterface>(getNum().
getType());
2611 results.
set(cast<OpResult>(getResult()),
result);
2631 .Case([&](TransformHandleTypeInterface x) {
2634 .Case([&](TransformValueHandleTypeInterface x) {
2637 .Case([&](TransformParamTypeInterface x) {
2638 return llvm::range_size(state.
getParams(getHandle()));
2640 .DefaultUnreachable(
"unknown transform dialect type interface");
2642 auto produceNumOpsError = [&]() {
2643 return emitSilenceableError()
2644 << getHandle() <<
" expected to contain " << this->getNumResults()
2645 <<
" payloads but it contains " << numPayloads <<
" payloads";
2650 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2651 return produceNumOpsError();
2656 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2657 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2658 return produceNumOpsError();
2662 if (getOverflowResult())
2663 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2665 auto container = [&]() {
2666 if (isa<TransformHandleTypeInterface>(getHandle().
getType())) {
2667 return llvm::map_to_vector(
2669 [](
Operation *op) -> MappedValue {
return op; });
2671 if (isa<TransformValueHandleTypeInterface>(getHandle().
getType())) {
2673 [](
Value v) -> MappedValue {
return v; });
2675 assert(isa<TransformParamTypeInterface>(getHandle().
getType()) &&
2676 "unsupported kind of transform dialect type");
2677 return llvm::map_to_vector(state.
getParams(getHandle()),
2678 [](
Attribute a) -> MappedValue {
return a; });
2681 for (
auto &&en : llvm::enumerate(container)) {
2682 int64_t resultNum = en.index();
2683 if (resultNum >= getNumResults())
2684 resultNum = *getOverflowResult();
2685 resultHandles[resultNum].push_back(en.value());
2689 for (
auto &&it : llvm::enumerate(resultHandles))
2696void transform::SplitHandleOp::getEffects(
2704LogicalResult transform::SplitHandleOp::verify() {
2705 if (getOverflowResult().has_value() &&
2706 !(*getOverflowResult() < getNumResults()))
2707 return emitOpError(
"overflow_result is not a valid result index");
2709 for (
Type resultType : getResultTypes()) {
2713 return emitOpError(
"expects result types to implement the same transform "
2714 "interface as the operand type");
2724void transform::PayloadOp::getCheckedNormalForms(
2726 llvm::append_range(normalForms,
2727 getNormalForms().getAsRange<NormalFormAttrInterface>());
2738 unsigned numRepetitions = llvm::range_size(state.
getPayloadOps(getPattern()));
2739 for (
const auto &en : llvm::enumerate(getHandles())) {
2740 Value handle = en.value();
2741 if (isa<TransformHandleTypeInterface>(handle.getType())) {
2745 payload.reserve(numRepetitions * current.size());
2746 for (
unsigned i = 0; i < numRepetitions; ++i)
2747 llvm::append_range(payload, current);
2748 results.
set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2750 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2751 "expected param type");
2754 params.reserve(numRepetitions * current.size());
2755 for (
unsigned i = 0; i < numRepetitions; ++i)
2756 llvm::append_range(params, current);
2757 results.
setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2764void transform::ReplicateOp::getEffects(
2781 if (
failed(mapBlockArguments(state)))
2789 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2796 root = std::nullopt;
2799 if (failed(hasRoot.
value()))
2813 if (failed(parser.
parseType(rootType))) {
2817 if (!extraBindings.empty()) {
2822 if (extraBindingTypes.size() != extraBindings.size()) {
2824 "expected types to be provided for all operands");
2840 bool hasExtras = !extraBindings.empty();
2850 printer << rootType;
2852 printer <<
", " << llvm::interleaved(extraBindingTypes) <<
')';
2859 auto iface = dyn_cast<transform::TransformOpInterface>(use.
getOwner());
2863 return isHandleConsumed(use.
get(), iface);
2874 if (!potentialConsumer) {
2875 potentialConsumer = &use;
2880 <<
" has more than one potential consumer";
2883 diag.attachNote(use.getOwner()->getLoc())
2884 <<
"used here as operand #" << use.getOperandNumber();
2891LogicalResult transform::SequenceOp::verify() {
2892 assert(getBodyBlock()->getNumArguments() >= 1 &&
2893 "the number of arguments must have been verified to be more than 1 by "
2894 "PossibleTopLevelTransformOpTrait");
2896 if (!getRoot() && !getExtraBindings().empty()) {
2898 <<
"does not expect extra operands when used as top-level";
2904 return (
emitOpError() <<
"block argument #" << arg.getArgNumber());
2911 for (
Operation &child : *getBodyBlock()) {
2912 if (!isa<TransformOpInterface>(child) &&
2913 &child != &getBodyBlock()->back()) {
2916 <<
"expected children ops to implement TransformOpInterface";
2917 diag.attachNote(child.getLoc()) <<
"op without interface";
2922 auto report = [&]() {
2923 return (child.emitError() <<
"result #" <<
result.getResultNumber());
2930 if (!getBodyBlock()->mightHaveTerminator())
2931 return emitOpError() <<
"expects to have a terminator in the body";
2933 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2934 getOperation()->getResultTypes()) {
2936 <<
"expects the types of the terminator operands "
2937 "to match the types of the result";
2938 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) <<
"terminator";
2944void transform::SequenceOp::getEffects(
2950transform::SequenceOp::getEntrySuccessorOperands(
RegionSuccessor successor) {
2951 assert(successor.
getSuccessor() == &getBody() &&
"unexpected region index");
2952 if (getOperation()->getNumOperands() > 0)
2953 return getOperation()->getOperands();
2955 getOperation()->operand_end());
2958void transform::SequenceOp::getSuccessorRegions(
2961 Region *bodyRegion = &getBody();
2962 regions.emplace_back(bodyRegion);
2968 "unexpected region index");
2974 if (getNumOperands() == 0)
2977 return getResults();
2978 return getBody().getArguments();
2981void transform::SequenceOp::getRegionInvocationBounds(
2984 bounds.emplace_back(1, 1);
2989 FailurePropagationMode failurePropagationMode,
2991 SequenceBodyBuilderFn bodyBuilder) {
2992 build(builder, state, resultTypes, failurePropagationMode, root,
3001 FailurePropagationMode failurePropagationMode,
3003 SequenceBodyBuilderArgsFn bodyBuilder) {
3004 build(builder, state, resultTypes, failurePropagationMode, root,
3012 FailurePropagationMode failurePropagationMode,
3014 SequenceBodyBuilderFn bodyBuilder) {
3015 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3023 FailurePropagationMode failurePropagationMode,
3025 SequenceBodyBuilderArgsFn bodyBuilder) {
3026 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3044 build(builder,
result, name);
3051 llvm::outs() <<
"[[[ IR printer: ";
3052 if (getName().has_value())
3053 llvm::outs() << *getName() <<
" ";
3056 if (getAssumeVerified().value_or(
false))
3058 if (getUseLocalScope().value_or(
false))
3060 if (getSkipRegions().value_or(
false))
3064 llvm::outs() <<
"top-level ]]]\n";
3066 llvm::outs() <<
"\n";
3067 llvm::outs().flush();
3071 llvm::outs() <<
"]]]\n";
3073 target->print(llvm::outs(), printFlags);
3074 llvm::outs() <<
"\n";
3077 llvm::outs().flush();
3081void transform::PrintOp::getEffects(
3086 if (!getTargetMutable().empty())
3106 <<
"failed to verify payload op";
3107 diag.attachNote(
target->getLoc()) <<
"payload op";
3113void transform::VerifyOp::getEffects(
3122void transform::YieldOp::getEffects(
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a dlti.dl_entry attribute.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
GreedyRewriteConfig & setListener(RewriterBase::Listener *listener)
GreedyRewriteConfig & enableCSEBetweenIterations(bool enable=true)
GreedyRewriteConfig & setMaxIterations(int64_t iterations)
GreedyRewriteConfig & setMaxNumRewrites(int64_t limit)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
This is a "type erased" representation of a registered operation.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
static DerivedEffect * get()
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
A named class for passing around the variadic flag.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr, int64_t *numCSE=nullptr, int64_t *numDCE=nullptr)
Eliminate common subexpressions within the given operation.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
llvm::SetVector< T, Vector, Set, N > SetVector
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool eliminateTriviallyDeadOps(RewriterBase &rewriter, Region ®ion, bool includeNestedRegions=true)
Remove trivially dead operations from region.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
llvm::function_ref< Fn > function_ref
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.