34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/ScopeExit.h"
37#include "llvm/ADT/SmallPtrSet.h"
38#include "llvm/ADT/TypeSwitch.h"
39#include "llvm/Support/Debug.h"
40#include "llvm/Support/DebugLog.h"
41#include "llvm/Support/ErrorHandling.h"
42#include "llvm/Support/InterleavedRange.h"
45#define DEBUG_TYPE "transform-dialect"
46#define DEBUG_TYPE_MATCHER "transform-matcher"
58 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
79 while (transformAncestor) {
80 if (transformAncestor == payload) {
83 <<
"cannot apply transform to itself (or one of its ancestors)";
84 diag.attachNote(payload->
getLoc()) <<
"target payload op";
87 transformAncestor = transformAncestor->
getParentOp();
93#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
99OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
101 if (!successor.
isParent() && getOperation()->getNumOperands() == 1)
102 return getOperation()->getOperands();
104 getOperation()->operand_end());
107void transform::AlternativesOp::getSuccessorRegions(
109 for (
Region &alternative : llvm::drop_begin(
114 ->getRegionNumber() +
116 regions.emplace_back(&alternative);
123transform::AlternativesOp::getSuccessorInputs(
RegionSuccessor successor) {
125 return getOperation()->getResults();
129void transform::AlternativesOp::getRegionInvocationBounds(
134 bounds.reserve(getNumRegions());
135 bounds.emplace_back(1, 1);
142 results.
set(res, {});
150 if (
Value scopeHandle = getScope())
151 llvm::append_range(originals, state.
getPayloadOps(scopeHandle));
156 if (original->isAncestor(getOperation())) {
158 <<
"scope must not contain the transforms being applied";
159 diag.attachNote(original->getLoc()) <<
"scope";
164 <<
"only isolated-from-above ops can be alternative scopes";
165 diag.attachNote(original->getLoc()) <<
"scope";
170 for (
Region ® : getAlternatives()) {
176 auto clones = llvm::to_vector(
177 llvm::map_range(originals, [](
Operation *op) {
return op->
clone(); }));
178 llvm::scope_exit deleteClones([&] {
189 if (
result.isSilenceableFailure()) {
190 LDBG() <<
"alternative failed: " <<
result.getMessage();
195 if (::mlir::failed(
result.silence()))
204 deleteClones.release();
205 TrackingListener listener(state, *
this);
207 for (
const auto &kvp : llvm::zip(originals, clones)) {
214 detail::forwardTerminatorOperands(®.front(), state, results);
218 return emitSilenceableError() <<
"all alternatives failed";
221void transform::AlternativesOp::getEffects(
225 for (
Region *region : getRegions()) {
226 if (!region->empty())
232LogicalResult transform::AlternativesOp::verify() {
233 for (
Region &alternative : getAlternatives()) {
238 <<
"expects terminator operands to have the "
239 "same type as results of the operation";
240 diag.attachNote(terminator->
getLoc()) <<
"terminator";
260 if (
auto paramH = getParam()) {
262 if (params.size() != 1) {
263 if (targets.size() != params.size()) {
264 return emitSilenceableError()
265 <<
"parameter and target have different payload lengths ("
266 << params.size() <<
" vs " << targets.size() <<
")";
268 for (
auto &&[
target, attr] : llvm::zip_equal(targets, params))
269 target->setAttr(getName(), attr);
274 for (
auto *
target : targets)
275 target->setAttr(getName(), attr);
279void transform::AnnotateOp::getEffects(
291transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
306void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
331 auto addDefiningOpsToWorklist = [&](
Operation *op) {
334 if (
Operation *defOp = v.getDefiningOp())
335 if (
target->isProperAncestor(defOp))
336 worklist.insert(defOp);
344 const auto *it = llvm::find(worklist, op);
345 if (it != worklist.end())
354 addDefiningOpsToWorklist(op);
360 while (!worklist.empty()) {
364 addDefiningOpsToWorklist(op);
371void transform::ApplyDeadCodeEliminationOp::getEffects(
396 if (!getRegion().empty()) {
397 for (
Operation &op : getRegion().front()) {
398 cast<transform::PatternDescriptorOpInterface>(&op)
399 .populatePatternsWithState(
patterns, state);
409 config.setMaxIterations(getMaxIterations() ==
static_cast<uint64_t
>(-1)
411 : getMaxIterations());
412 config.setMaxNumRewrites(getMaxNumRewrites() ==
static_cast<uint64_t
>(-1)
414 : getMaxNumRewrites());
419 bool cseChanged =
false;
422 static const int64_t kNumMaxIterations = 50;
425 LogicalResult
result = failure();
439 ops.push_back(nestedOp);
448 <<
"greedy pattern application failed";
456 }
while (cseChanged && ++iteration < kNumMaxIterations);
458 if (iteration == kNumMaxIterations)
464LogicalResult transform::ApplyPatternsOp::verify() {
465 if (!getRegion().empty()) {
466 for (
Operation &op : getRegion().front()) {
467 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
469 <<
"expected children ops to implement "
470 "PatternDescriptorOpInterface";
471 diag.attachNote(op.
getLoc()) <<
"op without interface";
479void transform::ApplyPatternsOp::getEffects(
485void transform::ApplyPatternsOp::build(
494 bodyBuilder(builder,
result.location);
501void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
505 dialect->getCanonicalizationPatterns(
patterns);
507 op.getCanonicalizationPatterns(
patterns, ctx);
521 std::unique_ptr<TypeConverter> defaultTypeConverter;
522 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
523 getDefaultTypeConverter();
524 if (typeConverterBuilder)
525 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
530 for (
Attribute attr : cast<ArrayAttr>(*getLegalOps()))
531 conversionTarget.addLegalOp(
534 for (
Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
535 conversionTarget.addIllegalOp(
537 if (getLegalDialects())
538 for (
Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
539 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
540 if (getIllegalDialects())
541 for (
Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
542 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
550 if (!getPatterns().empty()) {
551 for (
Operation &op : getPatterns().front()) {
553 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
556 std::unique_ptr<TypeConverter> typeConverter =
557 descriptor.getTypeConverter();
560 keepAliveConverters.emplace_back(std::move(typeConverter));
561 converter = keepAliveConverters.back().get();
564 if (!defaultTypeConverter) {
566 <<
"pattern descriptor does not specify type "
567 "converter and apply_conversion_patterns op has "
568 "no default type converter";
569 diag.attachNote(op.
getLoc()) <<
"pattern descriptor op";
572 converter = defaultTypeConverter.get();
578 descriptor.populateConversionTargetRules(*converter, conversionTarget);
580 descriptor.populatePatterns(*converter,
patterns);
588 TrackingListenerConfig trackingConfig;
589 trackingConfig.requireMatchingReplacementOpName =
false;
590 ErrorCheckingTrackingListener trackingListener(state, *
this, trackingConfig);
591 ConversionConfig conversionConfig;
592 if (getPreserveHandles())
593 conversionConfig.listener = &trackingListener;
604 LogicalResult status = failure();
605 if (getPartialConversion()) {
606 status = applyPartialConversion(
target, conversionTarget, frozenPatterns,
609 status = applyFullConversion(
target, conversionTarget, frozenPatterns,
616 diag = emitSilenceableError() <<
"dialect conversion failed";
617 diag.attachNote(
target->getLoc()) <<
"target op";
622 trackingListener.checkAndResetError();
624 if (
diag.succeeded()) {
626 return trackingFailure;
628 diag.attachNote() <<
"tracking listener also failed: "
633 if (!
diag.succeeded())
640LogicalResult transform::ApplyConversionPatternsOp::verify() {
641 if (getNumRegions() != 1 && getNumRegions() != 2)
643 if (!getPatterns().empty()) {
644 for (
Operation &op : getPatterns().front()) {
645 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
647 emitOpError() <<
"expected pattern children ops to implement "
648 "ConversionPatternDescriptorOpInterface";
649 diag.attachNote(op.
getLoc()) <<
"op without interface";
654 if (getNumRegions() == 2) {
655 Region &typeConverterRegion = getRegion(1);
656 if (!llvm::hasSingleElement(typeConverterRegion.
front()))
658 <<
"expected exactly one op in default type converter region";
660 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
662 if (!typeConverterOp) {
664 <<
"expected default converter child op to "
665 "implement TypeConverterBuilderOpInterface";
666 diag.attachNote(maybeTypeConverter->
getLoc()) <<
"op without interface";
670 if (!getPatterns().empty()) {
671 for (
Operation &op : getPatterns().front()) {
673 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
674 if (
failed(descriptor.verifyTypeConverter(typeConverterOp)))
682void transform::ApplyConversionPatternsOp::getEffects(
684 if (!getPreserveHandles()) {
692void transform::ApplyConversionPatternsOp::build(
702 if (patternsBodyBuilder)
703 patternsBodyBuilder(builder,
result.location);
709 if (typeConverterBodyBuilder)
710 typeConverterBodyBuilder(builder,
result.location);
718void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
721 assert(dialect &&
"expected that dialect is loaded");
722 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
726 iface->populateConvertToLLVMConversionPatterns(
730LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
731 transform::TypeConverterBuilderOpInterface builder) {
732 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
737LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
740 return emitOpError(
"unknown dialect or dialect not loaded: ")
742 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
745 "dialect does not implement ConvertToLLVMPatternInterface or "
746 "extension was not loaded: ")
756transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
766void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
776void transform::ApplyRegisteredPassOp::getEffects(
794 llvm::raw_string_ostream optionsStream(
options);
799 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
802 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
803 assert(dynamicOptionIdx <
static_cast<int64_t>(dynamicOptions.size()) &&
804 "the number of ParamOperandAttrs in the options DictionaryAttr"
805 "should be the same as the number of options passed as params");
807 state.
getParams(dynamicOptions[dynamicOptionIdx]);
809 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
811 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
813 llvm::interleave(arrayAttr, optionsStream, appendValueAttr,
",");
814 }
else if (
auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
816 optionsStream << strAttr.getValue().str();
819 valueAttr.print(optionsStream,
true);
825 getOptions(), optionsStream,
826 [&](
auto namedAttribute) {
827 optionsStream << namedAttribute.getName().str();
828 optionsStream <<
"=";
829 appendValueAttr(namedAttribute.getValue());
832 optionsStream.flush();
840 <<
"unknown pass or pass pipeline: " << getPassName();
849 <<
"failed to add pass or pass pipeline to pipeline: "
866 auto diag = emitSilenceableError() <<
"pass pipeline failed";
867 diag.attachNote(
target->getLoc()) <<
"target op";
873 results.
set(llvm::cast<OpResult>(getResult()), targets);
882 size_t dynamicOptionsIdx = 0;
888 std::function<ParseResult(
Attribute &)> parseValue =
889 [&](
Attribute &valueAttr) -> ParseResult {
897 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
898 " in options dictionary") ||
902 valueAttr = ArrayAttr::get(parser.
getContext(), attrs);
912 ParseResult parsedOperand = parser.
parseOperand(operand);
913 if (failed(parsedOperand))
919 dynamicOptions.push_back(operand);
920 auto wrappedIndex = IntegerAttr::get(
921 IntegerType::get(parser.
getContext(), 64), dynamicOptionsIdx++);
923 transform::ParamOperandAttr::get(parser.
getContext(), wrappedIndex);
924 }
else if (failed(parsedValueAttr.
value())) {
926 }
else if (isa<transform::ParamOperandAttr>(valueAttr)) {
928 <<
"the param_operand attribute is a marker reserved for "
929 <<
"indicating a value will be passed via params and is only used "
930 <<
"in the generic print format";
944 <<
"expected key to either be an identifier or a string";
948 <<
"expected '=' after key in key-value pair";
950 if (failed(parseValue(valueAttr)))
952 <<
"expected a valid attribute or operand as value associated "
953 <<
"to key '" << key <<
"'";
962 " in options dictionary"))
965 if (DictionaryAttr::findDuplicate(
966 keyValuePairs,
false)
969 <<
"duplicate keys found in options dictionary";
984 if (
auto paramOperandAttr =
985 dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
988 dynamicOptions[paramOperandAttr.getIndex().getInt()]);
989 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
992 llvm::interleaveComma(arrayAttr, printer, printOptionValue);
1001 printer << namedAttribute.
getName();
1003 printOptionValue(namedAttribute.
getValue());
1008LogicalResult transform::ApplyRegisteredPassOp::verify() {
1015 std::function<LogicalResult(
Attribute)> checkOptionValue =
1016 [&](
Attribute valueAttr) -> LogicalResult {
1017 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1018 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1019 if (dynamicOptionIdx < 0 ||
1020 dynamicOptionIdx >=
static_cast<int64_t>(dynamicOptions.size()))
1022 <<
"dynamic option index " << dynamicOptionIdx
1023 <<
" is out of bounds for the number of dynamic options: "
1024 << dynamicOptions.size();
1025 if (dynamicOptions[dynamicOptionIdx] ==
nullptr)
1026 return emitOpError() <<
"dynamic option index " << dynamicOptionIdx
1027 <<
" is already used in options";
1028 dynamicOptions[dynamicOptionIdx] =
nullptr;
1029 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1031 for (
auto eltAttr : arrayAttr)
1032 if (
failed(checkOptionValue(eltAttr)))
1039 if (
failed(checkOptionValue(namedAttr.getValue())))
1043 for (
Value dynamicOption : dynamicOptions)
1045 return emitOpError() <<
"a param operand does not have a corresponding "
1046 <<
"param_operand attr in the options dict";
1059 results.push_back(
target);
1063void transform::CastOp::getEffects(
1071 assert(inputs.size() == 1 &&
"expected one input");
1072 assert(outputs.size() == 1 &&
"expected one output");
1073 return llvm::all_of(
1074 std::initializer_list<Type>{inputs.front(), outputs.front()},
1075 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1095 assert(block.
getParent() &&
"cannot match using a detached block");
1102 if (!isa<transform::MatchOpInterface>(match)) {
1104 <<
"expected operations in the match part to "
1105 "implement MatchOpInterface";
1108 state.
applyTransform(cast<transform::TransformOpInterface>(match));
1109 if (
diag.succeeded())
1127template <
typename... Tys>
1129 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... ||
false);
1136 transform::TransformParamTypeInterface,
1137 transform::TransformValueHandleTypeInterface>(
1150 getOperation(), getMatcher());
1151 if (matcher.isExternal()) {
1153 <<
"unresolved external symbol " << getMatcher();
1157 rawResults.resize(getOperation()->getNumResults());
1158 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1170 matcher.getFunctionBody().front(),
1173 if (
diag.isDefiniteFailure())
1175 if (
diag.isSilenceableFailure()) {
1177 <<
" failed: " <<
diag.getMessage();
1182 for (
auto &&[i, mapping] : llvm::enumerate(mappings)) {
1183 if (mapping.size() != 1) {
1184 maybeFailure.emplace(emitSilenceableError()
1185 <<
"result #" << i <<
", associated with "
1187 <<
" payload objects, expected 1");
1190 rawResults[i].push_back(mapping[0]);
1195 return std::move(*maybeFailure);
1196 assert(!maybeFailure &&
"failure set but the walk was not interrupted");
1198 for (
auto &&[opResult, rawResult] :
1199 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1206void transform::CollectMatchingOp::getEffects(
1213LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1215 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1217 if (!matcherSymbol ||
1218 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1219 return emitError() <<
"unresolved matcher symbol " << getMatcher();
1222 if (argumentTypes.size() != 1 ||
1223 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1225 <<
"expected the matcher to take one operation handle argument";
1227 if (!matcherSymbol.getArgAttr(
1228 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1229 return emitError() <<
"expected the matcher argument to be marked readonly";
1233 if (resultTypes.size() != getOperation()->getNumResults()) {
1235 <<
"expected the matcher to yield as many values as op has results ("
1236 << getOperation()->getNumResults() <<
"), got "
1237 << resultTypes.size();
1240 for (
auto &&[i, matcherType, resultType] :
1241 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1246 <<
"mismatching type interfaces for matcher result and op result #"
/// Hook queried on ForeachMatchOp; the implementation below unconditionally
/// returns true. Judging by its name, it presumably tells the transform
/// infrastructure that the same handle may appear more than once among this
/// op's operands — NOTE(review): confirm against the interface declaration.
// NOTE(review): the leading "1258" on the next line looks like a line-number
// artifact fused into the source during extraction, not real code.
1258bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() {
return true; }
1266 matchActionPairs.reserve(getMatchers().size());
1268 for (
auto &&[matcher, action] :
1269 llvm::zip_equal(getMatchers(), getActions())) {
1270 auto matcherSymbol =
1272 getOperation(), cast<SymbolRefAttr>(matcher));
1275 getOperation(), cast<SymbolRefAttr>(action));
1276 assert(matcherSymbol && actionSymbol &&
1277 "unresolved symbols not caught by the verifier");
1279 if (matcherSymbol.isExternal())
1281 if (actionSymbol.isExternal())
1284 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1295 matchInputMapping.emplace_back();
1297 getForwardedInputs(), state);
1299 actionResultMapping.resize(getForwardedOutputs().size());
1305 if (!getRestrictRoot() && op == root)
1313 firstMatchArgument.clear();
1314 firstMatchArgument.push_back(op);
1317 for (
auto [matcher, action] : matchActionPairs) {
1319 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1320 state, matchOutputMapping);
1321 if (
diag.isDefiniteFailure())
1323 if (
diag.isSilenceableFailure()) {
1325 <<
" failed: " <<
diag.getMessage();
1331 action.getFunctionBody().front().getArguments(),
1332 matchOutputMapping))) {
1337 action.getFunctionBody().front().without_terminator()) {
1340 if (
result.isDefiniteFailure())
1342 if (
result.isSilenceableFailure()) {
1344 overallDiag = emitSilenceableError() <<
"actions failed";
1347 <<
"failed action: " <<
result.getMessage();
1349 <<
"when applied to this matching payload";
1354 if (
failed(detail::appendValueMappings(
1356 action.getFunctionBody().front().getTerminator()->getOperands(),
1357 state, getFlattenResults()))) {
1359 <<
"action @" << action.getName()
1360 <<
" has results associated with multiple payload entities, "
1361 "but flattening was not requested";
1376 results.
set(llvm::cast<OpResult>(getUpdated()),
1378 for (
auto &&[
result, mapping] :
1379 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1385void transform::ForeachMatchOp::getAsmResultNames(
1387 setNameFn(getUpdated(),
"updated_root");
1388 for (
Value v : getForwardedOutputs()) {
1389 setNameFn(v,
"yielded");
1393void transform::ForeachMatchOp::getEffects(
1396 if (getOperation()->getNumOperands() < 1 ||
1397 getOperation()->getNumResults() < 1) {
1421 matcherList.push_back(SymbolRefAttr::get(matcher));
1422 actionList.push_back(SymbolRefAttr::get(action));
1436 for (
auto &&[matcher, action, idx] : llvm::zip_equal(
1439 printer << cast<SymbolRefAttr>(matcher) <<
" -> "
1440 << cast<SymbolRefAttr>(action);
1448LogicalResult transform::ForeachMatchOp::verify() {
1449 if (getMatchers().size() != getActions().size())
1450 return emitOpError() <<
"expected the same number of matchers and actions";
1451 if (getMatchers().empty())
1452 return emitOpError() <<
"expected at least one match/action pair";
1456 if (matcherNames.insert(name).second)
1459 <<
" is used more than once, only the first match will apply";
1470 bool alsoVerifyInternal =
false) {
1471 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1472 llvm::SmallDenseSet<unsigned> consumedArguments;
1473 if (!op.isExternal()) {
1477 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1479 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1482 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1484 if (isConsumed && isReadOnly) {
1485 return transformOp.emitSilenceableError()
1486 <<
"argument #" << i <<
" cannot be both readonly and consumed";
1488 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1489 return transformOp.emitSilenceableError()
1490 <<
"must provide consumed/readonly status for arguments of "
1491 "external or called ops";
1493 if (op.isExternal())
1496 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1497 return transformOp.emitSilenceableError()
1498 <<
"argument #" << i
1499 <<
" is consumed in the body but is not marked as such";
1501 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1505 <<
"op argument #" << i
1506 <<
" is not consumed in the body but is marked as consumed";
1512LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1514 assert(getMatchers().size() == getActions().size());
1516 StringAttr::get(
getContext(), TransformDialect::kArgConsumedAttrName);
1517 for (
auto &&[matcher, action] :
1518 llvm::zip_equal(getMatchers(), getActions())) {
1520 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1522 cast<SymbolRefAttr>(matcher)));
1523 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1525 cast<SymbolRefAttr>(action)));
1526 if (!matcherSymbol ||
1527 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1528 return emitError() <<
"unresolved matcher symbol " << matcher;
1529 if (!actionSymbol ||
1530 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1531 return emitError() <<
"unresolved action symbol " << action;
1536 .checkAndReport())) {
1542 .checkAndReport())) {
1547 TypeRange operandTypes = getOperandTypes();
1548 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1549 if (operandTypes.size() != matcherArguments.size()) {
1551 emitError() <<
"the number of operands (" << operandTypes.size()
1552 <<
") doesn't match the number of matcher arguments ("
1553 << matcherArguments.size() <<
") for " << matcher;
1554 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1557 for (
auto &&[i, operand, argument] :
1558 llvm::enumerate(operandTypes, matcherArguments)) {
1559 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1562 <<
"does not expect matcher symbol to consume its operand #" << i;
1563 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1572 <<
"mismatching type interfaces for operand and matcher argument #"
1573 << i <<
" of matcher " << matcher;
1574 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1579 TypeRange matcherResults = matcherSymbol.getResultTypes();
1580 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1581 if (matcherResults.size() != actionArguments.size()) {
1582 return emitError() <<
"mismatching number of matcher results and "
1583 "action arguments between "
1584 << matcher <<
" (" << matcherResults.size() <<
") and "
1585 << action <<
" (" << actionArguments.size() <<
")";
1587 for (
auto &&[i, matcherType, actionType] :
1588 llvm::enumerate(matcherResults, actionArguments)) {
1592 return emitError() <<
"mismatching type interfaces for matcher result "
1593 "and action argument #"
1594 << i <<
"of matcher " << matcher <<
" and action "
1599 TypeRange actionResults = actionSymbol.getResultTypes();
1600 auto resultTypes =
TypeRange(getResultTypes()).drop_front();
1601 if (actionResults.size() != resultTypes.size()) {
1603 emitError() <<
"the number of action results ("
1604 << actionResults.size() <<
") for " << action
1605 <<
" doesn't match the number of extra op results ("
1606 << resultTypes.size() <<
")";
1607 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1610 for (
auto &&[i, resultType, actionType] :
1611 llvm::enumerate(resultTypes, actionResults)) {
1616 emitError() <<
"mismatching type interfaces for action result #" << i
1617 <<
" of action " << action <<
" and op result";
1618 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1636 detail::prepareValueMappings(payloads, getTargets(), state);
1637 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1638 bool withZipShortest = getWithZipShortest();
1642 if (withZipShortest) {
1646 return a.size() <
b.size();
1649 for (
auto &payload : payloads)
1650 payload.resize(numIterations);
1656 for (
size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1658 if (payloads[argIdx].size() != numIterations) {
1659 return emitSilenceableError()
1660 <<
"prior targets' payload size (" << numIterations
1661 <<
") differs from payload size (" << payloads[argIdx].size()
1662 <<
") of target " << getTargets()[argIdx];
1671 for (
size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1674 for (
auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1685 llvm::cast<transform::TransformOpInterface>(
transform));
1691 OperandRange yieldOperands = getYieldOp().getOperands();
1692 for (
auto &&[
result, yieldOperand, resTuple] :
1693 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1695 if (isa<TransformHandleTypeInterface>(
result.getType()))
1696 llvm::append_range(resTuple, state.
getPayloadOps(yieldOperand));
1697 else if (isa<TransformValueHandleTypeInterface>(
result.getType()))
1699 else if (isa<TransformParamTypeInterface>(
result.getType()))
1700 llvm::append_range(resTuple, state.
getParams(yieldOperand));
1702 assert(
false &&
"unhandled handle type");
1706 for (
auto &&[
result, resPayload] : zip_equal(getResults(), zippedResults))
1712void transform::ForeachOp::getEffects(
1716 for (
auto &&[
target, blockArg] :
1717 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1719 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1721 cast<TransformOpInterface>(&op));
1729 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1733 }
else if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1742void transform::ForeachOp::getSuccessorRegions(
1744 Region *bodyRegion = &getBody();
1746 regions.emplace_back(bodyRegion);
1753 "unexpected region index");
1754 regions.emplace_back(bodyRegion);
1764transform::ForeachOp::getEntrySuccessorOperands(
RegionSuccessor successor) {
1767 assert(successor.
getSuccessor() == &getBody() &&
"unexpected region index");
1768 return getOperation()->getOperands();
1771transform::YieldOp transform::ForeachOp::getYieldOp() {
1772 return cast<transform::YieldOp>(getBody().front().getTerminator());
1775LogicalResult transform::ForeachOp::verify() {
1776 for (
auto [targetOpt, bodyArgOpt] :
1777 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1778 if (!targetOpt || !bodyArgOpt)
1779 return emitOpError() <<
"expects the same number of targets as the body "
1780 "has block arguments";
1781 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1783 "expects co-indexed targets and the body's "
1784 "block arguments to have the same op/value/param type");
1787 for (
auto [resultOpt, yieldOperandOpt] :
1788 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1789 if (!resultOpt || !yieldOperandOpt)
1790 return emitOpError() <<
"expects the same number of results as the "
1791 "yield terminator has operands";
1792 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1793 return emitOpError(
"expects co-indexed results and yield "
1794 "operands to have the same op/value/param type");
1812 for (
int64_t i = 0, e = getNthParent(); i < e; ++i) {
1815 bool checkIsolatedFromAbove =
1816 !getIsolatedFromAbove() ||
1818 bool checkOpName = !getOpName().has_value() ||
1820 if (checkIsolatedFromAbove && checkOpName)
1825 if (getAllowEmptyResults()) {
1826 results.
set(llvm::cast<OpResult>(getResult()), parents);
1830 emitSilenceableError()
1831 <<
"could not find a parent op that matches all requirements";
1832 diag.attachNote(
target->getLoc()) <<
"target op";
1836 if (getDeduplicate()) {
1837 if (resultSet.insert(parent).second)
1838 parents.push_back(parent);
1840 parents.push_back(parent);
1843 results.
set(llvm::cast<OpResult>(getResult()), parents);
1855 int64_t resultNumber = getResultNumber();
1857 if (std::empty(payloadOps)) {
1858 results.
set(cast<OpResult>(getResult()), {});
1861 if (!llvm::hasSingleElement(payloadOps))
1863 <<
"handle must be mapped to exactly one payload op";
1866 if (
target->getNumResults() <= resultNumber)
1868 results.
set(llvm::cast<OpResult>(getResult()),
1869 llvm::to_vector(
target->getResult(resultNumber).getUsers()));
1883 if (llvm::isa<BlockArgument>(v)) {
1885 emitSilenceableError() <<
"cannot get defining op of block argument";
1886 diag.attachNote(v.getLoc()) <<
"target value";
1889 definingOps.push_back(v.getDefiningOp());
1891 results.
set(llvm::cast<OpResult>(getResult()), definingOps);
1903 int64_t operandNumber = getOperandNumber();
1907 target->getNumOperands() <= operandNumber
1909 :
target->getOperand(operandNumber).getDefiningOp();
1912 emitSilenceableError()
1913 <<
"could not find a producer for operand number: " << operandNumber
1915 diag.attachNote(
target->getLoc()) <<
"target op";
1918 producers.push_back(producer);
1920 results.
set(llvm::cast<OpResult>(getResult()), producers);
1936 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1937 target->getNumOperands(), operandPositions);
1938 if (
diag.isSilenceableFailure()) {
1940 <<
"while considering positions of this payload operation";
1943 llvm::append_range(operands,
1944 llvm::map_range(operandPositions, [&](
int64_t pos) {
1945 return target->getOperand(pos);
1948 results.
setValues(cast<OpResult>(getResult()), operands);
1952LogicalResult transform::GetOperandOp::verify() {
1954 getIsInverted(), getIsAll());
1969 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1970 target->getNumResults(), resultPositions);
1971 if (
diag.isSilenceableFailure()) {
1973 <<
"while considering positions of this payload operation";
1976 llvm::append_range(opResults,
1977 llvm::map_range(resultPositions, [&](
int64_t pos) {
1978 return target->getResult(pos);
1981 results.
setValues(cast<OpResult>(getResult()), opResults);
1985LogicalResult transform::GetResultOp::verify() {
1987 getIsInverted(), getIsAll());
1994void transform::GetTypeOp::getEffects(
2007 Type type = value.getType();
2008 if (getElemental()) {
2009 if (
auto shaped = dyn_cast<ShapedType>(type)) {
2010 type = shaped.getElementType();
2013 params.push_back(TypeAttr::get(type));
2015 results.
setParams(cast<OpResult>(getResult()), params);
2033 if (
result.isDefiniteFailure())
2036 if (
result.isSilenceableFailure()) {
2037 if (mode == transform::FailurePropagationMode::Propagate) {
2057 getOperation(), getTarget());
2058 assert(callee &&
"unverified reference to unknown symbol");
2060 if (callee.isExternal())
2065 detail::prepareValueMappings(mappings, getOperands(), state);
2067 for (
auto &&[arg, map] :
2068 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2074 callee.getBody().front(), getFailurePropagationMode(), state, results);
2080 detail::prepareValueMappings(
2081 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2082 for (
auto &&[
result, mapping] : llvm::zip_equal(getResults(), mappings))
2090void transform::IncludeOp::getEffects(
2105 auto defaultEffects = [&] {
2112 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2114 return defaultEffects();
2116 getOperation(), getTarget());
2118 return defaultEffects();
2120 for (
unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2121 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2123 else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2132 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(
"target");
2134 return emitOpError() <<
"expects a 'target' symbol reference attribute";
2139 return emitOpError() <<
"does not reference a named transform sequence";
2141 FunctionType fnType =
target.getFunctionType();
2142 if (fnType.getNumInputs() != getNumOperands())
2143 return emitError(
"incorrect number of operands for callee");
2145 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2146 if (getOperand(i).
getType() != fnType.getInput(i)) {
2147 return emitOpError(
"operand type mismatch: expected operand type ")
2148 << fnType.getInput(i) <<
", but provided "
2149 << getOperand(i).getType() <<
" for operand number " << i;
2153 if (fnType.getNumResults() != getNumResults())
2154 return emitError(
"incorrect number of results for callee");
2156 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2157 Type resultType = getResult(i).getType();
2158 Type funcType = fnType.getResult(i);
2161 <<
" must implement the same transform dialect "
2162 "interface as the corresponding callee result";
2167 cast<FunctionOpInterface>(*
target),
false,
2177 ::std::optional<::mlir::Operation *> maybeCurrent,
2179 if (!maybeCurrent.has_value()) {
2184 return emitSilenceableError() <<
"operation is not empty";
2195 for (
auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2196 if (acceptedAttr.getValue() == currentOpName)
2199 return emitSilenceableError() <<
"wrong operation name";
2210 auto signedAPIntAsString = [&](
const APInt &value) {
2212 llvm::raw_string_ostream os(str);
2213 value.print(os,
true);
2220 if (params.size() != references.size()) {
2221 return emitSilenceableError()
2222 <<
"parameters have different payload lengths (" << params.size()
2223 <<
" vs " << references.size() <<
")";
2226 for (
auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2227 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2228 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2229 if (!intAttr || !refAttr) {
2231 <<
"non-integer parameter value not expected";
2233 if (intAttr.getType() != refAttr.getType()) {
2235 <<
"mismatching integer attribute types in parameter #" << i;
2237 APInt value = intAttr.getValue();
2238 APInt refValue = refAttr.getValue();
2242 auto reportError = [&](StringRef direction) {
2244 emitSilenceableError() <<
"expected parameter to be " << direction
2245 <<
" " << signedAPIntAsString(refValue)
2246 <<
", got " << signedAPIntAsString(value);
2247 diag.attachNote(getParam().getLoc())
2248 <<
"value # " << position
2249 <<
" associated with the parameter defined here";
2253 switch (getPredicate()) {
2254 case MatchCmpIPredicate::eq:
2255 if (value.eq(refValue))
2257 return reportError(
"equal to");
2258 case MatchCmpIPredicate::ne:
2259 if (value.ne(refValue))
2261 return reportError(
"not equal to");
2262 case MatchCmpIPredicate::lt:
2263 if (value.slt(refValue))
2265 return reportError(
"less than");
2266 case MatchCmpIPredicate::le:
2267 if (value.sle(refValue))
2269 return reportError(
"less than or equal to");
2270 case MatchCmpIPredicate::gt:
2271 if (value.sgt(refValue))
2273 return reportError(
"greater than");
2274 case MatchCmpIPredicate::ge:
2275 if (value.sge(refValue))
2277 return reportError(
"greater than or equal to");
2283void transform::MatchParamCmpIOp::getEffects(
2297 results.
setParams(cast<OpResult>(getParam()), {getValue()});
2310 if (isa<TransformHandleTypeInterface>(handles.front().
getType())) {
2312 for (
Value operand : handles)
2313 llvm::append_range(operations, state.
getPayloadOps(operand));
2314 if (!getDeduplicate()) {
2315 results.
set(llvm::cast<OpResult>(getResult()), operations);
2320 results.
set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2324 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2326 for (
Value attribute : handles)
2327 llvm::append_range(attrs, state.
getParams(attribute));
2328 if (!getDeduplicate()) {
2329 results.
setParams(cast<OpResult>(getResult()), attrs);
2334 results.
setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2339 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2340 "expected value handle type");
2342 for (
Value value : handles)
2344 if (!getDeduplicate()) {
2345 results.
setValues(cast<OpResult>(getResult()), payloadValues);
2350 results.
setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2354bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2356 return getDeduplicate();
2359void transform::MergeHandlesOp::getEffects(
2368OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2369 if (getDeduplicate() || getHandles().size() != 1)
2374 return getHandles().front();
2393 if (
failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
2394 state, this->getOperation(), getBody())))
2398 FailurePropagationMode::Propagate, state, results);
2401void transform::NamedSequenceOp::getEffects(
2404ParseResult transform::NamedSequenceOp::parse(
OpAsmParser &parser,
2408 getFunctionTypeAttrName(
result.name),
2411 std::string &) { return builder.getFunctionType(inputs, results); },
2412 getArgAttrsAttrName(
result.name), getResAttrsAttrName(
result.name));
2415void transform::NamedSequenceOp::print(
OpAsmPrinter &printer) {
2417 printer, cast<FunctionOpInterface>(getOperation()),
false,
2418 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2419 getResAttrsAttrName());
2429 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2432 <<
"cannot be defined inside another transform op";
2433 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2437 if (op.isExternal() || op.getFunctionBody().empty()) {
2444 if (op.getFunctionBody().front().empty())
2447 Operation *terminator = &op.getFunctionBody().front().back();
2448 if (!isa<transform::YieldOp>(terminator)) {
2451 << transform::YieldOp::getOperationName()
2452 <<
"' as terminator";
2453 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2457 if (terminator->
getNumOperands() != op.getResultTypes().size()) {
2459 <<
"expected terminator to have as many operands as the parent op "
2462 for (
auto [i, operandType, resultType] : llvm::zip_equal(
2465 if (operandType == resultType)
2468 <<
"the type of the terminator operand #" << i
2469 <<
" must match the type of the corresponding parent op result ("
2470 << operandType <<
" vs " << resultType <<
")";
2483 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2486 <<
"expects the parent symbol table to have the '"
2487 << transform::TransformDialect::kWithNamedSequenceAttrName
2489 diag.attachNote(parent->
getLoc()) <<
"symbol table operation";
2494 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2497 <<
"cannot be defined inside another transform op";
2498 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2502 if (op.isExternal() || op.getBody().empty())
2506 if (op.getBody().front().empty())
2509 Operation *terminator = &op.getBody().front().back();
2510 if (!isa<transform::YieldOp>(terminator)) {
2513 << transform::YieldOp::getOperationName()
2514 <<
"' as terminator";
2515 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2519 if (terminator->
getNumOperands() != op.getFunctionType().getNumResults()) {
2521 <<
"expected terminator to have as many operands as the parent op "
2524 for (
auto [i, operandType, resultType] :
2525 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->
getNumOperands()),
2527 op.getFunctionType().getResults())) {
2528 if (operandType == resultType)
2531 <<
"the type of the terminator operand #" << i
2532 <<
" must match the type of the corresponding parent op result ("
2533 << operandType <<
" vs " << resultType <<
")";
2536 auto funcOp = cast<FunctionOpInterface>(*op);
2539 if (!
diag.succeeded())
2546LogicalResult transform::NamedSequenceOp::verify() {
2551template <
typename FnTy>
2556 types.reserve(1 + extraBindingTypes.size());
2557 types.push_back(bbArgType);
2558 llvm::append_range(types, extraBindingTypes);
2568 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2576void transform::NamedSequenceOp::build(
OpBuilder &builder,
2579 SequenceBodyBuilderFn bodyBuilder,
2585 TypeAttr::get(FunctionType::get(builder.
getContext(),
2586 rootType, resultTypes)));
2602 size_t numAssociations =
2604 .Case([&](TransformHandleTypeInterface opHandle) {
2607 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2610 .Case([&](TransformParamTypeInterface param) {
2611 return llvm::range_size(state.
getParams(getHandle()));
2613 .DefaultUnreachable(
"unknown kind of transform dialect type");
2614 results.
setParams(cast<OpResult>(getNum()),
2619LogicalResult transform::NumAssociationsOp::verify() {
2621 auto resultType = cast<TransformParamTypeInterface>(getNum().
getType());
2641 results.
set(cast<OpResult>(getResult()),
result);
2661 .Case([&](TransformHandleTypeInterface x) {
2664 .Case([&](TransformValueHandleTypeInterface x) {
2667 .Case([&](TransformParamTypeInterface x) {
2668 return llvm::range_size(state.
getParams(getHandle()));
2670 .DefaultUnreachable(
"unknown transform dialect type interface");
2672 auto produceNumOpsError = [&]() {
2673 return emitSilenceableError()
2674 << getHandle() <<
" expected to contain " << this->getNumResults()
2675 <<
" payloads but it contains " << numPayloads <<
" payloads";
2680 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2681 return produceNumOpsError();
2686 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2687 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2688 return produceNumOpsError();
2692 if (getOverflowResult())
2693 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2695 auto container = [&]() {
2696 if (isa<TransformHandleTypeInterface>(getHandle().
getType())) {
2697 return llvm::map_to_vector(
2699 [](
Operation *op) -> MappedValue {
return op; });
2701 if (isa<TransformValueHandleTypeInterface>(getHandle().
getType())) {
2703 [](
Value v) -> MappedValue {
return v; });
2705 assert(isa<TransformParamTypeInterface>(getHandle().
getType()) &&
2706 "unsupported kind of transform dialect type");
2707 return llvm::map_to_vector(state.
getParams(getHandle()),
2708 [](
Attribute a) -> MappedValue {
return a; });
2711 for (
auto &&en : llvm::enumerate(container)) {
2712 int64_t resultNum = en.index();
2713 if (resultNum >= getNumResults())
2714 resultNum = *getOverflowResult();
2715 resultHandles[resultNum].push_back(en.value());
2719 for (
auto &&it : llvm::enumerate(resultHandles))
2726void transform::SplitHandleOp::getEffects(
2734LogicalResult transform::SplitHandleOp::verify() {
2735 if (getOverflowResult().has_value() &&
2736 !(*getOverflowResult() < getNumResults()))
2737 return emitOpError(
"overflow_result is not a valid result index");
2739 for (
Type resultType : getResultTypes()) {
2743 return emitOpError(
"expects result types to implement the same transform "
2744 "interface as the operand type");
2758 unsigned numRepetitions = llvm::range_size(state.
getPayloadOps(getPattern()));
2759 for (
const auto &en : llvm::enumerate(getHandles())) {
2760 Value handle = en.value();
2761 if (isa<TransformHandleTypeInterface>(handle.getType())) {
2765 payload.reserve(numRepetitions * current.size());
2766 for (
unsigned i = 0; i < numRepetitions; ++i)
2767 llvm::append_range(payload, current);
2768 results.
set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2770 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2771 "expected param type");
2774 params.reserve(numRepetitions * current.size());
2775 for (
unsigned i = 0; i < numRepetitions; ++i)
2776 llvm::append_range(params, current);
2777 results.
setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2784void transform::ReplicateOp::getEffects(
2801 if (
failed(mapBlockArguments(state)))
2809 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2816 root = std::nullopt;
2819 if (failed(hasRoot.
value()))
2833 if (failed(parser.
parseType(rootType))) {
2837 if (!extraBindings.empty()) {
2842 if (extraBindingTypes.size() != extraBindings.size()) {
2844 "expected types to be provided for all operands");
2860 bool hasExtras = !extraBindings.empty();
2870 printer << rootType;
2872 printer <<
", " << llvm::interleaved(extraBindingTypes) <<
')';
2879 auto iface = dyn_cast<transform::TransformOpInterface>(use.
getOwner());
2883 return isHandleConsumed(use.
get(), iface);
2894 if (!potentialConsumer) {
2895 potentialConsumer = &use;
2900 <<
" has more than one potential consumer";
2903 diag.attachNote(use.getOwner()->getLoc())
2904 <<
"used here as operand #" << use.getOperandNumber();
2911LogicalResult transform::SequenceOp::verify() {
2912 assert(getBodyBlock()->getNumArguments() >= 1 &&
2913 "the number of arguments must have been verified to be more than 1 by "
2914 "PossibleTopLevelTransformOpTrait");
2916 if (!getRoot() && !getExtraBindings().empty()) {
2918 <<
"does not expect extra operands when used as top-level";
2924 return (
emitOpError() <<
"block argument #" << arg.getArgNumber());
2931 for (
Operation &child : *getBodyBlock()) {
2932 if (!isa<TransformOpInterface>(child) &&
2933 &child != &getBodyBlock()->back()) {
2936 <<
"expected children ops to implement TransformOpInterface";
2937 diag.attachNote(child.getLoc()) <<
"op without interface";
2942 auto report = [&]() {
2943 return (child.emitError() <<
"result #" <<
result.getResultNumber());
2950 if (!getBodyBlock()->mightHaveTerminator())
2951 return emitOpError() <<
"expects to have a terminator in the body";
2953 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2954 getOperation()->getResultTypes()) {
2956 <<
"expects the types of the terminator operands "
2957 "to match the types of the result";
2958 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) <<
"terminator";
2964void transform::SequenceOp::getEffects(
2970transform::SequenceOp::getEntrySuccessorOperands(
RegionSuccessor successor) {
2971 assert(successor.
getSuccessor() == &getBody() &&
"unexpected region index");
2972 if (getOperation()->getNumOperands() > 0)
2973 return getOperation()->getOperands();
2975 getOperation()->operand_end());
2978void transform::SequenceOp::getSuccessorRegions(
2981 Region *bodyRegion = &getBody();
2982 regions.emplace_back(bodyRegion);
2988 "unexpected region index");
2994 if (getNumOperands() == 0)
2997 return getResults();
2998 return getBody().getArguments();
3001void transform::SequenceOp::getRegionInvocationBounds(
3004 bounds.emplace_back(1, 1);
3009 FailurePropagationMode failurePropagationMode,
3011 SequenceBodyBuilderFn bodyBuilder) {
3012 build(builder, state, resultTypes, failurePropagationMode, root,
3021 FailurePropagationMode failurePropagationMode,
3023 SequenceBodyBuilderArgsFn bodyBuilder) {
3024 build(builder, state, resultTypes, failurePropagationMode, root,
3032 FailurePropagationMode failurePropagationMode,
3034 SequenceBodyBuilderFn bodyBuilder) {
3035 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3043 FailurePropagationMode failurePropagationMode,
3045 SequenceBodyBuilderArgsFn bodyBuilder) {
3046 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3064 build(builder,
result, name);
3071 llvm::outs() <<
"[[[ IR printer: ";
3072 if (getName().has_value())
3073 llvm::outs() << *getName() <<
" ";
3076 if (getAssumeVerified().value_or(
false))
3078 if (getUseLocalScope().value_or(
false))
3080 if (getSkipRegions().value_or(
false))
3084 llvm::outs() <<
"top-level ]]]\n";
3086 llvm::outs() <<
"\n";
3087 llvm::outs().flush();
3091 llvm::outs() <<
"]]]\n";
3093 target->print(llvm::outs(), printFlags);
3094 llvm::outs() <<
"\n";
3097 llvm::outs().flush();
3101void transform::PrintOp::getEffects(
3106 if (!getTargetMutable().empty())
3126 <<
"failed to verify payload op";
3127 diag.attachNote(
target->getLoc()) <<
"payload op";
3133void transform::VerifyOp::getEffects(
3142void transform::YieldOp::getEffects(
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a dlti.dl_entry attribute.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getOpResults()
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
A named class for passing around the variadic flag.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
llvm::SetVector< T, Vector, Set, N > SetVector
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
llvm::function_ref< Fn > function_ref
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.