34 #include "llvm/ADT/DenseSet.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/ScopeExit.h"
37 #include "llvm/ADT/SmallPtrSet.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/Support/Debug.h"
40 #include "llvm/Support/DebugLog.h"
41 #include "llvm/Support/ErrorHandling.h"
42 #include "llvm/Support/InterleavedRange.h"
45 #define DEBUG_TYPE "transform-dialect"
46 #define DEBUG_TYPE_MATCHER "transform-matcher"
52 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
58 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
60 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
61 SmallVectorImpl<Type> &extraBindingTypes);
67 ArrayAttr matchers, ArrayAttr actions);
78 Operation *transformAncestor = transform.getOperation();
79 while (transformAncestor) {
80 if (transformAncestor == payload) {
82 transform.emitDefiniteFailure()
83 <<
"cannot apply transform to itself (or one of its ancestors)";
84 diag.attachNote(payload->
getLoc()) <<
"target payload op";
87 transformAncestor = transformAncestor->
getParentOp();
92 #define GET_OP_CLASSES
93 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
101 if (!point.
isParent() && getOperation()->getNumOperands() == 1)
102 return getOperation()->getOperands();
104 getOperation()->operand_end());
107 void transform::AlternativesOp::getSuccessorRegions(
109 for (
Region &alternative : llvm::drop_begin(
113 regions.emplace_back(&alternative, !getOperands().empty()
114 ? alternative.getArguments()
118 regions.emplace_back(getOperation()->getResults());
121 void transform::AlternativesOp::getRegionInvocationBounds(
126 bounds.reserve(getNumRegions());
127 bounds.emplace_back(1, 1);
134 results.
set(res, {});
142 if (
Value scopeHandle = getScope())
143 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
145 originals.push_back(state.getTopLevel());
148 if (original->isAncestor(getOperation())) {
150 <<
"scope must not contain the transforms being applied";
151 diag.attachNote(original->getLoc()) <<
"scope";
156 <<
"only isolated-from-above ops can be alternative scopes";
157 diag.attachNote(original->getLoc()) <<
"scope";
162 for (
Region ® : getAlternatives()) {
167 auto scope = state.make_region_scope(reg);
168 auto clones = llvm::to_vector(
169 llvm::map_range(originals, [](
Operation *op) {
return op->
clone(); }));
170 auto deleteClones = llvm::make_scope_exit([&] {
174 if (
failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
178 for (
Operation &transform : reg.front().without_terminator()) {
180 state.applyTransform(cast<TransformOpInterface>(transform));
182 LDBG() <<
"alternative failed: " << result.
getMessage();
196 deleteClones.release();
197 TrackingListener listener(state, *
this);
199 for (
const auto &kvp : llvm::zip(originals, clones)) {
210 return emitSilenceableError() <<
"all alternatives failed";
213 void transform::AlternativesOp::getEffects(
214 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
217 for (
Region *region : getRegions()) {
218 if (!region->empty())
225 for (
Region &alternative : getAlternatives()) {
230 <<
"expects terminator operands to have the "
231 "same type as results of the operation";
232 diag.attachNote(terminator->
getLoc()) <<
"terminator";
249 llvm::to_vector(state.getPayloadOps(getTarget()));
252 if (
auto paramH = getParam()) {
254 if (params.size() != 1) {
255 if (targets.size() != params.size()) {
256 return emitSilenceableError()
257 <<
"parameter and target have different payload lengths ("
258 << params.size() <<
" vs " << targets.size() <<
")";
260 for (
auto &&[target, attr] : llvm::zip_equal(targets, params))
261 target->setAttr(getName(), attr);
266 for (
auto *target : targets)
267 target->setAttr(getName(), attr);
271 void transform::AnnotateOp::getEffects(
272 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
283 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
298 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
299 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
323 auto addDefiningOpsToWorklist = [&](
Operation *op) {
326 if (
Operation *defOp = v.getDefiningOp())
328 worklist.insert(defOp);
336 const auto *it = llvm::find(worklist, op);
337 if (it != worklist.end())
346 addDefiningOpsToWorklist(op);
352 while (!worklist.empty()) {
356 addDefiningOpsToWorklist(op);
363 void transform::ApplyDeadCodeEliminationOp::getEffects(
364 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
388 if (!getRegion().empty()) {
389 for (
Operation &op : getRegion().front()) {
390 cast<transform::PatternDescriptorOpInterface>(&op)
391 .populatePatternsWithState(
patterns, state);
401 config.setMaxIterations(getMaxIterations() ==
static_cast<uint64_t
>(-1)
403 : getMaxIterations());
404 config.setMaxNumRewrites(getMaxNumRewrites() ==
static_cast<uint64_t
>(-1)
406 : getMaxNumRewrites());
411 bool cseChanged =
false;
414 static const int64_t kNumMaxIterations = 50;
415 int64_t iteration = 0;
417 LogicalResult result = failure();
430 if (target != nestedOp)
431 ops.push_back(nestedOp);
440 <<
"greedy pattern application failed";
448 }
while (cseChanged && ++iteration < kNumMaxIterations);
450 if (iteration == kNumMaxIterations)
457 if (!getRegion().empty()) {
458 for (
Operation &op : getRegion().front()) {
459 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
461 <<
"expected children ops to implement "
462 "PatternDescriptorOpInterface";
463 diag.attachNote(op.
getLoc()) <<
"op without interface";
471 void transform::ApplyPatternsOp::getEffects(
472 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
477 void transform::ApplyPatternsOp::build(
486 bodyBuilder(builder, result.
location);
493 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
497 dialect->getCanonicalizationPatterns(
patterns);
499 op.getCanonicalizationPatterns(
patterns, ctx);
513 std::unique_ptr<TypeConverter> defaultTypeConverter;
514 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
515 getDefaultTypeConverter();
516 if (typeConverterBuilder)
517 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
522 for (
Attribute attr : cast<ArrayAttr>(*getLegalOps()))
523 conversionTarget.addLegalOp(
526 for (
Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
527 conversionTarget.addIllegalOp(
529 if (getLegalDialects())
530 for (
Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
531 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
532 if (getIllegalDialects())
533 for (
Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
534 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
542 if (!getPatterns().empty()) {
543 for (
Operation &op : getPatterns().front()) {
545 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
548 std::unique_ptr<TypeConverter> typeConverter =
549 descriptor.getTypeConverter();
552 keepAliveConverters.emplace_back(std::move(typeConverter));
553 converter = keepAliveConverters.back().get();
556 if (!defaultTypeConverter) {
558 <<
"pattern descriptor does not specify type "
559 "converter and apply_conversion_patterns op has "
560 "no default type converter";
561 diag.attachNote(op.
getLoc()) <<
"pattern descriptor op";
564 converter = defaultTypeConverter.get();
570 descriptor.populateConversionTargetRules(*converter, conversionTarget);
572 descriptor.populatePatterns(*converter,
patterns);
580 TrackingListenerConfig trackingConfig;
581 trackingConfig.requireMatchingReplacementOpName =
false;
582 ErrorCheckingTrackingListener trackingListener(state, *
this, trackingConfig);
584 if (getPreserveHandles())
585 conversionConfig.
listener = &trackingListener;
588 for (
Operation *target : state.getPayloadOps(getTarget())) {
596 LogicalResult status = failure();
597 if (getPartialConversion()) {
608 diag = emitSilenceableError() <<
"dialect conversion failed";
609 diag.attachNote(target->
getLoc()) <<
"target op";
614 trackingListener.checkAndResetError();
616 if (
diag.succeeded()) {
618 return trackingFailure;
620 diag.attachNote() <<
"tracking listener also failed: "
622 (void)trackingFailure.
silence();
625 if (!
diag.succeeded())
633 if (getNumRegions() != 1 && getNumRegions() != 2)
634 return emitOpError() <<
"expected 1 or 2 regions";
635 if (!getPatterns().empty()) {
636 for (
Operation &op : getPatterns().front()) {
637 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
639 emitOpError() <<
"expected pattern children ops to implement "
640 "ConversionPatternDescriptorOpInterface";
641 diag.attachNote(op.
getLoc()) <<
"op without interface";
646 if (getNumRegions() == 2) {
647 Region &typeConverterRegion = getRegion(1);
648 if (!llvm::hasSingleElement(typeConverterRegion.
front()))
650 <<
"expected exactly one op in default type converter region";
652 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
654 if (!typeConverterOp) {
656 <<
"expected default converter child op to "
657 "implement TypeConverterBuilderOpInterface";
658 diag.attachNote(maybeTypeConverter->
getLoc()) <<
"op without interface";
662 if (!getPatterns().empty()) {
663 for (
Operation &op : getPatterns().front()) {
665 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
666 if (
failed(descriptor.verifyTypeConverter(typeConverterOp)))
674 void transform::ApplyConversionPatternsOp::getEffects(
675 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
676 if (!getPreserveHandles()) {
684 void transform::ApplyConversionPatternsOp::build(
694 if (patternsBodyBuilder)
695 patternsBodyBuilder(builder, result.
location);
701 if (typeConverterBodyBuilder)
702 typeConverterBodyBuilder(builder, result.
location);
710 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
713 assert(dialect &&
"expected that dialect is loaded");
714 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
718 iface->populateConvertToLLVMConversionPatterns(
722 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
723 transform::TypeConverterBuilderOpInterface builder) {
724 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
725 return emitOpError(
"expected LLVMTypeConverter");
732 return emitOpError(
"unknown dialect or dialect not loaded: ")
734 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
737 "dialect does not implement ConvertToLLVMPatternInterface or "
738 "extension was not loaded: ")
748 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
758 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
759 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
768 void transform::ApplyRegisteredPassOp::getEffects(
769 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
786 llvm::raw_string_ostream optionsStream(
options);
791 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
794 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
795 assert(dynamicOptionIdx <
static_cast<int64_t
>(dynamicOptions.size()) &&
796 "the number of ParamOperandAttrs in the options DictionaryAttr"
797 "should be the same as the number of options passed as params");
799 state.getParams(dynamicOptions[dynamicOptionIdx]);
801 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
803 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
805 llvm::interleave(arrayAttr, optionsStream, appendValueAttr,
",");
806 }
else if (
auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
808 optionsStream << strAttr.getValue().str();
811 valueAttr.print(optionsStream,
true);
817 getOptions(), optionsStream,
818 [&](
auto namedAttribute) {
819 optionsStream << namedAttribute.getName().str();
820 optionsStream <<
"=";
821 appendValueAttr(namedAttribute.getValue());
824 optionsStream.flush();
832 <<
"unknown pass or pass pipeline: " << getPassName();
841 <<
"failed to add pass or pass pipeline to pipeline: "
857 if (
failed(pm.run(target))) {
858 auto diag = emitSilenceableError() <<
"pass pipeline failed";
859 diag.attachNote(target->
getLoc()) <<
"target op";
865 results.
set(llvm::cast<OpResult>(getResult()), targets);
871 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
874 size_t dynamicOptionsIdx = 0;
880 std::function<ParseResult(
Attribute &)> parseValue =
881 [&](
Attribute &valueAttr) -> ParseResult {
889 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
890 " in options dictionary") ||
904 ParseResult parsedOperand = parser.
parseOperand(operand);
905 if (
failed(parsedOperand))
911 dynamicOptions.push_back(operand);
918 }
else if (isa<transform::ParamOperandAttr>(valueAttr)) {
920 <<
"the param_operand attribute is a marker reserved for "
921 <<
"indicating a value will be passed via params and is only used "
922 <<
"in the generic print format";
936 <<
"expected key to either be an identifier or a string";
940 <<
"expected '=' after key in key-value pair";
942 if (
failed(parseValue(valueAttr)))
944 <<
"expected a valid attribute or operand as value associated "
945 <<
"to key '" << key <<
"'";
954 " in options dictionary"))
957 if (DictionaryAttr::findDuplicate(
958 keyValuePairs,
false)
961 <<
"duplicate keys found in options dictionary";
976 if (
auto paramOperandAttr =
977 dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
980 dynamicOptions[paramOperandAttr.getIndex().getInt()]);
981 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
993 printer << namedAttribute.
getName();
1007 std::function<LogicalResult(
Attribute)> checkOptionValue =
1008 [&](
Attribute valueAttr) -> LogicalResult {
1009 if (
auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1010 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1011 if (dynamicOptionIdx < 0 ||
1012 dynamicOptionIdx >=
static_cast<int64_t
>(dynamicOptions.size()))
1013 return emitOpError()
1014 <<
"dynamic option index " << dynamicOptionIdx
1015 <<
" is out of bounds for the number of dynamic options: "
1016 << dynamicOptions.size();
1017 if (dynamicOptions[dynamicOptionIdx] ==
nullptr)
1018 return emitOpError() <<
"dynamic option index " << dynamicOptionIdx
1019 <<
" is already used in options";
1020 dynamicOptions[dynamicOptionIdx] =
nullptr;
1021 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1023 for (
auto eltAttr : arrayAttr)
1024 if (
failed(checkOptionValue(eltAttr)))
1031 if (
failed(checkOptionValue(namedAttr.getValue())))
1035 for (
Value dynamicOption : dynamicOptions)
1037 return emitOpError() <<
"a param operand does not have a corresponding "
1038 <<
"param_operand attr in the options dict";
1049 Operation *target, ApplyToEachResultList &results,
1051 results.push_back(target);
1055 void transform::CastOp::getEffects(
1056 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1063 assert(inputs.size() == 1 &&
"expected one input");
1064 assert(outputs.size() == 1 &&
"expected one output");
1065 return llvm::all_of(
1066 std::initializer_list<Type>{inputs.front(), outputs.front()},
1067 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1087 assert(block.
getParent() &&
"cannot match using a detached block");
1088 auto matchScope = state.make_region_scope(*block.
getParent());
1090 state.mapBlockArguments(block.
getArguments(), blockArgumentMapping)))
1094 if (!isa<transform::MatchOpInterface>(match)) {
1096 <<
"expected operations in the match part to "
1097 "implement MatchOpInterface";
1100 state.applyTransform(cast<transform::TransformOpInterface>(match));
1101 if (
diag.succeeded())
1119 template <
typename... Tys>
1121 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... ||
false);
1128 transform::TransformParamTypeInterface,
1129 transform::TransformValueHandleTypeInterface>(
1141 auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
1142 getOperation(), getMatcher());
1143 if (matcher.isExternal()) {
1145 <<
"unresolved external symbol " << getMatcher();
1149 rawResults.resize(getOperation()->getNumResults());
1150 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1151 for (
Operation *root : state.getPayloadOps(getRoot())) {
1162 matcher.getFunctionBody().front(),
1165 if (
diag.isDefiniteFailure())
1167 if (
diag.isSilenceableFailure()) {
1169 <<
" failed: " <<
diag.getMessage();
1175 if (mapping.size() != 1) {
1176 maybeFailure.emplace(emitSilenceableError()
1177 <<
"result #" << i <<
", associated with "
1179 <<
" payload objects, expected 1");
1182 rawResults[i].push_back(mapping[0]);
1187 return std::move(*maybeFailure);
1188 assert(!maybeFailure &&
"failure set but the walk was not interrupted");
1190 for (
auto &&[opResult, rawResult] :
1191 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1198 void transform::CollectMatchingOp::getEffects(
1199 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1205 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1207 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1209 if (!matcherSymbol ||
1210 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1211 return emitError() <<
"unresolved matcher symbol " << getMatcher();
1214 if (argumentTypes.size() != 1 ||
1215 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1217 <<
"expected the matcher to take one operation handle argument";
1219 if (!matcherSymbol.getArgAttr(
1220 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1221 return emitError() <<
"expected the matcher argument to be marked readonly";
1225 if (resultTypes.size() != getOperation()->getNumResults()) {
1227 <<
"expected the matcher to yield as many values as op has results ("
1228 << getOperation()->getNumResults() <<
"), got "
1229 << resultTypes.size();
1232 for (
auto &&[i, matcherType, resultType] :
1238 <<
"mismatching type interfaces for matcher result and op result #"
1250 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() {
return true; }
1258 matchActionPairs.reserve(getMatchers().size());
1260 for (
auto &&[matcher, action] :
1261 llvm::zip_equal(getMatchers(), getActions())) {
1262 auto matcherSymbol =
1264 getOperation(), cast<SymbolRefAttr>(matcher));
1267 getOperation(), cast<SymbolRefAttr>(action));
1268 assert(matcherSymbol && actionSymbol &&
1269 "unresolved symbols not caught by the verifier");
1271 if (matcherSymbol.isExternal())
1273 if (actionSymbol.isExternal())
1276 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1287 matchInputMapping.emplace_back();
1289 getForwardedInputs(), state);
1291 actionResultMapping.resize(getForwardedOutputs().size());
1293 for (
Operation *root : state.getPayloadOps(getRoot())) {
1297 if (!getRestrictRoot() && op == root)
1305 firstMatchArgument.clear();
1306 firstMatchArgument.push_back(op);
1309 for (
auto [matcher, action] : matchActionPairs) {
1311 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1312 state, matchOutputMapping);
1313 if (
diag.isDefiniteFailure())
1315 if (
diag.isSilenceableFailure()) {
1317 <<
" failed: " <<
diag.getMessage();
1321 auto scope = state.make_region_scope(action.getFunctionBody());
1322 if (
failed(state.mapBlockArguments(
1323 action.getFunctionBody().front().getArguments(),
1324 matchOutputMapping))) {
1329 action.getFunctionBody().front().without_terminator()) {
1331 state.applyTransform(cast<TransformOpInterface>(transform));
1336 overallDiag = emitSilenceableError() <<
"actions failed";
1341 <<
"when applied to this matching payload";
1348 action.getFunctionBody().front().getTerminator()->getOperands(),
1349 state, getFlattenResults()))) {
1351 <<
"action @" << action.getName()
1352 <<
" has results associated with multiple payload entities, "
1353 "but flattening was not requested";
1368 results.
set(llvm::cast<OpResult>(getUpdated()),
1369 state.getPayloadOps(getRoot()));
1370 for (
auto &&[result, mapping] :
1371 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1377 void transform::ForeachMatchOp::getAsmResultNames(
1379 setNameFn(getUpdated(),
"updated_root");
1380 for (
Value v : getForwardedOutputs()) {
1381 setNameFn(v,
"yielded");
1385 void transform::ForeachMatchOp::getEffects(
1386 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1388 if (getOperation()->getNumOperands() < 1 ||
1389 getOperation()->getNumResults() < 1) {
1402 ArrayAttr &matchers,
1403 ArrayAttr &actions) {
1425 ArrayAttr matchers, ArrayAttr actions) {
1428 for (
auto &&[matcher, action, idx] : llvm::zip_equal(
1429 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1431 printer << cast<SymbolRefAttr>(matcher) <<
" -> "
1432 << cast<SymbolRefAttr>(action);
1433 if (idx != matchers.size() - 1)
1441 if (getMatchers().size() != getActions().size())
1442 return emitOpError() <<
"expected the same number of matchers and actions";
1443 if (getMatchers().empty())
1444 return emitOpError() <<
"expected at least one match/action pair";
1448 if (matcherNames.insert(name).second)
1451 <<
" is used more than once, only the first match will apply";
1462 bool alsoVerifyInternal =
false) {
1463 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1464 llvm::SmallDenseSet<unsigned> consumedArguments;
1465 if (!op.isExternal()) {
1469 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1471 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1474 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1476 if (isConsumed && isReadOnly) {
1477 return transformOp.emitSilenceableError()
1478 <<
"argument #" << i <<
" cannot be both readonly and consumed";
1480 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1481 return transformOp.emitSilenceableError()
1482 <<
"must provide consumed/readonly status for arguments of "
1483 "external or called ops";
1485 if (op.isExternal())
1488 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1489 return transformOp.emitSilenceableError()
1490 <<
"argument #" << i
1491 <<
" is consumed in the body but is not marked as such";
1493 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1497 <<
"op argument #" << i
1498 <<
" is not consumed in the body but is marked as consumed";
1504 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1506 assert(getMatchers().size() == getActions().size());
1509 for (
auto &&[matcher, action] :
1510 llvm::zip_equal(getMatchers(), getActions())) {
1512 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1514 cast<SymbolRefAttr>(matcher)));
1515 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1517 cast<SymbolRefAttr>(action)));
1518 if (!matcherSymbol ||
1519 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1520 return emitError() <<
"unresolved matcher symbol " << matcher;
1521 if (!actionSymbol ||
1522 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1523 return emitError() <<
"unresolved action symbol " << action;
1528 .checkAndReport())) {
1534 .checkAndReport())) {
1539 TypeRange operandTypes = getOperandTypes();
1540 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1541 if (operandTypes.size() != matcherArguments.size()) {
1543 emitError() <<
"the number of operands (" << operandTypes.size()
1544 <<
") doesn't match the number of matcher arguments ("
1545 << matcherArguments.size() <<
") for " << matcher;
1546 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1549 for (
auto &&[i, operand, argument] :
1551 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1554 <<
"does not expect matcher symbol to consume its operand #" << i;
1555 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1564 <<
"mismatching type interfaces for operand and matcher argument #"
1565 << i <<
" of matcher " << matcher;
1566 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1571 TypeRange matcherResults = matcherSymbol.getResultTypes();
1572 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1573 if (matcherResults.size() != actionArguments.size()) {
1574 return emitError() <<
"mismatching number of matcher results and "
1575 "action arguments between "
1576 << matcher <<
" (" << matcherResults.size() <<
") and "
1577 << action <<
" (" << actionArguments.size() <<
")";
1579 for (
auto &&[i, matcherType, actionType] :
1584 return emitError() <<
"mismatching type interfaces for matcher result "
1585 "and action argument #"
1586 << i <<
"of matcher " << matcher <<
" and action "
1591 TypeRange actionResults = actionSymbol.getResultTypes();
1592 auto resultTypes =
TypeRange(getResultTypes()).drop_front();
1593 if (actionResults.size() != resultTypes.size()) {
1595 emitError() <<
"the number of action results ("
1596 << actionResults.size() <<
") for " << action
1597 <<
" doesn't match the number of extra op results ("
1598 << resultTypes.size() <<
")";
1599 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1602 for (
auto &&[i, resultType, actionType] :
1608 emitError() <<
"mismatching type interfaces for action result #" << i
1609 <<
" of action " << action <<
" and op result";
1610 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1629 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1630 bool withZipShortest = getWithZipShortest();
1634 if (withZipShortest) {
1638 return a.size() < b.size();
1641 for (
auto &payload : payloads)
1642 payload.resize(numIterations);
1648 for (
size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1650 if (payloads[argIdx].size() != numIterations) {
1651 return emitSilenceableError()
1652 <<
"prior targets' payload size (" << numIterations
1653 <<
") differs from payload size (" << payloads[argIdx].size()
1654 <<
") of target " << getTargets()[argIdx];
1663 for (
size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1664 auto scope = state.make_region_scope(getBody());
1670 if (
failed(state.mapBlockArgument(blockArg, {argument})))
1675 for (
Operation &transform : getBody().front().without_terminator()) {
1677 llvm::cast<transform::TransformOpInterface>(transform));
1683 OperandRange yieldOperands = getYieldOp().getOperands();
1684 for (
auto &&[result, yieldOperand, resTuple] :
1685 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1687 if (isa<TransformHandleTypeInterface>(result.getType()))
1688 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1689 else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1690 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1691 else if (isa<TransformParamTypeInterface>(result.getType()))
1692 llvm::append_range(resTuple, state.getParams(yieldOperand));
1694 assert(
false &&
"unhandled handle type");
1698 for (
auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1704 void transform::ForeachOp::getEffects(
1705 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1708 for (
auto &&[target, blockArg] :
1709 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1711 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1713 cast<TransformOpInterface>(&op));
1721 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1725 }
else if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1734 void transform::ForeachOp::getSuccessorRegions(
1736 Region *bodyRegion = &getBody();
1738 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1743 assert(point == getBody() &&
"unexpected region index");
1744 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1745 regions.emplace_back();
1752 assert(point == getBody() &&
"unexpected region index");
1753 return getOperation()->getOperands();
1756 transform::YieldOp transform::ForeachOp::getYieldOp() {
1757 return cast<transform::YieldOp>(getBody().front().getTerminator());
1761 for (
auto [targetOpt, bodyArgOpt] :
1762 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1763 if (!targetOpt || !bodyArgOpt)
1764 return emitOpError() <<
"expects the same number of targets as the body "
1765 "has block arguments";
1766 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1768 "expects co-indexed targets and the body's "
1769 "block arguments to have the same op/value/param type");
1772 for (
auto [resultOpt, yieldOperandOpt] :
1773 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1774 if (!resultOpt || !yieldOperandOpt)
1775 return emitOpError() <<
"expects the same number of results as the "
1776 "yield terminator has operands";
1777 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1778 return emitOpError(
"expects co-indexed results and yield "
1779 "operands to have the same op/value/param type");
1795 for (
Operation *target : state.getPayloadOps(getTarget())) {
1797 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1800 bool checkIsolatedFromAbove =
1801 !getIsolatedFromAbove() ||
1803 bool checkOpName = !getOpName().has_value() ||
1805 if (checkIsolatedFromAbove && checkOpName)
1810 if (getAllowEmptyResults()) {
1811 results.
set(llvm::cast<OpResult>(getResult()), parents);
1815 emitSilenceableError()
1816 <<
"could not find a parent op that matches all requirements";
1817 diag.attachNote(target->
getLoc()) <<
"target op";
1821 if (getDeduplicate()) {
1822 if (resultSet.insert(parent).second)
1823 parents.push_back(parent);
1825 parents.push_back(parent);
1828 results.
set(llvm::cast<OpResult>(getResult()), parents);
1840 int64_t resultNumber = getResultNumber();
1841 auto payloadOps = state.getPayloadOps(getTarget());
1842 if (std::empty(payloadOps)) {
1843 results.
set(cast<OpResult>(getResult()), {});
1846 if (!llvm::hasSingleElement(payloadOps))
1848 <<
"handle must be mapped to exactly one payload op";
1850 Operation *target = *payloadOps.begin();
1853 results.
set(llvm::cast<OpResult>(getResult()),
1867 for (
Value v : state.getPayloadValues(getTarget())) {
1868 if (llvm::isa<BlockArgument>(v)) {
1870 emitSilenceableError() <<
"cannot get defining op of block argument";
1871 diag.attachNote(v.getLoc()) <<
"target value";
1874 definingOps.push_back(v.getDefiningOp());
1876 results.
set(llvm::cast<OpResult>(getResult()), definingOps);
1888 int64_t operandNumber = getOperandNumber();
1890 for (
Operation *target : state.getPayloadOps(getTarget())) {
1897 emitSilenceableError()
1898 <<
"could not find a producer for operand number: " << operandNumber
1899 <<
" of " << *target;
1900 diag.attachNote(target->getLoc()) <<
"target op";
1903 producers.push_back(producer);
1905 results.
set(llvm::cast<OpResult>(getResult()), producers);
1918 for (
Operation *target : state.getPayloadOps(getTarget())) {
1921 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1922 target->getNumOperands(), operandPositions);
1923 if (
diag.isSilenceableFailure()) {
1924 diag.attachNote(target->getLoc())
1925 <<
"while considering positions of this payload operation";
1928 llvm::append_range(operands,
1929 llvm::map_range(operandPositions, [&](int64_t pos) {
1930 return target->getOperand(pos);
1933 results.
setValues(cast<OpResult>(getResult()), operands);
1939 getIsInverted(), getIsAll());
1951 for (
Operation *target : state.getPayloadOps(getTarget())) {
1954 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1955 target->getNumResults(), resultPositions);
1956 if (
diag.isSilenceableFailure()) {
1957 diag.attachNote(target->getLoc())
1958 <<
"while considering positions of this payload operation";
1961 llvm::append_range(opResults,
1962 llvm::map_range(resultPositions, [&](int64_t pos) {
1963 return target->getResult(pos);
1966 results.
setValues(cast<OpResult>(getResult()), opResults);
1972 getIsInverted(), getIsAll());
1979 void transform::GetTypeOp::getEffects(
1980 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1991 for (
Value value : state.getPayloadValues(getValue())) {
1992 Type type = value.getType();
1993 if (getElemental()) {
1994 if (
auto shaped = dyn_cast<ShapedType>(type)) {
1995 type = shaped.getElementType();
2000 results.
setParams(cast<OpResult>(getResult()), params);
2017 state.applyTransform(cast<transform::TransformOpInterface>(transform));
2022 if (mode == transform::FailurePropagationMode::Propagate) {
2041 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2042 getOperation(), getTarget());
2043 assert(callee &&
"unverified reference to unknown symbol");
2045 if (callee.isExternal())
2051 auto scope = state.make_region_scope(callee.getBody());
2052 for (
auto &&[arg, map] :
2053 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2054 if (
failed(state.mapBlockArgument(arg, map)))
2059 callee.getBody().front(), getFailurePropagationMode(), state, results);
2062 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2063 for (
auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
2071 void transform::IncludeOp::getEffects(
2072 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2086 auto defaultEffects = [&] {
2093 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2095 return defaultEffects();
2096 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2097 getOperation(), getTarget());
2099 return defaultEffects();
2101 for (
unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2102 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2104 else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2113 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(
"target");
2115 return emitOpError() <<
"expects a 'target' symbol reference attribute";
2120 return emitOpError() <<
"does not reference a named transform sequence";
2122 FunctionType fnType = target.getFunctionType();
2123 if (fnType.getNumInputs() != getNumOperands())
2124 return emitError(
"incorrect number of operands for callee");
2126 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2127 if (getOperand(i).
getType() != fnType.getInput(i)) {
2128 return emitOpError(
"operand type mismatch: expected operand type ")
2129 << fnType.getInput(i) <<
", but provided "
2130 << getOperand(i).getType() <<
" for operand number " << i;
2134 if (fnType.getNumResults() != getNumResults())
2135 return emitError(
"incorrect number of results for callee");
2137 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2138 Type resultType = getResult(i).getType();
2139 Type funcType = fnType.getResult(i);
2141 return emitOpError() <<
"type of result #" << i
2142 <<
" must implement the same transform dialect "
2143 "interface as the corresponding callee result";
2148 cast<FunctionOpInterface>(*target),
false,
2158 ::std::optional<::mlir::Operation *> maybeCurrent,
2160 if (!maybeCurrent.has_value()) {
2165 return emitSilenceableError() <<
"operation is not empty";
2176 for (
auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2177 if (acceptedAttr.getValue() == currentOpName)
2180 return emitSilenceableError() <<
"wrong operation name";
2191 auto signedAPIntAsString = [&](
const APInt &value) {
2193 llvm::raw_string_ostream os(str);
2194 value.print(os,
true);
2201 if (params.size() != references.size()) {
2202 return emitSilenceableError()
2203 <<
"parameters have different payload lengths (" << params.size()
2204 <<
" vs " << references.size() <<
")";
2207 for (
auto &&[i, param, reference] :
llvm::enumerate(params, references)) {
2208 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2209 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2210 if (!intAttr || !refAttr) {
2212 <<
"non-integer parameter value not expected";
2214 if (intAttr.getType() != refAttr.getType()) {
2216 <<
"mismatching integer attribute types in parameter #" << i;
2218 APInt value = intAttr.getValue();
2219 APInt refValue = refAttr.getValue();
2222 int64_t position = i;
2223 auto reportError = [&](StringRef direction) {
2225 emitSilenceableError() <<
"expected parameter to be " << direction
2226 <<
" " << signedAPIntAsString(refValue)
2227 <<
", got " << signedAPIntAsString(value);
2228 diag.attachNote(getParam().getLoc())
2229 <<
"value # " << position
2230 <<
" associated with the parameter defined here";
2234 switch (getPredicate()) {
2235 case MatchCmpIPredicate::eq:
2236 if (value.eq(refValue))
2238 return reportError(
"equal to");
2239 case MatchCmpIPredicate::ne:
2240 if (value.ne(refValue))
2242 return reportError(
"not equal to");
2243 case MatchCmpIPredicate::lt:
2244 if (value.slt(refValue))
2246 return reportError(
"less than");
2247 case MatchCmpIPredicate::le:
2248 if (value.sle(refValue))
2250 return reportError(
"less than or equal to");
2251 case MatchCmpIPredicate::gt:
2252 if (value.sgt(refValue))
2254 return reportError(
"greater than");
2255 case MatchCmpIPredicate::ge:
2256 if (value.sge(refValue))
2258 return reportError(
"greater than or equal to");
2264 void transform::MatchParamCmpIOp::getEffects(
2265 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2278 results.
setParams(cast<OpResult>(getParam()), {getValue()});
2291 if (isa<TransformHandleTypeInterface>(handles.front().
getType())) {
2293 for (
Value operand : handles)
2294 llvm::append_range(operations, state.getPayloadOps(operand));
2295 if (!getDeduplicate()) {
2296 results.
set(llvm::cast<OpResult>(getResult()), operations);
2301 results.
set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2305 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2307 for (
Value attribute : handles)
2308 llvm::append_range(attrs, state.getParams(attribute));
2309 if (!getDeduplicate()) {
2310 results.
setParams(cast<OpResult>(getResult()), attrs);
2315 results.
setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2320 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2321 "expected value handle type");
2323 for (
Value value : handles)
2324 llvm::append_range(payloadValues, state.getPayloadValues(value));
2325 if (!getDeduplicate()) {
2326 results.
setValues(cast<OpResult>(getResult()), payloadValues);
2331 results.
setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2335 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2337 return getDeduplicate();
2340 void transform::MergeHandlesOp::getEffects(
2341 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2349 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2350 if (getDeduplicate() || getHandles().size() != 1)
2355 return getHandles().front();
2373 auto scope = state.make_region_scope(getBody());
2375 state, this->getOperation(), getBody())))
2379 FailurePropagationMode::Propagate, state, results);
2382 void transform::NamedSequenceOp::getEffects(
2383 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2388 parser, result,
false,
2389 getFunctionTypeAttrName(result.
name),
2392 std::string &) { return builder.getFunctionType(inputs, results); },
2393 getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
2398 printer, cast<FunctionOpInterface>(getOperation()),
false,
2399 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2400 getResAttrsAttrName());
2410 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2413 <<
"cannot be defined inside another transform op";
2414 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2418 if (op.isExternal() || op.getFunctionBody().empty()) {
2425 if (op.getFunctionBody().front().empty())
2428 Operation *terminator = &op.getFunctionBody().front().back();
2429 if (!isa<transform::YieldOp>(terminator)) {
2432 << transform::YieldOp::getOperationName()
2433 <<
"' as terminator";
2434 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2438 if (terminator->
getNumOperands() != op.getResultTypes().size()) {
2440 <<
"expected terminator to have as many operands as the parent op "
2443 for (
auto [i, operandType, resultType] : llvm::zip_equal(
2446 if (operandType == resultType)
2449 <<
"the type of the terminator operand #" << i
2450 <<
" must match the type of the corresponding parent op result ("
2451 << operandType <<
" vs " << resultType <<
")";
2464 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2467 <<
"expects the parent symbol table to have the '"
2468 << transform::TransformDialect::kWithNamedSequenceAttrName
2470 diag.attachNote(parent->
getLoc()) <<
"symbol table operation";
2475 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2478 <<
"cannot be defined inside another transform op";
2479 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2483 if (op.isExternal() || op.getBody().empty())
2487 if (op.getBody().front().empty())
2490 Operation *terminator = &op.getBody().front().back();
2491 if (!isa<transform::YieldOp>(terminator)) {
2494 << transform::YieldOp::getOperationName()
2495 <<
"' as terminator";
2496 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2500 if (terminator->
getNumOperands() != op.getFunctionType().getNumResults()) {
2502 <<
"expected terminator to have as many operands as the parent op "
2505 for (
auto [i, operandType, resultType] :
2506 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->
getNumOperands()),
2508 op.getFunctionType().getResults())) {
2509 if (operandType == resultType)
2512 <<
"the type of the terminator operand #" << i
2513 <<
" must match the type of the corresponding parent op result ("
2514 << operandType <<
" vs " << resultType <<
")";
2517 auto funcOp = cast<FunctionOpInterface>(*op);
2520 if (!
diag.succeeded())
2532 template <
typename FnTy>
2537 types.reserve(1 + extraBindingTypes.size());
2538 types.push_back(bbArgType);
2539 llvm::append_range(types, extraBindingTypes);
2542 Region *region = state.regions.back().get();
2549 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2550 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0));
2552 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0),
2557 void transform::NamedSequenceOp::build(
OpBuilder &builder,
2565 state.addAttribute(getFunctionTypeAttrName(state.name),
2567 rootType, resultTypes)));
2568 state.attributes.append(attrs.begin(), attrs.end());
2583 size_t numAssociations =
2585 .Case([&](TransformHandleTypeInterface opHandle) {
2586 return llvm::range_size(state.getPayloadOps(getHandle()));
2588 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2589 return llvm::range_size(state.getPayloadValues(getHandle()));
2591 .Case([&](TransformParamTypeInterface param) {
2592 return llvm::range_size(state.getParams(getHandle()));
2594 .DefaultUnreachable(
"unknown kind of transform dialect type");
2595 results.
setParams(cast<OpResult>(getNum()),
2602 auto resultType = cast<TransformParamTypeInterface>(getNum().
getType());
2617 auto payloadOps = state.getPayloadOps(getTarget());
2620 result.push_back(op);
2622 results.
set(cast<OpResult>(getResult()), result);
2631 Value target, int64_t numResultHandles) {
2640 int64_t numPayloads =
2642 .Case<TransformHandleTypeInterface>([&](
auto x) {
2643 return llvm::range_size(state.getPayloadOps(getHandle()));
2645 .Case<TransformValueHandleTypeInterface>([&](
auto x) {
2646 return llvm::range_size(state.getPayloadValues(getHandle()));
2648 .Case<TransformParamTypeInterface>([&](
auto x) {
2649 return llvm::range_size(state.getParams(getHandle()));
2651 .DefaultUnreachable(
"unknown transform dialect type interface");
2653 auto produceNumOpsError = [&]() {
2654 return emitSilenceableError()
2655 << getHandle() <<
" expected to contain " << this->getNumResults()
2656 <<
" payloads but it contains " << numPayloads <<
" payloads";
2661 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2662 return produceNumOpsError();
2667 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2668 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2669 return produceNumOpsError();
2673 if (getOverflowResult())
2674 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2676 auto container = [&]() {
2677 if (isa<TransformHandleTypeInterface>(getHandle().
getType())) {
2678 return llvm::map_to_vector(
2679 state.getPayloadOps(getHandle()),
2682 if (isa<TransformValueHandleTypeInterface>(getHandle().
getType())) {
2683 return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2686 assert(isa<TransformParamTypeInterface>(getHandle().
getType()) &&
2687 "unsupported kind of transform dialect type");
2688 return llvm::map_to_vector(state.getParams(getHandle()),
2693 int64_t resultNum = en.index();
2694 if (resultNum >= getNumResults())
2695 resultNum = *getOverflowResult();
2696 resultHandles[resultNum].push_back(en.value());
2707 void transform::SplitHandleOp::getEffects(
2708 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2716 if (getOverflowResult().has_value() &&
2717 !(*getOverflowResult() < getNumResults()))
2718 return emitOpError(
"overflow_result is not a valid result index");
2720 for (
Type resultType : getResultTypes()) {
2724 return emitOpError(
"expects result types to implement the same transform "
2725 "interface as the operand type");
2739 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2741 Value handle = en.value();
2742 if (isa<TransformHandleTypeInterface>(handle.
getType())) {
2744 llvm::to_vector(state.getPayloadOps(handle));
2746 payload.reserve(numRepetitions * current.size());
2747 for (
unsigned i = 0; i < numRepetitions; ++i)
2748 llvm::append_range(payload, current);
2749 results.
set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2751 assert(llvm::isa<TransformParamTypeInterface>(handle.
getType()) &&
2752 "expected param type");
2755 params.reserve(numRepetitions * current.size());
2756 for (
unsigned i = 0; i < numRepetitions; ++i)
2757 llvm::append_range(params, current);
2758 results.
setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2765 void transform::ReplicateOp::getEffects(
2766 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2781 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2782 if (
failed(mapBlockArguments(state)))
2790 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2792 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2793 SmallVectorImpl<Type> &extraBindingTypes) {
2797 root = std::nullopt;
2818 if (!extraBindings.empty()) {
2823 if (extraBindingTypes.size() != extraBindings.size()) {
2825 "expected types to be provided for all operands");
2841 bool hasExtras = !extraBindings.empty();
2851 printer << rootType;
2853 printer <<
", " << llvm::interleaved(extraBindingTypes) <<
')';
2860 auto iface = dyn_cast<transform::TransformOpInterface>(use.
getOwner());
2875 if (!potentialConsumer) {
2876 potentialConsumer = &use;
2881 <<
" has more than one potential consumer";
2884 diag.attachNote(use.getOwner()->getLoc())
2885 <<
"used here as operand #" << use.getOperandNumber();
2893 assert(getBodyBlock()->getNumArguments() >= 1 &&
2894 "the number of arguments must have been verified to be more than 1 by "
2895 "PossibleTopLevelTransformOpTrait");
2897 if (!getRoot() && !getExtraBindings().empty()) {
2898 return emitOpError()
2899 <<
"does not expect extra operands when used as top-level";
2905 return (emitOpError() <<
"block argument #" << arg.getArgNumber());
2912 for (
Operation &child : *getBodyBlock()) {
2913 if (!isa<TransformOpInterface>(child) &&
2914 &child != &getBodyBlock()->back()) {
2917 <<
"expected children ops to implement TransformOpInterface";
2918 diag.attachNote(child.getLoc()) <<
"op without interface";
2922 for (
OpResult result : child.getResults()) {
2923 auto report = [&]() {
2924 return (child.emitError() <<
"result #" << result.getResultNumber());
2931 if (!getBodyBlock()->mightHaveTerminator())
2932 return emitOpError() <<
"expects to have a terminator in the body";
2934 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2935 getOperation()->getResultTypes()) {
2937 <<
"expects the types of the terminator operands "
2938 "to match the types of the result";
2939 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) <<
"terminator";
2945 void transform::SequenceOp::getEffects(
2946 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2952 assert(point == getBody() &&
"unexpected region index");
2953 if (getOperation()->getNumOperands() > 0)
2954 return getOperation()->getOperands();
2956 getOperation()->operand_end());
2959 void transform::SequenceOp::getSuccessorRegions(
2962 Region *bodyRegion = &getBody();
2963 regions.emplace_back(bodyRegion, getNumOperands() != 0
2969 assert(point == getBody() &&
"unexpected region index");
2970 regions.emplace_back(getOperation()->getResults());
2973 void transform::SequenceOp::getRegionInvocationBounds(
2976 bounds.emplace_back(1, 1);
2981 FailurePropagationMode failurePropagationMode,
2984 build(builder, state, resultTypes, failurePropagationMode, root,
2993 FailurePropagationMode failurePropagationMode,
2996 build(builder, state, resultTypes, failurePropagationMode, root,
3004 FailurePropagationMode failurePropagationMode,
3007 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3015 FailurePropagationMode failurePropagationMode,
3018 build(builder, state, resultTypes, failurePropagationMode,
Value(),
3034 Value target, StringRef name) {
3036 build(builder, result, name);
3043 llvm::outs() <<
"[[[ IR printer: ";
3044 if (getName().has_value())
3045 llvm::outs() << *getName() <<
" ";
3048 if (getAssumeVerified().value_or(
false))
3050 if (getUseLocalScope().value_or(
false))
3052 if (getSkipRegions().value_or(
false))
3056 llvm::outs() <<
"top-level ]]]\n";
3057 state.getTopLevel()->print(llvm::outs(), printFlags);
3058 llvm::outs() <<
"\n";
3059 llvm::outs().flush();
3063 llvm::outs() <<
"]]]\n";
3064 for (
Operation *target : state.getPayloadOps(getTarget())) {
3065 target->
print(llvm::outs(), printFlags);
3066 llvm::outs() <<
"\n";
3069 llvm::outs().flush();
3073 void transform::PrintOp::getEffects(
3074 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3078 if (!getTargetMutable().empty())
3098 <<
"failed to verify payload op";
3099 diag.attachNote(target->
getLoc()) <<
"payload op";
3105 void transform::VerifyOp::getEffects(
3106 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3114 void transform::YieldOp::getEffects(
3115 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a #dlti.dl_entry attribute.
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getRegionNumber()
Return the number of this region in the parent operation.
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
A named class for passing around the variadic flag.
static void printOptionValue(raw_ostream &os, const bool &value)
Utility methods for printing option values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.