36 #include "llvm/ADT/DenseSet.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/ScopeExit.h"
39 #include "llvm/ADT/SmallPtrSet.h"
40 #include "llvm/ADT/TypeSwitch.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/ErrorHandling.h"
45 #define DEBUG_TYPE "transform-dialect"
46 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
48 #define DEBUG_TYPE_MATCHER "transform-matcher"
49 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
50 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
55 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
57 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
58 SmallVectorImpl<Type> &extraBindingTypes);
64 ArrayAttr matchers, ArrayAttr actions);
75 Operation *transformAncestor = transform.getOperation();
76 while (transformAncestor) {
77 if (transformAncestor == payload) {
79 transform.emitDefiniteFailure()
80 <<
"cannot apply transform to itself (or one of its ancestors)";
81 diag.attachNote(payload->
getLoc()) <<
"target payload op";
84 transformAncestor = transformAncestor->
getParentOp();
89 #define GET_OP_CLASSES
90 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
98 if (!point.
isParent() && getOperation()->getNumOperands() == 1)
99 return getOperation()->getOperands();
101 getOperation()->operand_end());
104 void transform::AlternativesOp::getSuccessorRegions(
106 for (
Region &alternative : llvm::drop_begin(
110 regions.emplace_back(&alternative, !getOperands().empty()
111 ? alternative.getArguments()
115 regions.emplace_back(getOperation()->getResults());
118 void transform::AlternativesOp::getRegionInvocationBounds(
123 bounds.reserve(getNumRegions());
124 bounds.emplace_back(1, 1);
131 results.
set(res, {});
139 if (
Value scopeHandle = getScope())
140 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
142 originals.push_back(state.getTopLevel());
145 if (original->isAncestor(getOperation())) {
147 <<
"scope must not contain the transforms being applied";
148 diag.attachNote(original->getLoc()) <<
"scope";
153 <<
"only isolated-from-above ops can be alternative scopes";
154 diag.attachNote(original->getLoc()) <<
"scope";
159 for (
Region ® : getAlternatives()) {
164 auto scope = state.make_region_scope(reg);
165 auto clones = llvm::to_vector(
166 llvm::map_range(originals, [](
Operation *op) {
return op->
clone(); }));
167 auto deleteClones = llvm::make_scope_exit([&] {
171 if (
failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
175 for (
Operation &transform : reg.front().without_terminator()) {
177 state.applyTransform(cast<TransformOpInterface>(transform));
179 LLVM_DEBUG(
DBGS() <<
"alternative failed: " << result.
getMessage()
194 deleteClones.release();
195 TrackingListener listener(state, *
this);
197 for (
const auto &kvp : llvm::zip(originals, clones)) {
208 return emitSilenceableError() <<
"all alternatives failed";
211 void transform::AlternativesOp::getEffects(
212 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
215 for (
Region *region : getRegions()) {
216 if (!region->empty())
223 for (
Region &alternative : getAlternatives()) {
228 <<
"expects terminator operands to have the "
229 "same type as results of the operation";
230 diag.attachNote(terminator->
getLoc()) <<
"terminator";
247 llvm::to_vector(state.getPayloadOps(getTarget()));
250 if (
auto paramH = getParam()) {
252 if (params.size() != 1) {
253 if (targets.size() != params.size()) {
254 return emitSilenceableError()
255 <<
"parameter and target have different payload lengths ("
256 << params.size() <<
" vs " << targets.size() <<
")";
258 for (
auto &&[target, attr] : llvm::zip_equal(targets, params))
259 target->setAttr(getName(), attr);
264 for (
auto *target : targets)
265 target->setAttr(getName(), attr);
269 void transform::AnnotateOp::getEffects(
270 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
281 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
296 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
297 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
321 auto addDefiningOpsToWorklist = [&](
Operation *op) {
324 if (
Operation *defOp = v.getDefiningOp())
326 worklist.insert(defOp);
334 const auto *it = llvm::find(worklist, op);
335 if (it != worklist.end())
344 addDefiningOpsToWorklist(op);
350 while (!worklist.empty()) {
354 addDefiningOpsToWorklist(op);
361 void transform::ApplyDeadCodeEliminationOp::getEffects(
362 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
386 if (!getRegion().empty()) {
387 for (
Operation &op : getRegion().front()) {
388 cast<transform::PatternDescriptorOpInterface>(&op)
389 .populatePatternsWithState(patterns, state);
399 config.
maxIterations = getMaxIterations() ==
static_cast<uint64_t
>(-1)
401 : getMaxIterations();
402 config.
maxNumRewrites = getMaxNumRewrites() ==
static_cast<uint64_t
>(-1)
404 : getMaxNumRewrites();
409 bool cseChanged =
false;
412 static const int64_t kNumMaxIterations = 50;
413 int64_t iteration = 0;
428 if (target != nestedOp)
429 ops.push_back(nestedOp);
438 <<
"greedy pattern application failed";
446 }
while (cseChanged && ++iteration < kNumMaxIterations);
448 if (iteration == kNumMaxIterations)
455 if (!getRegion().empty()) {
456 for (
Operation &op : getRegion().front()) {
457 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
459 <<
"expected children ops to implement "
460 "PatternDescriptorOpInterface";
461 diag.attachNote(op.
getLoc()) <<
"op without interface";
469 void transform::ApplyPatternsOp::getEffects(
470 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
475 void transform::ApplyPatternsOp::build(
484 bodyBuilder(builder, result.
location);
491 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
495 dialect->getCanonicalizationPatterns(patterns);
497 op.getCanonicalizationPatterns(patterns, ctx);
511 std::unique_ptr<TypeConverter> defaultTypeConverter;
512 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
513 getDefaultTypeConverter();
514 if (typeConverterBuilder)
515 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
520 for (
Attribute attr : cast<ArrayAttr>(*getLegalOps()))
521 conversionTarget.addLegalOp(
524 for (
Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
525 conversionTarget.addIllegalOp(
527 if (getLegalDialects())
528 for (
Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
529 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
530 if (getIllegalDialects())
531 for (
Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
532 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
540 if (!getPatterns().empty()) {
541 for (
Operation &op : getPatterns().front()) {
543 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
546 std::unique_ptr<TypeConverter> typeConverter =
547 descriptor.getTypeConverter();
550 keepAliveConverters.emplace_back(std::move(typeConverter));
551 converter = keepAliveConverters.back().get();
554 if (!defaultTypeConverter) {
556 <<
"pattern descriptor does not specify type "
557 "converter and apply_conversion_patterns op has "
558 "no default type converter";
559 diag.attachNote(op.
getLoc()) <<
"pattern descriptor op";
562 converter = defaultTypeConverter.get();
568 descriptor.populateConversionTargetRules(*converter, conversionTarget);
570 descriptor.populatePatterns(*converter, patterns);
578 TrackingListenerConfig trackingConfig;
579 trackingConfig.requireMatchingReplacementOpName =
false;
580 ErrorCheckingTrackingListener trackingListener(state, *
this, trackingConfig);
582 if (getPreserveHandles())
583 conversionConfig.
listener = &trackingListener;
586 for (
Operation *target : state.getPayloadOps(getTarget())) {
595 if (getPartialConversion()) {
606 diag = emitSilenceableError() <<
"dialect conversion failed";
607 diag.attachNote(target->
getLoc()) <<
"target op";
612 trackingListener.checkAndResetError();
614 if (
diag.succeeded()) {
616 return trackingFailure;
618 diag.attachNote() <<
"tracking listener also failed: "
620 (void)trackingFailure.
silence();
624 if (!
diag.succeeded())
632 if (getNumRegions() != 1 && getNumRegions() != 2)
633 return emitOpError() <<
"expected 1 or 2 regions";
634 if (!getPatterns().empty()) {
635 for (
Operation &op : getPatterns().front()) {
636 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
638 emitOpError() <<
"expected pattern children ops to implement "
639 "ConversionPatternDescriptorOpInterface";
640 diag.attachNote(op.
getLoc()) <<
"op without interface";
645 if (getNumRegions() == 2) {
646 Region &typeConverterRegion = getRegion(1);
647 if (!llvm::hasSingleElement(typeConverterRegion.
front()))
649 <<
"expected exactly one op in default type converter region";
650 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
652 if (!typeConverterOp) {
654 <<
"expected default converter child op to "
655 "implement TypeConverterBuilderOpInterface";
656 diag.attachNote(typeConverterOp->getLoc()) <<
"op without interface";
660 if (!getPatterns().empty()) {
661 for (
Operation &op : getPatterns().front()) {
663 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
664 if (
failed(descriptor.verifyTypeConverter(typeConverterOp)))
672 void transform::ApplyConversionPatternsOp::getEffects(
673 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
674 if (!getPreserveHandles()) {
682 void transform::ApplyConversionPatternsOp::build(
692 if (patternsBodyBuilder)
693 patternsBodyBuilder(builder, result.
location);
699 if (typeConverterBodyBuilder)
700 typeConverterBodyBuilder(builder, result.
location);
708 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
711 assert(dialect &&
"expected that dialect is loaded");
712 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
716 iface->populateConvertToLLVMConversionPatterns(
720 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
721 transform::TypeConverterBuilderOpInterface builder) {
722 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
723 return emitOpError(
"expected LLVMTypeConverter");
730 return emitOpError(
"unknown dialect or dialect not loaded: ")
732 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
735 "dialect does not implement ConvertToLLVMPatternInterface or "
736 "extension was not loaded: ")
746 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
756 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
757 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
784 <<
"unknown pass or pass pipeline: " << getPassName();
793 <<
"failed to add pass or pass pipeline to pipeline: "
796 if (
failed(pm.run(target))) {
797 auto diag = emitSilenceableError() <<
"pass pipeline failed";
798 diag.attachNote(target->
getLoc()) <<
"target op";
802 results.push_back(target);
812 Operation *target, ApplyToEachResultList &results,
814 results.push_back(target);
818 void transform::CastOp::getEffects(
819 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
826 assert(inputs.size() == 1 &&
"expected one input");
827 assert(outputs.size() == 1 &&
"expected one output");
829 std::initializer_list<Type>{inputs.front(), outputs.front()},
830 llvm::IsaPred<transform::TransformHandleTypeInterface>);
847 assert(block.
getParent() &&
"cannot match using a detached block");
848 auto matchScope = state.make_region_scope(*block.
getParent());
853 if (!isa<transform::MatchOpInterface>(match)) {
855 <<
"expected operations in the match part to "
856 "implement MatchOpInterface";
859 state.applyTransform(cast<transform::TransformOpInterface>(match));
860 if (
diag.succeeded())
875 template <
typename... Tys>
877 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... ||
false);
884 transform::TransformParamTypeInterface,
885 transform::TransformValueHandleTypeInterface>(
897 auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
898 getOperation(), getMatcher());
899 if (matcher.isExternal()) {
901 <<
"unresolved external symbol " << getMatcher();
905 rawResults.resize(getOperation()->getNumResults());
906 std::optional<DiagnosedSilenceableFailure> maybeFailure;
907 for (
Operation *root : state.getPayloadOps(getRoot())) {
911 op->
print(llvm::dbgs(),
913 llvm::dbgs() <<
" @" << op <<
"\n";
919 matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
920 if (
diag.isDefiniteFailure())
922 if (
diag.isSilenceableFailure()) {
924 <<
" failed: " <<
diag.getMessage());
930 if (mapping.size() != 1) {
931 maybeFailure.emplace(emitSilenceableError()
932 <<
"result #" << i <<
", associated with "
934 <<
" payload objects, expected 1");
937 rawResults[i].push_back(mapping[0]);
942 return std::move(*maybeFailure);
943 assert(!maybeFailure &&
"failure set but the walk was not interrupted");
945 for (
auto &&[opResult, rawResult] :
946 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
953 void transform::CollectMatchingOp::getEffects(
954 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
960 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
962 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
964 if (!matcherSymbol ||
965 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
966 return emitError() <<
"unresolved matcher symbol " << getMatcher();
969 if (argumentTypes.size() != 1 ||
970 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
972 <<
"expected the matcher to take one operation handle argument";
974 if (!matcherSymbol.getArgAttr(
975 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
976 return emitError() <<
"expected the matcher argument to be marked readonly";
980 if (resultTypes.size() != getOperation()->getNumResults()) {
982 <<
"expected the matcher to yield as many values as op has results ("
983 << getOperation()->getNumResults() <<
"), got "
984 << resultTypes.size();
987 for (
auto &&[i, matcherType, resultType] :
993 <<
"mismatching type interfaces for matcher result and op result #"
1010 matchActionPairs.reserve(getMatchers().size());
1012 for (
auto &&[matcher, action] :
1013 llvm::zip_equal(getMatchers(), getActions())) {
1014 auto matcherSymbol =
1016 getOperation(), cast<SymbolRefAttr>(matcher));
1019 getOperation(), cast<SymbolRefAttr>(action));
1020 assert(matcherSymbol && actionSymbol &&
1021 "unresolved symbols not caught by the verifier");
1023 if (matcherSymbol.isExternal())
1025 if (actionSymbol.isExternal())
1028 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1033 for (
Operation *root : state.getPayloadOps(getRoot())) {
1037 if (!getRestrictRoot() && op == root)
1042 op->
print(llvm::dbgs(),
1044 llvm::dbgs() <<
" @" << op <<
"\n";
1048 for (
auto [matcher, action] : matchActionPairs) {
1051 matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
1052 if (
diag.isDefiniteFailure())
1054 if (
diag.isSilenceableFailure()) {
1056 <<
" failed: " <<
diag.getMessage());
1060 auto scope = state.make_region_scope(action.getFunctionBody());
1061 for (
auto &&[arg, map] : llvm::zip_equal(
1062 action.getFunctionBody().front().getArguments(), mappings)) {
1063 if (
failed(state.mapBlockArgument(arg, map)))
1068 action.getFunctionBody().front().without_terminator()) {
1070 state.applyTransform(cast<TransformOpInterface>(transform));
1075 overallDiag = emitSilenceableError() <<
"actions failed";
1080 <<
"when applied to this matching payload";
1097 results.
set(llvm::cast<OpResult>(getUpdated()),
1098 state.getPayloadOps(getRoot()));
1102 void transform::ForeachMatchOp::getEffects(
1103 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1105 if (getOperation()->getNumOperands() < 1 ||
1106 getOperation()->getNumResults() < 1) {
1118 ArrayAttr &matchers,
1119 ArrayAttr &actions) {
1141 ArrayAttr matchers, ArrayAttr actions) {
1144 for (
auto &&[matcher, action, idx] : llvm::zip_equal(
1145 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1147 printer << cast<SymbolRefAttr>(matcher) <<
" -> "
1148 << cast<SymbolRefAttr>(action);
1149 if (idx != matchers.size() - 1)
1157 if (getMatchers().size() != getActions().size())
1158 return emitOpError() <<
"expected the same number of matchers and actions";
1159 if (getMatchers().empty())
1160 return emitOpError() <<
"expected at least one match/action pair";
1164 if (matcherNames.insert(name).second)
1167 <<
" is used more than once, only the first match will apply";
1178 bool alsoVerifyInternal =
false) {
1179 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1180 llvm::SmallDenseSet<unsigned> consumedArguments;
1181 if (!op.isExternal()) {
1185 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1187 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1190 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1192 if (isConsumed && isReadOnly) {
1193 return transformOp.emitSilenceableError()
1194 <<
"argument #" << i <<
" cannot be both readonly and consumed";
1196 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1197 return transformOp.emitSilenceableError()
1198 <<
"must provide consumed/readonly status for arguments of "
1199 "external or called ops";
1201 if (op.isExternal())
1204 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1205 return transformOp.emitSilenceableError()
1206 <<
"argument #" << i
1207 <<
" is consumed in the body but is not marked as such";
1209 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1213 <<
"op argument #" << i
1214 <<
" is not consumed in the body but is marked as consumed";
1222 assert(getMatchers().size() == getActions().size());
1225 for (
auto &&[matcher, action] :
1226 llvm::zip_equal(getMatchers(), getActions())) {
1227 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1229 cast<SymbolRefAttr>(matcher)));
1230 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1232 cast<SymbolRefAttr>(action)));
1233 if (!matcherSymbol ||
1234 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1235 return emitError() <<
"unresolved matcher symbol " << matcher;
1236 if (!actionSymbol ||
1237 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1238 return emitError() <<
"unresolved action symbol " << action;
1243 .checkAndReport())) {
1249 .checkAndReport())) {
1254 ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes();
1255 if (matcherResults.size() != actionArguments.size()) {
1256 return emitError() <<
"mismatching number of matcher results and "
1257 "action arguments between "
1258 << matcher <<
" (" << matcherResults.size() <<
") and "
1259 << action <<
" (" << actionArguments.size() <<
")";
1261 for (
auto &&[i, matcherType, actionType] :
1266 return emitError() <<
"mismatching type interfaces for matcher result "
1267 "and action argument #"
1271 if (!actionSymbol.getResultTypes().empty()) {
1273 emitError() <<
"action symbol is not expected to have results";
1274 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1278 if (matcherSymbol.getArgumentTypes().size() != 1 ||
1280 getRoot().getType())) {
1282 emitOpError() <<
"expects matcher symbol to have one argument with "
1283 "the same transform interface as the first operand";
1284 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1288 if (matcherSymbol.getArgAttr(0, consumedAttr)) {
1291 <<
"does not expect matcher symbol to consume its operand";
1292 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1311 llvm::to_vector(state.getPayloadOps(getTarget()));
1313 auto scope = state.make_region_scope(getBody());
1314 if (
failed(state.mapBlockArguments(getIterationVariable(), {op})))
1318 for (
Operation &transform : getBody().front().without_terminator()) {
1320 cast<transform::TransformOpInterface>(transform));
1326 for (
unsigned i = 0; i < getNumResults(); ++i) {
1327 auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
1328 resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
1332 for (
unsigned i = 0; i < getNumResults(); ++i)
1333 results.
set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
1338 void transform::ForeachOp::getEffects(
1339 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1341 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1349 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1353 }
else if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1359 for (
Value result : getResults())
1363 void transform::ForeachOp::getSuccessorRegions(
1365 Region *bodyRegion = &getBody();
1367 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1372 assert(point == getBody() &&
"unexpected region index");
1373 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1374 regions.emplace_back();
1381 assert(point == getBody() &&
"unexpected region index");
1382 return getOperation()->getOperands();
1385 transform::YieldOp transform::ForeachOp::getYieldOp() {
1386 return cast<transform::YieldOp>(getBody().front().getTerminator());
1390 auto yieldOp = getYieldOp();
1391 if (getNumResults() != yieldOp.getNumOperands())
1392 return emitOpError() <<
"expects the same number of results as the "
1393 "terminator has operands";
1394 for (
Value v : yieldOp.getOperands())
1395 if (!llvm::isa<TransformHandleTypeInterface>(v.getType()))
1396 return yieldOp->emitOpError(
"expects operands to have types implementing "
1397 "TransformHandleTypeInterface");
1411 for (
Operation *target : state.getPayloadOps(getTarget())) {
1413 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1416 bool checkIsolatedFromAbove =
1417 !getIsolatedFromAbove() ||
1419 bool checkOpName = !getOpName().has_value() ||
1421 if (checkIsolatedFromAbove && checkOpName)
1426 if (getAllowEmptyResults()) {
1427 results.
set(llvm::cast<OpResult>(getResult()), parents);
1431 emitSilenceableError()
1432 <<
"could not find a parent op that matches all requirements";
1433 diag.attachNote(target->
getLoc()) <<
"target op";
1437 if (getDeduplicate()) {
1438 if (!resultSet.contains(parent)) {
1439 parents.push_back(parent);
1440 resultSet.insert(parent);
1443 parents.push_back(parent);
1446 results.
set(llvm::cast<OpResult>(getResult()), parents);
1458 int64_t resultNumber = getResultNumber();
1459 auto payloadOps = state.getPayloadOps(getTarget());
1460 if (std::empty(payloadOps)) {
1461 results.
set(cast<OpResult>(getResult()), {});
1464 if (!llvm::hasSingleElement(payloadOps))
1466 <<
"handle must be mapped to exactly one payload op";
1468 Operation *target = *payloadOps.begin();
1471 results.
set(llvm::cast<OpResult>(getResult()),
1485 for (
Value v : state.getPayloadValues(getTarget())) {
1486 if (llvm::isa<BlockArgument>(v)) {
1488 emitSilenceableError() <<
"cannot get defining op of block argument";
1489 diag.attachNote(v.getLoc()) <<
"target value";
1492 definingOps.push_back(v.getDefiningOp());
1494 results.
set(llvm::cast<OpResult>(getResult()), definingOps);
1506 int64_t operandNumber = getOperandNumber();
1508 for (
Operation *target : state.getPayloadOps(getTarget())) {
1515 emitSilenceableError()
1516 <<
"could not find a producer for operand number: " << operandNumber
1517 <<
" of " << *target;
1518 diag.attachNote(target->getLoc()) <<
"target op";
1521 producers.push_back(producer);
1523 results.
set(llvm::cast<OpResult>(getResult()), producers);
1536 for (
Operation *target : state.getPayloadOps(getTarget())) {
1539 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1540 target->getNumOperands(), operandPositions);
1541 if (
diag.isSilenceableFailure()) {
1542 diag.attachNote(target->getLoc())
1543 <<
"while considering positions of this payload operation";
1546 llvm::append_range(operands,
1547 llvm::map_range(operandPositions, [&](int64_t pos) {
1548 return target->getOperand(pos);
1551 results.
setValues(cast<OpResult>(getResult()), operands);
1557 getIsInverted(), getIsAll());
1569 for (
Operation *target : state.getPayloadOps(getTarget())) {
1572 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1573 target->getNumResults(), resultPositions);
1574 if (
diag.isSilenceableFailure()) {
1575 diag.attachNote(target->getLoc())
1576 <<
"while considering positions of this payload operation";
1579 llvm::append_range(opResults,
1580 llvm::map_range(resultPositions, [&](int64_t pos) {
1581 return target->getResult(pos);
1584 results.
setValues(cast<OpResult>(getResult()), opResults);
1590 getIsInverted(), getIsAll());
1597 void transform::GetTypeOp::getEffects(
1598 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1609 for (
Value value : state.getPayloadValues(getValue())) {
1610 Type type = value.getType();
1611 if (getElemental()) {
1612 if (
auto shaped = dyn_cast<ShapedType>(type)) {
1613 type = shaped.getElementType();
1618 results.
setParams(getResult().cast<OpResult>(), params);
1635 state.applyTransform(cast<transform::TransformOpInterface>(transform));
1640 if (mode == transform::FailurePropagationMode::Propagate) {
1659 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1660 getOperation(), getTarget());
1661 assert(callee &&
"unverified reference to unknown symbol");
1663 if (callee.isExternal())
1669 auto scope = state.make_region_scope(callee.getBody());
1670 for (
auto &&[arg, map] :
1671 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
1672 if (
failed(state.mapBlockArgument(arg, map)))
1677 callee.getBody().front(), getFailurePropagationMode(), state, results);
1680 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
1681 for (
auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
1689 void transform::IncludeOp::getEffects(
1690 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1704 auto defaultEffects = [&] {
onlyReadsHandle(getOperands(), effects); };
1709 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
1711 return defaultEffects();
1712 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1713 getOperation(), getTarget());
1715 return defaultEffects();
1719 (void)earlyVerifierResult.
silence();
1720 return defaultEffects();
1723 for (
unsigned i = 0, e = getNumOperands(); i < e; ++i) {
1724 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
1735 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(
"target");
1737 return emitOpError() <<
"expects a 'target' symbol reference attribute";
1742 return emitOpError() <<
"does not reference a named transform sequence";
1744 FunctionType fnType = target.getFunctionType();
1745 if (fnType.getNumInputs() != getNumOperands())
1746 return emitError(
"incorrect number of operands for callee");
1748 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1749 if (getOperand(i).getType() != fnType.getInput(i)) {
1750 return emitOpError(
"operand type mismatch: expected operand type ")
1751 << fnType.getInput(i) <<
", but provided "
1752 << getOperand(i).getType() <<
" for operand number " << i;
1756 if (fnType.getNumResults() != getNumResults())
1757 return emitError(
"incorrect number of results for callee");
1759 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1760 Type resultType = getResult(i).getType();
1761 Type funcType = fnType.getResult(i);
1763 return emitOpError() <<
"type of result #" << i
1764 <<
" must implement the same transform dialect "
1765 "interface as the corresponding callee result";
1770 cast<FunctionOpInterface>(*target),
false,
1780 ::std::optional<::mlir::Operation *> maybeCurrent,
1782 if (!maybeCurrent.has_value()) {
1787 return emitSilenceableError() <<
"operation is not empty";
1798 for (
auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
1799 if (acceptedAttr.getValue() == currentOpName)
1802 return emitSilenceableError() <<
"wrong operation name";
1813 auto signedAPIntAsString = [&](
const APInt &value) {
1815 llvm::raw_string_ostream os(str);
1816 value.print(os,
true);
1823 if (params.size() != references.size()) {
1824 return emitSilenceableError()
1825 <<
"parameters have different payload lengths (" << params.size()
1826 <<
" vs " << references.size() <<
")";
1829 for (
auto &&[i, param, reference] :
llvm::enumerate(params, references)) {
1830 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
1831 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
1832 if (!intAttr || !refAttr) {
1834 <<
"non-integer parameter value not expected";
1836 if (intAttr.getType() != refAttr.getType()) {
1838 <<
"mismatching integer attribute types in parameter #" << i;
1840 APInt value = intAttr.getValue();
1841 APInt refValue = refAttr.getValue();
1844 int64_t position = i;
1845 auto reportError = [&](StringRef direction) {
1847 emitSilenceableError() <<
"expected parameter to be " << direction
1848 <<
" " << signedAPIntAsString(refValue)
1849 <<
", got " << signedAPIntAsString(value);
1850 diag.attachNote(getParam().getLoc())
1851 <<
"value # " << position
1852 <<
" associated with the parameter defined here";
1856 switch (getPredicate()) {
1857 case MatchCmpIPredicate::eq:
1858 if (value.eq(refValue))
1860 return reportError(
"equal to");
1861 case MatchCmpIPredicate::ne:
1862 if (value.ne(refValue))
1864 return reportError(
"not equal to");
1865 case MatchCmpIPredicate::lt:
1866 if (value.slt(refValue))
1868 return reportError(
"less than");
1869 case MatchCmpIPredicate::le:
1870 if (value.sle(refValue))
1872 return reportError(
"less than or equal to");
1873 case MatchCmpIPredicate::gt:
1874 if (value.sgt(refValue))
1876 return reportError(
"greater than");
1877 case MatchCmpIPredicate::ge:
1878 if (value.sge(refValue))
1880 return reportError(
"greater than or equal to");
1886 void transform::MatchParamCmpIOp::getEffects(
1887 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1900 results.
setParams(cast<OpResult>(getParam()), {getValue()});
1913 if (isa<TransformHandleTypeInterface>(handles.front().
getType())) {
1915 for (
Value operand : handles)
1916 llvm::append_range(operations, state.getPayloadOps(operand));
1917 if (!getDeduplicate()) {
1918 results.
set(llvm::cast<OpResult>(getResult()), operations);
1923 results.
set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
1927 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
1929 for (
Value attribute : handles)
1930 llvm::append_range(attrs, state.getParams(attribute));
1931 if (!getDeduplicate()) {
1932 results.
setParams(cast<OpResult>(getResult()), attrs);
1937 results.
setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
1942 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
1943 "expected value handle type");
1945 for (
Value value : handles)
1946 llvm::append_range(payloadValues, state.getPayloadValues(value));
1947 if (!getDeduplicate()) {
1948 results.
setValues(cast<OpResult>(getResult()), payloadValues);
1953 results.
setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
1957 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
1959 return getDeduplicate();
1962 void transform::MergeHandlesOp::getEffects(
1963 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1971 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
1972 if (getDeduplicate() || getHandles().size() != 1)
1977 return getHandles().front();
1995 auto scope = state.make_region_scope(getBody());
1997 state, this->getOperation(), getBody())))
2001 FailurePropagationMode::Propagate, state, results);
2004 void transform::NamedSequenceOp::getEffects(
2005 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2010 parser, result,
false,
2011 getFunctionTypeAttrName(result.
name),
2014 std::string &) { return builder.getFunctionType(inputs, results); },
2015 getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
2020 printer, cast<FunctionOpInterface>(getOperation()),
false,
2021 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2022 getResAttrsAttrName());
2032 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2035 <<
"cannot be defined inside another transform op";
2036 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2040 if (op.isExternal() || op.getFunctionBody().empty()) {
2047 if (op.getFunctionBody().front().empty())
2050 Operation *terminator = &op.getFunctionBody().front().back();
2051 if (!isa<transform::YieldOp>(terminator)) {
2054 << transform::YieldOp::getOperationName()
2055 <<
"' as terminator";
2056 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2062 <<
"expected terminator to have as many operands as the parent op "
2065 for (
auto [i, operandType, resultType] : llvm::zip_equal(
2068 if (operandType == resultType)
2071 <<
"the type of the terminator operand #" << i
2072 <<
" must match the type of the corresponding parent op result ("
2073 << operandType <<
" vs " << resultType <<
")";
2086 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2089 <<
"expects the parent symbol table to have the '"
2090 << transform::TransformDialect::kWithNamedSequenceAttrName
2092 diag.attachNote(parent->
getLoc()) <<
"symbol table operation";
2097 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
2100 <<
"cannot be defined inside another transform op";
2101 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
2105 if (op.isExternal() || op.getBody().empty())
2109 if (op.getBody().front().empty())
2112 Operation *terminator = &op.getBody().front().back();
2113 if (!isa<transform::YieldOp>(terminator)) {
2116 << transform::YieldOp::getOperationName()
2117 <<
"' as terminator";
2118 diag.attachNote(terminator->
getLoc()) <<
"terminator";
2124 <<
"expected terminator to have as many operands as the parent op "
2127 for (
auto [i, operandType, resultType] :
2128 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->
getNumOperands()),
2131 if (operandType == resultType)
2134 <<
"the type of the terminator operand #" << i
2135 <<
" must match the type of the corresponding parent op result ("
2136 << operandType <<
" vs " << resultType <<
")";
2139 auto funcOp = cast<FunctionOpInterface>(*op);
2142 if (!
diag.succeeded())
2154 template <
typename FnTy>
2159 types.reserve(1 + extraBindingTypes.size());
2160 types.push_back(bbArgType);
2161 llvm::append_range(types, extraBindingTypes);
2164 Region *region = state.regions.back().get();
2171 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2172 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0));
2174 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0),
2179 void transform::NamedSequenceOp::build(
OpBuilder &builder,
2187 state.addAttribute(getFunctionTypeAttrName(state.name),
2189 rootType, resultTypes)));
2190 state.attributes.append(attrs.begin(), attrs.end());
2205 size_t numAssociations =
2207 .Case([&](TransformHandleTypeInterface opHandle) {
2208 return llvm::range_size(state.getPayloadOps(getHandle()));
2210 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2211 return llvm::range_size(state.getPayloadValues(getHandle()));
2213 .Case([&](TransformParamTypeInterface param) {
2214 return llvm::range_size(state.getParams(getHandle()));
2217 llvm_unreachable(
"unknown kind of transform dialect type");
2227 auto resultType = getNum().getType().
cast<TransformParamTypeInterface>();
2242 auto payloadOps = state.getPayloadOps(getTarget());
2245 result.push_back(op);
2247 results.
set(cast<OpResult>(getResult()), result);
2256 Value target, int64_t numResultHandles) {
2265 int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
2266 auto produceNumOpsError = [&]() {
2267 return emitSilenceableError()
2268 << getHandle() <<
" expected to contain " << this->getNumResults()
2269 <<
" payload ops but it contains " << numPayloadOps
2275 if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
2276 return produceNumOpsError();
2281 if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
2282 (numPayloadOps != 0 || !getPassThroughEmptyHandle()))
2283 return produceNumOpsError();
2287 if (getOverflowResult())
2288 resultHandles[*getOverflowResult()].reserve(numPayloadOps -
2291 int64_t resultNum = en.index();
2292 if (resultNum >= getNumResults())
2293 resultNum = *getOverflowResult();
2294 resultHandles[resultNum].push_back(en.value());
2299 results.
set(llvm::cast<OpResult>(getResult(it.index())), it.value());
2304 void transform::SplitHandleOp::getEffects(
2305 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2313 if (getOverflowResult().has_value() &&
2314 !(*getOverflowResult() < getNumResults()))
2315 return emitOpError(
"overflow_result is not a valid result index");
2327 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2329 Value handle = en.value();
2330 if (isa<TransformHandleTypeInterface>(handle.
getType())) {
2332 llvm::to_vector(state.getPayloadOps(handle));
2334 payload.reserve(numRepetitions * current.size());
2335 for (
unsigned i = 0; i < numRepetitions; ++i)
2336 llvm::append_range(payload, current);
2337 results.
set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2339 assert(llvm::isa<TransformParamTypeInterface>(handle.
getType()) &&
2340 "expected param type");
2343 params.reserve(numRepetitions * current.size());
2344 for (
unsigned i = 0; i < numRepetitions; ++i)
2345 llvm::append_range(params, current);
2346 results.
setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2353 void transform::ReplicateOp::getEffects(
2354 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2369 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2370 if (
failed(mapBlockArguments(state)))
2378 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2380 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2381 SmallVectorImpl<Type> &extraBindingTypes) {
2385 root = std::nullopt;
2406 if (!extraBindings.empty()) {
2411 if (extraBindingTypes.size() != extraBindings.size()) {
2413 "expected types to be provided for all operands");
2429 bool hasExtras = !extraBindings.empty();
2439 printer << rootType;
2442 llvm::interleaveComma(extraBindingTypes, printer.
getStream());
2451 auto iface = dyn_cast<transform::TransformOpInterface>(use.
getOwner());
2466 if (!potentialConsumer) {
2467 potentialConsumer = &use;
2472 <<
" has more than one potential consumer";
2475 diag.attachNote(use.getOwner()->getLoc())
2476 <<
"used here as operand #" << use.getOperandNumber();
2484 assert(getBodyBlock()->getNumArguments() >= 1 &&
2485 "the number of arguments must have been verified to be more than 1 by "
2486 "PossibleTopLevelTransformOpTrait");
2488 if (!getRoot() && !getExtraBindings().empty()) {
2489 return emitOpError()
2490 <<
"does not expect extra operands when used as top-level";
2496 return (emitOpError() <<
"block argument #" << arg.getArgNumber());
2503 for (
Operation &child : *getBodyBlock()) {
2504 if (!isa<TransformOpInterface>(child) &&
2505 &child != &getBodyBlock()->back()) {
2508 <<
"expected children ops to implement TransformOpInterface";
2509 diag.attachNote(child.getLoc()) <<
"op without interface";
2513 for (
OpResult result : child.getResults()) {
2514 auto report = [&]() {
2515 return (child.emitError() <<
"result #" << result.getResultNumber());
2522 if (!getBodyBlock()->mightHaveTerminator())
2523 return emitOpError() <<
"expects to have a terminator in the body";
2525 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2526 getOperation()->getResultTypes()) {
2528 <<
"expects the types of the terminator operands "
2529 "to match the types of the result";
2530 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) <<
"terminator";
2536 void transform::SequenceOp::getEffects(
2537 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2543 assert(point == getBody() &&
"unexpected region index");
2544 if (getOperation()->getNumOperands() > 0)
2545 return getOperation()->getOperands();
2547 getOperation()->operand_end());
2550 void transform::SequenceOp::getSuccessorRegions(
2553 Region *bodyRegion = &getBody();
2554 regions.emplace_back(bodyRegion, getNumOperands() != 0
2560 assert(point == getBody() &&
"unexpected region index");
2561 regions.emplace_back(getOperation()->getResults());
2564 void transform::SequenceOp::getRegionInvocationBounds(
2567 bounds.emplace_back(1, 1);
2572 FailurePropagationMode failurePropagationMode,
2575 build(builder, state, resultTypes, failurePropagationMode, root,
2584 FailurePropagationMode failurePropagationMode,
2587 build(builder, state, resultTypes, failurePropagationMode, root,
2595 FailurePropagationMode failurePropagationMode,
2598 build(builder, state, resultTypes, failurePropagationMode,
Value(),
2606 FailurePropagationMode failurePropagationMode,
2609 build(builder, state, resultTypes, failurePropagationMode,
Value(),
2625 Value target, StringRef name) {
2627 build(builder, result, name);
2634 llvm::outs() <<
"[[[ IR printer: ";
2635 if (getName().has_value())
2636 llvm::outs() << *getName() <<
" ";
2639 if (getAssumeVerified().value_or(
false))
2641 if (getUseLocalScope().value_or(
false))
2643 if (getSkipRegions().value_or(
false))
2647 llvm::outs() <<
"top-level ]]]\n";
2648 state.getTopLevel()->print(llvm::outs(), printFlags);
2649 llvm::outs() <<
"\n";
2653 llvm::outs() <<
"]]]\n";
2654 for (
Operation *target : state.getPayloadOps(getTarget())) {
2655 target->
print(llvm::outs(), printFlags);
2656 llvm::outs() <<
"\n";
2662 void transform::PrintOp::getEffects(
2663 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2667 if (!getTargetMutable().empty())
2687 <<
"failed to verify payload op";
2688 diag.attachNote(target->
getLoc()) <<
"payload op";
2694 void transform::VerifyOp::getEffects(
2695 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2703 void transform::YieldOp::getEffects(
2704 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
static MLIRContext * getContext(OpFoldResult val)
static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs)
Checks that two types are the same or can be cast into one another.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
int64_t maxIterations
This specifies the maximum number of times the rewriter will iterate between applying patterns and si...
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
int64_t maxNumRewrites
This specifies the maximum number of rewrites within an iteration.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & assumeVerified()
Do not verify the operation when using custom operation printers.
OpPrintingFlags & useLocalScope()
Use local scope when printing the operation.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
This class implements the operand iterators for the Operation class.
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents success/failure for parsing-like operations that find it important to chain tog...
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getRegionNumber()
Return the number of this region in the parent operation.
This is a "type erased" representation of a registered operation.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
size_t size() const
Return the size of this range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
A named class for passing around the variadic flag.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This class represents an efficient way to signal success or failure.
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.