37 #include "llvm/ADT/DenseSet.h"
38 #include "llvm/ADT/STLExtras.h"
39 #include "llvm/ADT/ScopeExit.h"
40 #include "llvm/ADT/SmallPtrSet.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Support/Debug.h"
43 #include "llvm/Support/ErrorHandling.h"
44 #include "llvm/Support/InterleavedRange.h"
// Debug-output configuration for this file. DEBUG_TYPE is the LLVM debug type
// consumed by LLVM_DEBUG; output is enabled with -debug-only=transform-dialect
// in builds that compile debug output in.
47 #define DEBUG_TYPE "transform-dialect"
// Streams to llvm::dbgs(), prefixing each message with "[transform-dialect] ".
48 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
// A second, distinct debug type used for matcher tracing, so that matcher
// logs can be toggled separately from the general transform-dialect logs.
50 #define DEBUG_TYPE_MATCHER "transform-matcher"
// Streams to llvm::dbgs(), prefixing each message with "[transform-matcher] ".
51 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
// Runs `x` only when the "transform-matcher" debug type is enabled, via
// LLVM's DEBUG_WITH_TYPE.
52 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
58 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
64 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
66 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
67 SmallVectorImpl<Type> &extraBindingTypes);
73 ArrayAttr matchers, ArrayAttr actions);
84 Operation *transformAncestor = transform.getOperation();
85 while (transformAncestor) {
86 if (transformAncestor == payload) {
88 transform.emitDefiniteFailure()
89 <<
"cannot apply transform to itself (or one of its ancestors)";
90 diag.attachNote(payload->
getLoc()) <<
"target payload op";
93 transformAncestor = transformAncestor->
getParentOp();
98 #define GET_OP_CLASSES
99 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
107 if (!point.
isParent() && getOperation()->getNumOperands() == 1)
108 return getOperation()->getOperands();
110 getOperation()->operand_end());
113 void transform::AlternativesOp::getSuccessorRegions(
115 for (
Region &alternative : llvm::drop_begin(
119 regions.emplace_back(&alternative, !getOperands().empty()
120 ? alternative.getArguments()
124 regions.emplace_back(getOperation()->getResults());
127 void transform::AlternativesOp::getRegionInvocationBounds(
132 bounds.reserve(getNumRegions());
133 bounds.emplace_back(1, 1);
140 results.
set(res, {});
148 if (
Value scopeHandle = getScope())
149 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
151 originals.push_back(state.getTopLevel());
154 if (original->isAncestor(getOperation())) {
156 <<
"scope must not contain the transforms being applied";
157 diag.attachNote(original->getLoc()) <<
"scope";
162 <<
"only isolated-from-above ops can be alternative scopes";
163 diag.attachNote(original->getLoc()) <<
"scope";
168 for (
Region ® : getAlternatives()) {
173 auto scope = state.make_region_scope(reg);
174 auto clones = llvm::to_vector(
175 llvm::map_range(originals, [](
Operation *op) {
return op->
clone(); }));
176 auto deleteClones = llvm::make_scope_exit([&] {
180 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
184 for (
Operation &transform : reg.front().without_terminator()) {
186 state.applyTransform(cast<TransformOpInterface>(transform));
188 LLVM_DEBUG(
DBGS() <<
"alternative failed: " << result.
getMessage()
194 if (::mlir::failed(result.
silence()))
203 deleteClones.release();
204 TrackingListener listener(state, *
this);
206 for (
const auto &kvp : llvm::zip(originals, clones)) {
217 return emitSilenceableError() <<
"all alternatives failed";
220 void transform::AlternativesOp::getEffects(
221 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
224 for (
Region *region : getRegions()) {
225 if (!region->empty())
232 for (
Region &alternative : getAlternatives()) {
237 <<
"expects terminator operands to have the "
238 "same type as results of the operation";
239 diag.attachNote(terminator->
getLoc()) <<
"terminator";
256 llvm::to_vector(state.getPayloadOps(getTarget()));
259 if (
auto paramH = getParam()) {
261 if (params.size() != 1) {
262 if (targets.size() != params.size()) {
263 return emitSilenceableError()
264 <<
"parameter and target have different payload lengths ("
265 << params.size() <<
" vs " << targets.size() <<
")";
267 for (
auto &&[target, attr] : llvm::zip_equal(targets, params))
268 target->setAttr(getName(), attr);
273 for (
auto *target : targets)
274 target->setAttr(getName(), attr);
278 void transform::AnnotateOp::getEffects(
279 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
290 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
305 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
306 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
330 auto addDefiningOpsToWorklist = [&](
Operation *op) {
333 if (
Operation *defOp = v.getDefiningOp())
335 worklist.insert(defOp);
343 const auto *it = llvm::find(worklist, op);
344 if (it != worklist.end())
353 addDefiningOpsToWorklist(op);
359 while (!worklist.empty()) {
363 addDefiningOpsToWorklist(op);
370 void transform::ApplyDeadCodeEliminationOp::getEffects(
371 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
395 if (!getRegion().empty()) {
396 for (
Operation &op : getRegion().front()) {
397 cast<transform::PatternDescriptorOpInterface>(&op)
398 .populatePatternsWithState(
patterns, state);
408 config.setMaxIterations(getMaxIterations() ==
static_cast<uint64_t
>(-1)
410 : getMaxIterations());
411 config.setMaxNumRewrites(getMaxNumRewrites() ==
static_cast<uint64_t
>(-1)
413 : getMaxNumRewrites());
418 bool cseChanged =
false;
421 static const int64_t kNumMaxIterations = 50;
422 int64_t iteration = 0;
424 LogicalResult result = failure();
437 if (target != nestedOp)
438 ops.push_back(nestedOp);
445 if (failed(result)) {
447 <<
"greedy pattern application failed";
455 }
while (cseChanged && ++iteration < kNumMaxIterations);
457 if (iteration == kNumMaxIterations)
464 if (!getRegion().empty()) {
465 for (
Operation &op : getRegion().front()) {
466 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
468 <<
"expected children ops to implement "
469 "PatternDescriptorOpInterface";
470 diag.attachNote(op.
getLoc()) <<
"op without interface";
478 void transform::ApplyPatternsOp::getEffects(
479 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
484 void transform::ApplyPatternsOp::build(
493 bodyBuilder(builder, result.
location);
500 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
504 dialect->getCanonicalizationPatterns(
patterns);
506 op.getCanonicalizationPatterns(
patterns, ctx);
520 std::unique_ptr<TypeConverter> defaultTypeConverter;
521 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
522 getDefaultTypeConverter();
523 if (typeConverterBuilder)
524 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
529 for (
Attribute attr : cast<ArrayAttr>(*getLegalOps()))
530 conversionTarget.addLegalOp(
533 for (
Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
534 conversionTarget.addIllegalOp(
536 if (getLegalDialects())
537 for (
Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
538 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
539 if (getIllegalDialects())
540 for (
Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
541 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
549 if (!getPatterns().empty()) {
550 for (
Operation &op : getPatterns().front()) {
552 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
555 std::unique_ptr<TypeConverter> typeConverter =
556 descriptor.getTypeConverter();
559 keepAliveConverters.emplace_back(std::move(typeConverter));
560 converter = keepAliveConverters.back().get();
563 if (!defaultTypeConverter) {
565 <<
"pattern descriptor does not specify type "
566 "converter and apply_conversion_patterns op has "
567 "no default type converter";
568 diag.attachNote(op.
getLoc()) <<
"pattern descriptor op";
571 converter = defaultTypeConverter.get();
577 descriptor.populateConversionTargetRules(*converter, conversionTarget);
579 descriptor.populatePatterns(*converter,
patterns);
587 TrackingListenerConfig trackingConfig;
588 trackingConfig.requireMatchingReplacementOpName =
false;
589 ErrorCheckingTrackingListener trackingListener(state, *
this, trackingConfig);
591 if (getPreserveHandles())
592 conversionConfig.
listener = &trackingListener;
595 for (
Operation *target : state.getPayloadOps(getTarget())) {
603 LogicalResult status = failure();
604 if (getPartialConversion()) {
614 if (failed(status)) {
615 diag = emitSilenceableError() <<
"dialect conversion failed";
616 diag.attachNote(target->
getLoc()) <<
"target op";
621 trackingListener.checkAndResetError();
623 if (
diag.succeeded()) {
625 return trackingFailure;
627 diag.attachNote() <<
"tracking listener also failed: "
629 (void)trackingFailure.
silence();
633 if (!
diag.succeeded())
641 if (getNumRegions() != 1 && getNumRegions() != 2)
642 return emitOpError() <<
"expected 1 or 2 regions";
643 if (!getPatterns().empty()) {
644 for (
Operation &op : getPatterns().front()) {
645 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
647 emitOpError() <<
"expected pattern children ops to implement "
648 "ConversionPatternDescriptorOpInterface";
649 diag.attachNote(op.
getLoc()) <<
"op without interface";
654 if (getNumRegions() == 2) {
655 Region &typeConverterRegion = getRegion(1);
656 if (!llvm::hasSingleElement(typeConverterRegion.
front()))
658 <<
"expected exactly one op in default type converter region";
660 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
662 if (!typeConverterOp) {
664 <<
"expected default converter child op to "
665 "implement TypeConverterBuilderOpInterface";
666 diag.attachNote(maybeTypeConverter->
getLoc()) <<
"op without interface";
670 if (!getPatterns().empty()) {
671 for (
Operation &op : getPatterns().front()) {
673 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
674 if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
682 void transform::ApplyConversionPatternsOp::getEffects(
683 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
684 if (!getPreserveHandles()) {
692 void transform::ApplyConversionPatternsOp::build(
702 if (patternsBodyBuilder)
703 patternsBodyBuilder(builder, result.
location);
709 if (typeConverterBodyBuilder)
710 typeConverterBodyBuilder(builder, result.
location);
718 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
721 assert(dialect &&
"expected that dialect is loaded");
722 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
726 iface->populateConvertToLLVMConversionPatterns(
730 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
731 transform::TypeConverterBuilderOpInterface builder) {
732 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
733 return emitOpError(
"expected LLVMTypeConverter");
740 return emitOpError(
"unknown dialect or dialect not loaded: ")
742 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
745 "dialect does not implement ConvertToLLVMPatternInterface or "
746 "extension was not loaded: ")
756 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
766 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
767 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
776 void transform::ApplyRegisteredPassOp::getEffects(
777 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
794 llvm::raw_string_ostream optionsStream(
options);
799 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
802 size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
803 assert(dynamicOptionIdx < dynamicOptions.size() &&
804 "the number of ParamOperandAttrs in the options DictionaryAttr"
805 "should be the same as the number of options passed as params");
807 state.getParams(dynamicOptions[dynamicOptionIdx]);
809 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
811 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
813 llvm::interleave(arrayAttr, optionsStream, appendValueAttr,
",");
814 }
else if (
auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
816 optionsStream << strAttr.getValue().str();
819 valueAttr.print(optionsStream,
true);
825 getOptions(), optionsStream,
826 [&](
auto namedAttribute) {
827 optionsStream << namedAttribute.getName().str();
828 optionsStream <<
"=";
829 appendValueAttr(namedAttribute.getValue());
832 optionsStream.flush();
840 <<
"unknown pass or pass pipeline: " << getPassName();
849 <<
"failed to add pass or pass pipeline to pipeline: "
865 if (failed(pm.run(target))) {
866 auto diag = emitSilenceableError() <<
"pass pipeline failed";
867 diag.attachNote(target->
getLoc()) <<
"target op";
873 results.
set(llvm::cast<OpResult>(getResult()), targets);
879 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
882 size_t dynamicOptionsIdx = 0;
888 std::function<ParseResult(
Attribute &)> parseValue =
889 [&](
Attribute &valueAttr) -> ParseResult {
897 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
898 " in options dictionary") ||
912 ParseResult parsedOperand = parser.
parseOperand(operand);
913 if (failed(parsedOperand))
919 dynamicOptions.push_back(operand);
924 }
else if (failed(parsedValueAttr.
value())) {
926 }
else if (isa<transform::ParamOperandAttr>(valueAttr)) {
928 <<
"the param_operand attribute is a marker reserved for "
929 <<
"indicating a value will be passed via params and is only used "
930 <<
"in the generic print format";
944 <<
"expected key to either be an identifier or a string";
948 <<
"expected '=' after key in key-value pair";
950 if (failed(parseValue(valueAttr)))
952 <<
"expected a valid attribute or operand as value associated "
953 <<
"to key '" << key <<
"'";
962 " in options dictionary"))
965 if (DictionaryAttr::findDuplicate(
966 keyValuePairs,
false)
969 <<
"duplicate keys found in options dictionary";
984 if (
auto paramOperandAttr =
985 dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
988 dynamicOptions[paramOperandAttr.getIndex().getInt()]);
989 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1001 printer << namedAttribute.
getName();
1015 std::function<LogicalResult(
Attribute)> checkOptionValue =
1016 [&](
Attribute valueAttr) -> LogicalResult {
1017 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1018 size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1019 if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
1020 return emitOpError()
1021 <<
"dynamic option index " << dynamicOptionIdx
1022 <<
" is out of bounds for the number of dynamic options: "
1023 << dynamicOptions.size();
1024 if (dynamicOptions[dynamicOptionIdx] ==
nullptr)
1025 return emitOpError() <<
"dynamic option index " << dynamicOptionIdx
1026 <<
" is already used in options";
1027 dynamicOptions[dynamicOptionIdx] =
nullptr;
1028 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1030 for (
auto eltAttr : arrayAttr)
1031 if (failed(checkOptionValue(eltAttr)))
1038 if (failed(checkOptionValue(namedAttr.getValue())))
1042 for (
Value dynamicOption : dynamicOptions)
1044 return emitOpError() <<
"a param operand does not have a corresponding "
1045 <<
"param_operand attr in the options dict";
1056 Operation *target, ApplyToEachResultList &results,
1058 results.push_back(target);
1062 void transform::CastOp::getEffects(
1063 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1070 assert(inputs.size() == 1 &&
"expected one input");
1071 assert(outputs.size() == 1 &&
"expected one output");
1072 return llvm::all_of(
1073 std::initializer_list<Type>{inputs.front(), outputs.front()},
1074 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1094 assert(block.
getParent() &&
"cannot match using a detached block");
1095 auto matchScope = state.make_region_scope(*block.
getParent());
1097 state.mapBlockArguments(block.
getArguments(), blockArgumentMapping)))
1101 if (!isa<transform::MatchOpInterface>(match)) {
1103 <<
"expected operations in the match part to "
1104 "implement MatchOpInterface";
1107 state.applyTransform(cast<transform::TransformOpInterface>(match));
1108 if (
diag.succeeded())
1126 template <
typename... Tys>
1128 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... ||
false);
1135 transform::TransformParamTypeInterface,
1136 transform::TransformValueHandleTypeInterface>(
1148 auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
1149 getOperation(), getMatcher());
1150 if (matcher.isExternal()) {
1152 <<
"unresolved external symbol " << getMatcher();
1156 rawResults.resize(getOperation()->getNumResults());
1157 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1158 for (
Operation *root : state.getPayloadOps(getRoot())) {
1162 op->
print(llvm::dbgs(),
1164 llvm::dbgs() <<
" @" << op <<
"\n";
1171 matcher.getFunctionBody().front(),
1174 if (
diag.isDefiniteFailure())
1176 if (
diag.isSilenceableFailure()) {
1178 <<
" failed: " <<
diag.getMessage());
1184 if (mapping.size() != 1) {
1185 maybeFailure.emplace(emitSilenceableError()
1186 <<
"result #" << i <<
", associated with "
1188 <<
" payload objects, expected 1");
1191 rawResults[i].push_back(mapping[0]);
1196 return std::move(*maybeFailure);
1197 assert(!maybeFailure &&
"failure set but the walk was not interrupted");
1199 for (
auto &&[opResult, rawResult] :
1200 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1207 void transform::CollectMatchingOp::getEffects(
1208 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1214 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1216 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1218 if (!matcherSymbol ||
1219 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1220 return emitError() <<
"unresolved matcher symbol " << getMatcher();
1223 if (argumentTypes.size() != 1 ||
1224 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1226 <<
"expected the matcher to take one operation handle argument";
1228 if (!matcherSymbol.getArgAttr(
1229 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1230 return emitError() <<
"expected the matcher argument to be marked readonly";
1234 if (resultTypes.size() != getOperation()->getNumResults()) {
1236 <<
"expected the matcher to yield as many values as op has results ("
1237 << getOperation()->getNumResults() <<
"), got "
1238 << resultTypes.size();
1241 for (
auto &&[i, matcherType, resultType] :
1247 <<
"mismatching type interfaces for matcher result and op result #"
1259 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() {
return true; }
1267 matchActionPairs.reserve(getMatchers().size());
1269 for (
auto &&[matcher, action] :
1270 llvm::zip_equal(getMatchers(), getActions())) {
1271 auto matcherSymbol =
1273 getOperation(), cast<SymbolRefAttr>(matcher));
1276 getOperation(), cast<SymbolRefAttr>(action));
1277 assert(matcherSymbol && actionSymbol &&
1278 "unresolved symbols not caught by the verifier");
1280 if (matcherSymbol.isExternal())
1282 if (actionSymbol.isExternal())
1285 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1296 matchInputMapping.emplace_back();
1298 getForwardedInputs(), state);
1300 actionResultMapping.resize(getForwardedOutputs().size());
1302 for (
Operation *root : state.getPayloadOps(getRoot())) {
1306 if (!getRestrictRoot() && op == root)
1311 op->
print(llvm::dbgs(),
1313 llvm::dbgs() <<
" @" << op <<
"\n";
1316 firstMatchArgument.clear();
1317 firstMatchArgument.push_back(op);
1320 for (
auto [matcher, action] : matchActionPairs) {
1322 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1323 state, matchOutputMapping);
1324 if (
diag.isDefiniteFailure())
1326 if (
diag.isSilenceableFailure()) {
1328 <<
" failed: " <<
diag.getMessage());
1332 auto scope = state.make_region_scope(action.getFunctionBody());
1333 if (failed(state.mapBlockArguments(
1334 action.getFunctionBody().front().getArguments(),
1335 matchOutputMapping))) {
1340 action.getFunctionBody().front().without_terminator()) {
1342 state.applyTransform(cast<TransformOpInterface>(transform));
1347 overallDiag = emitSilenceableError() <<
"actions failed";
1352 <<
"when applied to this matching payload";
1359 action.getFunctionBody().front().getTerminator()->getOperands(),
1360 state, getFlattenResults()))) {
1362 <<
"action @" << action.getName()
1363 <<
" has results associated with multiple payload entities, "
1364 "but flattening was not requested";
1379 results.
set(llvm::cast<OpResult>(getUpdated()),
1380 state.getPayloadOps(getRoot()));
1381 for (
auto &&[result, mapping] :
1382 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1388 void transform::ForeachMatchOp::getAsmResultNames(
1390 setNameFn(getUpdated(),
"updated_root");
1391 for (
Value v : getForwardedOutputs()) {
1392 setNameFn(v,
"yielded");
1396 void transform::ForeachMatchOp::getEffects(
1397 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1399 if (getOperation()->getNumOperands() < 1 ||
1400 getOperation()->getNumResults() < 1) {
1413 ArrayAttr &matchers,
1414 ArrayAttr &actions) {
1436 ArrayAttr matchers, ArrayAttr actions) {
1439 for (
auto &&[matcher, action, idx] : llvm::zip_equal(
1440 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1442 printer << cast<SymbolRefAttr>(matcher) <<
" -> "
1443 << cast<SymbolRefAttr>(action);
1444 if (idx != matchers.size() - 1)
1452 if (getMatchers().size() != getActions().size())
1453 return emitOpError() <<
"expected the same number of matchers and actions";
1454 if (getMatchers().empty())
1455 return emitOpError() <<
"expected at least one match/action pair";
1459 if (matcherNames.insert(name).second)
1462 <<
" is used more than once, only the first match will apply";
1473 bool alsoVerifyInternal =
false) {
1474 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1475 llvm::SmallDenseSet<unsigned> consumedArguments;
1476 if (!op.isExternal()) {
1480 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1482 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1485 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1487 if (isConsumed && isReadOnly) {
1488 return transformOp.emitSilenceableError()
1489 <<
"argument #" << i <<
" cannot be both readonly and consumed";
1491 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1492 return transformOp.emitSilenceableError()
1493 <<
"must provide consumed/readonly status for arguments of "
1494 "external or called ops";
1496 if (op.isExternal())
1499 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1500 return transformOp.emitSilenceableError()
1501 <<
"argument #" << i
1502 <<
" is consumed in the body but is not marked as such";
1504 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1508 <<
"op argument #" << i
1509 <<
" is not consumed in the body but is marked as consumed";
1515 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1517 assert(getMatchers().size() == getActions().size());
1520 for (
auto &&[matcher, action] :
1521 llvm::zip_equal(getMatchers(), getActions())) {
1523 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1525 cast<SymbolRefAttr>(matcher)));
1526 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1528 cast<SymbolRefAttr>(action)));
1529 if (!matcherSymbol ||
1530 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1531 return emitError() <<
"unresolved matcher symbol " << matcher;
1532 if (!actionSymbol ||
1533 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1534 return emitError() <<
"unresolved action symbol " << action;
1539 .checkAndReport())) {
1545 .checkAndReport())) {
1550 TypeRange operandTypes = getOperandTypes();
1551 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1552 if (operandTypes.size() != matcherArguments.size()) {
1554 emitError() <<
"the number of operands (" << operandTypes.size()
1555 <<
") doesn't match the number of matcher arguments ("
1556 << matcherArguments.size() <<
") for " << matcher;
1557 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1560 for (
auto &&[i, operand, argument] :
1562 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1565 <<
"does not expect matcher symbol to consume its operand #" << i;
1566 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1575 <<
"mismatching type interfaces for operand and matcher argument #"
1576 << i <<
" of matcher " << matcher;
1577 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1582 TypeRange matcherResults = matcherSymbol.getResultTypes();
1583 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1584 if (matcherResults.size() != actionArguments.size()) {
1585 return emitError() <<
"mismatching number of matcher results and "
1586 "action arguments between "
1587 << matcher <<
" (" << matcherResults.size() <<
") and "
1588 << action <<
" (" << actionArguments.size() <<
")";
1590 for (
auto &&[i, matcherType, actionType] :
1595 return emitError() <<
"mismatching type interfaces for matcher result "
1596 "and action argument #"
1597 << i <<
"of matcher " << matcher <<
" and action "
1602 TypeRange actionResults = actionSymbol.getResultTypes();
1603 auto resultTypes =
TypeRange(getResultTypes()).drop_front();
1604 if (actionResults.size() != resultTypes.size()) {
1606 emitError() <<
"the number of action results ("
1607 << actionResults.size() <<
") for " << action
1608 <<
" doesn't match the number of extra op results ("
1609 << resultTypes.size() <<
")";
1610 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1613 for (
auto &&[i, resultType, actionType] :
1619 emitError() <<
"mismatching type interfaces for action result #" << i
1620 <<
" of action " << action <<
" and op result";
1621 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1640 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1641 bool withZipShortest = getWithZipShortest();
1645 if (withZipShortest) {
1649 return A.size() <
B.size();
1652 for (
size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1653 payloads[argIdx].resize(numIterations);
1659 for (
size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1661 if (payloads[argIdx].size() != numIterations) {
1662 return emitSilenceableError()
1663 <<
"prior targets' payload size (" << numIterations
1664 <<
") differs from payload size (" << payloads[argIdx].size()
1665 <<
") of target " << getTargets()[argIdx];
1674 for (
size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1675 auto scope = state.make_region_scope(getBody());
1681 if (failed(state.mapBlockArgument(blockArg, {argument})))
1686 for (
Operation &transform : getBody().front().without_terminator()) {
1688 llvm::cast<transform::TransformOpInterface>(transform));
1694 OperandRange yieldOperands = getYieldOp().getOperands();
1695 for (
auto &&[result, yieldOperand, resTuple] :
1696 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1698 if (isa<TransformHandleTypeInterface>(result.getType()))
1699 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1700 else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1701 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1702 else if (isa<TransformParamTypeInterface>(result.getType()))
1703 llvm::append_range(resTuple, state.getParams(yieldOperand));
1705 assert(
false &&
"unhandled handle type");
1709 for (
auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1715 void transform::ForeachOp::getEffects(
1716 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1719 for (
auto &&[target, blockArg] :
1720 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1722 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1724 cast<TransformOpInterface>(&op));
1732 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1736 }
else if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1745 void transform::ForeachOp::getSuccessorRegions(
1747 Region *bodyRegion = &getBody();
1749 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1754 assert(point == getBody() &&
"unexpected region index");
1755 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1756 regions.emplace_back();
1763 assert(point == getBody() &&
"unexpected region index");
1764 return getOperation()->getOperands();
1767 transform::YieldOp transform::ForeachOp::getYieldOp() {
1768 return cast<transform::YieldOp>(getBody().front().getTerminator());
1772 for (
auto [targetOpt, bodyArgOpt] :
1773 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1774 if (!targetOpt || !bodyArgOpt)
1775 return emitOpError() <<
"expects the same number of targets as the body "
1776 "has block arguments";
1777 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1779 "expects co-indexed targets and the body's "
1780 "block arguments to have the same op/value/param type");
1783 for (
auto [resultOpt, yieldOperandOpt] :
1784 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1785 if (!resultOpt || !yieldOperandOpt)
1786 return emitOpError() <<
"expects the same number of results as the "
1787 "yield terminator has operands";
1788 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1789 return emitOpError(
"expects co-indexed results and yield "
1790 "operands to have the same op/value/param type");
1806 for (
Operation *target : state.getPayloadOps(getTarget())) {
1808 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1811 bool checkIsolatedFromAbove =
1812 !getIsolatedFromAbove() ||
1814 bool checkOpName = !getOpName().has_value() ||
1816 if (checkIsolatedFromAbove && checkOpName)
1821 if (getAllowEmptyResults()) {
1822 results.
set(llvm::cast<OpResult>(getResult()), parents);
1826 emitSilenceableError()
1827 <<
"could not find a parent op that matches all requirements";
1828 diag.attachNote(target->
getLoc()) <<
"target op";
1832 if (getDeduplicate()) {
1833 if (resultSet.insert(parent).second)
1834 parents.push_back(parent);
1836 parents.push_back(parent);
1839 results.
set(llvm::cast<OpResult>(getResult()), parents);
1851 int64_t resultNumber = getResultNumber();
1852 auto payloadOps = state.getPayloadOps(getTarget());
1853 if (std::empty(payloadOps)) {
1854 results.
set(cast<OpResult>(getResult()), {});
1857 if (!llvm::hasSingleElement(payloadOps))
1859 <<
"handle must be mapped to exactly one payload op";
1861 Operation *target = *payloadOps.begin();
1864 results.
set(llvm::cast<OpResult>(getResult()),
1878 for (
Value v : state.getPayloadValues(getTarget())) {
1879 if (llvm::isa<BlockArgument>(v)) {
1881 emitSilenceableError() <<
"cannot get defining op of block argument";
1882 diag.attachNote(v.getLoc()) <<
"target value";
1885 definingOps.push_back(v.getDefiningOp());
1887 results.
set(llvm::cast<OpResult>(getResult()), definingOps);
1899 int64_t operandNumber = getOperandNumber();
1901 for (
Operation *target : state.getPayloadOps(getTarget())) {
1908 emitSilenceableError()
1909 <<
"could not find a producer for operand number: " << operandNumber
1910 <<
" of " << *target;
1911 diag.attachNote(target->getLoc()) <<
"target op";
1914 producers.push_back(producer);
1916 results.
set(llvm::cast<OpResult>(getResult()), producers);
1929 for (
Operation *target : state.getPayloadOps(getTarget())) {
1932 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1933 target->getNumOperands(), operandPositions);
1934 if (
diag.isSilenceableFailure()) {
1935 diag.attachNote(target->getLoc())
1936 <<
"while considering positions of this payload operation";
1939 llvm::append_range(operands,
1940 llvm::map_range(operandPositions, [&](int64_t pos) {
1941 return target->getOperand(pos);
1944 results.
setValues(cast<OpResult>(getResult()), operands);
1950 getIsInverted(), getIsAll());
1962 for (
Operation *target : state.getPayloadOps(getTarget())) {
1965 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1966 target->getNumResults(), resultPositions);
1967 if (
diag.isSilenceableFailure()) {
1968 diag.attachNote(target->getLoc())
1969 <<
"while considering positions of this payload operation";
1972 llvm::append_range(opResults,
1973 llvm::map_range(resultPositions, [&](int64_t pos) {
1974 return target->getResult(pos);
1977 results.
setValues(cast<OpResult>(getResult()), opResults);
1983 getIsInverted(), getIsAll());
1990 void transform::GetTypeOp::getEffects(
1991 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2002 for (
Value value : state.getPayloadValues(getValue())) {
2003 Type type = value.getType();
2004 if (getElemental()) {
2005 if (
auto shaped = dyn_cast<ShapedType>(type)) {
2006 type = shaped.getElementType();
2011 results.
setParams(cast<OpResult>(getResult()), params);
2028 state.applyTransform(cast<transform::TransformOpInterface>(transform));
2033 if (mode == transform::FailurePropagationMode::Propagate) {
2052 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2053 getOperation(), getTarget());
2054 assert(callee &&
"unverified reference to unknown symbol");
2056 if (callee.isExternal())
2062 auto scope = state.make_region_scope(callee.getBody());
2063 for (
auto &&[arg, map] :
2064 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2065 if (failed(state.mapBlockArgument(arg, map)))
2070 callee.getBody().front(), getFailurePropagationMode(), state, results);
2073 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2074 for (
auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
2082 void transform::IncludeOp::getEffects(
2083 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2097 auto defaultEffects = [&] {
2104 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2106 return defaultEffects();
2107 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2108 getOperation(), getTarget());
2110 return defaultEffects();
2114 (void)earlyVerifierResult.
silence();
2115 return defaultEffects();
2118 for (
unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2119 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2130 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(
"target");
2132 return emitOpError() <<
"expects a 'target' symbol reference attribute";
2137 return emitOpError() <<
"does not reference a named transform sequence";
2139 FunctionType fnType = target.getFunctionType();
2140 if (fnType.getNumInputs() != getNumOperands())
2141 return emitError(
"incorrect number of operands for callee");
2143 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2144 if (getOperand(i).
getType() != fnType.getInput(i)) {
2145 return emitOpError(
"operand type mismatch: expected operand type ")
2146 << fnType.getInput(i) <<
", but provided "
2147 << getOperand(i).getType() <<
" for operand number " << i;
2151 if (fnType.getNumResults() != getNumResults())
2152 return emitError(
"incorrect number of results for callee");
2154 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2155 Type resultType = getResult(i).getType();
2156 Type funcType = fnType.getResult(i);
2158 return emitOpError() <<
"type of result #" << i
2159 <<
" must implement the same transform dialect "
2160 "interface as the corresponding callee result";
2165 cast<FunctionOpInterface>(*target),
false,
2175 ::std::optional<::mlir::Operation *> maybeCurrent,
2177 if (!maybeCurrent.has_value()) {
2182 return emitSilenceableError() <<
"operation is not empty";
2193 for (
auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2194 if (acceptedAttr.getValue() == currentOpName)
2197 return emitSilenceableError() <<
"wrong operation name";
2208 auto signedAPIntAsString = [&](
const APInt &value) {
2210 llvm::raw_string_ostream os(str);
2211 value.print(os,
true);
2218 if (params.size() != references.size()) {
2219 return emitSilenceableError()
2220 <<
"parameters have different payload lengths (" << params.size()
2221 <<
" vs " << references.size() <<
")";
2224 for (
auto &&[i, param, reference] :
llvm::enumerate(params, references)) {
2225 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2226 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2227 if (!intAttr || !refAttr) {
2229 <<
"non-integer parameter value not expected";
2231 if (intAttr.getType() != refAttr.getType()) {
2233 <<
"mismatching integer attribute types in parameter #" << i;
2235 APInt value = intAttr.getValue();
2236 APInt refValue = refAttr.getValue();
2239 int64_t position = i;
2240 auto reportError = [&](StringRef direction) {
2242 emitSilenceableError() <<
"expected parameter to be " << direction
2243 <<
" " << signedAPIntAsString(refValue)
2244 <<
", got " << signedAPIntAsString(value);
2245 diag.attachNote(getParam().getLoc())
2246 <<
"value # " << position
2247 <<
" associated with the parameter defined here";
2251 switch (getPredicate()) {
2252 case MatchCmpIPredicate::eq:
2253 if (value.eq(refValue))
2255 return reportError(
"equal to");
2256 case MatchCmpIPredicate::ne:
2257 if (value.ne(refValue))
2259 return reportError(
"not equal to");
2260 case MatchCmpIPredicate::lt:
2261 if (value.slt(refValue))
2263 return reportError(
"less than");
2264 case MatchCmpIPredicate::le:
2265 if (value.sle(refValue))
2267 return reportError(
"less than or equal to");
2268 case MatchCmpIPredicate::gt:
2269 if (value.sgt(refValue))
2271 return reportError(
"greater than");
2272 case MatchCmpIPredicate::ge:
2273 if (value.sge(refValue))
2275 return reportError(
"greater than or equal to");
2281 void transform::MatchParamCmpIOp::getEffects(
2282 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2295 results.
setParams(cast<OpResult>(getParam()), {getValue()});
2308 if (isa<TransformHandleTypeInterface>(handles.front().
getType())) {
2310 for (
Value operand : handles)
2311 llvm::append_range(operations, state.getPayloadOps(operand));
2312 if (!getDeduplicate()) {
2313 results.
set(llvm::cast<OpResult>(getResult()), operations);
2318 results.
set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2322 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2324 for (
Value attribute : handles)
2325 llvm::append_range(attrs, state.getParams(attribute));
2326 if (!getDeduplicate()) {
2327 results.
setParams(cast<OpResult>(getResult()), attrs);
2332 results.
setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2337 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2338 "expected value handle type");
2340 for (
Value value : handles)
2341 llvm::append_range(payloadValues, state.getPayloadValues(value));
2342 if (!getDeduplicate()) {
2343 results.
setValues(cast<OpResult>(getResult()), payloadValues);
2348 results.
setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2352 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2354 return getDeduplicate();
2357 void transform::MergeHandlesOp::getEffects(
2358 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2366 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2367 if (getDeduplicate() || getHandles().size() != 1)
2372 return getHandles().front();
2390 auto scope = state.make_region_scope(getBody());
2392 state, this->getOperation(), getBody())))
2396 FailurePropagationMode::Propagate, state, results);
2399 void transform::NamedSequenceOp::getEffects(
2400 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2405 parser, result,
false,
2406 getFunctionTypeAttrName(result.
name),
2409 std::string &) { return builder.getFunctionType(inputs, results); },
2410 getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
2415 printer, cast<FunctionOpInterface>(getOperation()),
false,
2416 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2417 getResAttrsAttrName());
2427 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2430 <<
"cannot be defined inside another transform op";
2431 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2435 if (op.isExternal() || op.getFunctionBody().empty()) {
2442 if (op.getFunctionBody().front().empty())
2445 Operation *terminator = &op.getFunctionBody().front().back();
2446 if (!isa<transform::YieldOp>(terminator)) {
2449 << transform::YieldOp::getOperationName()
2450 <<
"' as terminator";
2451 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2455 if (terminator->
getNumOperands() != op.getResultTypes().size()) {
2457 <<
"expected terminator to have as many operands as the parent op "
2460 for (
auto [i, operandType, resultType] : llvm::zip_equal(
2463 if (operandType == resultType)
2466 <<
"the type of the terminator operand #" << i
2467 <<
" must match the type of the corresponding parent op result ("
2468 << operandType <<
" vs " << resultType <<
")";
2481 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2484 <<
"expects the parent symbol table to have the '"
2485 << transform::TransformDialect::kWithNamedSequenceAttrName
2487 diag.attachNote(parent->
getLoc()) <<
"symbol table operation";
2492 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2495 <<
"cannot be defined inside another transform op";
2496 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2500 if (op.isExternal() || op.getBody().empty())
2504 if (op.getBody().front().empty())
2507 Operation *terminator = &op.getBody().front().back();
2508 if (!isa<transform::YieldOp>(terminator)) {
2511 << transform::YieldOp::getOperationName()
2512 <<
"' as terminator";
2513 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2517 if (terminator->
getNumOperands() != op.getFunctionType().getNumResults()) {
2519 <<
"expected terminator to have as many operands as the parent op "
2522 for (
auto [i, operandType, resultType] :
2523 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->
getNumOperands()),
2525 op.getFunctionType().getResults())) {
2526 if (operandType == resultType)
2529 <<
"the type of the terminator operand #" << i
2530 <<
" must match the type of the corresponding parent op result ("
2531 << operandType <<
" vs " << resultType <<
")";
2534 auto funcOp = cast<FunctionOpInterface>(*op);
2537 if (!
diag.succeeded())
2549 template <
typename FnTy>
2554 types.reserve(1 + extraBindingTypes.size());
2555 types.push_back(bbArgType);
2556 llvm::append_range(types, extraBindingTypes);
2559 Region *region = state.regions.back().get();
2566 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2567 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0));
2569 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0),
2574 void transform::NamedSequenceOp::build(
OpBuilder &builder,
2582 state.addAttribute(getFunctionTypeAttrName(state.name),
2584 rootType, resultTypes)));
2585 state.attributes.append(attrs.begin(), attrs.end());
2600 size_t numAssociations =
2602 .Case([&](TransformHandleTypeInterface opHandle) {
2603 return llvm::range_size(state.getPayloadOps(getHandle()));
2605 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2606 return llvm::range_size(state.getPayloadValues(getHandle()));
2608 .Case([&](TransformParamTypeInterface param) {
2609 return llvm::range_size(state.getParams(getHandle()));
2612 llvm_unreachable(
"unknown kind of transform dialect type");
2615 results.
setParams(cast<OpResult>(getNum()),
2622 auto resultType = cast<TransformParamTypeInterface>(getNum().
getType());
2637 auto payloadOps = state.getPayloadOps(getTarget());
2640 result.push_back(op);
2642 results.
set(cast<OpResult>(getResult()), result);
2651 Value target, int64_t numResultHandles) {
2660 int64_t numPayloads =
2662 .Case<TransformHandleTypeInterface>([&](
auto x) {
2663 return llvm::range_size(state.getPayloadOps(getHandle()));
2665 .Case<TransformValueHandleTypeInterface>([&](
auto x) {
2666 return llvm::range_size(state.getPayloadValues(getHandle()));
2668 .Case<TransformParamTypeInterface>([&](
auto x) {
2669 return llvm::range_size(state.getParams(getHandle()));
2671 .Default([](
auto x) {
2672 llvm_unreachable(
"unknown transform dialect type interface");
2676 auto produceNumOpsError = [&]() {
2677 return emitSilenceableError()
2678 << getHandle() <<
" expected to contain " << this->getNumResults()
2679 <<
" payloads but it contains " << numPayloads <<
" payloads";
2684 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2685 return produceNumOpsError();
2690 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2691 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2692 return produceNumOpsError();
2696 if (getOverflowResult())
2697 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2699 auto container = [&]() {
2700 if (isa<TransformHandleTypeInterface>(getHandle().
getType())) {
2701 return llvm::map_to_vector(
2702 state.getPayloadOps(getHandle()),
2705 if (isa<TransformValueHandleTypeInterface>(getHandle().
getType())) {
2706 return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2709 assert(isa<TransformParamTypeInterface>(getHandle().
getType()) &&
2710 "unsupported kind of transform dialect type");
2711 return llvm::map_to_vector(state.getParams(getHandle()),
2716 int64_t resultNum = en.index();
2717 if (resultNum >= getNumResults())
2718 resultNum = *getOverflowResult();
2719 resultHandles[resultNum].push_back(en.value());
2730 void transform::SplitHandleOp::getEffects(
2731 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2739 if (getOverflowResult().has_value() &&
2740 !(*getOverflowResult() < getNumResults()))
2741 return emitOpError(
"overflow_result is not a valid result index");
2743 for (
Type resultType : getResultTypes()) {
2747 return emitOpError(
"expects result types to implement the same transform "
2748 "interface as the operand type");
2762 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2764 Value handle = en.value();
2765 if (isa<TransformHandleTypeInterface>(handle.
getType())) {
2767 llvm::to_vector(state.getPayloadOps(handle));
2769 payload.reserve(numRepetitions * current.size());
2770 for (
unsigned i = 0; i < numRepetitions; ++i)
2771 llvm::append_range(payload, current);
2772 results.
set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2774 assert(llvm::isa<TransformParamTypeInterface>(handle.
getType()) &&
2775 "expected param type");
2778 params.reserve(numRepetitions * current.size());
2779 for (
unsigned i = 0; i < numRepetitions; ++i)
2780 llvm::append_range(params, current);
2781 results.
setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2788 void transform::ReplicateOp::getEffects(
2789 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2804 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2805 if (failed(mapBlockArguments(state)))
2813 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2815 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2816 SmallVectorImpl<Type> &extraBindingTypes) {
2820 root = std::nullopt;
2823 if (failed(hasRoot.
value()))
2837 if (failed(parser.
parseType(rootType))) {
2841 if (!extraBindings.empty()) {
2846 if (extraBindingTypes.size() != extraBindings.size()) {
2848 "expected types to be provided for all operands");
2864 bool hasExtras = !extraBindings.empty();
2874 printer << rootType;
2876 printer <<
", " << llvm::interleaved(extraBindingTypes) <<
')';
2883 auto iface = dyn_cast<transform::TransformOpInterface>(use.
getOwner());
2898 if (!potentialConsumer) {
2899 potentialConsumer = &use;
2904 <<
" has more than one potential consumer";
2907 diag.attachNote(use.getOwner()->getLoc())
2908 <<
"used here as operand #" << use.getOperandNumber();
2916 assert(getBodyBlock()->getNumArguments() >= 1 &&
2917 "the number of arguments must have been verified to be more than 1 by "
2918 "PossibleTopLevelTransformOpTrait");
2920 if (!getRoot() && !getExtraBindings().empty()) {
2921 return emitOpError()
2922 <<
"does not expect extra operands when used as top-level";
2928 return (emitOpError() <<
"block argument #" << arg.getArgNumber());
2935 for (
Operation &child : *getBodyBlock()) {
2936 if (!isa<TransformOpInterface>(child) &&
2937 &child != &getBodyBlock()->back()) {
2940 <<
"expected children ops to implement TransformOpInterface";
2941 diag.attachNote(child.getLoc()) <<
"op without interface";
2945 for (
OpResult result : child.getResults()) {
2946 auto report = [&]() {
2947 return (child.emitError() <<
"result #" << result.getResultNumber());
2954 if (!getBodyBlock()->mightHaveTerminator())
2955 return emitOpError() <<
"expects to have a terminator in the body";
2957 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2958 getOperation()->getResultTypes()) {
2960 <<
"expects the types of the terminator operands "
2961 "to match the types of the result";
2962 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) <<
"terminator";
2968 void transform::SequenceOp::getEffects(
2969 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2975 assert(point == getBody() &&
"unexpected region index");
2976 if (getOperation()->getNumOperands() > 0)
2977 return getOperation()->getOperands();
2979 getOperation()->operand_end());
2982 void transform::SequenceOp::getSuccessorRegions(
2985 Region *bodyRegion = &getBody();
2986 regions.emplace_back(bodyRegion, getNumOperands() != 0
2992 assert(point == getBody() &&
"unexpected region index");
2993 regions.emplace_back(getOperation()->getResults());
2996 void transform::SequenceOp::getRegionInvocationBounds(
2999 bounds.emplace_back(1, 1);
3004 FailurePropagationMode failurePropagationMode,
3007 build(builder, state, resultTypes, failurePropagationMode, root,
3016 FailurePropagationMode failurePropagationMode,
3019 build(builder, state, resultTypes, failurePropagationMode, root,
3027 FailurePropagationMode failurePropagationMode,
3030 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3038 FailurePropagationMode failurePropagationMode,
3041 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3057 Value target, StringRef name) {
3059 build(builder, result, name);
3066 llvm::outs() <<
"[[[ IR printer: ";
3067 if (getName().has_value())
3068 llvm::outs() << *getName() <<
" ";
3071 if (getAssumeVerified().value_or(
false))
3073 if (getUseLocalScope().value_or(
false))
3075 if (getSkipRegions().value_or(
false))
3079 llvm::outs() <<
"top-level ]]]\n";
3080 state.getTopLevel()->print(llvm::outs(), printFlags);
3081 llvm::outs() <<
"\n";
3082 llvm::outs().flush();
3086 llvm::outs() <<
"]]]\n";
3087 for (
Operation *target : state.getPayloadOps(getTarget())) {
3088 target->
print(llvm::outs(), printFlags);
3089 llvm::outs() <<
"\n";
3092 llvm::outs().flush();
3096 void transform::PrintOp::getEffects(
3097 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3101 if (!getTargetMutable().empty())
3121 <<
"failed to verify payload op";
3122 diag.attachNote(target->
getLoc()) <<
"payload op";
3128 void transform::VerifyOp::getEffects(
3129 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3137 void transform::YieldOp::getEffects(
3138 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a #dlti.dl_entry attribute.
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
This class implements the operand iterators for the Operation class.
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getRegionNumber()
Return the number of this region in the parent operation.
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
A named class for passing around the variadic flag.
static void printOptionValue(raw_ostream &os, const bool &value)
Utility methods for printing option values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.