32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/SmallPtrSet.h"
35 #include "llvm/Support/Debug.h"
38 #define DEBUG_TYPE "transform-dialect"
39 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
41 #define DEBUG_TYPE_MATCHER "transform-matcher"
42 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
43 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
48 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
50 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
51 SmallVectorImpl<Type> &extraBindingTypes);
57 ArrayAttr matchers, ArrayAttr actions);
68 Operation *transformAncestor = transform.getOperation();
69 while (transformAncestor) {
70 if (transformAncestor == payload) {
72 transform.emitDefiniteFailure()
73 <<
"cannot apply transform to itself (or one of its ancestors)";
74 diag.attachNote(payload->
getLoc()) <<
"target payload op";
77 transformAncestor = transformAncestor->
getParentOp();
82 #define GET_OP_CLASSES
83 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
91 if (!point.
isParent() && getOperation()->getNumOperands() == 1)
92 return getOperation()->getOperands();
94 getOperation()->operand_end());
97 void transform::AlternativesOp::getSuccessorRegions(
99 for (
Region &alternative : llvm::drop_begin(
103 regions.emplace_back(&alternative, !getOperands().empty()
104 ? alternative.getArguments()
108 regions.emplace_back(getOperation()->getResults());
111 void transform::AlternativesOp::getRegionInvocationBounds(
116 bounds.reserve(getNumRegions());
117 bounds.emplace_back(1, 1);
124 results.
set(res, {});
132 if (
Value scopeHandle = getScope())
133 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
135 originals.push_back(state.getTopLevel());
138 if (original->isAncestor(getOperation())) {
140 <<
"scope must not contain the transforms being applied";
141 diag.attachNote(original->getLoc()) <<
"scope";
146 <<
"only isolated-from-above ops can be alternative scopes";
147 diag.attachNote(original->getLoc()) <<
"scope";
152 for (
Region ® : getAlternatives()) {
157 auto scope = state.make_region_scope(reg);
158 auto clones = llvm::to_vector(
159 llvm::map_range(originals, [](
Operation *op) {
return op->
clone(); }));
160 auto deleteClones = llvm::make_scope_exit([&] {
164 if (
failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
168 for (
Operation &transform : reg.front().without_terminator()) {
170 state.applyTransform(cast<TransformOpInterface>(transform));
172 LLVM_DEBUG(
DBGS() <<
"alternative failed: " << result.
getMessage()
187 deleteClones.release();
188 TrackingListener listener(state, *
this);
190 for (
const auto &kvp : llvm::zip(originals, clones)) {
201 return emitSilenceableError() <<
"all alternatives failed";
204 void transform::AlternativesOp::getEffects(
205 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
208 for (
Region *region : getRegions()) {
209 if (!region->empty())
216 for (
Region &alternative : getAlternatives()) {
221 <<
"expects terminator operands to have the "
222 "same type as results of the operation";
223 diag.attachNote(terminator->
getLoc()) <<
"terminator";
240 llvm::to_vector(state.getPayloadOps(getTarget()));
243 if (
auto paramH = getParam()) {
245 if (params.size() != 1) {
246 if (targets.size() != params.size()) {
247 return emitSilenceableError()
248 <<
"parameter and target have different payload lengths ("
249 << params.size() <<
" vs " << targets.size() <<
")";
251 for (
auto &&[target, attr] : llvm::zip_equal(targets, params))
252 target->setAttr(getName(), attr);
257 for (
auto target : targets)
258 target->setAttr(getName(), attr);
262 void transform::AnnotateOp::getEffects(
263 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
274 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
289 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
290 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
314 auto addDefiningOpsToWorklist = [&](
Operation *op) {
317 if (
Operation *defOp = v.getDefiningOp())
319 worklist.insert(defOp);
327 auto it = llvm::find(worklist, op);
328 if (it != worklist.end())
337 addDefiningOpsToWorklist(op);
343 while (!worklist.empty()) {
347 addDefiningOpsToWorklist(op);
354 void transform::ApplyDeadCodeEliminationOp::getEffects(
355 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
379 if (!getRegion().empty()) {
380 for (
Operation &op : getRegion().front()) {
381 cast<transform::PatternDescriptorOpInterface>(&op)
382 .populatePatternsWithState(patterns, state);
395 bool cseChanged =
false;
398 static const int64_t kNumMaxIterations = 50;
399 int64_t iteration = 0;
414 if (target != nestedOp)
415 ops.push_back(nestedOp);
424 <<
"greedy pattern application failed";
432 }
while (cseChanged && ++iteration < kNumMaxIterations);
434 if (iteration == kNumMaxIterations)
441 if (!getRegion().empty()) {
442 for (
Operation &op : getRegion().front()) {
443 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
445 <<
"expected children ops to implement "
446 "PatternDescriptorOpInterface";
447 diag.attachNote(op.
getLoc()) <<
"op without interface";
455 void transform::ApplyPatternsOp::getEffects(
456 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
461 void transform::ApplyPatternsOp::build(
470 bodyBuilder(builder, result.
location);
477 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
481 dialect->getCanonicalizationPatterns(patterns);
483 op.getCanonicalizationPatterns(patterns, ctx);
497 std::unique_ptr<TypeConverter> defaultTypeConverter;
498 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
499 getDefaultTypeConverter();
500 if (typeConverterBuilder)
501 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
506 for (
Attribute attr : cast<ArrayAttr>(*getLegalOps()))
507 conversionTarget.addLegalOp(
510 for (
Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
511 conversionTarget.addIllegalOp(
513 if (getLegalDialects())
514 for (
Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
515 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
516 if (getIllegalDialects())
517 for (
Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
518 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
526 if (!getPatterns().empty()) {
527 for (
Operation &op : getPatterns().front()) {
529 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
532 std::unique_ptr<TypeConverter> typeConverter =
533 descriptor.getTypeConverter();
536 keepAliveConverters.emplace_back(std::move(typeConverter));
537 converter = keepAliveConverters.back().get();
540 if (!defaultTypeConverter) {
542 <<
"pattern descriptor does not specify type "
543 "converter and apply_conversion_patterns op has "
544 "no default type converter";
545 diag.attachNote(op.
getLoc()) <<
"pattern descriptor op";
548 converter = defaultTypeConverter.get();
554 descriptor.populateConversionTargetRules(*converter, conversionTarget);
556 descriptor.populatePatterns(*converter, patterns);
561 for (
Operation *target : state.getPayloadOps(getTarget())) {
570 if (getPartialConversion()) {
577 auto diag = emitSilenceableError() <<
"dialect conversion failed";
578 diag.attachNote(target->
getLoc()) <<
"target op";
587 if (getNumRegions() != 1 && getNumRegions() != 2)
588 return emitOpError() <<
"expected 1 or 2 regions";
589 if (!getPatterns().empty()) {
590 for (
Operation &op : getPatterns().front()) {
591 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
593 emitOpError() <<
"expected pattern children ops to implement "
594 "ConversionPatternDescriptorOpInterface";
595 diag.attachNote(op.
getLoc()) <<
"op without interface";
600 if (getNumRegions() == 2) {
601 Region &typeConverterRegion = getRegion(1);
602 if (!llvm::hasSingleElement(typeConverterRegion.
front()))
604 <<
"expected exactly one op in default type converter region";
605 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
607 if (!typeConverterOp) {
609 <<
"expected default converter child op to "
610 "implement TypeConverterBuilderOpInterface";
611 diag.attachNote(typeConverterOp->getLoc()) <<
"op without interface";
615 if (!getPatterns().empty()) {
616 for (
Operation &op : getPatterns().front()) {
618 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
619 if (
failed(descriptor.verifyTypeConverter(typeConverterOp)))
627 void transform::ApplyConversionPatternsOp::getEffects(
628 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
633 void transform::ApplyConversionPatternsOp::build(
643 if (patternsBodyBuilder)
644 patternsBodyBuilder(builder, result.
location);
650 if (typeConverterBodyBuilder)
651 typeConverterBodyBuilder(builder, result.
location);
659 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
662 assert(dialect &&
"expected that dialect is loaded");
663 auto iface = cast<ConvertToLLVMPatternInterface>(dialect);
667 iface->populateConvertToLLVMConversionPatterns(
671 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
672 transform::TypeConverterBuilderOpInterface builder) {
673 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
674 return emitOpError(
"expected LLVMTypeConverter");
681 return emitOpError(
"unknown dialect or dialect not loaded: ")
683 auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
686 "dialect does not implement ConvertToLLVMPatternInterface or "
687 "extension was not loaded: ")
697 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
707 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
708 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
735 <<
"unknown pass or pass pipeline: " << getPassName();
744 <<
"failed to add pass or pass pipeline to pipeline: "
747 if (
failed(pm.run(target))) {
748 auto diag = emitSilenceableError() <<
"pass pipeline failed";
749 diag.attachNote(target->
getLoc()) <<
"target op";
753 results.push_back(target);
763 Operation *target, ApplyToEachResultList &results,
765 results.push_back(target);
769 void transform::CastOp::getEffects(
770 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
777 assert(inputs.size() == 1 &&
"expected one input");
778 assert(outputs.size() == 1 &&
"expected one output");
780 std::initializer_list<Type>{inputs.front(), outputs.front()},
781 [](
Type ty) {
return isa<transform::TransformHandleTypeInterface>(ty); });
798 assert(block.
getParent() &&
"cannot match using a detached block");
799 auto matchScope = state.make_region_scope(*block.
getParent());
804 if (!isa<transform::MatchOpInterface>(match)) {
806 <<
"expected operations in the match part to "
807 "implement MatchOpInterface";
810 state.applyTransform(cast<transform::TransformOpInterface>(match));
811 if (
diag.succeeded())
830 matchActionPairs.reserve(getMatchers().size());
832 for (
auto &&[matcher, action] :
833 llvm::zip_equal(getMatchers(), getActions())) {
836 getOperation(), cast<SymbolRefAttr>(matcher));
839 getOperation(), cast<SymbolRefAttr>(action));
840 assert(matcherSymbol && actionSymbol &&
841 "unresolved symbols not caught by the verifier");
843 if (matcherSymbol.isExternal())
845 if (actionSymbol.isExternal())
848 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
851 for (
Operation *root : state.getPayloadOps(getRoot())) {
855 if (!getRestrictRoot() && op == root)
860 op->
print(llvm::dbgs(),
862 llvm::dbgs() <<
" @" << op <<
"\n";
866 for (
auto [matcher, action] : matchActionPairs) {
869 matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
870 if (
diag.isDefiniteFailure())
872 if (
diag.isSilenceableFailure()) {
874 <<
" failed: " <<
diag.getMessage());
878 auto scope = state.make_region_scope(action.getFunctionBody());
879 for (
auto &&[arg, map] : llvm::zip_equal(
880 action.getFunctionBody().front().getArguments(), mappings)) {
881 if (
failed(state.mapBlockArgument(arg, map)))
886 action.getFunctionBody().front().without_terminator()) {
888 state.applyTransform(cast<TransformOpInterface>(transform));
904 results.
set(llvm::cast<OpResult>(getUpdated()),
905 state.getPayloadOps(getRoot()));
909 void transform::ForeachMatchOp::getEffects(
910 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
912 if (getOperation()->getNumOperands() < 1 ||
913 getOperation()->getNumResults() < 1) {
926 ArrayAttr &actions) {
948 ArrayAttr matchers, ArrayAttr actions) {
951 for (
auto &&[matcher, action, idx] : llvm::zip_equal(
952 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
954 printer << cast<SymbolRefAttr>(matcher) <<
" -> "
955 << cast<SymbolRefAttr>(action);
956 if (idx != matchers.size() - 1)
964 if (getMatchers().size() != getActions().size())
965 return emitOpError() <<
"expected the same number of matchers and actions";
966 if (getMatchers().empty())
967 return emitOpError() <<
"expected at least one match/action pair";
971 if (matcherNames.insert(name).second)
974 <<
" is used more than once, only the first match will apply";
982 template <
typename... Tys>
984 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... ||
false);
991 transform::TransformParamTypeInterface,
992 transform::TransformValueHandleTypeInterface>(
1001 bool alsoVerifyInternal =
false) {
1002 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1003 llvm::SmallDenseSet<unsigned> consumedArguments;
1004 if (!op.isExternal()) {
1008 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1010 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1013 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1015 if (isConsumed && isReadOnly) {
1016 return transformOp.emitSilenceableError()
1017 <<
"argument #" << i <<
" cannot be both readonly and consumed";
1019 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1020 return transformOp.emitSilenceableError()
1021 <<
"must provide consumed/readonly status for arguments of "
1022 "external or called ops";
1024 if (op.isExternal())
1027 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1028 return transformOp.emitSilenceableError()
1029 <<
"argument #" << i
1030 <<
" is consumed in the body but is not marked as such";
1032 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1036 <<
"op argument #" << i
1037 <<
" is not consumed in the body but is marked as consumed";
1045 assert(getMatchers().size() == getActions().size());
1048 for (
auto &&[matcher, action] :
1049 llvm::zip_equal(getMatchers(), getActions())) {
1050 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1052 cast<SymbolRefAttr>(matcher)));
1053 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1055 cast<SymbolRefAttr>(action)));
1056 if (!matcherSymbol ||
1057 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1058 return emitError() <<
"unresolved matcher symbol " << matcher;
1059 if (!actionSymbol ||
1060 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1061 return emitError() <<
"unresolved action symbol " << action;
1066 .checkAndReport())) {
1072 .checkAndReport())) {
1077 ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes();
1078 if (matcherResults.size() != actionArguments.size()) {
1079 return emitError() <<
"mismatching number of matcher results and "
1080 "action arguments between "
1081 << matcher <<
" (" << matcherResults.size() <<
") and "
1082 << action <<
" (" << actionArguments.size() <<
")";
1084 for (
auto &&[i, matcherType, actionType] :
1089 return emitError() <<
"mismatching type interfaces for matcher result "
1090 "and action argument #"
1094 if (!actionSymbol.getResultTypes().empty()) {
1096 emitError() <<
"action symbol is not expected to have results";
1097 diag.attachNote(actionSymbol->getLoc()) <<
"symbol declaration";
1101 if (matcherSymbol.getArgumentTypes().size() != 1 ||
1103 getRoot().getType())) {
1105 emitOpError() <<
"expects matcher symbol to have one argument with "
1106 "the same transform interface as the first operand";
1107 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1111 if (matcherSymbol.getArgAttr(0, consumedAttr)) {
1114 <<
"does not expect matcher symbol to consume its operand";
1115 diag.attachNote(matcherSymbol->getLoc()) <<
"symbol declaration";
1134 llvm::to_vector(state.getPayloadOps(getTarget()));
1136 auto scope = state.make_region_scope(getBody());
1137 if (
failed(state.mapBlockArguments(getIterationVariable(), {op})))
1141 for (
Operation &transform : getBody().front().without_terminator()) {
1143 cast<transform::TransformOpInterface>(transform));
1149 for (
unsigned i = 0; i < getNumResults(); ++i) {
1150 auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
1151 resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
1155 for (
unsigned i = 0; i < getNumResults(); ++i)
1156 results.
set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
1161 void transform::ForeachOp::getEffects(
1162 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1164 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1172 if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1176 }
else if (any_of(getBody().front().without_terminator(), [&](
Operation &op) {
1182 for (
Value result : getResults())
1186 void transform::ForeachOp::getSuccessorRegions(
1188 Region *bodyRegion = &getBody();
1190 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1195 assert(point == getBody() &&
"unexpected region index");
1196 regions.emplace_back(bodyRegion, bodyRegion->
getArguments());
1197 regions.emplace_back();
1204 assert(point == getBody() &&
"unexpected region index");
1205 return getOperation()->getOperands();
1208 transform::YieldOp transform::ForeachOp::getYieldOp() {
1209 return cast<transform::YieldOp>(getBody().front().getTerminator());
1213 auto yieldOp = getYieldOp();
1214 if (getNumResults() != yieldOp.getNumOperands())
1215 return emitOpError() <<
"expects the same number of results as the "
1216 "terminator has operands";
1217 for (
Value v : yieldOp.getOperands())
1218 if (!llvm::isa<TransformHandleTypeInterface>(v.getType()))
1219 return yieldOp->emitOpError(
"expects operands to have types implementing "
1220 "TransformHandleTypeInterface");
1234 for (
Operation *target : state.getPayloadOps(getTarget())) {
1236 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1239 bool checkIsolatedFromAbove =
1240 !getIsolatedFromAbove() ||
1242 bool checkOpName = !getOpName().has_value() ||
1244 if (checkIsolatedFromAbove && checkOpName)
1249 if (getAllowEmptyResults()) {
1250 results.
set(llvm::cast<OpResult>(getResult()), parents);
1254 emitSilenceableError()
1255 <<
"could not find a parent op that matches all requirements";
1256 diag.attachNote(target->
getLoc()) <<
"target op";
1260 if (getDeduplicate()) {
1261 if (!resultSet.contains(parent)) {
1262 parents.push_back(parent);
1263 resultSet.insert(parent);
1266 parents.push_back(parent);
1269 results.
set(llvm::cast<OpResult>(getResult()), parents);
1281 int64_t resultNumber = getResultNumber();
1282 auto payloadOps = state.getPayloadOps(getTarget());
1283 if (std::empty(payloadOps)) {
1284 results.
set(cast<OpResult>(getResult()), {});
1287 if (!llvm::hasSingleElement(payloadOps))
1289 <<
"handle must be mapped to exactly one payload op";
1291 Operation *target = *payloadOps.begin();
1294 results.
set(llvm::cast<OpResult>(getResult()),
1308 for (
Value v : state.getPayloadValues(getTarget())) {
1309 if (llvm::isa<BlockArgument>(v)) {
1311 emitSilenceableError() <<
"cannot get defining op of block argument";
1312 diag.attachNote(v.getLoc()) <<
"target value";
1315 definingOps.push_back(v.getDefiningOp());
1317 results.
set(llvm::cast<OpResult>(getResult()), definingOps);
1329 int64_t operandNumber = getOperandNumber();
1331 for (
Operation *target : state.getPayloadOps(getTarget())) {
1338 emitSilenceableError()
1339 <<
"could not find a producer for operand number: " << operandNumber
1340 <<
" of " << *target;
1341 diag.attachNote(target->getLoc()) <<
"target op";
1344 producers.push_back(producer);
1346 results.
set(llvm::cast<OpResult>(getResult()), producers);
1358 int64_t resultNumber = getResultNumber();
1360 for (
Operation *target : state.getPayloadOps(getTarget())) {
1361 if (resultNumber >= target->getNumResults()) {
1363 emitSilenceableError() <<
"targeted op does not have enough results";
1364 diag.attachNote(target->getLoc()) <<
"target op";
1367 opResults.push_back(target->getOpResult(resultNumber));
1369 results.
setValues(llvm::cast<OpResult>(getResult()), opResults);
1377 void transform::GetTypeOp::getEffects(
1378 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1389 for (
Value value : state.getPayloadValues(getValue())) {
1390 Type type = value.getType();
1391 if (getElemental()) {
1392 if (
auto shaped = dyn_cast<ShapedType>(type)) {
1393 type = shaped.getElementType();
1398 results.
setParams(getResult().cast<OpResult>(), params);
1415 state.applyTransform(cast<transform::TransformOpInterface>(transform));
1420 if (mode == transform::FailurePropagationMode::Propagate) {
1439 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1440 getOperation(), getTarget());
1441 assert(callee &&
"unverified reference to unknown symbol");
1443 if (callee.isExternal())
1449 auto scope = state.make_region_scope(callee.getBody());
1450 for (
auto &&[arg, map] :
1451 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
1452 if (
failed(state.mapBlockArgument(arg, map)))
1457 callee.getBody().front(), getFailurePropagationMode(), state, results);
1460 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
1461 for (
auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
1469 void transform::IncludeOp::getEffects(
1470 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1484 auto defaultEffects = [&] {
onlyReadsHandle(getOperands(), effects); };
1489 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
1491 return defaultEffects();
1492 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1493 getOperation(), getTarget());
1495 return defaultEffects();
1499 (void)earlyVerifierResult.
silence();
1500 return defaultEffects();
1503 for (
unsigned i = 0, e = getNumOperands(); i < e; ++i) {
1504 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
1515 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(
"target");
1517 return emitOpError() <<
"expects a 'target' symbol reference attribute";
1522 return emitOpError() <<
"does not reference a named transform sequence";
1524 FunctionType fnType = target.getFunctionType();
1525 if (fnType.getNumInputs() != getNumOperands())
1526 return emitError(
"incorrect number of operands for callee");
1528 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1529 if (getOperand(i).getType() != fnType.getInput(i)) {
1530 return emitOpError(
"operand type mismatch: expected operand type ")
1531 << fnType.getInput(i) <<
", but provided "
1532 << getOperand(i).getType() <<
" for operand number " << i;
1536 if (fnType.getNumResults() != getNumResults())
1537 return emitError(
"incorrect number of results for callee");
1539 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1540 Type resultType = getResult(i).getType();
1541 Type funcType = fnType.getResult(i);
1543 return emitOpError() <<
"type of result #" << i
1544 <<
" must implement the same transform dialect "
1545 "interface as the corresponding callee result";
1550 cast<FunctionOpInterface>(*target),
false,
1560 ::std::optional<::mlir::Operation *> maybeCurrent,
1562 if (!maybeCurrent.has_value()) {
1567 return emitSilenceableError() <<
"operation is not empty";
1578 for (
auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
1579 if (acceptedAttr.getValue() == currentOpName)
1582 return emitSilenceableError() <<
"wrong operation name";
1593 auto signedAPIntAsString = [&](APInt value) {
1595 llvm::raw_string_ostream os(str);
1596 value.print(os,
true);
1603 if (params.size() != references.size()) {
1604 return emitSilenceableError()
1605 <<
"parameters have different payload lengths (" << params.size()
1606 <<
" vs " << references.size() <<
")";
1609 for (
auto &&[i, param, reference] :
llvm::enumerate(params, references)) {
1610 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
1611 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
1612 if (!intAttr || !refAttr) {
1614 <<
"non-integer parameter value not expected";
1616 if (intAttr.getType() != refAttr.getType()) {
1618 <<
"mismatching integer attribute types in parameter #" << i;
1620 APInt value = intAttr.getValue();
1621 APInt refValue = refAttr.getValue();
1624 int64_t position = i;
1625 auto reportError = [&](StringRef direction) {
1627 emitSilenceableError() <<
"expected parameter to be " << direction
1628 <<
" " << signedAPIntAsString(refValue)
1629 <<
", got " << signedAPIntAsString(value);
1630 diag.attachNote(getParam().getLoc())
1631 <<
"value # " << position
1632 <<
" associated with the parameter defined here";
1636 switch (getPredicate()) {
1637 case MatchCmpIPredicate::eq:
1638 if (value.eq(refValue))
1640 return reportError(
"equal to");
1641 case MatchCmpIPredicate::ne:
1642 if (value.ne(refValue))
1644 return reportError(
"not equal to");
1645 case MatchCmpIPredicate::lt:
1646 if (value.slt(refValue))
1648 return reportError(
"less than");
1649 case MatchCmpIPredicate::le:
1650 if (value.sle(refValue))
1652 return reportError(
"less than or equal to");
1653 case MatchCmpIPredicate::gt:
1654 if (value.sgt(refValue))
1656 return reportError(
"greater than");
1657 case MatchCmpIPredicate::ge:
1658 if (value.sge(refValue))
1660 return reportError(
"greater than or equal to");
1666 void transform::MatchParamCmpIOp::getEffects(
1667 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1680 results.
setParams(cast<OpResult>(getParam()), {getValue()});
1693 if (isa<TransformHandleTypeInterface>(handles.front().
getType())) {
1695 for (
Value operand : handles)
1696 llvm::append_range(operations, state.getPayloadOps(operand));
1697 if (!getDeduplicate()) {
1698 results.
set(llvm::cast<OpResult>(getResult()), operations);
1703 results.
set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
1707 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
1709 for (
Value attribute : handles)
1710 llvm::append_range(attrs, state.getParams(attribute));
1711 if (!getDeduplicate()) {
1712 results.
setParams(cast<OpResult>(getResult()), attrs);
1717 results.
setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
1722 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
1723 "expected value handle type");
1725 for (
Value value : handles)
1726 llvm::append_range(payloadValues, state.getPayloadValues(value));
1727 if (!getDeduplicate()) {
1728 results.
setValues(cast<OpResult>(getResult()), payloadValues);
1733 results.
setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
1737 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
1739 return getDeduplicate();
1742 void transform::MergeHandlesOp::getEffects(
1743 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1751 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
1752 if (getDeduplicate() || getHandles().size() != 1)
1757 return getHandles().front();
1775 auto scope = state.make_region_scope(getBody());
1777 state, this->getOperation(), getBody())))
1781 FailurePropagationMode::Propagate, state, results);
1784 void transform::NamedSequenceOp::getEffects(
1785 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
1790 parser, result,
false,
1791 getFunctionTypeAttrName(result.
name),
1794 std::string &) { return builder.getFunctionType(inputs, results); },
1795 getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
1800 printer, cast<FunctionOpInterface>(getOperation()),
false,
1801 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
1802 getResAttrsAttrName());
1812 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
1815 <<
"cannot be defined inside another transform op";
1816 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
1820 if (op.isExternal() || op.getFunctionBody().empty()) {
1827 if (op.getFunctionBody().front().empty())
1830 Operation *terminator = &op.getFunctionBody().front().back();
1831 if (!isa<transform::YieldOp>(terminator)) {
1834 << transform::YieldOp::getOperationName()
1835 <<
"' as terminator";
1836 diag.attachNote(terminator->
getLoc()) <<
"terminator";
1842 <<
"expected terminator to have as many operands as the parent op "
1845 for (
auto [i, operandType, resultType] : llvm::zip_equal(
1848 if (operandType == resultType)
1851 <<
"the type of the terminator operand #" << i
1852 <<
" must match the type of the corresponding parent op result ("
1853 << operandType <<
" vs " << resultType <<
")";
1866 transform::TransformDialect::kWithNamedSequenceAttrName)) {
1869 <<
"expects the parent symbol table to have the '"
1870 << transform::TransformDialect::kWithNamedSequenceAttrName
1872 diag.attachNote(parent->
getLoc()) <<
"symbol table operation";
1877 if (
auto parent = op->
getParentOfType<transform::TransformOpInterface>()) {
1880 <<
"cannot be defined inside another transform op";
1881 diag.attachNote(parent.
getLoc()) <<
"ancestor transform op";
1885 if (op.isExternal() || op.getBody().empty())
1889 if (op.getBody().front().empty())
1892 Operation *terminator = &op.getBody().front().back();
1893 if (!isa<transform::YieldOp>(terminator)) {
1896 << transform::YieldOp::getOperationName()
1897 <<
"' as terminator";
1898 diag.attachNote(terminator->
getLoc()) <<
"terminator";
1904 <<
"expected terminator to have as many operands as the parent op "
1907 for (
auto [i, operandType, resultType] :
1908 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->
getNumOperands()),
1911 if (operandType == resultType)
1914 <<
"the type of the terminator operand #" << i
1915 <<
" must match the type of the corresponding parent op result ("
1916 << operandType <<
" vs " << resultType <<
")";
1919 auto funcOp = cast<FunctionOpInterface>(*op);
1922 if (!
diag.succeeded())
1934 template <
typename FnTy>
1939 types.reserve(1 + extraBindingTypes.size());
1940 types.push_back(bbArgType);
1941 llvm::append_range(types, extraBindingTypes);
1944 Region *region = state.regions.back().get();
1951 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
1952 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0));
1954 bodyBuilder(builder, state.location, bodyBlock->
getArgument(0),
1959 void transform::NamedSequenceOp::build(
OpBuilder &builder,
1967 state.addAttribute(getFunctionTypeAttrName(state.name),
1969 rootType, resultTypes)));
1970 state.attributes.append(attrs.begin(), attrs.end());
1986 auto payloadOps = state.getPayloadOps(getTarget());
1989 result.push_back(op);
1991 results.
set(cast<OpResult>(getResult()), result);
2000 Value target, int64_t numResultHandles) {
2009 int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
2010 auto produceNumOpsError = [&]() {
2011 return emitSilenceableError()
2012 << getHandle() <<
" expected to contain " << this->getNumResults()
2013 <<
" payload ops but it contains " << numPayloadOps
2019 if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
2020 return produceNumOpsError();
2025 if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
2026 !(numPayloadOps == 0 && getPassThroughEmptyHandle()))
2027 return produceNumOpsError();
2031 if (getOverflowResult())
2032 resultHandles[*getOverflowResult()].reserve(numPayloadOps -
2035 int64_t resultNum = en.index();
2036 if (resultNum >= getNumResults())
2037 resultNum = *getOverflowResult();
2038 resultHandles[resultNum].push_back(en.value());
2043 results.
set(llvm::cast<OpResult>(getResult(it.index())), it.value());
2048 void transform::SplitHandleOp::getEffects(
2049 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2057 if (getOverflowResult().has_value() &&
2058 !(*getOverflowResult() < getNumResults()))
2059 return emitOpError(
"overflow_result is not a valid result index");
2071 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2073 Value handle = en.value();
2074 if (isa<TransformHandleTypeInterface>(handle.
getType())) {
2076 llvm::to_vector(state.getPayloadOps(handle));
2078 payload.reserve(numRepetitions * current.size());
2079 for (
unsigned i = 0; i < numRepetitions; ++i)
2080 llvm::append_range(payload, current);
2081 results.
set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2083 assert(llvm::isa<TransformParamTypeInterface>(handle.
getType()) &&
2084 "expected param type");
2087 params.reserve(numRepetitions * current.size());
2088 for (
unsigned i = 0; i < numRepetitions; ++i)
2089 llvm::append_range(params, current);
2090 results.
setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2097 void transform::ReplicateOp::getEffects(
2098 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2113 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2114 if (
failed(mapBlockArguments(state)))
2122 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2124 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2125 SmallVectorImpl<Type> &extraBindingTypes) {
2129 root = std::nullopt;
2150 if (!extraBindings.empty()) {
2155 if (extraBindingTypes.size() != extraBindings.size()) {
2157 "expected types to be provided for all operands");
2173 bool hasExtras = !extraBindings.empty();
2183 printer << rootType;
2186 llvm::interleaveComma(extraBindingTypes, printer.
getStream());
2195 auto iface = dyn_cast<transform::TransformOpInterface>(use.
getOwner());
2210 if (!potentialConsumer) {
2211 potentialConsumer = &use;
2216 <<
" has more than one potential consumer";
2219 diag.attachNote(use.getOwner()->getLoc())
2220 <<
"used here as operand #" << use.getOperandNumber();
2228 assert(getBodyBlock()->getNumArguments() >= 1 &&
2229 "the number of arguments must have been verified to be more than 1 by "
2230 "PossibleTopLevelTransformOpTrait");
2232 if (!getRoot() && !getExtraBindings().empty()) {
2233 return emitOpError()
2234 <<
"does not expect extra operands when used as top-level";
2240 return (emitOpError() <<
"block argument #" << arg.getArgNumber());
2247 for (
Operation &child : *getBodyBlock()) {
2248 if (!isa<TransformOpInterface>(child) &&
2249 &child != &getBodyBlock()->back()) {
2252 <<
"expected children ops to implement TransformOpInterface";
2253 diag.attachNote(child.getLoc()) <<
"op without interface";
2257 for (
OpResult result : child.getResults()) {
2258 auto report = [&]() {
2259 return (child.emitError() <<
"result #" << result.getResultNumber());
2266 if (!getBodyBlock()->mightHaveTerminator())
2267 return emitOpError() <<
"expects to have a terminator in the body";
2269 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2270 getOperation()->getResultTypes()) {
2272 <<
"expects the types of the terminator operands "
2273 "to match the types of the result";
2274 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) <<
"terminator";
2280 void transform::SequenceOp::getEffects(
2281 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2287 assert(point == getBody() &&
"unexpected region index");
2288 if (getOperation()->getNumOperands() > 0)
2289 return getOperation()->getOperands();
2291 getOperation()->operand_end());
2294 void transform::SequenceOp::getSuccessorRegions(
2297 Region *bodyRegion = &getBody();
2298 regions.emplace_back(bodyRegion, getNumOperands() != 0
2304 assert(point == getBody() &&
"unexpected region index");
2305 regions.emplace_back(getOperation()->getResults());
2308 void transform::SequenceOp::getRegionInvocationBounds(
2311 bounds.emplace_back(1, 1);
2316 FailurePropagationMode failurePropagationMode,
2319 build(builder, state, resultTypes, failurePropagationMode, root,
2328 FailurePropagationMode failurePropagationMode,
2331 build(builder, state, resultTypes, failurePropagationMode, root,
2339 FailurePropagationMode failurePropagationMode,
2342 build(builder, state, resultTypes, failurePropagationMode,
Value(),
2350 FailurePropagationMode failurePropagationMode,
2353 build(builder, state, resultTypes, failurePropagationMode,
Value(),
2369 Value target, StringRef name) {
2371 build(builder, result, name);
2378 llvm::outs() <<
"[[[ IR printer: ";
2379 if (getName().has_value())
2380 llvm::outs() << *getName() <<
" ";
2383 llvm::outs() <<
"top-level ]]]\n" << *state.getTopLevel() <<
"\n";
2387 llvm::outs() <<
"]]]\n";
2388 for (
Operation *target : state.getPayloadOps(getTarget()))
2389 llvm::outs() << *target <<
"\n";
2394 void transform::PrintOp::getEffects(
2395 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2415 <<
"failed to verify payload op";
2416 diag.attachNote(target->
getLoc()) <<
"payload op";
2422 void transform::VerifyOp::getEffects(
2423 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2431 void transform::YieldOp::getEffects(
2432 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
This class implements the operand iterators for the Operation class.
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents success/failure for parsing-like operations that find it important to chain tog...
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getRegionNumber()
Return the number of this region in the parent operation.
This is a "type erased" representation of a registered operation.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
size_t size() const
Return the size of this range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
A named class for passing around the variadic flag.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Applies the specified rewrite patterns on ops while also trying to fold these ops.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
This class represents an efficient way to signal success or failure.
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.