34 #include "llvm/ADT/DenseSet.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/ScopeExit.h"
37 #include "llvm/ADT/SmallPtrSet.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/Support/Debug.h"
40 #include "llvm/Support/DebugLog.h"
41 #include "llvm/Support/ErrorHandling.h"
42 #include "llvm/Support/InterleavedRange.h"
45 #define DEBUG_TYPE "transform-dialect"
46 #define DEBUG_TYPE_MATCHER "transform-matcher"
52 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
58 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
60 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
61 SmallVectorImpl<Type> &extraBindingTypes);
67 ArrayAttr matchers, ArrayAttr actions);
78 Operation *transformAncestor = transform.getOperation();
79 while (transformAncestor) {
80 if (transformAncestor == payload) {
82 transform.emitDefiniteFailure()
83 <<
"cannot apply transform to itself (or one of its ancestors)";
84 diag.attachNote(payload->
getLoc()) <<
"target payload op";
87 transformAncestor = transformAncestor->
getParentOp();
92 #define GET_OP_CLASSES
93 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
101 if (!point.
isParent() && getOperation()->getNumOperands() == 1)
102 return getOperation()->getOperands();
104 getOperation()->operand_end());
107 void transform::AlternativesOp::getSuccessorRegions(
109 for (
Region &alternative : llvm::drop_begin(
113 regions.emplace_back(&alternative, !getOperands().empty()
114 ? alternative.getArguments()
118 regions.emplace_back(getOperation()->getResults());
121 void transform::AlternativesOp::getRegionInvocationBounds(
126 bounds.reserve(getNumRegions());
127 bounds.emplace_back(1, 1);
134 results.
set(res, {});
142 if (
Value scopeHandle = getScope())
143 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
145 originals.push_back(state.getTopLevel());
148 if (original->isAncestor(getOperation())) {
150 <<
"scope must not contain the transforms being applied";
151 diag.attachNote(original->getLoc()) <<
"scope";
156 <<
"only isolated-from-above ops can be alternative scopes";
157 diag.attachNote(original->getLoc()) <<
"scope";
162 for (
Region ® : getAlternatives()) {
167 auto scope = state.make_region_scope(reg);
168 auto clones = llvm::to_vector(
169 llvm::map_range(originals, [](
Operation *op) {
return op->
clone(); }));
170 auto deleteClones = llvm::make_scope_exit([&] {
174 if (
failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
178 for (
Operation &transform : reg.front().without_terminator()) {
180 state.applyTransform(cast<TransformOpInterface>(transform));
182 LDBG() <<
"alternative failed: " << result.
getMessage();
196 deleteClones.release();
197 TrackingListener listener(state, *
this);
199 for (
const auto &kvp : llvm::zip(originals, clones)) {
210 return emitSilenceableError() <<
"all alternatives failed";
213 void transform::AlternativesOp::getEffects(
214 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
217 for (
Region *region : getRegions()) {
218 if (!region->empty())
225 for (
Region &alternative : getAlternatives()) {
230 <<
"expects terminator operands to have the "
231 "same type as results of the operation";
232 diag.attachNote(terminator->
getLoc()) <<
"terminator";
249 llvm::to_vector(state.getPayloadOps(getTarget()));
252 if (
auto paramH = getParam()) {
254 if (params.size() != 1) {
255 if (targets.size() != params.size()) {
256 return emitSilenceableError()
257 <<
"parameter and target have different payload lengths ("
258 << params.size() <<
" vs " << targets.size() <<
")";
260 for (
auto &&[target, attr] : llvm::zip_equal(targets, params))
261 target->setAttr(getName(), attr);
266 for (
auto *target : targets)
267 target->setAttr(getName(), attr);
271 void transform::AnnotateOp::getEffects(
272 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
283 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
298 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
299 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
323 auto addDefiningOpsToWorklist = [&](
Operation *op) {
326 if (
Operation *defOp = v.getDefiningOp())
328 worklist.insert(defOp);
336 const auto *it = llvm::find(worklist, op);
337 if (it != worklist.end())
346 addDefiningOpsToWorklist(op);
352 while (!worklist.empty()) {
356 addDefiningOpsToWorklist(op);
363 void transform::ApplyDeadCodeEliminationOp::getEffects(
364 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
388 if (!getRegion().empty()) {
389 for (
Operation &op : getRegion().front()) {
390 cast<transform::PatternDescriptorOpInterface>(&op)
391 .populatePatternsWithState(
patterns, state);
401 config.setMaxIterations(getMaxIterations() ==
static_cast<uint64_t
>(-1)
403 : getMaxIterations());
404 config.setMaxNumRewrites(getMaxNumRewrites() ==
static_cast<uint64_t
>(-1)
406 : getMaxNumRewrites());
411 bool cseChanged =
false;
414 static const int64_t kNumMaxIterations = 50;
415 int64_t iteration = 0;
417 LogicalResult result = failure();
430 if (target != nestedOp)
431 ops.push_back(nestedOp);
440 <<
"greedy pattern application failed";
448 }
while (cseChanged && ++iteration < kNumMaxIterations);
450 if (iteration == kNumMaxIterations)
457 if (!getRegion().empty()) {
458 for (
Operation &op : getRegion().front()) {
459 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
461 <<
"expected children ops to implement "
462 "PatternDescriptorOpInterface";
463 diag.attachNote(op.
getLoc()) <<
"op without interface";
471 void transform::ApplyPatternsOp::getEffects(
472 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
477 void transform::ApplyPatternsOp::build(
486 bodyBuilder(builder, result.
location);
493 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
497 dialect->getCanonicalizationPatterns(
patterns);
499 op.getCanonicalizationPatterns(
patterns, ctx);
513 std::unique_ptr<TypeConverter> defaultTypeConverter;
514 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
515 getDefaultTypeConverter();
516 if (typeConverterBuilder)
517 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
522 for (
Attribute attr : cast<ArrayAttr>(*getLegalOps()))
523 conversionTarget.addLegalOp(
526 for (
Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
527 conversionTarget.addIllegalOp(
529 if (getLegalDialects())
530 for (
Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
531 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
532 if (getIllegalDialects())
533 for (
Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
534 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
542 if (!getPatterns().empty()) {
543 for (
Operation &op : getPatterns().front()) {
545 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
548 std::unique_ptr<TypeConverter> typeConverter =
549 descriptor.getTypeConverter();
552 keepAliveConverters.emplace_back(std::move(typeConverter));
553 converter = keepAliveConverters.back().get();
556 if (!defaultTypeConverter) {
558 <<
"pattern descriptor does not specify type "
559 "converter and apply_conversion_patterns op has "
560 "no default type converter";
561 diag.attachNote(op.
getLoc()) <<
"pattern descriptor op";
564 converter = defaultTypeConverter.get();
570 descriptor.populateConversionTargetRules(*converter, conversionTarget);
572 descriptor.populatePatterns(*converter,
patterns);
580 TrackingListenerConfig trackingConfig;
581 trackingConfig.requireMatchingReplacementOpName =
false;
582 ErrorCheckingTrackingListener trackingListener(state, *
this, trackingConfig);
584 if (getPreserveHandles())
585 conversionConfig.
listener = &trackingListener;
588 for (
Operation *target : state.getPayloadOps(getTarget())) {
596 LogicalResult status = failure();
597 if (getPartialConversion()) {
608 diag = emitSilenceableError() <<
"dialect conversion failed";
609 diag.attachNote(target->
getLoc()) <<
"target op";
614 trackingListener.checkAndResetError();
616 if (
diag.succeeded()) {
618 return trackingFailure;
620 diag.attachNote() <<
"tracking listener also failed: "
622 (void)trackingFailure.
silence();
626 if (!
diag.succeeded())
634 if (getNumRegions() != 1 && getNumRegions() != 2)
635 return emitOpError() <<
"expected 1 or 2 regions";
636 if (!getPatterns().empty()) {
637 for (
Operation &op : getPatterns().front()) {
638 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
640 emitOpError() <<
"expected pattern children ops to implement "
641 "ConversionPatternDescriptorOpInterface";
642 diag.attachNote(op.
getLoc()) <<
"op without interface";
647 if (getNumRegions() == 2) {
648 Region &typeConverterRegion = getRegion(1);
649 if (!llvm::hasSingleElement(typeConverterRegion.
front()))
651 <<
"expected exactly one op in default type converter region";
653 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
655 if (!typeConverterOp) {
657 <<
"expected default converter child op to "
658 "implement TypeConverterBuilderOpInterface";
659 diag.attachNote(maybeTypeConverter->
getLoc()) <<
"op without interface";
663 if (!getPatterns().empty()) {
664 for (
Operation &op : getPatterns().front()) {
666 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
667 if (
failed(descriptor.verifyTypeConverter(typeConverterOp)))
675 void transform::ApplyConversionPatternsOp::getEffects(
676 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
677 if (!getPreserveHandles()) {
685 void transform::ApplyConversionPatternsOp::build(
695 if (patternsBodyBuilder)
696 patternsBodyBuilder(builder, result.
location);
702 if (typeConverterBodyBuilder)
703 typeConverterBodyBuilder(builder, result.
location);
711 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
714 assert(dialect &&
"expected that dialect is loaded");
715 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
719 iface->populateConvertToLLVMConversionPatterns(
723 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
724 transform::TypeConverterBuilderOpInterface builder) {
725 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
726 return emitOpError(
"expected LLVMTypeConverter");
733 return emitOpError(
"unknown dialect or dialect not loaded: ")
735 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
738 "dialect does not implement ConvertToLLVMPatternInterface or "
739 "extension was not loaded: ")
749 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
759 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
760 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
769 void transform::ApplyRegisteredPassOp::getEffects(
770 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
787 llvm::raw_string_ostream optionsStream(
options);
792 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
795 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
796 assert(dynamicOptionIdx <
static_cast<int64_t
>(dynamicOptions.size()) &&
797 "the number of ParamOperandAttrs in the options DictionaryAttr"
798 "should be the same as the number of options passed as params");
800 state.getParams(dynamicOptions[dynamicOptionIdx]);
802 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
804 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
806 llvm::interleave(arrayAttr, optionsStream, appendValueAttr,
",");
807 }
else if (
auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
809 optionsStream << strAttr.getValue().str();
812 valueAttr.print(optionsStream,
true);
818 getOptions(), optionsStream,
819 [&](
auto namedAttribute) {
820 optionsStream << namedAttribute.getName().str();
821 optionsStream <<
"=";
822 appendValueAttr(namedAttribute.getValue());
825 optionsStream.flush();
833 <<
"unknown pass or pass pipeline: " << getPassName();
842 <<
"failed to add pass or pass pipeline to pipeline: "
858 if (
failed(pm.run(target))) {
859 auto diag = emitSilenceableError() <<
"pass pipeline failed";
860 diag.attachNote(target->
getLoc()) <<
"target op";
866 results.
set(llvm::cast<OpResult>(getResult()), targets);
872 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
875 size_t dynamicOptionsIdx = 0;
881 std::function<ParseResult(
Attribute &)> parseValue =
882 [&](
Attribute &valueAttr) -> ParseResult {
890 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
891 " in options dictionary") ||
905 ParseResult parsedOperand = parser.
parseOperand(operand);
906 if (
failed(parsedOperand))
912 dynamicOptions.push_back(operand);
919 }
else if (isa<transform::ParamOperandAttr>(valueAttr)) {
921 <<
"the param_operand attribute is a marker reserved for "
922 <<
"indicating a value will be passed via params and is only used "
923 <<
"in the generic print format";
937 <<
"expected key to either be an identifier or a string";
941 <<
"expected '=' after key in key-value pair";
943 if (
failed(parseValue(valueAttr)))
945 <<
"expected a valid attribute or operand as value associated "
946 <<
"to key '" << key <<
"'";
955 " in options dictionary"))
958 if (DictionaryAttr::findDuplicate(
959 keyValuePairs,
false)
962 <<
"duplicate keys found in options dictionary";
977 if (
auto paramOperandAttr =
978 dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
981 dynamicOptions[paramOperandAttr.getIndex().getInt()]);
982 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
994 printer << namedAttribute.
getName();
1008 std::function<LogicalResult(
Attribute)> checkOptionValue =
1009 [&](
Attribute valueAttr) -> LogicalResult {
1010 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1011 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1012 if (dynamicOptionIdx < 0 ||
1013 dynamicOptionIdx >=
static_cast<int64_t
>(dynamicOptions.size()))
1014 return emitOpError()
1015 <<
"dynamic option index " << dynamicOptionIdx
1016 <<
" is out of bounds for the number of dynamic options: "
1017 << dynamicOptions.size();
1018 if (dynamicOptions[dynamicOptionIdx] ==
nullptr)
1019 return emitOpError() <<
"dynamic option index " << dynamicOptionIdx
1020 <<
" is already used in options";
1021 dynamicOptions[dynamicOptionIdx] =
nullptr;
1022 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1024 for (
auto eltAttr : arrayAttr)
1025 if (
failed(checkOptionValue(eltAttr)))
1032 if (
failed(checkOptionValue(namedAttr.getValue())))
1036 for (
Value dynamicOption : dynamicOptions)
1038 return emitOpError() <<
"a param operand does not have a corresponding "
1039 <<
"param_operand attr in the options dict";
1050 Operation *target, ApplyToEachResultList &results,
1052 results.push_back(target);
1056 void transform::CastOp::getEffects(
1057 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1064 assert(inputs.size() == 1 &&
"expected one input");
1065 assert(outputs.size() == 1 &&
"expected one output");
1066 return llvm::all_of(
1067 std::initializer_list<Type>{inputs.front(), outputs.front()},
1068 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1088 assert(block.
getParent() &&
"cannot match using a detached block");
1089 auto matchScope = state.make_region_scope(*block.
getParent());
1091 state.mapBlockArguments(block.
getArguments(), blockArgumentMapping)))
1095 if (!isa<transform::MatchOpInterface>(match)) {
1097 <<
"expected operations in the match part to "
1098 "implement MatchOpInterface";
1101 state.applyTransform(cast<transform::TransformOpInterface>(match));
1102 if (
diag.succeeded())
1120 template <
typename... Tys>
1122 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... ||
false);
1129 transform::TransformParamTypeInterface,
1130 transform::TransformValueHandleTypeInterface>(
1142 auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
1143 getOperation(), getMatcher());
1144 if (matcher.isExternal()) {
1146 <<
"unresolved external symbol " << getMatcher();
1150 rawResults.resize(getOperation()->getNumResults());
1151 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1152 for (
Operation *root : state.getPayloadOps(getRoot())) {
1163 matcher.getFunctionBody().front(),
1166 if (
diag.isDefiniteFailure())
1168 if (
diag.isSilenceableFailure()) {
1170 <<
" failed: " <<
diag.getMessage();
1176 if (mapping.size() != 1) {
1177 maybeFailure.emplace(emitSilenceableError()
1178 <<
"result #" << i <<
", associated with "
1180 <<
" payload objects, expected 1");
1183 rawResults[i].push_back(mapping[0]);
1188 return std::move(*maybeFailure);
1189 assert(!maybeFailure &&
"failure set but the walk was not interrupted");
1191 for (
auto &&[opResult, rawResult] :
1192 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1199 void transform::CollectMatchingOp::getEffects(
1200 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1206 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1208 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1210 if (!matcherSymbol ||
1211 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1212 return emitError() <<
"unresolved matcher symbol " << getMatcher();
1215 if (argumentTypes.size() != 1 ||
1216 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1218 <<
"expected the matcher to take one operation handle argument";
1220 if (!matcherSymbol.getArgAttr(
1221 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1222 return emitError() <<
"expected the matcher argument to be marked readonly";
1226 if (resultTypes.size() != getOperation()->getNumResults()) {
1228 <<
"expected the matcher to yield as many values as op has results ("
1229 << getOperation()->getNumResults() <<
"), got "
1230 << resultTypes.size();
1233 for (
auto &&[i, matcherType, resultType] :
1239 <<
"mismatching type interfaces for matcher result and op result #"
// Unconditionally reports that ForeachMatchOp accepts the same handle in more
// than one of its operand positions.
// NOTE(review): presumably this hook is queried by the transform interpreter
// to skip its default rejection of repeated handle operands; confirm against
// the TransformOpInterface declaration.
1251 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() {
 return true; }
1259 matchActionPairs.reserve(getMatchers().size());
1261 for (
auto &&[matcher, action] :
1262 llvm::zip_equal(getMatchers(), getActions())) {
1263 auto matcherSymbol =
1265 getOperation(), cast<SymbolRefAttr>(matcher));
1268 getOperation(), cast<SymbolRefAttr>(action));
1269 assert(matcherSymbol && actionSymbol &&
1270 "unresolved symbols not caught by the verifier");
1272 if (matcherSymbol.isExternal())
1274 if (actionSymbol.isExternal())
1277 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1288 matchInputMapping.emplace_back();
1290 getForwardedInputs(), state);
1292 actionResultMapping.resize(getForwardedOutputs().size());
1294 for (
Operation *root : state.getPayloadOps(getRoot())) {
1298 if (!getRestrictRoot() && op == root)
1306 firstMatchArgument.clear();
1307 firstMatchArgument.push_back(op);
1310 for (
auto [matcher, action] : matchActionPairs) {
1312 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1313 state, matchOutputMapping);
1314 if (
diag.isDefiniteFailure())
1316 if (
diag.isSilenceableFailure()) {
1318 <<
" failed: " <<
diag.getMessage();
1322 auto scope = state.make_region_scope(action.getFunctionBody());
1323 if (
failed(state.mapBlockArguments(
1324 action.getFunctionBody().front().getArguments(),
1325 matchOutputMapping))) {
1330 action.getFunctionBody().front().without_terminator()) {
1332 state.applyTransform(cast<TransformOpInterface>(transform));
1337 overallDiag = emitSilenceableError() <<
"actions failed";
1342 <<
"when applied to this matching payload";
1349 action.getFunctionBody().front().getTerminator()->getOperands(),
1350 state, getFlattenResults()))) {
1352 <<
"action @" << action.getName()
1353 <<
" has results associated with multiple payload entities, "
1354 "but flattening was not requested";
1369 results.
set(llvm::cast<OpResult>(getUpdated()),
1370 state.getPayloadOps(getRoot()));
1371 for (
auto &&[result, mapping] :
1372 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1378 void transform::ForeachMatchOp::getAsmResultNames(
1380 setNameFn(getUpdated(),
"updated_root");
1381 for (
Value v : getForwardedOutputs()) {
1382 setNameFn(v,
"yielded");
1386 void transform::ForeachMatchOp::getEffects(
1387 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1389 if (getOperation()->getNumOperands() < 1 ||
1390 getOperation()->getNumResults() < 1) {
1403 ArrayAttr &matchers,
1404 ArrayAttr &actions) {
1426 ArrayAttr matchers, ArrayAttr actions) {
1429 for (
auto &&[matcher, action, idx] : llvm::zip_equal(
1430 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1432 printer << cast<SymbolRefAttr>(matcher) <<
" -> "
1433 << cast<SymbolRefAttr>(action);
1434 if (idx != matchers.size() - 1)
1442 if (getMatchers().size() != getActions().size())
1443 return emitOpError() <<
"expected the same number of matchers and actions";
1444 if (getMatchers().empty())
1445 return emitOpError() <<
"expected at least one match/action pair";
1449 if (matcherNames.insert(name).second)
1452 <<
" is used more than once, only the first match will apply";
1463 bool alsoVerifyInternal =
false) {
1464 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1465 llvm::SmallDenseSet<unsigned> consumedArguments;
1466 if (!op.isExternal()) {
1470 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1472 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1475 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1477 if (isConsumed && isReadOnly) {
1478 return transformOp.emitSilenceableError()
1479 <<
"argument #" << i <<
" cannot be both readonly and consumed";
1481 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1482 return transformOp.emitSilenceableError()
1483 <<
"must provide consumed/readonly status for arguments of "
1484 "external or called ops";
1486 if (op.isExternal())
1489 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1490 return transformOp.emitSilenceableError()
1491 <<
"argument #" << i
1492 <<
" is consumed in the body but is not marked as such";
1494 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1498 <<
"op argument #" << i
1499 <<
" is not consumed in the body but is marked as consumed";
1505 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1507 assert(getMatchers().size() == getActions().size());
1510 for (
auto &&[matcher, action] :
1511 llvm::zip_equal(getMatchers(), getActions())) {
1513 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1515 cast<SymbolRefAttr>(matcher)));
1516 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1518 cast<SymbolRefAttr>(action)));
1519 if (!matcherSymbol ||
1520 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1521 return emitError() <<
"unresolved matcher symbol " << matcher;
1522 if (!actionSymbol ||
1523 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1524 return emitError() <<
"unresolved action symbol " << action;
1529 .checkAndReport())) {
1535 .checkAndReport())) {
1540 TypeRange operandTypes = getOperandTypes();
1541 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1542 if (operandTypes.size() != matcherArguments.size()) {
1544 emitError() <<
"the number of operands (" << operandTypes.size()
1545 <<
") doesn't match the number of matcher arguments ("
1546 << matcherArguments.size() <<
") for " << matcher;
1547 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1550 for (
auto &&[i, operand, argument] :
1552 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1555 <<
"does not expect matcher symbol to consume its operand #" << i;
1556 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1565 <<
"mismatching type interfaces for operand and matcher argument #"
1566 << i <<
" of matcher " << matcher;
1567 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1572 TypeRange matcherResults = matcherSymbol.getResultTypes();
1573 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1574 if (matcherResults.size() != actionArguments.size()) {
1575 return emitError() <<
"mismatching number of matcher results and "
1576 "action arguments between "
1577 << matcher <<
" (" << matcherResults.size() <<
") and "
1578 << action <<
" (" << actionArguments.size() <<
")";
1580 for (
auto &&[i, matcherType, actionType] :
1585 return emitError() <<
"mismatching type interfaces for matcher result "
1586 "and action argument #"
1587 << i <<
"of matcher " << matcher <<
" and action "
1592 TypeRange actionResults = actionSymbol.getResultTypes();
1593 auto resultTypes =
TypeRange(getResultTypes()).drop_front();
1594 if (actionResults.size() != resultTypes.size()) {
1596 emitError() <<
"the number of action results ("
1597 << actionResults.size() <<
") for " << action
1598 <<
" doesn't match the number of extra op results ("
1599 << resultTypes.size() <<
")";
1600 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1603 for (
auto &&[i, resultType, actionType] :
1609 emitError() <<
"mismatching type interfaces for action result #" << i
1610 <<
" of action " << action <<
" and op result";
1611 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1630 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1631 bool withZipShortest = getWithZipShortest();
1635 if (withZipShortest) {
1639 return A.size() <
B.size();
1642 for (
size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1643 payloads[argIdx].resize(numIterations);
1649 for (
size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1651 if (payloads[argIdx].size() != numIterations) {
1652 return emitSilenceableError()
1653 <<
"prior targets' payload size (" << numIterations
1654 <<
") differs from payload size (" << payloads[argIdx].size()
1655 <<
") of target " << getTargets()[argIdx];
1664 for (
size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1665 auto scope = state.make_region_scope(getBody());
1671 if (
failed(state.mapBlockArgument(blockArg, {argument})))
1676 for (
Operation &transform : getBody().front().without_terminator()) {
1678 llvm::cast<transform::TransformOpInterface>(transform));
1684 OperandRange yieldOperands = getYieldOp().getOperands();
1685 for (
auto &&[result, yieldOperand, resTuple] :
1686 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1688 if (isa<TransformHandleTypeInterface>(result.getType()))
1689 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1690 else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1691 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1692 else if (isa<TransformParamTypeInterface>(result.getType()))
1693 llvm::append_range(resTuple, state.getParams(yieldOperand));
1695 assert(
false &&
"unhandled handle type");
1699 for (
auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1705 void transform::ForeachOp::getEffects(
1706 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1709 for (
auto &&[target, blockArg] :
1710 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1712 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1714 cast<TransformOpInterface>(&op));
1722 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1726 }
else if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1735 void transform::ForeachOp::getSuccessorRegions(
1737 Region *bodyRegion = &getBody();
1739 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1744 assert(point == getBody() &&
"unexpected region index");
1745 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1746 regions.emplace_back();
1753 assert(point == getBody() &&
"unexpected region index");
1754 return getOperation()->getOperands();
1757 transform::YieldOp transform::ForeachOp::getYieldOp() {
1758 return cast<transform::YieldOp>(getBody().front().getTerminator());
1762 for (
auto [targetOpt, bodyArgOpt] :
1763 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1764 if (!targetOpt || !bodyArgOpt)
1765 return emitOpError() <<
"expects the same number of targets as the body "
1766 "has block arguments";
1767 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1769 "expects co-indexed targets and the body's "
1770 "block arguments to have the same op/value/param type");
1773 for (
auto [resultOpt, yieldOperandOpt] :
1774 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1775 if (!resultOpt || !yieldOperandOpt)
1776 return emitOpError() <<
"expects the same number of results as the "
1777 "yield terminator has operands";
1778 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1779 return emitOpError(
"expects co-indexed results and yield "
1780 "operands to have the same op/value/param type");
1796 for (
Operation *target : state.getPayloadOps(getTarget())) {
1798 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1801 bool checkIsolatedFromAbove =
1802 !getIsolatedFromAbove() ||
1804 bool checkOpName = !getOpName().has_value() ||
1806 if (checkIsolatedFromAbove && checkOpName)
1811 if (getAllowEmptyResults()) {
1812 results.
set(llvm::cast<OpResult>(getResult()), parents);
1816 emitSilenceableError()
1817 <<
"could not find a parent op that matches all requirements";
1818 diag.attachNote(target->
getLoc()) <<
"target op";
1822 if (getDeduplicate()) {
1823 if (resultSet.insert(parent).second)
1824 parents.push_back(parent);
1826 parents.push_back(parent);
1829 results.
set(llvm::cast<OpResult>(getResult()), parents);
1841 int64_t resultNumber = getResultNumber();
1842 auto payloadOps = state.getPayloadOps(getTarget());
1843 if (std::empty(payloadOps)) {
1844 results.
set(cast<OpResult>(getResult()), {});
1847 if (!llvm::hasSingleElement(payloadOps))
1849 <<
"handle must be mapped to exactly one payload op";
1851 Operation *target = *payloadOps.begin();
1854 results.
set(llvm::cast<OpResult>(getResult()),
1868 for (
Value v : state.getPayloadValues(getTarget())) {
1869 if (llvm::isa<BlockArgument>(v)) {
1871 emitSilenceableError() <<
"cannot get defining op of block argument";
1872 diag.attachNote(v.getLoc()) <<
"target value";
1875 definingOps.push_back(v.getDefiningOp());
1877 results.
set(llvm::cast<OpResult>(getResult()), definingOps);
1889 int64_t operandNumber = getOperandNumber();
1891 for (
Operation *target : state.getPayloadOps(getTarget())) {
1898 emitSilenceableError()
1899 <<
"could not find a producer for operand number: " << operandNumber
1900 <<
" of " << *target;
1901 diag.attachNote(target->getLoc()) <<
"target op";
1904 producers.push_back(producer);
1906 results.
set(llvm::cast<OpResult>(getResult()), producers);
1919 for (
Operation *target : state.getPayloadOps(getTarget())) {
1922 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1923 target->getNumOperands(), operandPositions);
1924 if (
diag.isSilenceableFailure()) {
1925 diag.attachNote(target->getLoc())
1926 <<
"while considering positions of this payload operation";
1929 llvm::append_range(operands,
1930 llvm::map_range(operandPositions, [&](int64_t pos) {
1931 return target->getOperand(pos);
1934 results.
setValues(cast<OpResult>(getResult()), operands);
1940 getIsInverted(), getIsAll());
1952 for (
Operation *target : state.getPayloadOps(getTarget())) {
1955 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1956 target->getNumResults(), resultPositions);
1957 if (
diag.isSilenceableFailure()) {
1958 diag.attachNote(target->getLoc())
1959 <<
"while considering positions of this payload operation";
1962 llvm::append_range(opResults,
1963 llvm::map_range(resultPositions, [&](int64_t pos) {
1964 return target->getResult(pos);
1967 results.
setValues(cast<OpResult>(getResult()), opResults);
1973 getIsInverted(), getIsAll());
1980 void transform::GetTypeOp::getEffects(
1981 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1992 for (
Value value : state.getPayloadValues(getValue())) {
1993 Type type = value.getType();
1994 if (getElemental()) {
1995 if (
auto shaped = dyn_cast<ShapedType>(type)) {
1996 type = shaped.getElementType();
2001 results.
setParams(cast<OpResult>(getResult()), params);
2018 state.applyTransform(cast<transform::TransformOpInterface>(transform));
2023 if (mode == transform::FailurePropagationMode::Propagate) {
2042 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2043 getOperation(), getTarget());
2044 assert(callee &&
"unverified reference to unknown symbol");
2046 if (callee.isExternal())
2052 auto scope = state.make_region_scope(callee.getBody());
2053 for (
auto &&[arg, map] :
2054 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2055 if (
failed(state.mapBlockArgument(arg, map)))
2060 callee.getBody().front(), getFailurePropagationMode(), state, results);
2063 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2064 for (
auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
2072 void transform::IncludeOp::getEffects(
2073 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2087 auto defaultEffects = [&] {
2094 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2096 return defaultEffects();
2097 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2098 getOperation(), getTarget());
2100 return defaultEffects();
2104 (void)earlyVerifierResult.
silence();
2105 return defaultEffects();
2108 for (
unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2109 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2120 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(
"target");
2122 return emitOpError() <<
"expects a 'target' symbol reference attribute";
2127 return emitOpError() <<
"does not reference a named transform sequence";
2129 FunctionType fnType = target.getFunctionType();
2130 if (fnType.getNumInputs() != getNumOperands())
2131 return emitError(
"incorrect number of operands for callee");
2133 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2134 if (getOperand(i).
getType() != fnType.getInput(i)) {
2135 return emitOpError(
"operand type mismatch: expected operand type ")
2136 << fnType.getInput(i) <<
", but provided "
2137 << getOperand(i).getType() <<
" for operand number " << i;
2141 if (fnType.getNumResults() != getNumResults())
2142 return emitError(
"incorrect number of results for callee");
2144 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2145 Type resultType = getResult(i).getType();
2146 Type funcType = fnType.getResult(i);
2148 return emitOpError() <<
"type of result #" << i
2149 <<
" must implement the same transform dialect "
2150 "interface as the corresponding callee result";
2155 cast<FunctionOpInterface>(*target),
false,
2165 ::std::optional<::mlir::Operation *> maybeCurrent,
2167 if (!maybeCurrent.has_value()) {
2172 return emitSilenceableError() <<
"operation is not empty";
2183 for (
auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2184 if (acceptedAttr.getValue() == currentOpName)
2187 return emitSilenceableError() <<
"wrong operation name";
2198 auto signedAPIntAsString = [&](
const APInt &value) {
2200 llvm::raw_string_ostream os(str);
2201 value.print(os,
true);
2208 if (params.size() != references.size()) {
2209 return emitSilenceableError()
2210 <<
"parameters have different payload lengths (" << params.size()
2211 <<
" vs " << references.size() <<
")";
2214 for (
auto &&[i, param, reference] :
llvm::enumerate(params, references)) {
2215 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2216 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2217 if (!intAttr || !refAttr) {
2219 <<
"non-integer parameter value not expected";
2221 if (intAttr.getType() != refAttr.getType()) {
2223 <<
"mismatching integer attribute types in parameter #" << i;
2225 APInt value = intAttr.getValue();
2226 APInt refValue = refAttr.getValue();
2229 int64_t position = i;
2230 auto reportError = [&](StringRef direction) {
2232 emitSilenceableError() <<
"expected parameter to be " << direction
2233 <<
" " << signedAPIntAsString(refValue)
2234 <<
", got " << signedAPIntAsString(value);
2235 diag.attachNote(getParam().getLoc())
2236 <<
"value # " << position
2237 <<
" associated with the parameter defined here";
2241 switch (getPredicate()) {
2242 case MatchCmpIPredicate::eq:
2243 if (value.eq(refValue))
2245 return reportError(
"equal to");
2246 case MatchCmpIPredicate::ne:
2247 if (value.ne(refValue))
2249 return reportError(
"not equal to");
2250 case MatchCmpIPredicate::lt:
2251 if (value.slt(refValue))
2253 return reportError(
"less than");
2254 case MatchCmpIPredicate::le:
2255 if (value.sle(refValue))
2257 return reportError(
"less than or equal to");
2258 case MatchCmpIPredicate::gt:
2259 if (value.sgt(refValue))
2261 return reportError(
"greater than");
2262 case MatchCmpIPredicate::ge:
2263 if (value.sge(refValue))
2265 return reportError(
"greater than or equal to");
2271 void transform::MatchParamCmpIOp::getEffects(
2272 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2285 results.
setParams(cast<OpResult>(getParam()), {getValue()});
2298 if (isa<TransformHandleTypeInterface>(handles.front().
getType())) {
2300 for (
Value operand : handles)
2301 llvm::append_range(operations, state.getPayloadOps(operand));
2302 if (!getDeduplicate()) {
2303 results.
set(llvm::cast<OpResult>(getResult()), operations);
2308 results.
set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2312 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2314 for (
Value attribute : handles)
2315 llvm::append_range(attrs, state.getParams(attribute));
2316 if (!getDeduplicate()) {
2317 results.
setParams(cast<OpResult>(getResult()), attrs);
2322 results.
setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2327 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2328 "expected value handle type");
2330 for (
Value value : handles)
2331 llvm::append_range(payloadValues, state.getPayloadValues(value));
2332 if (!getDeduplicate()) {
2333 results.
setValues(cast<OpResult>(getResult()), payloadValues);
2338 results.
setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2342 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2344 return getDeduplicate();
2347 void transform::MergeHandlesOp::getEffects(
2348 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2356 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2357 if (getDeduplicate() || getHandles().size() != 1)
2362 return getHandles().front();
2380 auto scope = state.make_region_scope(getBody());
2382 state, this->getOperation(), getBody())))
2386 FailurePropagationMode::Propagate, state, results);
2389 void transform::NamedSequenceOp::getEffects(
2390 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2395 parser, result,
false,
2396 getFunctionTypeAttrName(result.
name),
2399 std::string &) { return builder.getFunctionType(inputs, results); },
2400 getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
2405 printer, cast<FunctionOpInterface>(getOperation()),
false,
2406 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2407 getResAttrsAttrName());
2417 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2420 <<
"cannot be defined inside another transform op";
2421 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2425 if (op.isExternal() || op.getFunctionBody().empty()) {
2432 if (op.getFunctionBody().front().empty())
2435 Operation *terminator = &op.getFunctionBody().front().back();
2436 if (!isa<transform::YieldOp>(terminator)) {
2439 << transform::YieldOp::getOperationName()
2440 <<
"' as terminator";
2441 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2445 if (terminator->
getNumOperands() != op.getResultTypes().size()) {
2447 <<
"expected terminator to have as many operands as the parent op "
2450 for (
auto [i, operandType, resultType] : llvm::zip_equal(
2453 if (operandType == resultType)
2456 <<
"the type of the terminator operand #" << i
2457 <<
" must match the type of the corresponding parent op result ("
2458 << operandType <<
" vs " << resultType <<
")";
2471 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2474 <<
"expects the parent symbol table to have the '"
2475 << transform::TransformDialect::kWithNamedSequenceAttrName
2477 diag.attachNote(parent->
getLoc()) <<
"symbol table operation";
2482 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2485 <<
"cannot be defined inside another transform op";
2486 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2490 if (op.isExternal() || op.getBody().empty())
2494 if (op.getBody().front().empty())
2497 Operation *terminator = &op.getBody().front().back();
2498 if (!isa<transform::YieldOp>(terminator)) {
2501 << transform::YieldOp::getOperationName()
2502 <<
"' as terminator";
2503 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2507 if (terminator->
getNumOperands() != op.getFunctionType().getNumResults()) {
2509 <<
"expected terminator to have as many operands as the parent op "
2512 for (
auto [i, operandType, resultType] :
2513 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->
getNumOperands()),
2515 op.getFunctionType().getResults())) {
2516 if (operandType == resultType)
2519 <<
"the type of the terminator operand #" << i
2520 <<
" must match the type of the corresponding parent op result ("
2521 << operandType <<
" vs " << resultType <<
")";
2524 auto funcOp = cast<FunctionOpInterface>(*op);
2527 if (!
diag.succeeded())
2539 template <
typename FnTy>
2544 types.reserve(1 + extraBindingTypes.size());
2545 types.push_back(bbArgType);
2546 llvm::append_range(types, extraBindingTypes);
2549 Region *region = state.regions.back().get();
2556 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2557 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0));
2559 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0),
2564 void transform::NamedSequenceOp::build(
OpBuilder &builder,
2572 state.addAttribute(getFunctionTypeAttrName(state.name),
2574 rootType, resultTypes)));
2575 state.attributes.append(attrs.begin(), attrs.end());
2590 size_t numAssociations =
2592 .Case([&](TransformHandleTypeInterface opHandle) {
2593 return llvm::range_size(state.getPayloadOps(getHandle()));
2595 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2596 return llvm::range_size(state.getPayloadValues(getHandle()));
2598 .Case([&](TransformParamTypeInterface param) {
2599 return llvm::range_size(state.getParams(getHandle()));
2602 llvm_unreachable(
"unknown kind of transform dialect type");
2605 results.
setParams(cast<OpResult>(getNum()),
2612 auto resultType = cast<TransformParamTypeInterface>(getNum().
getType());
2627 auto payloadOps = state.getPayloadOps(getTarget());
2630 result.push_back(op);
2632 results.
set(cast<OpResult>(getResult()), result);
2641 Value target, int64_t numResultHandles) {
2650 int64_t numPayloads =
2652 .Case<TransformHandleTypeInterface>([&](
auto x) {
2653 return llvm::range_size(state.getPayloadOps(getHandle()));
2655 .Case<TransformValueHandleTypeInterface>([&](
auto x) {
2656 return llvm::range_size(state.getPayloadValues(getHandle()));
2658 .Case<TransformParamTypeInterface>([&](
auto x) {
2659 return llvm::range_size(state.getParams(getHandle()));
2661 .Default([](
auto x) {
2662 llvm_unreachable(
"unknown transform dialect type interface");
2666 auto produceNumOpsError = [&]() {
2667 return emitSilenceableError()
2668 << getHandle() <<
" expected to contain " << this->getNumResults()
2669 <<
" payloads but it contains " << numPayloads <<
" payloads";
2674 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2675 return produceNumOpsError();
2680 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2681 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2682 return produceNumOpsError();
2686 if (getOverflowResult())
2687 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2689 auto container = [&]() {
2690 if (isa<TransformHandleTypeInterface>(getHandle().
getType())) {
2691 return llvm::map_to_vector(
2692 state.getPayloadOps(getHandle()),
2695 if (isa<TransformValueHandleTypeInterface>(getHandle().
getType())) {
2696 return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2699 assert(isa<TransformParamTypeInterface>(getHandle().
getType()) &&
2700 "unsupported kind of transform dialect type");
2701 return llvm::map_to_vector(state.getParams(getHandle()),
2706 int64_t resultNum = en.index();
2707 if (resultNum >= getNumResults())
2708 resultNum = *getOverflowResult();
2709 resultHandles[resultNum].push_back(en.value());
2720 void transform::SplitHandleOp::getEffects(
2721 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2729 if (getOverflowResult().has_value() &&
2730 !(*getOverflowResult() < getNumResults()))
2731 return emitOpError(
"overflow_result is not a valid result index");
2733 for (
Type resultType : getResultTypes()) {
2737 return emitOpError(
"expects result types to implement the same transform "
2738 "interface as the operand type");
2752 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2754 Value handle = en.value();
2755 if (isa<TransformHandleTypeInterface>(handle.
getType())) {
2757 llvm::to_vector(state.getPayloadOps(handle));
2759 payload.reserve(numRepetitions * current.size());
2760 for (
unsigned i = 0; i < numRepetitions; ++i)
2761 llvm::append_range(payload, current);
2762 results.
set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2764 assert(llvm::isa<TransformParamTypeInterface>(handle.
getType()) &&
2765 "expected param type");
2768 params.reserve(numRepetitions * current.size());
2769 for (
unsigned i = 0; i < numRepetitions; ++i)
2770 llvm::append_range(params, current);
2771 results.
setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2778 void transform::ReplicateOp::getEffects(
2779 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2794 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2795 if (
failed(mapBlockArguments(state)))
2803 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2805 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2806 SmallVectorImpl<Type> &extraBindingTypes) {
2810 root = std::nullopt;
2831 if (!extraBindings.empty()) {
2836 if (extraBindingTypes.size() != extraBindings.size()) {
2838 "expected types to be provided for all operands");
2854 bool hasExtras = !extraBindings.empty();
2864 printer << rootType;
2866 printer <<
", " << llvm::interleaved(extraBindingTypes) <<
')';
2873 auto iface = dyn_cast<transform::TransformOpInterface>(use.
getOwner());
2888 if (!potentialConsumer) {
2889 potentialConsumer = &use;
2894 <<
" has more than one potential consumer";
2897 diag.attachNote(use.getOwner()->getLoc())
2898 <<
"used here as operand #" << use.getOperandNumber();
2906 assert(getBodyBlock()->getNumArguments() >= 1 &&
2907 "the number of arguments must have been verified to be more than 1 by "
2908 "PossibleTopLevelTransformOpTrait");
2910 if (!getRoot() && !getExtraBindings().empty()) {
2911 return emitOpError()
2912 <<
"does not expect extra operands when used as top-level";
2918 return (emitOpError() <<
"block argument #" << arg.getArgNumber());
2925 for (
Operation &child : *getBodyBlock()) {
2926 if (!isa<TransformOpInterface>(child) &&
2927 &child != &getBodyBlock()->back()) {
2930 <<
"expected children ops to implement TransformOpInterface";
2931 diag.attachNote(child.getLoc()) <<
"op without interface";
2935 for (
OpResult result : child.getResults()) {
2936 auto report = [&]() {
2937 return (child.emitError() <<
"result #" << result.getResultNumber());
2944 if (!getBodyBlock()->mightHaveTerminator())
2945 return emitOpError() <<
"expects to have a terminator in the body";
2947 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2948 getOperation()->getResultTypes()) {
2950 <<
"expects the types of the terminator operands "
2951 "to match the types of the result";
2952 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) <<
"terminator";
2958 void transform::SequenceOp::getEffects(
2959 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2965 assert(point == getBody() &&
"unexpected region index");
2966 if (getOperation()->getNumOperands() > 0)
2967 return getOperation()->getOperands();
2969 getOperation()->operand_end());
2972 void transform::SequenceOp::getSuccessorRegions(
2975 Region *bodyRegion = &getBody();
2976 regions.emplace_back(bodyRegion, getNumOperands() != 0
2982 assert(point == getBody() &&
"unexpected region index");
2983 regions.emplace_back(getOperation()->getResults());
2986 void transform::SequenceOp::getRegionInvocationBounds(
2989 bounds.emplace_back(1, 1);
2994 FailurePropagationMode failurePropagationMode,
2997 build(builder, state, resultTypes, failurePropagationMode, root,
3006 FailurePropagationMode failurePropagationMode,
3009 build(builder, state, resultTypes, failurePropagationMode, root,
3017 FailurePropagationMode failurePropagationMode,
3020 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3028 FailurePropagationMode failurePropagationMode,
3031 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3047 Value target, StringRef name) {
3049 build(builder, result, name);
3056 llvm::outs() <<
"[[[ IR printer: ";
3057 if (getName().has_value())
3058 llvm::outs() << *getName() <<
" ";
3061 if (getAssumeVerified().value_or(
false))
3063 if (getUseLocalScope().value_or(
false))
3065 if (getSkipRegions().value_or(
false))
3069 llvm::outs() <<
"top-level ]]]\n";
3070 state.getTopLevel()->print(llvm::outs(), printFlags);
3071 llvm::outs() <<
"\n";
3072 llvm::outs().flush();
3076 llvm::outs() <<
"]]]\n";
3077 for (
Operation *target : state.getPayloadOps(getTarget())) {
3078 target->
print(llvm::outs(), printFlags);
3079 llvm::outs() <<
"\n";
3082 llvm::outs().flush();
3086 void transform::PrintOp::getEffects(
3087 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3091 if (!getTargetMutable().empty())
3111 <<
"failed to verify payload op";
3112 diag.attachNote(target->
getLoc()) <<
"payload op";
3118 void transform::VerifyOp::getEffects(
3119 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3127 void transform::YieldOp::getEffects(
3128 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a #dlti.dl_entry attribute.
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getRegionNumber()
Return the number of this region in the parent operation.
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
A named class for passing around the variadic flag.
static void printOptionValue(raw_ostream &os, const bool &value)
Utility methods for printing option values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.