34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/ScopeExit.h"
37#include "llvm/ADT/SmallPtrSet.h"
38#include "llvm/ADT/SmallVectorExtras.h"
39#include "llvm/ADT/TypeSwitch.h"
40#include "llvm/Support/Debug.h"
41#include "llvm/Support/DebugLog.h"
42#include "llvm/Support/ErrorHandling.h"
43#include "llvm/Support/InterleavedRange.h"
46#define DEBUG_TYPE "transform-dialect"
47#define DEBUG_TYPE_MATCHER "transform-matcher"
59 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
80 while (transformAncestor) {
81 if (transformAncestor == payload) {
84 <<
"cannot apply transform to itself (or one of its ancestors)";
85 diag.attachNote(payload->
getLoc()) <<
"target payload op";
88 transformAncestor = transformAncestor->
getParentOp();
94#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
100OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
102 if (!successor.
isParent() && getOperation()->getNumOperands() == 1)
103 return getOperation()->getOperands();
105 getOperation()->operand_end());
108void transform::AlternativesOp::getSuccessorRegions(
110 for (
Region &alternative : llvm::drop_begin(
115 ->getRegionNumber() +
117 regions.emplace_back(&alternative);
124transform::AlternativesOp::getSuccessorInputs(
RegionSuccessor successor) {
126 return getOperation()->getResults();
130void transform::AlternativesOp::getRegionInvocationBounds(
135 bounds.reserve(getNumRegions());
136 bounds.emplace_back(1, 1);
143 results.
set(res, {});
151 if (
Value scopeHandle = getScope())
152 llvm::append_range(originals, state.
getPayloadOps(scopeHandle));
157 if (original->isAncestor(getOperation())) {
159 <<
"scope must not contain the transforms being applied";
160 diag.attachNote(original->getLoc()) <<
"scope";
165 <<
"only isolated-from-above ops can be alternative scopes";
166 diag.attachNote(original->getLoc()) <<
"scope";
171 for (
Region ® : getAlternatives()) {
177 auto clones = llvm::map_to_vector(
179 llvm::scope_exit deleteClones([&] {
190 if (
result.isSilenceableFailure()) {
191 LDBG() <<
"alternative failed: " <<
result.getMessage();
196 if (::mlir::failed(
result.silence()))
205 deleteClones.release();
206 TrackingListener listener(state, *
this);
208 for (
const auto &kvp : llvm::zip(originals, clones)) {
215 detail::forwardTerminatorOperands(®.front(), state, results);
219 return emitSilenceableError() <<
"all alternatives failed";
222void transform::AlternativesOp::getEffects(
226 for (
Region *region : getRegions()) {
227 if (!region->empty())
233LogicalResult transform::AlternativesOp::verify() {
234 for (
Region &alternative : getAlternatives()) {
239 <<
"expects terminator operands to have the "
240 "same type as results of the operation";
241 diag.attachNote(terminator->
getLoc()) <<
"terminator";
261 if (
auto paramH = getParam()) {
263 if (params.size() != 1) {
264 if (targets.size() != params.size()) {
265 return emitSilenceableError()
266 <<
"parameter and target have different payload lengths ("
267 << params.size() <<
" vs " << targets.size() <<
")";
269 for (
auto &&[
target, attr] : llvm::zip_equal(targets, params))
270 target->setAttr(getName(), attr);
275 for (
auto *
target : targets)
276 target->setAttr(getName(), attr);
280void transform::AnnotateOp::getEffects(
292transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
307void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
332 auto addDefiningOpsToWorklist = [&](
Operation *op) {
335 if (
Operation *defOp = v.getDefiningOp())
336 if (
target->isProperAncestor(defOp))
337 worklist.insert(defOp);
345 const auto *it = llvm::find(worklist, op);
346 if (it != worklist.end())
355 addDefiningOpsToWorklist(op);
361 while (!worklist.empty()) {
365 addDefiningOpsToWorklist(op);
372void transform::ApplyDeadCodeEliminationOp::getEffects(
397 if (!getRegion().empty()) {
398 for (
Operation &op : getRegion().front()) {
399 cast<transform::PatternDescriptorOpInterface>(&op)
400 .populatePatternsWithState(
patterns, state);
410 config.setMaxIterations(getMaxIterations() ==
static_cast<uint64_t
>(-1)
412 : getMaxIterations());
413 config.setMaxNumRewrites(getMaxNumRewrites() ==
static_cast<uint64_t
>(-1)
415 : getMaxNumRewrites());
420 bool cseChanged =
false;
423 static const int64_t kNumMaxIterations = 50;
426 LogicalResult
result = failure();
440 ops.push_back(nestedOp);
449 <<
"greedy pattern application failed";
457 }
while (cseChanged && ++iteration < kNumMaxIterations);
459 if (iteration == kNumMaxIterations)
465LogicalResult transform::ApplyPatternsOp::verify() {
466 if (!getRegion().empty()) {
467 for (
Operation &op : getRegion().front()) {
468 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
470 <<
"expected children ops to implement "
471 "PatternDescriptorOpInterface";
472 diag.attachNote(op.
getLoc()) <<
"op without interface";
480void transform::ApplyPatternsOp::getEffects(
486void transform::ApplyPatternsOp::build(
495 bodyBuilder(builder,
result.location);
502void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
506 dialect->getCanonicalizationPatterns(
patterns);
508 op.getCanonicalizationPatterns(
patterns, ctx);
522 std::unique_ptr<TypeConverter> defaultTypeConverter;
523 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
524 getDefaultTypeConverter();
525 if (typeConverterBuilder)
526 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
531 for (
Attribute attr : cast<ArrayAttr>(*getLegalOps()))
532 conversionTarget.addLegalOp(
535 for (
Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
536 conversionTarget.addIllegalOp(
538 if (getLegalDialects())
539 for (
Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
540 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
541 if (getIllegalDialects())
542 for (
Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
543 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
551 if (!getPatterns().empty()) {
552 for (
Operation &op : getPatterns().front()) {
554 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
557 std::unique_ptr<TypeConverter> typeConverter =
558 descriptor.getTypeConverter();
561 keepAliveConverters.emplace_back(std::move(typeConverter));
562 converter = keepAliveConverters.back().get();
565 if (!defaultTypeConverter) {
567 <<
"pattern descriptor does not specify type "
568 "converter and apply_conversion_patterns op has "
569 "no default type converter";
570 diag.attachNote(op.
getLoc()) <<
"pattern descriptor op";
573 converter = defaultTypeConverter.get();
579 descriptor.populateConversionTargetRules(*converter, conversionTarget);
581 descriptor.populatePatterns(*converter,
patterns);
589 TrackingListenerConfig trackingConfig;
590 trackingConfig.requireMatchingReplacementOpName =
false;
591 ErrorCheckingTrackingListener trackingListener(state, *
this, trackingConfig);
592 ConversionConfig conversionConfig;
593 if (getPreserveHandles())
594 conversionConfig.listener = &trackingListener;
605 LogicalResult status = failure();
606 if (getPartialConversion()) {
607 status = applyPartialConversion(
target, conversionTarget, frozenPatterns,
610 status = applyFullConversion(
target, conversionTarget, frozenPatterns,
617 diag = emitSilenceableError() <<
"dialect conversion failed";
618 diag.attachNote(
target->getLoc()) <<
"target op";
623 trackingListener.checkAndResetError();
625 if (
diag.succeeded()) {
627 return trackingFailure;
629 diag.attachNote() <<
"tracking listener also failed: "
634 if (!
diag.succeeded())
641LogicalResult transform::ApplyConversionPatternsOp::verify() {
642 if (getNumRegions() != 1 && getNumRegions() != 2)
644 if (!getPatterns().empty()) {
645 for (
Operation &op : getPatterns().front()) {
646 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
648 emitOpError() <<
"expected pattern children ops to implement "
649 "ConversionPatternDescriptorOpInterface";
650 diag.attachNote(op.
getLoc()) <<
"op without interface";
655 if (getNumRegions() == 2) {
656 Region &typeConverterRegion = getRegion(1);
657 if (!llvm::hasSingleElement(typeConverterRegion.
front()))
659 <<
"expected exactly one op in default type converter region";
661 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
663 if (!typeConverterOp) {
665 <<
"expected default converter child op to "
666 "implement TypeConverterBuilderOpInterface";
667 diag.attachNote(maybeTypeConverter->
getLoc()) <<
"op without interface";
671 if (!getPatterns().empty()) {
672 for (
Operation &op : getPatterns().front()) {
674 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
675 if (
failed(descriptor.verifyTypeConverter(typeConverterOp)))
683void transform::ApplyConversionPatternsOp::getEffects(
685 if (!getPreserveHandles()) {
693void transform::ApplyConversionPatternsOp::build(
703 if (patternsBodyBuilder)
704 patternsBodyBuilder(builder,
result.location);
710 if (typeConverterBodyBuilder)
711 typeConverterBodyBuilder(builder,
result.location);
719void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
722 assert(dialect &&
"expected that dialect is loaded");
723 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
727 iface->populateConvertToLLVMConversionPatterns(
731LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
732 transform::TypeConverterBuilderOpInterface builder) {
733 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
738LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
741 return emitOpError(
"unknown dialect or dialect not loaded: ")
743 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
746 "dialect does not implement ConvertToLLVMPatternInterface or "
747 "extension was not loaded: ")
757transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
767void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
777void transform::ApplyRegisteredPassOp::getEffects(
795 llvm::raw_string_ostream optionsStream(
options);
800 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
803 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
804 assert(dynamicOptionIdx <
static_cast<int64_t>(dynamicOptions.size()) &&
805 "the number of ParamOperandAttrs in the options DictionaryAttr"
806 "should be the same as the number of options passed as params");
808 state.
getParams(dynamicOptions[dynamicOptionIdx]);
810 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
812 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
814 llvm::interleave(arrayAttr, optionsStream, appendValueAttr,
",");
815 }
else if (
auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
817 optionsStream << strAttr.getValue().str();
820 valueAttr.print(optionsStream,
true);
826 getOptions(), optionsStream,
827 [&](
auto namedAttribute) {
828 optionsStream << namedAttribute.getName().str();
829 optionsStream <<
"=";
830 appendValueAttr(namedAttribute.getValue());
833 optionsStream.flush();
841 <<
"unknown pass or pass pipeline: " << getPassName();
850 <<
"failed to add pass or pass pipeline to pipeline: "
867 auto diag = emitSilenceableError() <<
"pass pipeline failed";
868 diag.attachNote(
target->getLoc()) <<
"target op";
874 results.
set(llvm::cast<OpResult>(getResult()), targets);
883 size_t dynamicOptionsIdx = 0;
889 std::function<ParseResult(
Attribute &)> parseValue =
890 [&](
Attribute &valueAttr) -> ParseResult {
898 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
899 " in options dictionary") ||
903 valueAttr = ArrayAttr::get(parser.
getContext(), attrs);
913 ParseResult parsedOperand = parser.
parseOperand(operand);
914 if (failed(parsedOperand))
920 dynamicOptions.push_back(operand);
921 auto wrappedIndex = IntegerAttr::get(
922 IntegerType::get(parser.
getContext(), 64), dynamicOptionsIdx++);
924 transform::ParamOperandAttr::get(parser.
getContext(), wrappedIndex);
925 }
else if (failed(parsedValueAttr.
value())) {
927 }
else if (isa<transform::ParamOperandAttr>(valueAttr)) {
929 <<
"the param_operand attribute is a marker reserved for "
930 <<
"indicating a value will be passed via params and is only used "
931 <<
"in the generic print format";
945 <<
"expected key to either be an identifier or a string";
949 <<
"expected '=' after key in key-value pair";
951 if (failed(parseValue(valueAttr)))
953 <<
"expected a valid attribute or operand as value associated "
954 <<
"to key '" << key <<
"'";
963 " in options dictionary"))
966 if (DictionaryAttr::findDuplicate(
967 keyValuePairs,
false)
970 <<
"duplicate keys found in options dictionary";
985 if (
auto paramOperandAttr =
986 dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
989 dynamicOptions[paramOperandAttr.getIndex().getInt()]);
990 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
993 llvm::interleaveComma(arrayAttr, printer, printOptionValue);
1002 printer << namedAttribute.
getName();
1004 printOptionValue(namedAttribute.
getValue());
1009LogicalResult transform::ApplyRegisteredPassOp::verify() {
1016 std::function<LogicalResult(
Attribute)> checkOptionValue =
1017 [&](
Attribute valueAttr) -> LogicalResult {
1018 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1019 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1020 if (dynamicOptionIdx < 0 ||
1021 dynamicOptionIdx >=
static_cast<int64_t>(dynamicOptions.size()))
1023 <<
"dynamic option index " << dynamicOptionIdx
1024 <<
" is out of bounds for the number of dynamic options: "
1025 << dynamicOptions.size();
1026 if (dynamicOptions[dynamicOptionIdx] ==
nullptr)
1027 return emitOpError() <<
"dynamic option index " << dynamicOptionIdx
1028 <<
" is already used in options";
1029 dynamicOptions[dynamicOptionIdx] =
nullptr;
1030 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1032 for (
auto eltAttr : arrayAttr)
1033 if (
failed(checkOptionValue(eltAttr)))
1040 if (
failed(checkOptionValue(namedAttr.getValue())))
1044 for (
Value dynamicOption : dynamicOptions)
1046 return emitOpError() <<
"a param operand does not have a corresponding "
1047 <<
"param_operand attr in the options dict";
1060 results.push_back(
target);
1064void transform::CastOp::getEffects(
1072 assert(inputs.size() == 1 &&
"expected one input");
1073 assert(outputs.size() == 1 &&
"expected one output");
1074 return llvm::all_of(
1075 std::initializer_list<Type>{inputs.front(), outputs.front()},
1076 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1096 assert(block.
getParent() &&
"cannot match using a detached block");
1103 if (!isa<transform::MatchOpInterface>(match)) {
1105 <<
"expected operations in the match part to "
1106 "implement MatchOpInterface";
1109 state.
applyTransform(cast<transform::TransformOpInterface>(match));
1110 if (
diag.succeeded())
1128template <
typename... Tys>
1130 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... ||
false);
1137 transform::TransformParamTypeInterface,
1138 transform::TransformValueHandleTypeInterface>(
1151 getOperation(), getMatcher());
1152 if (matcher.isExternal()) {
1154 <<
"unresolved external symbol " << getMatcher();
1158 rawResults.resize(getOperation()->getNumResults());
1159 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1171 matcher.getFunctionBody().front(),
1174 if (
diag.isDefiniteFailure())
1176 if (
diag.isSilenceableFailure()) {
1178 <<
" failed: " <<
diag.getMessage();
1183 for (
auto &&[i, mapping] : llvm::enumerate(mappings)) {
1184 if (mapping.size() != 1) {
1185 maybeFailure.emplace(emitSilenceableError()
1186 <<
"result #" << i <<
", associated with "
1188 <<
" payload objects, expected 1");
1191 rawResults[i].push_back(mapping[0]);
1196 return std::move(*maybeFailure);
1197 assert(!maybeFailure &&
"failure set but the walk was not interrupted");
1199 for (
auto &&[opResult, rawResult] :
1200 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1207void transform::CollectMatchingOp::getEffects(
1214LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1216 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1218 if (!matcherSymbol ||
1219 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1220 return emitError() <<
"unresolved matcher symbol " << getMatcher();
1223 if (argumentTypes.size() != 1 ||
1224 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1226 <<
"expected the matcher to take one operation handle argument";
1228 if (!matcherSymbol.getArgAttr(
1229 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1230 return emitError() <<
"expected the matcher argument to be marked readonly";
1234 if (resultTypes.size() != getOperation()->getNumResults()) {
1236 <<
"expected the matcher to yield as many values as op has results ("
1237 << getOperation()->getNumResults() <<
"), got "
1238 << resultTypes.size();
1241 for (
auto &&[i, matcherType, resultType] :
1242 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1247 <<
"mismatching type interfaces for matcher result and op result #"
1259bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() {
return true; }
1267 matchActionPairs.reserve(getMatchers().size());
1269 for (
auto &&[matcher, action] :
1270 llvm::zip_equal(getMatchers(), getActions())) {
1271 auto matcherSymbol =
1273 getOperation(), cast<SymbolRefAttr>(matcher));
1276 getOperation(), cast<SymbolRefAttr>(action));
1277 assert(matcherSymbol && actionSymbol &&
1278 "unresolved symbols not caught by the verifier");
1280 if (matcherSymbol.isExternal())
1282 if (actionSymbol.isExternal())
1285 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1296 matchInputMapping.emplace_back();
1298 getForwardedInputs(), state);
1300 actionResultMapping.resize(getForwardedOutputs().size());
1306 if (!getRestrictRoot() && op == root)
1314 firstMatchArgument.clear();
1315 firstMatchArgument.push_back(op);
1318 for (
auto [matcher, action] : matchActionPairs) {
1320 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1321 state, matchOutputMapping);
1322 if (
diag.isDefiniteFailure())
1324 if (
diag.isSilenceableFailure()) {
1326 <<
" failed: " <<
diag.getMessage();
1332 action.getFunctionBody().front().getArguments(),
1333 matchOutputMapping))) {
1338 action.getFunctionBody().front().without_terminator()) {
1341 if (
result.isDefiniteFailure())
1343 if (
result.isSilenceableFailure()) {
1345 overallDiag = emitSilenceableError() <<
"actions failed";
1348 <<
"failed action: " <<
result.getMessage();
1350 <<
"when applied to this matching payload";
1355 if (
failed(detail::appendValueMappings(
1357 action.getFunctionBody().front().getTerminator()->getOperands(),
1358 state, getFlattenResults()))) {
1360 <<
"action @" << action.getName()
1361 <<
" has results associated with multiple payload entities, "
1362 "but flattening was not requested";
1377 results.
set(llvm::cast<OpResult>(getUpdated()),
1379 for (
auto &&[
result, mapping] :
1380 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1386void transform::ForeachMatchOp::getAsmResultNames(
1388 setNameFn(getUpdated(),
"updated_root");
1389 for (
Value v : getForwardedOutputs()) {
1390 setNameFn(v,
"yielded");
1394void transform::ForeachMatchOp::getEffects(
1397 if (getOperation()->getNumOperands() < 1 ||
1398 getOperation()->getNumResults() < 1) {
1422 matcherList.push_back(SymbolRefAttr::get(matcher));
1423 actionList.push_back(SymbolRefAttr::get(action));
1437 for (
auto &&[matcher, action, idx] : llvm::zip_equal(
1440 printer << cast<SymbolRefAttr>(matcher) <<
" -> "
1441 << cast<SymbolRefAttr>(action);
1449LogicalResult transform::ForeachMatchOp::verify() {
1450 if (getMatchers().size() != getActions().size())
1451 return emitOpError() <<
"expected the same number of matchers and actions";
1452 if (getMatchers().empty())
1453 return emitOpError() <<
"expected at least one match/action pair";
1457 if (matcherNames.insert(name).second)
1460 <<
" is used more than once, only the first match will apply";
1471 bool alsoVerifyInternal =
false) {
1472 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1473 llvm::SmallDenseSet<unsigned> consumedArguments;
1474 if (!op.isExternal()) {
1478 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1480 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1483 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1485 if (isConsumed && isReadOnly) {
1486 return transformOp.emitSilenceableError()
1487 <<
"argument #" << i <<
" cannot be both readonly and consumed";
1489 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1490 return transformOp.emitSilenceableError()
1491 <<
"must provide consumed/readonly status for arguments of "
1492 "external or called ops";
1494 if (op.isExternal())
1497 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1498 return transformOp.emitSilenceableError()
1499 <<
"argument #" << i
1500 <<
" is consumed in the body but is not marked as such";
1502 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1506 <<
"op argument #" << i
1507 <<
" is not consumed in the body but is marked as consumed";
1513LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1515 assert(getMatchers().size() == getActions().size());
1517 StringAttr::get(
getContext(), TransformDialect::kArgConsumedAttrName);
1518 for (
auto &&[matcher, action] :
1519 llvm::zip_equal(getMatchers(), getActions())) {
1521 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1523 cast<SymbolRefAttr>(matcher)));
1524 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1526 cast<SymbolRefAttr>(action)));
1527 if (!matcherSymbol ||
1528 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1529 return emitError() <<
"unresolved matcher symbol " << matcher;
1530 if (!actionSymbol ||
1531 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1532 return emitError() <<
"unresolved action symbol " << action;
1537 .checkAndReport())) {
1543 .checkAndReport())) {
1548 TypeRange operandTypes = getOperandTypes();
1549 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1550 if (operandTypes.size() != matcherArguments.size()) {
1552 emitError() <<
"the number of operands (" << operandTypes.size()
1553 <<
") doesn't match the number of matcher arguments ("
1554 << matcherArguments.size() <<
") for " << matcher;
1555 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1558 for (
auto &&[i, operand, argument] :
1559 llvm::enumerate(operandTypes, matcherArguments)) {
1560 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1563 <<
"does not expect matcher symbol to consume its operand #" << i;
1564 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1573 <<
"mismatching type interfaces for operand and matcher argument #"
1574 << i <<
" of matcher " << matcher;
1575 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1580 TypeRange matcherResults = matcherSymbol.getResultTypes();
1581 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1582 if (matcherResults.size() != actionArguments.size()) {
1583 return emitError() <<
"mismatching number of matcher results and "
1584 "action arguments between "
1585 << matcher <<
" (" << matcherResults.size() <<
") and "
1586 << action <<
" (" << actionArguments.size() <<
")";
1588 for (
auto &&[i, matcherType, actionType] :
1589 llvm::enumerate(matcherResults, actionArguments)) {
1593 return emitError() <<
"mismatching type interfaces for matcher result "
1594 "and action argument #"
1595 << i <<
"of matcher " << matcher <<
" and action "
1600 TypeRange actionResults = actionSymbol.getResultTypes();
1601 auto resultTypes =
TypeRange(getResultTypes()).drop_front();
1602 if (actionResults.size() != resultTypes.size()) {
1604 emitError() <<
"the number of action results ("
1605 << actionResults.size() <<
") for " << action
1606 <<
" doesn't match the number of extra op results ("
1607 << resultTypes.size() <<
")";
1608 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1611 for (
auto &&[i, resultType, actionType] :
1612 llvm::enumerate(resultTypes, actionResults)) {
1617 emitError() <<
"mismatching type interfaces for action result #" << i
1618 <<
" of action " << action <<
" and op result";
1619 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1637 detail::prepareValueMappings(payloads, getTargets(), state);
1638 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1639 bool withZipShortest = getWithZipShortest();
1643 if (withZipShortest) {
1647 return a.size() <
b.size();
1650 for (
auto &payload : payloads)
1651 payload.resize(numIterations);
1657 for (
size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1659 if (payloads[argIdx].size() != numIterations) {
1660 return emitSilenceableError()
1661 <<
"prior targets' payload size (" << numIterations
1662 <<
") differs from payload size (" << payloads[argIdx].size()
1663 <<
") of target " << getTargets()[argIdx];
1672 for (
size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1675 for (
auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1686 llvm::cast<transform::TransformOpInterface>(
transform));
1692 OperandRange yieldOperands = getYieldOp().getOperands();
1693 for (
auto &&[
result, yieldOperand, resTuple] :
1694 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1696 if (isa<TransformHandleTypeInterface>(
result.getType()))
1697 llvm::append_range(resTuple, state.
getPayloadOps(yieldOperand));
1698 else if (isa<TransformValueHandleTypeInterface>(
result.getType()))
1700 else if (isa<TransformParamTypeInterface>(
result.getType()))
1701 llvm::append_range(resTuple, state.
getParams(yieldOperand));
1703 assert(
false &&
"unhandled handle type");
1707 for (
auto &&[
result, resPayload] : zip_equal(getResults(), zippedResults))
1713void transform::ForeachOp::getEffects(
1717 for (
auto &&[
target, blockArg] :
1718 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1720 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1722 cast<TransformOpInterface>(&op));
1730 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1734 }
else if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1743void transform::ForeachOp::getSuccessorRegions(
1745 Region *bodyRegion = &getBody();
1747 regions.emplace_back(bodyRegion);
1754 "unexpected region index");
1755 regions.emplace_back(bodyRegion);
1765transform::ForeachOp::getEntrySuccessorOperands(
RegionSuccessor successor) {
1768 assert(successor.
getSuccessor() == &getBody() &&
"unexpected region index");
1769 return getOperation()->getOperands();
1772transform::YieldOp transform::ForeachOp::getYieldOp() {
1773 return cast<transform::YieldOp>(getBody().front().getTerminator());
1776LogicalResult transform::ForeachOp::verify() {
1777 for (
auto [targetOpt, bodyArgOpt] :
1778 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1779 if (!targetOpt || !bodyArgOpt)
1780 return emitOpError() <<
"expects the same number of targets as the body "
1781 "has block arguments";
1782 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1784 "expects co-indexed targets and the body's "
1785 "block arguments to have the same op/value/param type");
1788 for (
auto [resultOpt, yieldOperandOpt] :
1789 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1790 if (!resultOpt || !yieldOperandOpt)
1791 return emitOpError() <<
"expects the same number of results as the "
1792 "yield terminator has operands";
1793 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1794 return emitOpError(
"expects co-indexed results and yield "
1795 "operands to have the same op/value/param type");
1813 for (
int64_t i = 0, e = getNthParent(); i < e; ++i) {
1816 bool checkIsolatedFromAbove =
1817 !getIsolatedFromAbove() ||
1819 bool checkOpName = !getOpName().has_value() ||
1821 if (checkIsolatedFromAbove && checkOpName)
1826 if (getAllowEmptyResults()) {
1827 results.
set(llvm::cast<OpResult>(getResult()), parents);
1831 emitSilenceableError()
1832 <<
"could not find a parent op that matches all requirements";
1833 diag.attachNote(
target->getLoc()) <<
"target op";
1837 if (getDeduplicate()) {
1838 if (resultSet.insert(parent).second)
1839 parents.push_back(parent);
1841 parents.push_back(parent);
1844 results.
set(llvm::cast<OpResult>(getResult()), parents);
1856 int64_t resultNumber = getResultNumber();
1858 if (std::empty(payloadOps)) {
1859 results.
set(cast<OpResult>(getResult()), {});
1862 if (!llvm::hasSingleElement(payloadOps))
1864 <<
"handle must be mapped to exactly one payload op";
1867 if (
target->getNumResults() <= resultNumber)
1869 results.
set(llvm::cast<OpResult>(getResult()),
1870 llvm::to_vector(
target->getResult(resultNumber).getUsers()));
1884 if (llvm::isa<BlockArgument>(v)) {
1886 emitSilenceableError() <<
"cannot get defining op of block argument";
1887 diag.attachNote(v.getLoc()) <<
"target value";
1890 definingOps.push_back(v.getDefiningOp());
1892 results.
set(llvm::cast<OpResult>(getResult()), definingOps);
1904 int64_t operandNumber = getOperandNumber();
1908 target->getNumOperands() <= operandNumber
1910 :
target->getOperand(operandNumber).getDefiningOp();
1913 emitSilenceableError()
1914 <<
"could not find a producer for operand number: " << operandNumber
1916 diag.attachNote(
target->getLoc()) <<
"target op";
1919 producers.push_back(producer);
1921 results.
set(llvm::cast<OpResult>(getResult()), producers);
1937 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1938 target->getNumOperands(), operandPositions);
1939 if (
diag.isSilenceableFailure()) {
1941 <<
"while considering positions of this payload operation";
1944 llvm::append_range(operands,
1945 llvm::map_range(operandPositions, [&](
int64_t pos) {
1946 return target->getOperand(pos);
1949 results.
setValues(cast<OpResult>(getResult()), operands);
1953LogicalResult transform::GetOperandOp::verify() {
1955 getIsInverted(), getIsAll());
1970 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1971 target->getNumResults(), resultPositions);
1972 if (
diag.isSilenceableFailure()) {
1974 <<
"while considering positions of this payload operation";
1977 llvm::append_range(opResults,
1978 llvm::map_range(resultPositions, [&](
int64_t pos) {
1979 return target->getResult(pos);
1982 results.
setValues(cast<OpResult>(getResult()), opResults);
1986LogicalResult transform::GetResultOp::verify() {
1988 getIsInverted(), getIsAll());
1995void transform::GetTypeOp::getEffects(
2008 Type type = value.getType();
2009 if (getElemental()) {
2010 if (
auto shaped = dyn_cast<ShapedType>(type)) {
2011 type = shaped.getElementType();
2014 params.push_back(TypeAttr::get(type));
2016 results.
setParams(cast<OpResult>(getResult()), params);
2034 if (
result.isDefiniteFailure())
2037 if (
result.isSilenceableFailure()) {
2038 if (mode == transform::FailurePropagationMode::Propagate) {
2058 getOperation(), getTarget());
2059 assert(callee &&
"unverified reference to unknown symbol");
2061 if (callee.isExternal())
2066 detail::prepareValueMappings(mappings, getOperands(), state);
2068 for (
auto &&[arg, map] :
2069 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2075 callee.getBody().front(), getFailurePropagationMode(), state, results);
2081 detail::prepareValueMappings(
2082 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2083 for (
auto &&[
result, mapping] : llvm::zip_equal(getResults(), mappings))
2091void transform::IncludeOp::getEffects(
2106 auto defaultEffects = [&] {
2113 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2115 return defaultEffects();
2117 getOperation(), getTarget());
2119 return defaultEffects();
2121 for (
unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2122 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2124 else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2133 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(
"target");
2135 return emitOpError() <<
"expects a 'target' symbol reference attribute";
2140 return emitOpError() <<
"does not reference a named transform sequence";
2142 FunctionType fnType =
target.getFunctionType();
2143 if (fnType.getNumInputs() != getNumOperands())
2144 return emitError(
"incorrect number of operands for callee");
2146 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2147 if (getOperand(i).
getType() != fnType.getInput(i)) {
2148 return emitOpError(
"operand type mismatch: expected operand type ")
2149 << fnType.getInput(i) <<
", but provided "
2150 << getOperand(i).getType() <<
" for operand number " << i;
2154 if (fnType.getNumResults() != getNumResults())
2155 return emitError(
"incorrect number of results for callee");
2157 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2158 Type resultType = getResult(i).getType();
2159 Type funcType = fnType.getResult(i);
2162 <<
" must implement the same transform dialect "
2163 "interface as the corresponding callee result";
2168 cast<FunctionOpInterface>(*
target),
false,
2178 ::std::optional<::mlir::Operation *> maybeCurrent,
2180 if (!maybeCurrent.has_value()) {
2185 return emitSilenceableError() <<
"operation is not empty";
2196 for (
auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2197 if (acceptedAttr.getValue() == currentOpName)
2200 return emitSilenceableError() <<
"wrong operation name";
2211 auto signedAPIntAsString = [&](
const APInt &value) {
2213 llvm::raw_string_ostream os(str);
2214 value.print(os,
true);
2221 if (params.size() != references.size()) {
2222 return emitSilenceableError()
2223 <<
"parameters have different payload lengths (" << params.size()
2224 <<
" vs " << references.size() <<
")";
2227 for (
auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2228 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2229 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2230 if (!intAttr || !refAttr) {
2232 <<
"non-integer parameter value not expected";
2234 if (intAttr.getType() != refAttr.getType()) {
2236 <<
"mismatching integer attribute types in parameter #" << i;
2238 APInt value = intAttr.getValue();
2239 APInt refValue = refAttr.getValue();
2243 auto reportError = [&](StringRef direction) {
2245 emitSilenceableError() <<
"expected parameter to be " << direction
2246 <<
" " << signedAPIntAsString(refValue)
2247 <<
", got " << signedAPIntAsString(value);
2248 diag.attachNote(getParam().getLoc())
2249 <<
"value # " << position
2250 <<
" associated with the parameter defined here";
2254 switch (getPredicate()) {
2255 case MatchCmpIPredicate::eq:
2256 if (value.eq(refValue))
2258 return reportError(
"equal to");
2259 case MatchCmpIPredicate::ne:
2260 if (value.ne(refValue))
2262 return reportError(
"not equal to");
2263 case MatchCmpIPredicate::lt:
2264 if (value.slt(refValue))
2266 return reportError(
"less than");
2267 case MatchCmpIPredicate::le:
2268 if (value.sle(refValue))
2270 return reportError(
"less than or equal to");
2271 case MatchCmpIPredicate::gt:
2272 if (value.sgt(refValue))
2274 return reportError(
"greater than");
2275 case MatchCmpIPredicate::ge:
2276 if (value.sge(refValue))
2278 return reportError(
"greater than or equal to");
2284void transform::MatchParamCmpIOp::getEffects(
2298 results.
setParams(cast<OpResult>(getParam()), {getValue()});
2311 if (isa<TransformHandleTypeInterface>(handles.front().
getType())) {
2313 for (
Value operand : handles)
2314 llvm::append_range(operations, state.
getPayloadOps(operand));
2315 if (!getDeduplicate()) {
2316 results.
set(llvm::cast<OpResult>(getResult()), operations);
2321 results.
set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2325 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2327 for (
Value attribute : handles)
2328 llvm::append_range(attrs, state.
getParams(attribute));
2329 if (!getDeduplicate()) {
2330 results.
setParams(cast<OpResult>(getResult()), attrs);
2335 results.
setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2340 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2341 "expected value handle type");
2343 for (
Value value : handles)
2345 if (!getDeduplicate()) {
2346 results.
setValues(cast<OpResult>(getResult()), payloadValues);
2351 results.
setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2355bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2357 return getDeduplicate();
2360void transform::MergeHandlesOp::getEffects(
2369OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2370 if (getDeduplicate() || getHandles().size() != 1)
2375 return getHandles().front();
2394 if (
failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
2395 state, this->getOperation(), getBody())))
2399 FailurePropagationMode::Propagate, state, results);
2402void transform::NamedSequenceOp::getEffects(
2405ParseResult transform::NamedSequenceOp::parse(
OpAsmParser &parser,
2409 getFunctionTypeAttrName(
result.name),
2412 std::string &) { return builder.getFunctionType(inputs, results); },
2413 getArgAttrsAttrName(
result.name), getResAttrsAttrName(
result.name));
2416void transform::NamedSequenceOp::print(
OpAsmPrinter &printer) {
2418 printer, cast<FunctionOpInterface>(getOperation()),
false,
2419 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2420 getResAttrsAttrName());
2430 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2433 <<
"cannot be defined inside another transform op";
2434 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2438 if (op.isExternal() || op.getFunctionBody().empty()) {
2445 if (op.getFunctionBody().front().empty())
2448 Operation *terminator = &op.getFunctionBody().front().back();
2449 if (!isa<transform::YieldOp>(terminator)) {
2452 << transform::YieldOp::getOperationName()
2453 <<
"' as terminator";
2454 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2458 if (terminator->
getNumOperands() != op.getResultTypes().size()) {
2460 <<
"expected terminator to have as many operands as the parent op "
2463 for (
auto [i, operandType, resultType] : llvm::zip_equal(
2466 if (operandType == resultType)
2469 <<
"the type of the terminator operand #" << i
2470 <<
" must match the type of the corresponding parent op result ("
2471 << operandType <<
" vs " << resultType <<
")";
2484 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2487 <<
"expects the parent symbol table to have the '"
2488 << transform::TransformDialect::kWithNamedSequenceAttrName
2490 diag.attachNote(parent->
getLoc()) <<
"symbol table operation";
2495 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2498 <<
"cannot be defined inside another transform op";
2499 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2503 if (op.isExternal() || op.getBody().empty())
2507 if (op.getBody().front().empty())
2511 for (
Operation &child : op.getBody().front().without_terminator()) {
2512 if (!isa<transform::TransformOpInterface>(child)) {
2515 <<
"expected children ops to implement TransformOpInterface";
2516 diag.attachNote(child.getLoc()) <<
"op without interface";
2521 Operation *terminator = &op.getBody().front().back();
2522 if (!isa<transform::YieldOp>(terminator)) {
2525 << transform::YieldOp::getOperationName()
2526 <<
"' as terminator";
2527 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2531 if (terminator->
getNumOperands() != op.getFunctionType().getNumResults()) {
2533 <<
"expected terminator to have as many operands as the parent op "
2536 for (
auto [i, operandType, resultType] :
2537 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->
getNumOperands()),
2539 op.getFunctionType().getResults())) {
2540 if (operandType == resultType)
2543 <<
"the type of the terminator operand #" << i
2544 <<
" must match the type of the corresponding parent op result ("
2545 << operandType <<
" vs " << resultType <<
")";
2548 auto funcOp = cast<FunctionOpInterface>(*op);
2551 if (!
diag.succeeded())
2558LogicalResult transform::NamedSequenceOp::verify() {
2563template <
typename FnTy>
2568 types.reserve(1 + extraBindingTypes.size());
2569 types.push_back(bbArgType);
2570 llvm::append_range(types, extraBindingTypes);
2580 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2588void transform::NamedSequenceOp::build(
OpBuilder &builder,
2591 SequenceBodyBuilderFn bodyBuilder,
2597 TypeAttr::get(FunctionType::get(builder.
getContext(),
2598 rootType, resultTypes)));
2614 size_t numAssociations =
2616 .Case([&](TransformHandleTypeInterface opHandle) {
2619 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2622 .Case([&](TransformParamTypeInterface param) {
2623 return llvm::range_size(state.
getParams(getHandle()));
2625 .DefaultUnreachable(
"unknown kind of transform dialect type");
2626 results.
setParams(cast<OpResult>(getNum()),
2631LogicalResult transform::NumAssociationsOp::verify() {
2633 auto resultType = cast<TransformParamTypeInterface>(getNum().
getType());
2653 results.
set(cast<OpResult>(getResult()),
result);
2673 .Case([&](TransformHandleTypeInterface x) {
2676 .Case([&](TransformValueHandleTypeInterface x) {
2679 .Case([&](TransformParamTypeInterface x) {
2680 return llvm::range_size(state.
getParams(getHandle()));
2682 .DefaultUnreachable(
"unknown transform dialect type interface");
2684 auto produceNumOpsError = [&]() {
2685 return emitSilenceableError()
2686 << getHandle() <<
" expected to contain " << this->getNumResults()
2687 <<
" payloads but it contains " << numPayloads <<
" payloads";
2692 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2693 return produceNumOpsError();
2698 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2699 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2700 return produceNumOpsError();
2704 if (getOverflowResult())
2705 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2707 auto container = [&]() {
2708 if (isa<TransformHandleTypeInterface>(getHandle().
getType())) {
2709 return llvm::map_to_vector(
2711 [](
Operation *op) -> MappedValue {
return op; });
2713 if (isa<TransformValueHandleTypeInterface>(getHandle().
getType())) {
2715 [](
Value v) -> MappedValue {
return v; });
2717 assert(isa<TransformParamTypeInterface>(getHandle().
getType()) &&
2718 "unsupported kind of transform dialect type");
2719 return llvm::map_to_vector(state.
getParams(getHandle()),
2720 [](
Attribute a) -> MappedValue {
return a; });
2723 for (
auto &&en : llvm::enumerate(container)) {
2724 int64_t resultNum = en.index();
2725 if (resultNum >= getNumResults())
2726 resultNum = *getOverflowResult();
2727 resultHandles[resultNum].push_back(en.value());
2731 for (
auto &&it : llvm::enumerate(resultHandles))
2738void transform::SplitHandleOp::getEffects(
2746LogicalResult transform::SplitHandleOp::verify() {
2747 if (getOverflowResult().has_value() &&
2748 !(*getOverflowResult() < getNumResults()))
2749 return emitOpError(
"overflow_result is not a valid result index");
2751 for (
Type resultType : getResultTypes()) {
2755 return emitOpError(
"expects result types to implement the same transform "
2756 "interface as the operand type");
2770 unsigned numRepetitions = llvm::range_size(state.
getPayloadOps(getPattern()));
2771 for (
const auto &en : llvm::enumerate(getHandles())) {
2772 Value handle = en.value();
2773 if (isa<TransformHandleTypeInterface>(handle.getType())) {
2777 payload.reserve(numRepetitions * current.size());
2778 for (
unsigned i = 0; i < numRepetitions; ++i)
2779 llvm::append_range(payload, current);
2780 results.
set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2782 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2783 "expected param type");
2786 params.reserve(numRepetitions * current.size());
2787 for (
unsigned i = 0; i < numRepetitions; ++i)
2788 llvm::append_range(params, current);
2789 results.
setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2796void transform::ReplicateOp::getEffects(
2813 if (
failed(mapBlockArguments(state)))
2821 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2828 root = std::nullopt;
2831 if (failed(hasRoot.
value()))
2845 if (failed(parser.
parseType(rootType))) {
2849 if (!extraBindings.empty()) {
2854 if (extraBindingTypes.size() != extraBindings.size()) {
2856 "expected types to be provided for all operands");
2872 bool hasExtras = !extraBindings.empty();
2882 printer << rootType;
2884 printer <<
", " << llvm::interleaved(extraBindingTypes) <<
')';
2891 auto iface = dyn_cast<transform::TransformOpInterface>(use.
getOwner());
2895 return isHandleConsumed(use.
get(), iface);
2906 if (!potentialConsumer) {
2907 potentialConsumer = &use;
2912 <<
" has more than one potential consumer";
2915 diag.attachNote(use.getOwner()->getLoc())
2916 <<
"used here as operand #" << use.getOperandNumber();
2923LogicalResult transform::SequenceOp::verify() {
2924 assert(getBodyBlock()->getNumArguments() >= 1 &&
2925 "the number of arguments must have been verified to be more than 1 by "
2926 "PossibleTopLevelTransformOpTrait");
2928 if (!getRoot() && !getExtraBindings().empty()) {
2930 <<
"does not expect extra operands when used as top-level";
2936 return (
emitOpError() <<
"block argument #" << arg.getArgNumber());
2943 for (
Operation &child : *getBodyBlock()) {
2944 if (!isa<TransformOpInterface>(child) &&
2945 &child != &getBodyBlock()->back()) {
2948 <<
"expected children ops to implement TransformOpInterface";
2949 diag.attachNote(child.getLoc()) <<
"op without interface";
2954 auto report = [&]() {
2955 return (child.emitError() <<
"result #" <<
result.getResultNumber());
2962 if (!getBodyBlock()->mightHaveTerminator())
2963 return emitOpError() <<
"expects to have a terminator in the body";
2965 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2966 getOperation()->getResultTypes()) {
2968 <<
"expects the types of the terminator operands "
2969 "to match the types of the result";
2970 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) <<
"terminator";
2976void transform::SequenceOp::getEffects(
2982transform::SequenceOp::getEntrySuccessorOperands(
RegionSuccessor successor) {
2983 assert(successor.
getSuccessor() == &getBody() &&
"unexpected region index");
2984 if (getOperation()->getNumOperands() > 0)
2985 return getOperation()->getOperands();
2987 getOperation()->operand_end());
2990void transform::SequenceOp::getSuccessorRegions(
2993 Region *bodyRegion = &getBody();
2994 regions.emplace_back(bodyRegion);
3000 "unexpected region index");
3006 if (getNumOperands() == 0)
3009 return getResults();
3010 return getBody().getArguments();
3013void transform::SequenceOp::getRegionInvocationBounds(
3016 bounds.emplace_back(1, 1);
3021 FailurePropagationMode failurePropagationMode,
3023 SequenceBodyBuilderFn bodyBuilder) {
3024 build(builder, state, resultTypes, failurePropagationMode, root,
3033 FailurePropagationMode failurePropagationMode,
3035 SequenceBodyBuilderArgsFn bodyBuilder) {
3036 build(builder, state, resultTypes, failurePropagationMode, root,
3044 FailurePropagationMode failurePropagationMode,
3046 SequenceBodyBuilderFn bodyBuilder) {
3047 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3055 FailurePropagationMode failurePropagationMode,
3057 SequenceBodyBuilderArgsFn bodyBuilder) {
3058 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3076 build(builder,
result, name);
3083 llvm::outs() <<
"[[[ IR printer: ";
3084 if (getName().has_value())
3085 llvm::outs() << *getName() <<
" ";
3088 if (getAssumeVerified().value_or(
false))
3090 if (getUseLocalScope().value_or(
false))
3092 if (getSkipRegions().value_or(
false))
3096 llvm::outs() <<
"top-level ]]]\n";
3098 llvm::outs() <<
"\n";
3099 llvm::outs().flush();
3103 llvm::outs() <<
"]]]\n";
3105 target->print(llvm::outs(), printFlags);
3106 llvm::outs() <<
"\n";
3109 llvm::outs().flush();
3113void transform::PrintOp::getEffects(
3118 if (!getTargetMutable().empty())
3138 <<
"failed to verify payload op";
3139 diag.attachNote(
target->getLoc()) <<
"payload op";
3145void transform::VerifyOp::getEffects(
3154void transform::YieldOp::getEffects(
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a dlti.dl_entry attribute.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getOpResults()
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
A named class for passing around the variadic flag.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
llvm::SetVector< T, Vector, Set, N > SetVector
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
llvm::function_ref< Fn > function_ref
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.