34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/ScopeExit.h"
37#include "llvm/ADT/SmallPtrSet.h"
38#include "llvm/ADT/TypeSwitch.h"
39#include "llvm/Support/Debug.h"
40#include "llvm/Support/DebugLog.h"
41#include "llvm/Support/ErrorHandling.h"
42#include "llvm/Support/InterleavedRange.h"
45#define DEBUG_TYPE "transform-dialect"
46#define DEBUG_TYPE_MATCHER "transform-matcher"
58 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
79 while (transformAncestor) {
80 if (transformAncestor == payload) {
83 <<
"cannot apply transform to itself (or one of its ancestors)";
84 diag.attachNote(payload->
getLoc()) <<
"target payload op";
87 transformAncestor = transformAncestor->
getParentOp();
93#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
99OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
101 if (!successor.
isParent() && getOperation()->getNumOperands() == 1)
102 return getOperation()->getOperands();
104 getOperation()->operand_end());
107void transform::AlternativesOp::getSuccessorRegions(
109 for (
Region &alternative : llvm::drop_begin(
116 regions.emplace_back(&alternative, !getOperands().empty()
117 ? alternative.getArguments()
121 regions.emplace_back(getOperation(), getOperation()->getResults());
124void transform::AlternativesOp::getRegionInvocationBounds(
129 bounds.reserve(getNumRegions());
130 bounds.emplace_back(1, 1);
137 results.
set(res, {});
145 if (
Value scopeHandle = getScope())
146 llvm::append_range(originals, state.
getPayloadOps(scopeHandle));
151 if (original->isAncestor(getOperation())) {
153 <<
"scope must not contain the transforms being applied";
154 diag.attachNote(original->getLoc()) <<
"scope";
159 <<
"only isolated-from-above ops can be alternative scopes";
160 diag.attachNote(original->getLoc()) <<
"scope";
165 for (
Region ® : getAlternatives()) {
171 auto clones = llvm::to_vector(
172 llvm::map_range(originals, [](
Operation *op) {
return op->
clone(); }));
173 auto deleteClones = llvm::make_scope_exit([&] {
184 if (
result.isSilenceableFailure()) {
185 LDBG() <<
"alternative failed: " <<
result.getMessage();
190 if (::mlir::failed(
result.silence()))
199 deleteClones.release();
200 TrackingListener listener(state, *
this);
202 for (
const auto &kvp : llvm::zip(originals, clones)) {
209 detail::forwardTerminatorOperands(®.front(), state, results);
213 return emitSilenceableError() <<
"all alternatives failed";
216void transform::AlternativesOp::getEffects(
220 for (
Region *region : getRegions()) {
221 if (!region->empty())
227LogicalResult transform::AlternativesOp::verify() {
228 for (
Region &alternative : getAlternatives()) {
233 <<
"expects terminator operands to have the "
234 "same type as results of the operation";
235 diag.attachNote(terminator->
getLoc()) <<
"terminator";
255 if (
auto paramH = getParam()) {
257 if (params.size() != 1) {
258 if (targets.size() != params.size()) {
259 return emitSilenceableError()
260 <<
"parameter and target have different payload lengths ("
261 << params.size() <<
" vs " << targets.size() <<
")";
263 for (
auto &&[
target, attr] : llvm::zip_equal(targets, params))
264 target->setAttr(getName(), attr);
269 for (
auto *
target : targets)
270 target->setAttr(getName(), attr);
274void transform::AnnotateOp::getEffects(
286transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
301void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
326 auto addDefiningOpsToWorklist = [&](
Operation *op) {
329 if (
Operation *defOp = v.getDefiningOp())
330 if (
target->isProperAncestor(defOp))
331 worklist.insert(defOp);
339 const auto *it = llvm::find(worklist, op);
340 if (it != worklist.end())
349 addDefiningOpsToWorklist(op);
355 while (!worklist.empty()) {
359 addDefiningOpsToWorklist(op);
366void transform::ApplyDeadCodeEliminationOp::getEffects(
391 if (!getRegion().empty()) {
392 for (
Operation &op : getRegion().front()) {
393 cast<transform::PatternDescriptorOpInterface>(&op)
394 .populatePatternsWithState(
patterns, state);
404 config.setMaxIterations(getMaxIterations() ==
static_cast<uint64_t
>(-1)
406 : getMaxIterations());
407 config.setMaxNumRewrites(getMaxNumRewrites() ==
static_cast<uint64_t
>(-1)
409 : getMaxNumRewrites());
414 bool cseChanged =
false;
417 static const int64_t kNumMaxIterations = 50;
420 LogicalResult
result = failure();
434 ops.push_back(nestedOp);
443 <<
"greedy pattern application failed";
451 }
while (cseChanged && ++iteration < kNumMaxIterations);
453 if (iteration == kNumMaxIterations)
459LogicalResult transform::ApplyPatternsOp::verify() {
460 if (!getRegion().empty()) {
461 for (
Operation &op : getRegion().front()) {
462 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
464 <<
"expected children ops to implement "
465 "PatternDescriptorOpInterface";
466 diag.attachNote(op.
getLoc()) <<
"op without interface";
474void transform::ApplyPatternsOp::getEffects(
480void transform::ApplyPatternsOp::build(
489 bodyBuilder(builder,
result.location);
496void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
500 dialect->getCanonicalizationPatterns(
patterns);
502 op.getCanonicalizationPatterns(
patterns, ctx);
516 std::unique_ptr<TypeConverter> defaultTypeConverter;
517 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
518 getDefaultTypeConverter();
519 if (typeConverterBuilder)
520 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
525 for (
Attribute attr : cast<ArrayAttr>(*getLegalOps()))
526 conversionTarget.addLegalOp(
529 for (
Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
530 conversionTarget.addIllegalOp(
532 if (getLegalDialects())
533 for (
Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
534 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
535 if (getIllegalDialects())
536 for (
Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
537 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
545 if (!getPatterns().empty()) {
546 for (
Operation &op : getPatterns().front()) {
548 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
551 std::unique_ptr<TypeConverter> typeConverter =
552 descriptor.getTypeConverter();
555 keepAliveConverters.emplace_back(std::move(typeConverter));
556 converter = keepAliveConverters.back().get();
559 if (!defaultTypeConverter) {
561 <<
"pattern descriptor does not specify type "
562 "converter and apply_conversion_patterns op has "
563 "no default type converter";
564 diag.attachNote(op.
getLoc()) <<
"pattern descriptor op";
567 converter = defaultTypeConverter.get();
573 descriptor.populateConversionTargetRules(*converter, conversionTarget);
575 descriptor.populatePatterns(*converter,
patterns);
583 TrackingListenerConfig trackingConfig;
584 trackingConfig.requireMatchingReplacementOpName =
false;
585 ErrorCheckingTrackingListener trackingListener(state, *
this, trackingConfig);
586 ConversionConfig conversionConfig;
587 if (getPreserveHandles())
588 conversionConfig.listener = &trackingListener;
599 LogicalResult status = failure();
600 if (getPartialConversion()) {
601 status = applyPartialConversion(
target, conversionTarget, frozenPatterns,
604 status = applyFullConversion(
target, conversionTarget, frozenPatterns,
611 diag = emitSilenceableError() <<
"dialect conversion failed";
612 diag.attachNote(
target->getLoc()) <<
"target op";
617 trackingListener.checkAndResetError();
619 if (
diag.succeeded()) {
621 return trackingFailure;
623 diag.attachNote() <<
"tracking listener also failed: "
628 if (!
diag.succeeded())
635LogicalResult transform::ApplyConversionPatternsOp::verify() {
636 if (getNumRegions() != 1 && getNumRegions() != 2)
638 if (!getPatterns().empty()) {
639 for (
Operation &op : getPatterns().front()) {
640 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
642 emitOpError() <<
"expected pattern children ops to implement "
643 "ConversionPatternDescriptorOpInterface";
644 diag.attachNote(op.
getLoc()) <<
"op without interface";
649 if (getNumRegions() == 2) {
650 Region &typeConverterRegion = getRegion(1);
651 if (!llvm::hasSingleElement(typeConverterRegion.
front()))
653 <<
"expected exactly one op in default type converter region";
655 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
657 if (!typeConverterOp) {
659 <<
"expected default converter child op to "
660 "implement TypeConverterBuilderOpInterface";
661 diag.attachNote(maybeTypeConverter->
getLoc()) <<
"op without interface";
665 if (!getPatterns().empty()) {
666 for (
Operation &op : getPatterns().front()) {
668 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
669 if (
failed(descriptor.verifyTypeConverter(typeConverterOp)))
677void transform::ApplyConversionPatternsOp::getEffects(
679 if (!getPreserveHandles()) {
687void transform::ApplyConversionPatternsOp::build(
697 if (patternsBodyBuilder)
698 patternsBodyBuilder(builder,
result.location);
704 if (typeConverterBodyBuilder)
705 typeConverterBodyBuilder(builder,
result.location);
713void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
716 assert(dialect &&
"expected that dialect is loaded");
717 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
721 iface->populateConvertToLLVMConversionPatterns(
725LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
726 transform::TypeConverterBuilderOpInterface builder) {
727 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
732LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
735 return emitOpError(
"unknown dialect or dialect not loaded: ")
737 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
740 "dialect does not implement ConvertToLLVMPatternInterface or "
741 "extension was not loaded: ")
751transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
761void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
771void transform::ApplyRegisteredPassOp::getEffects(
789 llvm::raw_string_ostream optionsStream(
options);
794 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
797 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
798 assert(dynamicOptionIdx <
static_cast<int64_t>(dynamicOptions.size()) &&
799 "the number of ParamOperandAttrs in the options DictionaryAttr"
800 "should be the same as the number of options passed as params");
802 state.
getParams(dynamicOptions[dynamicOptionIdx]);
804 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
806 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
808 llvm::interleave(arrayAttr, optionsStream, appendValueAttr,
",");
809 }
else if (
auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
811 optionsStream << strAttr.getValue().str();
814 valueAttr.print(optionsStream,
true);
820 getOptions(), optionsStream,
821 [&](
auto namedAttribute) {
822 optionsStream << namedAttribute.getName().str();
823 optionsStream <<
"=";
824 appendValueAttr(namedAttribute.getValue());
827 optionsStream.flush();
835 <<
"unknown pass or pass pipeline: " << getPassName();
844 <<
"failed to add pass or pass pipeline to pipeline: "
861 auto diag = emitSilenceableError() <<
"pass pipeline failed";
862 diag.attachNote(
target->getLoc()) <<
"target op";
868 results.
set(llvm::cast<OpResult>(getResult()), targets);
877 size_t dynamicOptionsIdx = 0;
883 std::function<ParseResult(
Attribute &)> parseValue =
884 [&](
Attribute &valueAttr) -> ParseResult {
892 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
893 " in options dictionary") ||
897 valueAttr = ArrayAttr::get(parser.
getContext(), attrs);
907 ParseResult parsedOperand = parser.
parseOperand(operand);
908 if (failed(parsedOperand))
914 dynamicOptions.push_back(operand);
915 auto wrappedIndex = IntegerAttr::get(
916 IntegerType::get(parser.
getContext(), 64), dynamicOptionsIdx++);
918 transform::ParamOperandAttr::get(parser.
getContext(), wrappedIndex);
919 }
else if (failed(parsedValueAttr.
value())) {
921 }
else if (isa<transform::ParamOperandAttr>(valueAttr)) {
923 <<
"the param_operand attribute is a marker reserved for "
924 <<
"indicating a value will be passed via params and is only used "
925 <<
"in the generic print format";
939 <<
"expected key to either be an identifier or a string";
943 <<
"expected '=' after key in key-value pair";
945 if (failed(parseValue(valueAttr)))
947 <<
"expected a valid attribute or operand as value associated "
948 <<
"to key '" << key <<
"'";
957 " in options dictionary"))
960 if (DictionaryAttr::findDuplicate(
961 keyValuePairs,
false)
964 <<
"duplicate keys found in options dictionary";
979 if (
auto paramOperandAttr =
980 dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
983 dynamicOptions[paramOperandAttr.getIndex().getInt()]);
984 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
987 llvm::interleaveComma(arrayAttr, printer, printOptionValue);
996 printer << namedAttribute.
getName();
998 printOptionValue(namedAttribute.
getValue());
1003LogicalResult transform::ApplyRegisteredPassOp::verify() {
1010 std::function<LogicalResult(
Attribute)> checkOptionValue =
1011 [&](
Attribute valueAttr) -> LogicalResult {
1012 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1013 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1014 if (dynamicOptionIdx < 0 ||
1015 dynamicOptionIdx >=
static_cast<int64_t>(dynamicOptions.size()))
1017 <<
"dynamic option index " << dynamicOptionIdx
1018 <<
" is out of bounds for the number of dynamic options: "
1019 << dynamicOptions.size();
1020 if (dynamicOptions[dynamicOptionIdx] ==
nullptr)
1021 return emitOpError() <<
"dynamic option index " << dynamicOptionIdx
1022 <<
" is already used in options";
1023 dynamicOptions[dynamicOptionIdx] =
nullptr;
1024 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1026 for (
auto eltAttr : arrayAttr)
1027 if (
failed(checkOptionValue(eltAttr)))
1034 if (
failed(checkOptionValue(namedAttr.getValue())))
1038 for (
Value dynamicOption : dynamicOptions)
1040 return emitOpError() <<
"a param operand does not have a corresponding "
1041 <<
"param_operand attr in the options dict";
1054 results.push_back(
target);
1058void transform::CastOp::getEffects(
1066 assert(inputs.size() == 1 &&
"expected one input");
1067 assert(outputs.size() == 1 &&
"expected one output");
1068 return llvm::all_of(
1069 std::initializer_list<Type>{inputs.front(), outputs.front()},
1070 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1090 assert(block.
getParent() &&
"cannot match using a detached block");
1097 if (!isa<transform::MatchOpInterface>(match)) {
1099 <<
"expected operations in the match part to "
1100 "implement MatchOpInterface";
1103 state.
applyTransform(cast<transform::TransformOpInterface>(match));
1104 if (
diag.succeeded())
1122template <
typename... Tys>
1124 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... ||
false);
1131 transform::TransformParamTypeInterface,
1132 transform::TransformValueHandleTypeInterface>(
1145 getOperation(), getMatcher());
1146 if (matcher.isExternal()) {
1148 <<
"unresolved external symbol " << getMatcher();
1152 rawResults.resize(getOperation()->getNumResults());
1153 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1165 matcher.getFunctionBody().front(),
1168 if (
diag.isDefiniteFailure())
1170 if (
diag.isSilenceableFailure()) {
1172 <<
" failed: " <<
diag.getMessage();
1177 for (
auto &&[i, mapping] : llvm::enumerate(mappings)) {
1178 if (mapping.size() != 1) {
1179 maybeFailure.emplace(emitSilenceableError()
1180 <<
"result #" << i <<
", associated with "
1182 <<
" payload objects, expected 1");
1185 rawResults[i].push_back(mapping[0]);
1190 return std::move(*maybeFailure);
1191 assert(!maybeFailure &&
"failure set but the walk was not interrupted");
1193 for (
auto &&[opResult, rawResult] :
1194 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1201void transform::CollectMatchingOp::getEffects(
1208LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1210 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1212 if (!matcherSymbol ||
1213 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1214 return emitError() <<
"unresolved matcher symbol " << getMatcher();
1217 if (argumentTypes.size() != 1 ||
1218 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1220 <<
"expected the matcher to take one operation handle argument";
1222 if (!matcherSymbol.getArgAttr(
1223 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1224 return emitError() <<
"expected the matcher argument to be marked readonly";
1228 if (resultTypes.size() != getOperation()->getNumResults()) {
1230 <<
"expected the matcher to yield as many values as op has results ("
1231 << getOperation()->getNumResults() <<
"), got "
1232 << resultTypes.size();
1235 for (
auto &&[i, matcherType, resultType] :
1236 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1241 <<
"mismatching type interfaces for matcher result and op result #"
/// Hook reporting whether ForeachMatchOp may be given the same handle in more
/// than one operand position. For this op the answer is unconditionally true.
/// NOTE(review): the code that queries this hook is not visible in this chunk;
/// presumably it relaxes a repeated-operand check during handle verification —
/// confirm against the TransformOpInterface definition.
1253bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() {
return true; }
1261 matchActionPairs.reserve(getMatchers().size());
1263 for (
auto &&[matcher, action] :
1264 llvm::zip_equal(getMatchers(), getActions())) {
1265 auto matcherSymbol =
1267 getOperation(), cast<SymbolRefAttr>(matcher));
1270 getOperation(), cast<SymbolRefAttr>(action));
1271 assert(matcherSymbol && actionSymbol &&
1272 "unresolved symbols not caught by the verifier");
1274 if (matcherSymbol.isExternal())
1276 if (actionSymbol.isExternal())
1279 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1290 matchInputMapping.emplace_back();
1292 getForwardedInputs(), state);
1294 actionResultMapping.resize(getForwardedOutputs().size());
1300 if (!getRestrictRoot() && op == root)
1308 firstMatchArgument.clear();
1309 firstMatchArgument.push_back(op);
1312 for (
auto [matcher, action] : matchActionPairs) {
1314 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1315 state, matchOutputMapping);
1316 if (
diag.isDefiniteFailure())
1318 if (
diag.isSilenceableFailure()) {
1320 <<
" failed: " <<
diag.getMessage();
1326 action.getFunctionBody().front().getArguments(),
1327 matchOutputMapping))) {
1332 action.getFunctionBody().front().without_terminator()) {
1335 if (
result.isDefiniteFailure())
1337 if (
result.isSilenceableFailure()) {
1339 overallDiag = emitSilenceableError() <<
"actions failed";
1342 <<
"failed action: " <<
result.getMessage();
1344 <<
"when applied to this matching payload";
1349 if (
failed(detail::appendValueMappings(
1351 action.getFunctionBody().front().getTerminator()->getOperands(),
1352 state, getFlattenResults()))) {
1354 <<
"action @" << action.getName()
1355 <<
" has results associated with multiple payload entities, "
1356 "but flattening was not requested";
1371 results.
set(llvm::cast<OpResult>(getUpdated()),
1373 for (
auto &&[
result, mapping] :
1374 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1380void transform::ForeachMatchOp::getAsmResultNames(
1382 setNameFn(getUpdated(),
"updated_root");
1383 for (
Value v : getForwardedOutputs()) {
1384 setNameFn(v,
"yielded");
1388void transform::ForeachMatchOp::getEffects(
1391 if (getOperation()->getNumOperands() < 1 ||
1392 getOperation()->getNumResults() < 1) {
1416 matcherList.push_back(SymbolRefAttr::get(matcher));
1417 actionList.push_back(SymbolRefAttr::get(action));
1431 for (
auto &&[matcher, action, idx] : llvm::zip_equal(
1434 printer << cast<SymbolRefAttr>(matcher) <<
" -> "
1435 << cast<SymbolRefAttr>(action);
1443LogicalResult transform::ForeachMatchOp::verify() {
1444 if (getMatchers().size() != getActions().size())
1445 return emitOpError() <<
"expected the same number of matchers and actions";
1446 if (getMatchers().empty())
1447 return emitOpError() <<
"expected at least one match/action pair";
1451 if (matcherNames.insert(name).second)
1454 <<
" is used more than once, only the first match will apply";
1465 bool alsoVerifyInternal =
false) {
1466 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1467 llvm::SmallDenseSet<unsigned> consumedArguments;
1468 if (!op.isExternal()) {
1472 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1474 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1477 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1479 if (isConsumed && isReadOnly) {
1480 return transformOp.emitSilenceableError()
1481 <<
"argument #" << i <<
" cannot be both readonly and consumed";
1483 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1484 return transformOp.emitSilenceableError()
1485 <<
"must provide consumed/readonly status for arguments of "
1486 "external or called ops";
1488 if (op.isExternal())
1491 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1492 return transformOp.emitSilenceableError()
1493 <<
"argument #" << i
1494 <<
" is consumed in the body but is not marked as such";
1496 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1500 <<
"op argument #" << i
1501 <<
" is not consumed in the body but is marked as consumed";
1507LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1509 assert(getMatchers().size() == getActions().size());
1511 StringAttr::get(
getContext(), TransformDialect::kArgConsumedAttrName);
1512 for (
auto &&[matcher, action] :
1513 llvm::zip_equal(getMatchers(), getActions())) {
1515 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1517 cast<SymbolRefAttr>(matcher)));
1518 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1520 cast<SymbolRefAttr>(action)));
1521 if (!matcherSymbol ||
1522 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1523 return emitError() <<
"unresolved matcher symbol " << matcher;
1524 if (!actionSymbol ||
1525 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1526 return emitError() <<
"unresolved action symbol " << action;
1531 .checkAndReport())) {
1537 .checkAndReport())) {
1542 TypeRange operandTypes = getOperandTypes();
1543 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1544 if (operandTypes.size() != matcherArguments.size()) {
1546 emitError() <<
"the number of operands (" << operandTypes.size()
1547 <<
") doesn't match the number of matcher arguments ("
1548 << matcherArguments.size() <<
") for " << matcher;
1549 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1552 for (
auto &&[i, operand, argument] :
1553 llvm::enumerate(operandTypes, matcherArguments)) {
1554 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1557 <<
"does not expect matcher symbol to consume its operand #" << i;
1558 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1567 <<
"mismatching type interfaces for operand and matcher argument #"
1568 << i <<
" of matcher " << matcher;
1569 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1574 TypeRange matcherResults = matcherSymbol.getResultTypes();
1575 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1576 if (matcherResults.size() != actionArguments.size()) {
1577 return emitError() <<
"mismatching number of matcher results and "
1578 "action arguments between "
1579 << matcher <<
" (" << matcherResults.size() <<
") and "
1580 << action <<
" (" << actionArguments.size() <<
")";
1582 for (
auto &&[i, matcherType, actionType] :
1583 llvm::enumerate(matcherResults, actionArguments)) {
1587 return emitError() <<
"mismatching type interfaces for matcher result "
1588 "and action argument #"
1589 << i <<
"of matcher " << matcher <<
" and action "
1594 TypeRange actionResults = actionSymbol.getResultTypes();
1595 auto resultTypes =
TypeRange(getResultTypes()).drop_front();
1596 if (actionResults.size() != resultTypes.size()) {
1598 emitError() <<
"the number of action results ("
1599 << actionResults.size() <<
") for " << action
1600 <<
" doesn't match the number of extra op results ("
1601 << resultTypes.size() <<
")";
1602 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1605 for (
auto &&[i, resultType, actionType] :
1606 llvm::enumerate(resultTypes, actionResults)) {
1611 emitError() <<
"mismatching type interfaces for action result #" << i
1612 <<
" of action " << action <<
" and op result";
1613 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1631 detail::prepareValueMappings(payloads, getTargets(), state);
1632 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1633 bool withZipShortest = getWithZipShortest();
1637 if (withZipShortest) {
1641 return a.size() <
b.size();
1644 for (
auto &payload : payloads)
1645 payload.resize(numIterations);
1651 for (
size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1653 if (payloads[argIdx].size() != numIterations) {
1654 return emitSilenceableError()
1655 <<
"prior targets' payload size (" << numIterations
1656 <<
") differs from payload size (" << payloads[argIdx].size()
1657 <<
") of target " << getTargets()[argIdx];
1666 for (
size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1669 for (
auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1680 llvm::cast<transform::TransformOpInterface>(
transform));
1686 OperandRange yieldOperands = getYieldOp().getOperands();
1687 for (
auto &&[
result, yieldOperand, resTuple] :
1688 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1690 if (isa<TransformHandleTypeInterface>(
result.getType()))
1691 llvm::append_range(resTuple, state.
getPayloadOps(yieldOperand));
1692 else if (isa<TransformValueHandleTypeInterface>(
result.getType()))
1694 else if (isa<TransformParamTypeInterface>(
result.getType()))
1695 llvm::append_range(resTuple, state.
getParams(yieldOperand));
1697 assert(
false &&
"unhandled handle type");
1701 for (
auto &&[
result, resPayload] : zip_equal(getResults(), zippedResults))
1707void transform::ForeachOp::getEffects(
1711 for (
auto &&[
target, blockArg] :
1712 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1714 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1716 cast<TransformOpInterface>(&op));
1724 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1728 }
else if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1737void transform::ForeachOp::getSuccessorRegions(
1739 Region *bodyRegion = &getBody();
1741 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1748 "unexpected region index");
1749 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1750 regions.emplace_back(getOperation(), getOperation()->getResults());
1754transform::ForeachOp::getEntrySuccessorOperands(
RegionSuccessor successor) {
1757 assert(successor.
getSuccessor() == &getBody() &&
"unexpected region index");
1758 return getOperation()->getOperands();
1761transform::YieldOp transform::ForeachOp::getYieldOp() {
1762 return cast<transform::YieldOp>(getBody().front().getTerminator());
1765LogicalResult transform::ForeachOp::verify() {
1766 for (
auto [targetOpt, bodyArgOpt] :
1767 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1768 if (!targetOpt || !bodyArgOpt)
1769 return emitOpError() <<
"expects the same number of targets as the body "
1770 "has block arguments";
1771 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1773 "expects co-indexed targets and the body's "
1774 "block arguments to have the same op/value/param type");
1777 for (
auto [resultOpt, yieldOperandOpt] :
1778 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1779 if (!resultOpt || !yieldOperandOpt)
1780 return emitOpError() <<
"expects the same number of results as the "
1781 "yield terminator has operands";
1782 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1783 return emitOpError(
"expects co-indexed results and yield "
1784 "operands to have the same op/value/param type");
1802 for (
int64_t i = 0, e = getNthParent(); i < e; ++i) {
1805 bool checkIsolatedFromAbove =
1806 !getIsolatedFromAbove() ||
1808 bool checkOpName = !getOpName().has_value() ||
1810 if (checkIsolatedFromAbove && checkOpName)
1815 if (getAllowEmptyResults()) {
1816 results.
set(llvm::cast<OpResult>(getResult()), parents);
1820 emitSilenceableError()
1821 <<
"could not find a parent op that matches all requirements";
1822 diag.attachNote(
target->getLoc()) <<
"target op";
1826 if (getDeduplicate()) {
1827 if (resultSet.insert(parent).second)
1828 parents.push_back(parent);
1830 parents.push_back(parent);
1833 results.
set(llvm::cast<OpResult>(getResult()), parents);
1845 int64_t resultNumber = getResultNumber();
1847 if (std::empty(payloadOps)) {
1848 results.
set(cast<OpResult>(getResult()), {});
1851 if (!llvm::hasSingleElement(payloadOps))
1853 <<
"handle must be mapped to exactly one payload op";
1856 if (
target->getNumResults() <= resultNumber)
1858 results.
set(llvm::cast<OpResult>(getResult()),
1859 llvm::to_vector(
target->getResult(resultNumber).getUsers()));
1873 if (llvm::isa<BlockArgument>(v)) {
1875 emitSilenceableError() <<
"cannot get defining op of block argument";
1876 diag.attachNote(v.getLoc()) <<
"target value";
1879 definingOps.push_back(v.getDefiningOp());
1881 results.
set(llvm::cast<OpResult>(getResult()), definingOps);
1893 int64_t operandNumber = getOperandNumber();
1897 target->getNumOperands() <= operandNumber
1899 :
target->getOperand(operandNumber).getDefiningOp();
1902 emitSilenceableError()
1903 <<
"could not find a producer for operand number: " << operandNumber
1905 diag.attachNote(
target->getLoc()) <<
"target op";
1908 producers.push_back(producer);
1910 results.
set(llvm::cast<OpResult>(getResult()), producers);
1926 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1927 target->getNumOperands(), operandPositions);
1928 if (
diag.isSilenceableFailure()) {
1930 <<
"while considering positions of this payload operation";
1933 llvm::append_range(operands,
1934 llvm::map_range(operandPositions, [&](
int64_t pos) {
1935 return target->getOperand(pos);
1938 results.
setValues(cast<OpResult>(getResult()), operands);
1942LogicalResult transform::GetOperandOp::verify() {
1944 getIsInverted(), getIsAll());
1959 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1960 target->getNumResults(), resultPositions);
1961 if (
diag.isSilenceableFailure()) {
1963 <<
"while considering positions of this payload operation";
1966 llvm::append_range(opResults,
1967 llvm::map_range(resultPositions, [&](
int64_t pos) {
1968 return target->getResult(pos);
1971 results.
setValues(cast<OpResult>(getResult()), opResults);
1975LogicalResult transform::GetResultOp::verify() {
1977 getIsInverted(), getIsAll());
1984void transform::GetTypeOp::getEffects(
1997 Type type = value.getType();
1998 if (getElemental()) {
1999 if (
auto shaped = dyn_cast<ShapedType>(type)) {
2000 type = shaped.getElementType();
2003 params.push_back(TypeAttr::get(type));
2005 results.
setParams(cast<OpResult>(getResult()), params);
2023 if (
result.isDefiniteFailure())
2026 if (
result.isSilenceableFailure()) {
2027 if (mode == transform::FailurePropagationMode::Propagate) {
2047 getOperation(), getTarget());
2048 assert(callee &&
"unverified reference to unknown symbol");
2050 if (callee.isExternal())
2055 detail::prepareValueMappings(mappings, getOperands(), state);
2057 for (
auto &&[arg, map] :
2058 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2064 callee.getBody().front(), getFailurePropagationMode(), state, results);
2070 detail::prepareValueMappings(
2071 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2072 for (
auto &&[
result, mapping] : llvm::zip_equal(getResults(), mappings))
2080void transform::IncludeOp::getEffects(
2095 auto defaultEffects = [&] {
2102 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2104 return defaultEffects();
2106 getOperation(), getTarget());
2108 return defaultEffects();
2110 for (
unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2111 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2113 else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2122 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(
"target");
2124 return emitOpError() <<
"expects a 'target' symbol reference attribute";
2129 return emitOpError() <<
"does not reference a named transform sequence";
2131 FunctionType fnType =
target.getFunctionType();
2132 if (fnType.getNumInputs() != getNumOperands())
2133 return emitError(
"incorrect number of operands for callee");
2135 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2136 if (getOperand(i).
getType() != fnType.getInput(i)) {
2137 return emitOpError(
"operand type mismatch: expected operand type ")
2138 << fnType.getInput(i) <<
", but provided "
2139 << getOperand(i).getType() <<
" for operand number " << i;
2143 if (fnType.getNumResults() != getNumResults())
2144 return emitError(
"incorrect number of results for callee");
2146 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2147 Type resultType = getResult(i).getType();
2148 Type funcType = fnType.getResult(i);
2151 <<
" must implement the same transform dialect "
2152 "interface as the corresponding callee result";
2157 cast<FunctionOpInterface>(*
target),
false,
2167 ::std::optional<::mlir::Operation *> maybeCurrent,
2169 if (!maybeCurrent.has_value()) {
2174 return emitSilenceableError() <<
"operation is not empty";
2185 for (
auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2186 if (acceptedAttr.getValue() == currentOpName)
2189 return emitSilenceableError() <<
"wrong operation name";
2200 auto signedAPIntAsString = [&](
const APInt &value) {
2202 llvm::raw_string_ostream os(str);
2203 value.print(os,
true);
2210 if (params.size() != references.size()) {
2211 return emitSilenceableError()
2212 <<
"parameters have different payload lengths (" << params.size()
2213 <<
" vs " << references.size() <<
")";
2216 for (
auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2217 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2218 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2219 if (!intAttr || !refAttr) {
2221 <<
"non-integer parameter value not expected";
2223 if (intAttr.getType() != refAttr.getType()) {
2225 <<
"mismatching integer attribute types in parameter #" << i;
2227 APInt value = intAttr.getValue();
2228 APInt refValue = refAttr.getValue();
2232 auto reportError = [&](StringRef direction) {
2234 emitSilenceableError() <<
"expected parameter to be " << direction
2235 <<
" " << signedAPIntAsString(refValue)
2236 <<
", got " << signedAPIntAsString(value);
2237 diag.attachNote(getParam().getLoc())
2238 <<
"value # " << position
2239 <<
" associated with the parameter defined here";
2243 switch (getPredicate()) {
2244 case MatchCmpIPredicate::eq:
2245 if (value.eq(refValue))
2247 return reportError(
"equal to");
2248 case MatchCmpIPredicate::ne:
2249 if (value.ne(refValue))
2251 return reportError(
"not equal to");
2252 case MatchCmpIPredicate::lt:
2253 if (value.slt(refValue))
2255 return reportError(
"less than");
2256 case MatchCmpIPredicate::le:
2257 if (value.sle(refValue))
2259 return reportError(
"less than or equal to");
2260 case MatchCmpIPredicate::gt:
2261 if (value.sgt(refValue))
2263 return reportError(
"greater than");
2264 case MatchCmpIPredicate::ge:
2265 if (value.sge(refValue))
2267 return reportError(
"greater than or equal to");
2273void transform::MatchParamCmpIOp::getEffects(
2287 results.
setParams(cast<OpResult>(getParam()), {getValue()});
2300 if (isa<TransformHandleTypeInterface>(handles.front().
getType())) {
2302 for (
Value operand : handles)
2303 llvm::append_range(operations, state.
getPayloadOps(operand));
2304 if (!getDeduplicate()) {
2305 results.
set(llvm::cast<OpResult>(getResult()), operations);
2310 results.
set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2314 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2316 for (
Value attribute : handles)
2317 llvm::append_range(attrs, state.
getParams(attribute));
2318 if (!getDeduplicate()) {
2319 results.
setParams(cast<OpResult>(getResult()), attrs);
2324 results.
setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2329 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2330 "expected value handle type");
2332 for (
Value value : handles)
2334 if (!getDeduplicate()) {
2335 results.
setValues(cast<OpResult>(getResult()), payloadValues);
2340 results.
setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2344bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2346 return getDeduplicate();
2349void transform::MergeHandlesOp::getEffects(
2358OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2359 if (getDeduplicate() || getHandles().size() != 1)
2364 return getHandles().front();
2383 if (
failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
2384 state, this->getOperation(), getBody())))
2388 FailurePropagationMode::Propagate, state, results);
2391void transform::NamedSequenceOp::getEffects(
2394ParseResult transform::NamedSequenceOp::parse(
OpAsmParser &parser,
2398 getFunctionTypeAttrName(
result.name),
2401 std::string &) { return builder.getFunctionType(inputs, results); },
2402 getArgAttrsAttrName(
result.name), getResAttrsAttrName(
result.name));
2405void transform::NamedSequenceOp::print(
OpAsmPrinter &printer) {
2407 printer, cast<FunctionOpInterface>(getOperation()),
false,
2408 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2409 getResAttrsAttrName());
2419 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2422 <<
"cannot be defined inside another transform op";
2423 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2427 if (op.isExternal() || op.getFunctionBody().empty()) {
2434 if (op.getFunctionBody().front().empty())
2437 Operation *terminator = &op.getFunctionBody().front().back();
2438 if (!isa<transform::YieldOp>(terminator)) {
2441 << transform::YieldOp::getOperationName()
2442 <<
"' as terminator";
2443 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2447 if (terminator->
getNumOperands() != op.getResultTypes().size()) {
2449 <<
"expected terminator to have as many operands as the parent op "
2452 for (
auto [i, operandType, resultType] : llvm::zip_equal(
2455 if (operandType == resultType)
2458 <<
"the type of the terminator operand #" << i
2459 <<
" must match the type of the corresponding parent op result ("
2460 << operandType <<
" vs " << resultType <<
")";
2473 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2476 <<
"expects the parent symbol table to have the '"
2477 << transform::TransformDialect::kWithNamedSequenceAttrName
2479 diag.attachNote(parent->
getLoc()) <<
"symbol table operation";
2484 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2487 <<
"cannot be defined inside another transform op";
2488 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2492 if (op.isExternal() || op.getBody().empty())
2496 if (op.getBody().front().empty())
2499 Operation *terminator = &op.getBody().front().back();
2500 if (!isa<transform::YieldOp>(terminator)) {
2503 << transform::YieldOp::getOperationName()
2504 <<
"' as terminator";
2505 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2509 if (terminator->
getNumOperands() != op.getFunctionType().getNumResults()) {
2511 <<
"expected terminator to have as many operands as the parent op "
2514 for (
auto [i, operandType, resultType] :
2515 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->
getNumOperands()),
2517 op.getFunctionType().getResults())) {
2518 if (operandType == resultType)
2521 <<
"the type of the terminator operand #" << i
2522 <<
" must match the type of the corresponding parent op result ("
2523 << operandType <<
" vs " << resultType <<
")";
2526 auto funcOp = cast<FunctionOpInterface>(*op);
2529 if (!
diag.succeeded())
2536LogicalResult transform::NamedSequenceOp::verify() {
2541template <
typename FnTy>
2546 types.reserve(1 + extraBindingTypes.size());
2547 types.push_back(bbArgType);
2548 llvm::append_range(types, extraBindingTypes);
2558 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2566void transform::NamedSequenceOp::build(
OpBuilder &builder,
2569 SequenceBodyBuilderFn bodyBuilder,
2575 TypeAttr::get(FunctionType::get(builder.
getContext(),
2576 rootType, resultTypes)));
2592 size_t numAssociations =
2594 .Case([&](TransformHandleTypeInterface opHandle) {
2597 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2600 .Case([&](TransformParamTypeInterface param) {
2601 return llvm::range_size(state.
getParams(getHandle()));
2603 .DefaultUnreachable(
"unknown kind of transform dialect type");
2604 results.
setParams(cast<OpResult>(getNum()),
2609LogicalResult transform::NumAssociationsOp::verify() {
2611 auto resultType = cast<TransformParamTypeInterface>(getNum().
getType());
2631 results.
set(cast<OpResult>(getResult()),
result);
2651 .Case<TransformHandleTypeInterface>([&](
auto x) {
2654 .Case<TransformValueHandleTypeInterface>([&](
auto x) {
2657 .Case<TransformParamTypeInterface>([&](
auto x) {
2658 return llvm::range_size(state.
getParams(getHandle()));
2660 .DefaultUnreachable(
"unknown transform dialect type interface");
2662 auto produceNumOpsError = [&]() {
2663 return emitSilenceableError()
2664 << getHandle() <<
" expected to contain " << this->getNumResults()
2665 <<
" payloads but it contains " << numPayloads <<
" payloads";
2670 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2671 return produceNumOpsError();
2676 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2677 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2678 return produceNumOpsError();
2682 if (getOverflowResult())
2683 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2685 auto container = [&]() {
2686 if (isa<TransformHandleTypeInterface>(getHandle().
getType())) {
2687 return llvm::map_to_vector(
2689 [](
Operation *op) -> MappedValue {
return op; });
2691 if (isa<TransformValueHandleTypeInterface>(getHandle().
getType())) {
2693 [](
Value v) -> MappedValue {
return v; });
2695 assert(isa<TransformParamTypeInterface>(getHandle().
getType()) &&
2696 "unsupported kind of transform dialect type");
2697 return llvm::map_to_vector(state.
getParams(getHandle()),
2698 [](
Attribute a) -> MappedValue {
return a; });
2701 for (
auto &&en : llvm::enumerate(container)) {
2702 int64_t resultNum = en.index();
2703 if (resultNum >= getNumResults())
2704 resultNum = *getOverflowResult();
2705 resultHandles[resultNum].push_back(en.value());
2709 for (
auto &&it : llvm::enumerate(resultHandles))
2716void transform::SplitHandleOp::getEffects(
2724LogicalResult transform::SplitHandleOp::verify() {
2725 if (getOverflowResult().has_value() &&
2726 !(*getOverflowResult() < getNumResults()))
2727 return emitOpError(
"overflow_result is not a valid result index");
2729 for (
Type resultType : getResultTypes()) {
2733 return emitOpError(
"expects result types to implement the same transform "
2734 "interface as the operand type");
2748 unsigned numRepetitions = llvm::range_size(state.
getPayloadOps(getPattern()));
2749 for (
const auto &en : llvm::enumerate(getHandles())) {
2750 Value handle = en.value();
2751 if (isa<TransformHandleTypeInterface>(handle.getType())) {
2755 payload.reserve(numRepetitions * current.size());
2756 for (
unsigned i = 0; i < numRepetitions; ++i)
2757 llvm::append_range(payload, current);
2758 results.
set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2760 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2761 "expected param type");
2764 params.reserve(numRepetitions * current.size());
2765 for (
unsigned i = 0; i < numRepetitions; ++i)
2766 llvm::append_range(params, current);
2767 results.
setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2774void transform::ReplicateOp::getEffects(
2791 if (
failed(mapBlockArguments(state)))
2799 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2806 root = std::nullopt;
2809 if (failed(hasRoot.
value()))
2823 if (failed(parser.
parseType(rootType))) {
2827 if (!extraBindings.empty()) {
2832 if (extraBindingTypes.size() != extraBindings.size()) {
2834 "expected types to be provided for all operands");
2850 bool hasExtras = !extraBindings.empty();
2860 printer << rootType;
2862 printer <<
", " << llvm::interleaved(extraBindingTypes) <<
')';
2869 auto iface = dyn_cast<transform::TransformOpInterface>(use.
getOwner());
2873 return isHandleConsumed(use.
get(), iface);
2884 if (!potentialConsumer) {
2885 potentialConsumer = &use;
2890 <<
" has more than one potential consumer";
2893 diag.attachNote(use.getOwner()->getLoc())
2894 <<
"used here as operand #" << use.getOperandNumber();
2901LogicalResult transform::SequenceOp::verify() {
2902 assert(getBodyBlock()->getNumArguments() >= 1 &&
2903 "the number of arguments must have been verified to be more than 1 by "
2904 "PossibleTopLevelTransformOpTrait");
2906 if (!getRoot() && !getExtraBindings().empty()) {
2908 <<
"does not expect extra operands when used as top-level";
2914 return (
emitOpError() <<
"block argument #" << arg.getArgNumber());
2921 for (
Operation &child : *getBodyBlock()) {
2922 if (!isa<TransformOpInterface>(child) &&
2923 &child != &getBodyBlock()->back()) {
2926 <<
"expected children ops to implement TransformOpInterface";
2927 diag.attachNote(child.getLoc()) <<
"op without interface";
2932 auto report = [&]() {
2933 return (child.emitError() <<
"result #" <<
result.getResultNumber());
2940 if (!getBodyBlock()->mightHaveTerminator())
2941 return emitOpError() <<
"expects to have a terminator in the body";
2943 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2944 getOperation()->getResultTypes()) {
2946 <<
"expects the types of the terminator operands "
2947 "to match the types of the result";
2948 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) <<
"terminator";
2954void transform::SequenceOp::getEffects(
2960transform::SequenceOp::getEntrySuccessorOperands(
RegionSuccessor successor) {
2961 assert(successor.
getSuccessor() == &getBody() &&
"unexpected region index");
2962 if (getOperation()->getNumOperands() > 0)
2963 return getOperation()->getOperands();
2965 getOperation()->operand_end());
2968void transform::SequenceOp::getSuccessorRegions(
2971 Region *bodyRegion = &getBody();
2972 regions.emplace_back(bodyRegion, getNumOperands() != 0
2980 "unexpected region index");
2981 regions.emplace_back(getOperation(), getOperation()->getResults());
2984void transform::SequenceOp::getRegionInvocationBounds(
2987 bounds.emplace_back(1, 1);
2992 FailurePropagationMode failurePropagationMode,
2994 SequenceBodyBuilderFn bodyBuilder) {
2995 build(builder, state, resultTypes, failurePropagationMode, root,
3004 FailurePropagationMode failurePropagationMode,
3006 SequenceBodyBuilderArgsFn bodyBuilder) {
3007 build(builder, state, resultTypes, failurePropagationMode, root,
3015 FailurePropagationMode failurePropagationMode,
3017 SequenceBodyBuilderFn bodyBuilder) {
3018 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3026 FailurePropagationMode failurePropagationMode,
3028 SequenceBodyBuilderArgsFn bodyBuilder) {
3029 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3047 build(builder,
result, name);
3054 llvm::outs() <<
"[[[ IR printer: ";
3055 if (getName().has_value())
3056 llvm::outs() << *getName() <<
" ";
3059 if (getAssumeVerified().value_or(
false))
3061 if (getUseLocalScope().value_or(
false))
3063 if (getSkipRegions().value_or(
false))
3067 llvm::outs() <<
"top-level ]]]\n";
3069 llvm::outs() <<
"\n";
3070 llvm::outs().flush();
3074 llvm::outs() <<
"]]]\n";
3076 target->print(llvm::outs(), printFlags);
3077 llvm::outs() <<
"\n";
3080 llvm::outs().flush();
3084void transform::PrintOp::getEffects(
3089 if (!getTargetMutable().empty())
3109 <<
"failed to verify payload op";
3110 diag.attachNote(
target->getLoc()) <<
"payload op";
3116void transform::VerifyOp::getEffects(
3125void transform::YieldOp::getEffects(
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a dlti.dl_entry attribute.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getOpResults()
Region * getParentRegion()
Returns the region to which the instruction belongs.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getRegionNumber()
Return the number of this region in the parent operation.
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
A named class for passing around the variadic flag.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
llvm::SetVector< T, Vector, Set, N > SetVector
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
llvm::function_ref< Fn > function_ref
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.