34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/ScopeExit.h"
37#include "llvm/ADT/SmallPtrSet.h"
38#include "llvm/ADT/SmallVectorExtras.h"
39#include "llvm/ADT/TypeSwitch.h"
40#include "llvm/Support/Debug.h"
41#include "llvm/Support/DebugLog.h"
42#include "llvm/Support/ErrorHandling.h"
43#include "llvm/Support/InterleavedRange.h"
46#define DEBUG_TYPE "transform-dialect"
47#define DEBUG_TYPE_MATCHER "transform-matcher"
59 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
80 while (transformAncestor) {
81 if (transformAncestor == payload) {
84 <<
"cannot apply transform to itself (or one of its ancestors)";
85 diag.attachNote(payload->
getLoc()) <<
"target payload op";
88 transformAncestor = transformAncestor->
getParentOp();
94#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
100OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
102 if (!successor.
isParent() && getOperation()->getNumOperands() == 1)
103 return getOperation()->getOperands();
105 getOperation()->operand_end());
108void transform::AlternativesOp::getSuccessorRegions(
110 for (
Region &alternative : llvm::drop_begin(
115 ->getRegionNumber() +
117 regions.emplace_back(&alternative);
124transform::AlternativesOp::getSuccessorInputs(
RegionSuccessor successor) {
126 return getOperation()->getResults();
130void transform::AlternativesOp::getRegionInvocationBounds(
135 bounds.reserve(getNumRegions());
136 bounds.emplace_back(1, 1);
143 results.
set(res, {});
151 if (
Value scopeHandle = getScope())
152 llvm::append_range(originals, state.
getPayloadOps(scopeHandle));
157 if (original->isAncestor(getOperation())) {
159 <<
"scope must not contain the transforms being applied";
160 diag.attachNote(original->getLoc()) <<
"scope";
165 <<
"only isolated-from-above ops can be alternative scopes";
166 diag.attachNote(original->getLoc()) <<
"scope";
171 for (
Region ® : getAlternatives()) {
177 auto clones = llvm::map_to_vector(
179 llvm::scope_exit deleteClones([&] {
190 if (
result.isSilenceableFailure()) {
191 LDBG() <<
"alternative failed: " <<
result.getMessage();
196 if (::mlir::failed(
result.silence()))
205 deleteClones.release();
206 TrackingListener listener(state, *
this);
208 for (
const auto &kvp : llvm::zip(originals, clones)) {
215 detail::forwardTerminatorOperands(®.front(), state, results);
219 return emitSilenceableError() <<
"all alternatives failed";
222void transform::AlternativesOp::getEffects(
226 for (
Region *region : getRegions()) {
227 if (!region->empty())
233LogicalResult transform::AlternativesOp::verify() {
234 for (
Region &alternative : getAlternatives()) {
239 <<
"expects terminator operands to have the "
240 "same type as results of the operation";
241 diag.attachNote(terminator->
getLoc()) <<
"terminator";
261 if (
auto paramH = getParam()) {
263 if (params.size() != 1) {
264 if (targets.size() != params.size()) {
265 return emitSilenceableError()
266 <<
"parameter and target have different payload lengths ("
267 << params.size() <<
" vs " << targets.size() <<
")";
269 for (
auto &&[
target, attr] : llvm::zip_equal(targets, params))
270 target->setAttr(getName(), attr);
275 for (
auto *
target : targets)
276 target->setAttr(getName(), attr);
280void transform::AnnotateOp::getEffects(
292transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
307void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
332 auto addDefiningOpsToWorklist = [&](
Operation *op) {
335 if (
Operation *defOp = v.getDefiningOp())
336 if (
target->isProperAncestor(defOp))
337 worklist.insert(defOp);
345 const auto *it = llvm::find(worklist, op);
346 if (it != worklist.end())
355 addDefiningOpsToWorklist(op);
361 while (!worklist.empty()) {
365 addDefiningOpsToWorklist(op);
372void transform::ApplyDeadCodeEliminationOp::getEffects(
397 if (!getRegion().empty()) {
398 for (
Operation &op : getRegion().front()) {
399 cast<transform::PatternDescriptorOpInterface>(&op)
400 .populatePatternsWithState(patterns, state);
412 : getMaxIterations());
415 : getMaxNumRewrites());
423 <<
"greedy pattern application failed";
434 ops.push_back(nestedOp);
439 static const int64_t kNumMaxIterations = 50;
441 bool cseChanged =
false;
445 <<
"greedy pattern application failed";
453 }
while (cseChanged && ++iteration < kNumMaxIterations);
455 if (iteration == kNumMaxIterations)
461LogicalResult transform::ApplyPatternsOp::verify() {
462 if (!getRegion().empty()) {
463 for (
Operation &op : getRegion().front()) {
464 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
466 <<
"expected children ops to implement "
467 "PatternDescriptorOpInterface";
468 diag.attachNote(op.
getLoc()) <<
"op without interface";
476void transform::ApplyPatternsOp::getEffects(
482void transform::ApplyPatternsOp::build(
491 bodyBuilder(builder,
result.location);
498void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
502 dialect->getCanonicalizationPatterns(patterns);
504 op.getCanonicalizationPatterns(patterns, ctx);
518 std::unique_ptr<TypeConverter> defaultTypeConverter;
519 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
520 getDefaultTypeConverter();
521 if (typeConverterBuilder)
522 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
527 for (
Attribute attr : cast<ArrayAttr>(*getLegalOps()))
528 conversionTarget.addLegalOp(
531 for (
Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
532 conversionTarget.addIllegalOp(
534 if (getLegalDialects())
535 for (
Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
536 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
537 if (getIllegalDialects())
538 for (
Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
539 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
547 if (!getPatterns().empty()) {
548 for (
Operation &op : getPatterns().front()) {
550 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
553 std::unique_ptr<TypeConverter> typeConverter =
554 descriptor.getTypeConverter();
557 keepAliveConverters.emplace_back(std::move(typeConverter));
558 converter = keepAliveConverters.back().get();
561 if (!defaultTypeConverter) {
563 <<
"pattern descriptor does not specify type "
564 "converter and apply_conversion_patterns op has "
565 "no default type converter";
566 diag.attachNote(op.
getLoc()) <<
"pattern descriptor op";
569 converter = defaultTypeConverter.get();
575 descriptor.populateConversionTargetRules(*converter, conversionTarget);
577 descriptor.populatePatterns(*converter, patterns);
585 TrackingListenerConfig trackingConfig;
586 trackingConfig.requireMatchingReplacementOpName =
false;
587 ErrorCheckingTrackingListener trackingListener(state, *
this, trackingConfig);
588 ConversionConfig conversionConfig;
589 if (getPreserveHandles())
590 conversionConfig.listener = &trackingListener;
601 LogicalResult status = failure();
602 if (getPartialConversion()) {
603 status = applyPartialConversion(
target, conversionTarget, frozenPatterns,
606 status = applyFullConversion(
target, conversionTarget, frozenPatterns,
613 diag = emitSilenceableError() <<
"dialect conversion failed";
614 diag.attachNote(
target->getLoc()) <<
"target op";
619 trackingListener.checkAndResetError();
621 if (
diag.succeeded()) {
623 return trackingFailure;
625 diag.attachNote() <<
"tracking listener also failed: "
630 if (!
diag.succeeded())
637LogicalResult transform::ApplyConversionPatternsOp::verify() {
638 if (getNumRegions() != 1 && getNumRegions() != 2)
640 if (!getPatterns().empty()) {
641 for (
Operation &op : getPatterns().front()) {
642 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
644 emitOpError() <<
"expected pattern children ops to implement "
645 "ConversionPatternDescriptorOpInterface";
646 diag.attachNote(op.
getLoc()) <<
"op without interface";
651 if (getNumRegions() == 2) {
652 Region &typeConverterRegion = getRegion(1);
653 if (!llvm::hasSingleElement(typeConverterRegion.
front()))
655 <<
"expected exactly one op in default type converter region";
657 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
659 if (!typeConverterOp) {
661 <<
"expected default converter child op to "
662 "implement TypeConverterBuilderOpInterface";
663 diag.attachNote(maybeTypeConverter->
getLoc()) <<
"op without interface";
667 if (!getPatterns().empty()) {
668 for (
Operation &op : getPatterns().front()) {
670 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
671 if (
failed(descriptor.verifyTypeConverter(typeConverterOp)))
679void transform::ApplyConversionPatternsOp::getEffects(
681 if (!getPreserveHandles()) {
689void transform::ApplyConversionPatternsOp::build(
699 if (patternsBodyBuilder)
700 patternsBodyBuilder(builder,
result.location);
706 if (typeConverterBodyBuilder)
707 typeConverterBodyBuilder(builder,
result.location);
715void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
718 assert(dialect &&
"expected that dialect is loaded");
719 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
723 iface->populateConvertToLLVMConversionPatterns(
727LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
728 transform::TypeConverterBuilderOpInterface builder) {
729 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
734LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
737 return emitOpError(
"unknown dialect or dialect not loaded: ")
739 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
742 "dialect does not implement ConvertToLLVMPatternInterface or "
743 "extension was not loaded: ")
753transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
763void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
773void transform::ApplyRegisteredPassOp::getEffects(
791 llvm::raw_string_ostream optionsStream(
options);
796 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
799 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
800 assert(dynamicOptionIdx <
static_cast<int64_t>(dynamicOptions.size()) &&
801 "the number of ParamOperandAttrs in the options DictionaryAttr"
802 "should be the same as the number of options passed as params");
804 state.
getParams(dynamicOptions[dynamicOptionIdx]);
806 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
808 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
810 llvm::interleave(arrayAttr, optionsStream, appendValueAttr,
",");
811 }
else if (
auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
813 optionsStream << strAttr.getValue().str();
816 valueAttr.print(optionsStream,
true);
822 getOptions(), optionsStream,
823 [&](
auto namedAttribute) {
824 optionsStream << namedAttribute.getName().str();
825 optionsStream <<
"=";
826 appendValueAttr(namedAttribute.getValue());
829 optionsStream.flush();
837 <<
"unknown pass or pass pipeline: " << getPassName();
846 <<
"failed to add pass or pass pipeline to pipeline: "
863 auto diag = emitSilenceableError() <<
"pass pipeline failed";
864 diag.attachNote(
target->getLoc()) <<
"target op";
870 results.
set(llvm::cast<OpResult>(getResult()), targets);
879 size_t dynamicOptionsIdx = 0;
885 std::function<ParseResult(
Attribute &)> parseValue =
886 [&](
Attribute &valueAttr) -> ParseResult {
894 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
895 " in options dictionary") ||
899 valueAttr = ArrayAttr::get(parser.
getContext(), attrs);
909 ParseResult parsedOperand = parser.
parseOperand(operand);
910 if (failed(parsedOperand))
916 dynamicOptions.push_back(operand);
917 auto wrappedIndex = IntegerAttr::get(
918 IntegerType::get(parser.
getContext(), 64), dynamicOptionsIdx++);
920 transform::ParamOperandAttr::get(parser.
getContext(), wrappedIndex);
921 }
else if (failed(parsedValueAttr.
value())) {
923 }
else if (isa<transform::ParamOperandAttr>(valueAttr)) {
925 <<
"the param_operand attribute is a marker reserved for "
926 <<
"indicating a value will be passed via params and is only used "
927 <<
"in the generic print format";
941 <<
"expected key to either be an identifier or a string";
945 <<
"expected '=' after key in key-value pair";
947 if (failed(parseValue(valueAttr)))
949 <<
"expected a valid attribute or operand as value associated "
950 <<
"to key '" << key <<
"'";
959 " in options dictionary"))
962 if (DictionaryAttr::findDuplicate(
963 keyValuePairs,
false)
966 <<
"duplicate keys found in options dictionary";
981 if (
auto paramOperandAttr =
982 dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
985 dynamicOptions[paramOperandAttr.getIndex().getInt()]);
986 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
989 llvm::interleaveComma(arrayAttr, printer, printOptionValue);
998 printer << namedAttribute.
getName();
1000 printOptionValue(namedAttribute.
getValue());
1005LogicalResult transform::ApplyRegisteredPassOp::verify() {
1012 std::function<LogicalResult(
Attribute)> checkOptionValue =
1013 [&](
Attribute valueAttr) -> LogicalResult {
1014 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1015 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1016 if (dynamicOptionIdx < 0 ||
1017 dynamicOptionIdx >=
static_cast<int64_t>(dynamicOptions.size()))
1019 <<
"dynamic option index " << dynamicOptionIdx
1020 <<
" is out of bounds for the number of dynamic options: "
1021 << dynamicOptions.size();
1022 if (dynamicOptions[dynamicOptionIdx] ==
nullptr)
1023 return emitOpError() <<
"dynamic option index " << dynamicOptionIdx
1024 <<
" is already used in options";
1025 dynamicOptions[dynamicOptionIdx] =
nullptr;
1026 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1028 for (
auto eltAttr : arrayAttr)
1029 if (
failed(checkOptionValue(eltAttr)))
1036 if (
failed(checkOptionValue(namedAttr.getValue())))
1040 for (
Value dynamicOption : dynamicOptions)
1042 return emitOpError() <<
"a param operand does not have a corresponding "
1043 <<
"param_operand attr in the options dict";
1056 results.push_back(
target);
1060void transform::CastOp::getEffects(
1068 assert(inputs.size() == 1 &&
"expected one input");
1069 assert(outputs.size() == 1 &&
"expected one output");
1070 return llvm::all_of(
1071 std::initializer_list<Type>{inputs.front(), outputs.front()},
1072 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1092 assert(block.
getParent() &&
"cannot match using a detached block");
1099 if (!isa<transform::MatchOpInterface>(match)) {
1101 <<
"expected operations in the match part to "
1102 "implement MatchOpInterface";
1105 state.
applyTransform(cast<transform::TransformOpInterface>(match));
1106 if (
diag.succeeded())
1124template <
typename... Tys>
1126 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... ||
false);
1133 transform::TransformParamTypeInterface,
1134 transform::TransformValueHandleTypeInterface>(
1147 getOperation(), getMatcher());
1148 if (matcher.isExternal()) {
1150 <<
"unresolved external symbol " << getMatcher();
1154 rawResults.resize(getOperation()->getNumResults());
1155 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1167 matcher.getFunctionBody().front(),
1170 if (
diag.isDefiniteFailure())
1172 if (
diag.isSilenceableFailure()) {
1174 <<
" failed: " <<
diag.getMessage();
1179 for (
auto &&[i, mapping] : llvm::enumerate(mappings)) {
1180 if (mapping.size() != 1) {
1181 maybeFailure.emplace(emitSilenceableError()
1182 <<
"result #" << i <<
", associated with "
1184 <<
" payload objects, expected 1");
1187 rawResults[i].push_back(mapping[0]);
1192 return std::move(*maybeFailure);
1193 assert(!maybeFailure &&
"failure set but the walk was not interrupted");
1195 for (
auto &&[opResult, rawResult] :
1196 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1203void transform::CollectMatchingOp::getEffects(
1210LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1212 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1214 if (!matcherSymbol ||
1215 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1216 return emitError() <<
"unresolved matcher symbol " << getMatcher();
1219 if (argumentTypes.size() != 1 ||
1220 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1222 <<
"expected the matcher to take one operation handle argument";
1224 if (!matcherSymbol.getArgAttr(
1225 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1226 return emitError() <<
"expected the matcher argument to be marked readonly";
1230 if (resultTypes.size() != getOperation()->getNumResults()) {
1232 <<
"expected the matcher to yield as many values as op has results ("
1233 << getOperation()->getNumResults() <<
"), got "
1234 << resultTypes.size();
1237 for (
auto &&[i, matcherType, resultType] :
1238 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1243 <<
"mismatching type interfaces for matcher result and op result #"
// Reports whether ForeachMatchOp accepts the same handle more than once among
// its operands. This implementation unconditionally returns true.
// NOTE(review): presumably consulted by the transform-op interface checks that
// otherwise reject repeated handle operands — confirm against the interface
// definition.
1255bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() {
return true; }
1263 matchActionPairs.reserve(getMatchers().size());
1265 for (
auto &&[matcher, action] :
1266 llvm::zip_equal(getMatchers(), getActions())) {
1267 auto matcherSymbol =
1269 getOperation(), cast<SymbolRefAttr>(matcher));
1272 getOperation(), cast<SymbolRefAttr>(action));
1273 assert(matcherSymbol && actionSymbol &&
1274 "unresolved symbols not caught by the verifier");
1276 if (matcherSymbol.isExternal())
1278 if (actionSymbol.isExternal())
1281 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1292 matchInputMapping.emplace_back();
1294 getForwardedInputs(), state);
1296 actionResultMapping.resize(getForwardedOutputs().size());
1302 if (!getRestrictRoot() && op == root)
1310 firstMatchArgument.clear();
1311 firstMatchArgument.push_back(op);
1314 for (
auto [matcher, action] : matchActionPairs) {
1316 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1317 state, matchOutputMapping);
1318 if (
diag.isDefiniteFailure())
1320 if (
diag.isSilenceableFailure()) {
1322 <<
" failed: " <<
diag.getMessage();
1328 action.getFunctionBody().front().getArguments(),
1329 matchOutputMapping))) {
1334 action.getFunctionBody().front().without_terminator()) {
1337 if (
result.isDefiniteFailure())
1339 if (
result.isSilenceableFailure()) {
1341 overallDiag = emitSilenceableError() <<
"actions failed";
1344 <<
"failed action: " <<
result.getMessage();
1346 <<
"when applied to this matching payload";
1351 if (
failed(detail::appendValueMappings(
1353 action.getFunctionBody().front().getTerminator()->getOperands(),
1354 state, getFlattenResults()))) {
1356 <<
"action @" << action.getName()
1357 <<
" has results associated with multiple payload entities, "
1358 "but flattening was not requested";
1373 results.
set(llvm::cast<OpResult>(getUpdated()),
1375 for (
auto &&[
result, mapping] :
1376 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1382void transform::ForeachMatchOp::getAsmResultNames(
1384 setNameFn(getUpdated(),
"updated_root");
1385 for (
Value v : getForwardedOutputs()) {
1386 setNameFn(v,
"yielded");
1390void transform::ForeachMatchOp::getEffects(
1393 if (getOperation()->getNumOperands() < 1 ||
1394 getOperation()->getNumResults() < 1) {
1418 matcherList.push_back(SymbolRefAttr::get(matcher));
1419 actionList.push_back(SymbolRefAttr::get(action));
1433 for (
auto &&[matcher, action, idx] : llvm::zip_equal(
1436 printer << cast<SymbolRefAttr>(matcher) <<
" -> "
1437 << cast<SymbolRefAttr>(action);
1445LogicalResult transform::ForeachMatchOp::verify() {
1446 if (getMatchers().size() != getActions().size())
1447 return emitOpError() <<
"expected the same number of matchers and actions";
1448 if (getMatchers().empty())
1449 return emitOpError() <<
"expected at least one match/action pair";
1453 if (matcherNames.insert(name).second)
1456 <<
" is used more than once, only the first match will apply";
1467 bool alsoVerifyInternal =
false) {
1468 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1469 llvm::SmallDenseSet<unsigned> consumedArguments;
1470 if (!op.isExternal()) {
1474 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1476 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1479 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1481 if (isConsumed && isReadOnly) {
1482 return transformOp.emitSilenceableError()
1483 <<
"argument #" << i <<
" cannot be both readonly and consumed";
1485 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1486 return transformOp.emitSilenceableError()
1487 <<
"must provide consumed/readonly status for arguments of "
1488 "external or called ops";
1490 if (op.isExternal())
1493 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1494 return transformOp.emitSilenceableError()
1495 <<
"argument #" << i
1496 <<
" is consumed in the body but is not marked as such";
1498 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1502 <<
"op argument #" << i
1503 <<
" is not consumed in the body but is marked as consumed";
1509LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1511 assert(getMatchers().size() == getActions().size());
1513 StringAttr::get(
getContext(), TransformDialect::kArgConsumedAttrName);
1514 for (
auto &&[matcher, action] :
1515 llvm::zip_equal(getMatchers(), getActions())) {
1517 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1519 cast<SymbolRefAttr>(matcher)));
1520 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1522 cast<SymbolRefAttr>(action)));
1523 if (!matcherSymbol ||
1524 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1525 return emitError() <<
"unresolved matcher symbol " << matcher;
1526 if (!actionSymbol ||
1527 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1528 return emitError() <<
"unresolved action symbol " << action;
1533 .checkAndReport())) {
1539 .checkAndReport())) {
1544 TypeRange operandTypes = getOperandTypes();
1545 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1546 if (operandTypes.size() != matcherArguments.size()) {
1548 emitError() <<
"the number of operands (" << operandTypes.size()
1549 <<
") doesn't match the number of matcher arguments ("
1550 << matcherArguments.size() <<
") for " << matcher;
1551 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1554 for (
auto &&[i, operand, argument] :
1555 llvm::enumerate(operandTypes, matcherArguments)) {
1556 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1559 <<
"does not expect matcher symbol to consume its operand #" << i;
1560 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1569 <<
"mismatching type interfaces for operand and matcher argument #"
1570 << i <<
" of matcher " << matcher;
1571 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1576 TypeRange matcherResults = matcherSymbol.getResultTypes();
1577 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1578 if (matcherResults.size() != actionArguments.size()) {
1579 return emitError() <<
"mismatching number of matcher results and "
1580 "action arguments between "
1581 << matcher <<
" (" << matcherResults.size() <<
") and "
1582 << action <<
" (" << actionArguments.size() <<
")";
1584 for (
auto &&[i, matcherType, actionType] :
1585 llvm::enumerate(matcherResults, actionArguments)) {
1589 return emitError() <<
"mismatching type interfaces for matcher result "
1590 "and action argument #"
1591 << i <<
"of matcher " << matcher <<
" and action "
1596 TypeRange actionResults = actionSymbol.getResultTypes();
1597 auto resultTypes =
TypeRange(getResultTypes()).drop_front();
1598 if (actionResults.size() != resultTypes.size()) {
1600 emitError() <<
"the number of action results ("
1601 << actionResults.size() <<
") for " << action
1602 <<
" doesn't match the number of extra op results ("
1603 << resultTypes.size() <<
")";
1604 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1607 for (
auto &&[i, resultType, actionType] :
1608 llvm::enumerate(resultTypes, actionResults)) {
1613 emitError() <<
"mismatching type interfaces for action result #" << i
1614 <<
" of action " << action <<
" and op result";
1615 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1633 detail::prepareValueMappings(payloads, getTargets(), state);
1634 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1635 bool withZipShortest = getWithZipShortest();
1639 if (withZipShortest) {
1643 return a.size() <
b.size();
1646 for (
auto &payload : payloads)
1647 payload.resize(numIterations);
1653 for (
size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1655 if (payloads[argIdx].size() != numIterations) {
1656 return emitSilenceableError()
1657 <<
"prior targets' payload size (" << numIterations
1658 <<
") differs from payload size (" << payloads[argIdx].size()
1659 <<
") of target " << getTargets()[argIdx];
1668 for (
size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1671 for (
auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1682 llvm::cast<transform::TransformOpInterface>(
transform));
1688 OperandRange yieldOperands = getYieldOp().getOperands();
1689 for (
auto &&[
result, yieldOperand, resTuple] :
1690 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1692 if (isa<TransformHandleTypeInterface>(
result.getType()))
1693 llvm::append_range(resTuple, state.
getPayloadOps(yieldOperand));
1694 else if (isa<TransformValueHandleTypeInterface>(
result.getType()))
1696 else if (isa<TransformParamTypeInterface>(
result.getType()))
1697 llvm::append_range(resTuple, state.
getParams(yieldOperand));
1699 assert(
false &&
"unhandled handle type");
1703 for (
auto &&[
result, resPayload] : zip_equal(getResults(), zippedResults))
1709void transform::ForeachOp::getEffects(
1713 for (
auto &&[
target, blockArg] :
1714 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1716 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1718 cast<TransformOpInterface>(&op));
1726 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1730 }
else if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1739void transform::ForeachOp::getSuccessorRegions(
1741 Region *bodyRegion = &getBody();
1743 regions.emplace_back(bodyRegion);
1750 "unexpected region index");
1751 regions.emplace_back(bodyRegion);
1761transform::ForeachOp::getEntrySuccessorOperands(
RegionSuccessor successor) {
1764 assert(successor.
getSuccessor() == &getBody() &&
"unexpected region index");
1765 return getOperation()->getOperands();
1768transform::YieldOp transform::ForeachOp::getYieldOp() {
1769 return cast<transform::YieldOp>(getBody().front().getTerminator());
1772LogicalResult transform::ForeachOp::verify() {
1773 for (
auto [targetOpt, bodyArgOpt] :
1774 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1775 if (!targetOpt || !bodyArgOpt)
1776 return emitOpError() <<
"expects the same number of targets as the body "
1777 "has block arguments";
1778 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1780 "expects co-indexed targets and the body's "
1781 "block arguments to have the same op/value/param type");
1784 for (
auto [resultOpt, yieldOperandOpt] :
1785 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1786 if (!resultOpt || !yieldOperandOpt)
1787 return emitOpError() <<
"expects the same number of results as the "
1788 "yield terminator has operands";
1789 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1790 return emitOpError(
"expects co-indexed results and yield "
1791 "operands to have the same op/value/param type");
1809 for (
int64_t i = 0, e = getNthParent(); i < e; ++i) {
1812 bool checkIsolatedFromAbove =
1813 !getIsolatedFromAbove() ||
1815 bool checkOpName = !getOpName().has_value() ||
1817 if (checkIsolatedFromAbove && checkOpName)
1822 if (getAllowEmptyResults()) {
1823 results.
set(llvm::cast<OpResult>(getResult()), parents);
1827 emitSilenceableError()
1828 <<
"could not find a parent op that matches all requirements";
1829 diag.attachNote(
target->getLoc()) <<
"target op";
1833 if (getDeduplicate()) {
1834 if (resultSet.insert(parent).second)
1835 parents.push_back(parent);
1837 parents.push_back(parent);
1840 results.
set(llvm::cast<OpResult>(getResult()), parents);
1852 int64_t resultNumber = getResultNumber();
1854 if (std::empty(payloadOps)) {
1855 results.
set(cast<OpResult>(getResult()), {});
1858 if (!llvm::hasSingleElement(payloadOps))
1860 <<
"handle must be mapped to exactly one payload op";
1863 if (
target->getNumResults() <= resultNumber)
1865 results.
set(llvm::cast<OpResult>(getResult()),
1866 llvm::to_vector(
target->getResult(resultNumber).getUsers()));
1880 if (llvm::isa<BlockArgument>(v)) {
1882 emitSilenceableError() <<
"cannot get defining op of block argument";
1883 diag.attachNote(v.getLoc()) <<
"target value";
1886 definingOps.push_back(v.getDefiningOp());
1888 results.
set(llvm::cast<OpResult>(getResult()), definingOps);
1900 int64_t operandNumber = getOperandNumber();
1904 target->getNumOperands() <= operandNumber
1906 :
target->getOperand(operandNumber).getDefiningOp();
1909 emitSilenceableError()
1910 <<
"could not find a producer for operand number: " << operandNumber
1912 diag.attachNote(
target->getLoc()) <<
"target op";
1915 producers.push_back(producer);
1917 results.
set(llvm::cast<OpResult>(getResult()), producers);
1933 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1934 target->getNumOperands(), operandPositions);
1935 if (
diag.isSilenceableFailure()) {
1937 <<
"while considering positions of this payload operation";
1940 llvm::append_range(operands,
1941 llvm::map_range(operandPositions, [&](
int64_t pos) {
1942 return target->getOperand(pos);
1945 results.
setValues(cast<OpResult>(getResult()), operands);
1949LogicalResult transform::GetOperandOp::verify() {
1951 getIsInverted(), getIsAll());
1966 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1967 target->getNumResults(), resultPositions);
1968 if (
diag.isSilenceableFailure()) {
1970 <<
"while considering positions of this payload operation";
1973 llvm::append_range(opResults,
1974 llvm::map_range(resultPositions, [&](
int64_t pos) {
1975 return target->getResult(pos);
1978 results.
setValues(cast<OpResult>(getResult()), opResults);
1982LogicalResult transform::GetResultOp::verify() {
1984 getIsInverted(), getIsAll());
1991void transform::GetTypeOp::getEffects(
2004 Type type = value.getType();
2005 if (getElemental()) {
2006 if (
auto shaped = dyn_cast<ShapedType>(type)) {
2007 type = shaped.getElementType();
2010 params.push_back(TypeAttr::get(type));
2012 results.
setParams(cast<OpResult>(getResult()), params);
2030 if (
result.isDefiniteFailure())
2033 if (
result.isSilenceableFailure()) {
2034 if (mode == transform::FailurePropagationMode::Propagate) {
2054 getOperation(), getTarget());
2055 assert(callee &&
"unverified reference to unknown symbol");
2057 if (callee.isExternal())
2062 detail::prepareValueMappings(mappings, getOperands(), state);
2064 for (
auto &&[arg, map] :
2065 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2071 callee.getBody().front(), getFailurePropagationMode(), state, results);
2077 detail::prepareValueMappings(
2078 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2079 for (
auto &&[
result, mapping] : llvm::zip_equal(getResults(), mappings))
2087void transform::IncludeOp::getEffects(
2102 auto defaultEffects = [&] {
2109 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2111 return defaultEffects();
2113 getOperation(), getTarget());
2115 return defaultEffects();
2117 for (
unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2118 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2120 else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2129 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(
"target");
2131 return emitOpError() <<
"expects a 'target' symbol reference attribute";
2136 return emitOpError() <<
"does not reference a named transform sequence";
2138 FunctionType fnType =
target.getFunctionType();
2139 if (fnType.getNumInputs() != getNumOperands())
2140 return emitError(
"incorrect number of operands for callee");
2142 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2143 if (getOperand(i).
getType() != fnType.getInput(i)) {
2144 return emitOpError(
"operand type mismatch: expected operand type ")
2145 << fnType.getInput(i) <<
", but provided "
2146 << getOperand(i).getType() <<
" for operand number " << i;
2150 if (fnType.getNumResults() != getNumResults())
2151 return emitError(
"incorrect number of results for callee");
2153 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2154 Type resultType = getResult(i).getType();
2155 Type funcType = fnType.getResult(i);
2158 <<
" must implement the same transform dialect "
2159 "interface as the corresponding callee result";
2164 cast<FunctionOpInterface>(*
target),
false,
2174 ::std::optional<::mlir::Operation *> maybeCurrent,
2176 if (!maybeCurrent.has_value()) {
2181 return emitSilenceableError() <<
"operation is not empty";
2192 for (
auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2193 if (acceptedAttr.getValue() == currentOpName)
2196 return emitSilenceableError() <<
"wrong operation name";
2207 auto signedAPIntAsString = [&](
const APInt &value) {
2209 llvm::raw_string_ostream os(str);
2210 value.print(os,
true);
2217 if (params.size() != references.size()) {
2218 return emitSilenceableError()
2219 <<
"parameters have different payload lengths (" << params.size()
2220 <<
" vs " << references.size() <<
")";
2223 for (
auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2224 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2225 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2226 if (!intAttr || !refAttr) {
2228 <<
"non-integer parameter value not expected";
2230 if (intAttr.getType() != refAttr.getType()) {
2232 <<
"mismatching integer attribute types in parameter #" << i;
2234 APInt value = intAttr.getValue();
2235 APInt refValue = refAttr.getValue();
2239 auto reportError = [&](StringRef direction) {
2241 emitSilenceableError() <<
"expected parameter to be " << direction
2242 <<
" " << signedAPIntAsString(refValue)
2243 <<
", got " << signedAPIntAsString(value);
2244 diag.attachNote(getParam().getLoc())
2245 <<
"value # " << position
2246 <<
" associated with the parameter defined here";
2250 switch (getPredicate()) {
2251 case MatchCmpIPredicate::eq:
2252 if (value.eq(refValue))
2254 return reportError(
"equal to");
2255 case MatchCmpIPredicate::ne:
2256 if (value.ne(refValue))
2258 return reportError(
"not equal to");
2259 case MatchCmpIPredicate::lt:
2260 if (value.slt(refValue))
2262 return reportError(
"less than");
2263 case MatchCmpIPredicate::le:
2264 if (value.sle(refValue))
2266 return reportError(
"less than or equal to");
2267 case MatchCmpIPredicate::gt:
2268 if (value.sgt(refValue))
2270 return reportError(
"greater than");
2271 case MatchCmpIPredicate::ge:
2272 if (value.sge(refValue))
2274 return reportError(
"greater than or equal to");
2280void transform::MatchParamCmpIOp::getEffects(
2294 results.
setParams(cast<OpResult>(getParam()), {getValue()});
2307 if (isa<TransformHandleTypeInterface>(handles.front().
getType())) {
2309 for (
Value operand : handles)
2310 llvm::append_range(operations, state.
getPayloadOps(operand));
2311 if (!getDeduplicate()) {
2312 results.
set(llvm::cast<OpResult>(getResult()), operations);
2317 results.
set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2321 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2323 for (
Value attribute : handles)
2324 llvm::append_range(attrs, state.
getParams(attribute));
2325 if (!getDeduplicate()) {
2326 results.
setParams(cast<OpResult>(getResult()), attrs);
2331 results.
setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2336 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2337 "expected value handle type");
2339 for (
Value value : handles)
2341 if (!getDeduplicate()) {
2342 results.
setValues(cast<OpResult>(getResult()), payloadValues);
2347 results.
setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2351bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2353 return getDeduplicate();
2356void transform::MergeHandlesOp::getEffects(
2365OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2366 if (getDeduplicate() || getHandles().size() != 1)
2371 return getHandles().front();
2390 if (
failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
2391 state, this->getOperation(), getBody())))
2395 FailurePropagationMode::Propagate, state, results);
2398void transform::NamedSequenceOp::getEffects(
2401ParseResult transform::NamedSequenceOp::parse(
OpAsmParser &parser,
2405 getFunctionTypeAttrName(
result.name),
2408 std::string &) { return builder.getFunctionType(inputs, results); },
2409 getArgAttrsAttrName(
result.name), getResAttrsAttrName(
result.name));
2412void transform::NamedSequenceOp::print(
OpAsmPrinter &printer) {
2414 printer, cast<FunctionOpInterface>(getOperation()),
false,
2415 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2416 getResAttrsAttrName());
2426 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2429 <<
"cannot be defined inside another transform op";
2430 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2434 if (op.isExternal() || op.getFunctionBody().empty()) {
2441 if (op.getFunctionBody().front().empty())
2444 Operation *terminator = &op.getFunctionBody().front().back();
2445 if (!isa<transform::YieldOp>(terminator)) {
2448 << transform::YieldOp::getOperationName()
2449 <<
"' as terminator";
2450 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2454 if (terminator->
getNumOperands() != op.getResultTypes().size()) {
2456 <<
"expected terminator to have as many operands as the parent op "
2459 for (
auto [i, operandType, resultType] : llvm::zip_equal(
2462 if (operandType == resultType)
2465 <<
"the type of the terminator operand #" << i
2466 <<
" must match the type of the corresponding parent op result ("
2467 << operandType <<
" vs " << resultType <<
")";
2480 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2483 <<
"expects the parent symbol table to have the '"
2484 << transform::TransformDialect::kWithNamedSequenceAttrName
2486 diag.attachNote(parent->
getLoc()) <<
"symbol table operation";
2491 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2494 <<
"cannot be defined inside another transform op";
2495 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2499 if (op.isExternal() || op.getBody().empty())
2503 if (op.getBody().front().empty())
2507 for (
Operation &child : op.getBody().front().without_terminator()) {
2508 if (!isa<transform::TransformOpInterface>(child)) {
2511 <<
"expected children ops to implement TransformOpInterface";
2512 diag.attachNote(child.getLoc()) <<
"op without interface";
2517 Operation *terminator = &op.getBody().front().back();
2518 if (!isa<transform::YieldOp>(terminator)) {
2521 << transform::YieldOp::getOperationName()
2522 <<
"' as terminator";
2523 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2527 if (terminator->
getNumOperands() != op.getFunctionType().getNumResults()) {
2529 <<
"expected terminator to have as many operands as the parent op "
2532 for (
auto [i, operandType, resultType] :
2533 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->
getNumOperands()),
2535 op.getFunctionType().getResults())) {
2536 if (operandType == resultType)
2539 <<
"the type of the terminator operand #" << i
2540 <<
" must match the type of the corresponding parent op result ("
2541 << operandType <<
" vs " << resultType <<
")";
2544 auto funcOp = cast<FunctionOpInterface>(*op);
2547 if (!
diag.succeeded())
2554LogicalResult transform::NamedSequenceOp::verify() {
2559template <
typename FnTy>
2564 types.reserve(1 + extraBindingTypes.size());
2565 types.push_back(bbArgType);
2566 llvm::append_range(types, extraBindingTypes);
2576 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2584void transform::NamedSequenceOp::build(
OpBuilder &builder,
2587 SequenceBodyBuilderFn bodyBuilder,
2593 TypeAttr::get(FunctionType::get(builder.
getContext(),
2594 rootType, resultTypes)));
2610 size_t numAssociations =
2612 .Case([&](TransformHandleTypeInterface opHandle) {
2615 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2618 .Case([&](TransformParamTypeInterface param) {
2619 return llvm::range_size(state.
getParams(getHandle()));
2621 .DefaultUnreachable(
"unknown kind of transform dialect type");
2622 results.
setParams(cast<OpResult>(getNum()),
2627LogicalResult transform::NumAssociationsOp::verify() {
2629 auto resultType = cast<TransformParamTypeInterface>(getNum().
getType());
2649 results.
set(cast<OpResult>(getResult()),
result);
2669 .Case([&](TransformHandleTypeInterface x) {
2672 .Case([&](TransformValueHandleTypeInterface x) {
2675 .Case([&](TransformParamTypeInterface x) {
2676 return llvm::range_size(state.
getParams(getHandle()));
2678 .DefaultUnreachable(
"unknown transform dialect type interface");
2680 auto produceNumOpsError = [&]() {
2681 return emitSilenceableError()
2682 << getHandle() <<
" expected to contain " << this->getNumResults()
2683 <<
" payloads but it contains " << numPayloads <<
" payloads";
2688 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2689 return produceNumOpsError();
2694 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2695 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2696 return produceNumOpsError();
2700 if (getOverflowResult())
2701 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2703 auto container = [&]() {
2704 if (isa<TransformHandleTypeInterface>(getHandle().
getType())) {
2705 return llvm::map_to_vector(
2707 [](
Operation *op) -> MappedValue {
return op; });
2709 if (isa<TransformValueHandleTypeInterface>(getHandle().
getType())) {
2711 [](
Value v) -> MappedValue {
return v; });
2713 assert(isa<TransformParamTypeInterface>(getHandle().
getType()) &&
2714 "unsupported kind of transform dialect type");
2715 return llvm::map_to_vector(state.
getParams(getHandle()),
2716 [](
Attribute a) -> MappedValue {
return a; });
2719 for (
auto &&en : llvm::enumerate(container)) {
2720 int64_t resultNum = en.index();
2721 if (resultNum >= getNumResults())
2722 resultNum = *getOverflowResult();
2723 resultHandles[resultNum].push_back(en.value());
2727 for (
auto &&it : llvm::enumerate(resultHandles))
2734void transform::SplitHandleOp::getEffects(
2742LogicalResult transform::SplitHandleOp::verify() {
2743 if (getOverflowResult().has_value() &&
2744 !(*getOverflowResult() < getNumResults()))
2745 return emitOpError(
"overflow_result is not a valid result index");
2747 for (
Type resultType : getResultTypes()) {
2751 return emitOpError(
"expects result types to implement the same transform "
2752 "interface as the operand type");
2762void transform::PayloadOp::getCheckedNormalForms(
2764 llvm::append_range(normalForms,
2765 getNormalForms().getAsRange<NormalFormAttrInterface>());
2776 unsigned numRepetitions = llvm::range_size(state.
getPayloadOps(getPattern()));
2777 for (
const auto &en : llvm::enumerate(getHandles())) {
2778 Value handle = en.value();
2779 if (isa<TransformHandleTypeInterface>(handle.getType())) {
2783 payload.reserve(numRepetitions * current.size());
2784 for (
unsigned i = 0; i < numRepetitions; ++i)
2785 llvm::append_range(payload, current);
2786 results.
set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2788 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2789 "expected param type");
2792 params.reserve(numRepetitions * current.size());
2793 for (
unsigned i = 0; i < numRepetitions; ++i)
2794 llvm::append_range(params, current);
2795 results.
setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2802void transform::ReplicateOp::getEffects(
2819 if (
failed(mapBlockArguments(state)))
2827 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2834 root = std::nullopt;
2837 if (failed(hasRoot.
value()))
2851 if (failed(parser.
parseType(rootType))) {
2855 if (!extraBindings.empty()) {
2860 if (extraBindingTypes.size() != extraBindings.size()) {
2862 "expected types to be provided for all operands");
2878 bool hasExtras = !extraBindings.empty();
2888 printer << rootType;
2890 printer <<
", " << llvm::interleaved(extraBindingTypes) <<
')';
2897 auto iface = dyn_cast<transform::TransformOpInterface>(use.
getOwner());
2901 return isHandleConsumed(use.
get(), iface);
2912 if (!potentialConsumer) {
2913 potentialConsumer = &use;
2918 <<
" has more than one potential consumer";
2921 diag.attachNote(use.getOwner()->getLoc())
2922 <<
"used here as operand #" << use.getOperandNumber();
2929LogicalResult transform::SequenceOp::verify() {
2930 assert(getBodyBlock()->getNumArguments() >= 1 &&
2931 "the number of arguments must have been verified to be more than 1 by "
2932 "PossibleTopLevelTransformOpTrait");
2934 if (!getRoot() && !getExtraBindings().empty()) {
2936 <<
"does not expect extra operands when used as top-level";
2942 return (
emitOpError() <<
"block argument #" << arg.getArgNumber());
2949 for (
Operation &child : *getBodyBlock()) {
2950 if (!isa<TransformOpInterface>(child) &&
2951 &child != &getBodyBlock()->back()) {
2954 <<
"expected children ops to implement TransformOpInterface";
2955 diag.attachNote(child.getLoc()) <<
"op without interface";
2960 auto report = [&]() {
2961 return (child.emitError() <<
"result #" <<
result.getResultNumber());
2968 if (!getBodyBlock()->mightHaveTerminator())
2969 return emitOpError() <<
"expects to have a terminator in the body";
2971 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2972 getOperation()->getResultTypes()) {
2974 <<
"expects the types of the terminator operands "
2975 "to match the types of the result";
2976 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) <<
"terminator";
2982void transform::SequenceOp::getEffects(
2988transform::SequenceOp::getEntrySuccessorOperands(
RegionSuccessor successor) {
2989 assert(successor.
getSuccessor() == &getBody() &&
"unexpected region index");
2990 if (getOperation()->getNumOperands() > 0)
2991 return getOperation()->getOperands();
2993 getOperation()->operand_end());
2996void transform::SequenceOp::getSuccessorRegions(
2999 Region *bodyRegion = &getBody();
3000 regions.emplace_back(bodyRegion);
3006 "unexpected region index");
3012 if (getNumOperands() == 0)
3015 return getResults();
3016 return getBody().getArguments();
3019void transform::SequenceOp::getRegionInvocationBounds(
3022 bounds.emplace_back(1, 1);
3027 FailurePropagationMode failurePropagationMode,
3029 SequenceBodyBuilderFn bodyBuilder) {
3030 build(builder, state, resultTypes, failurePropagationMode, root,
3039 FailurePropagationMode failurePropagationMode,
3041 SequenceBodyBuilderArgsFn bodyBuilder) {
3042 build(builder, state, resultTypes, failurePropagationMode, root,
3050 FailurePropagationMode failurePropagationMode,
3052 SequenceBodyBuilderFn bodyBuilder) {
3053 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3061 FailurePropagationMode failurePropagationMode,
3063 SequenceBodyBuilderArgsFn bodyBuilder) {
3064 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3082 build(builder,
result, name);
3089 llvm::outs() <<
"[[[ IR printer: ";
3090 if (getName().has_value())
3091 llvm::outs() << *getName() <<
" ";
3094 if (getAssumeVerified().value_or(
false))
3096 if (getUseLocalScope().value_or(
false))
3098 if (getSkipRegions().value_or(
false))
3102 llvm::outs() <<
"top-level ]]]\n";
3104 llvm::outs() <<
"\n";
3105 llvm::outs().flush();
3109 llvm::outs() <<
"]]]\n";
3111 target->print(llvm::outs(), printFlags);
3112 llvm::outs() <<
"\n";
3115 llvm::outs().flush();
3119void transform::PrintOp::getEffects(
3124 if (!getTargetMutable().empty())
3144 <<
"failed to verify payload op";
3145 diag.attachNote(
target->getLoc()) <<
"payload op";
3151void transform::VerifyOp::getEffects(
3160void transform::YieldOp::getEffects(
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a dlti.dl_entry attribute.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
GreedyRewriteConfig & setListener(RewriterBase::Listener *listener)
GreedyRewriteConfig & enableCSEBetweenIterations(bool enable=true)
GreedyRewriteConfig & setMaxIterations(int64_t iterations)
GreedyRewriteConfig & setMaxNumRewrites(int64_t limit)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getOpResults()
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
This is a "type erased" representation of a registered operation.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
A named class for passing around the variadic flag.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr, int64_t *numCSE=nullptr, int64_t *numDCE=nullptr)
Eliminate common subexpressions within the given operation.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
llvm::SetVector< T, Vector, Set, N > SetVector
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
llvm::function_ref< Fn > function_ref
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.