26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/PostOrderIterator.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/STLForwardCompat.h"
30 #include "llvm/ADT/SmallString.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringRef.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/ADT/bit.h"
35 #include "llvm/Frontend/OpenMP/OMPConstants.h"
41 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
44 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
60 struct MemRefPointerLikeModel
61 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
64 return llvm::cast<MemRefType>(pointer).getElementType();
68 struct LLVMPointerPointerLikeModel
69 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
70 LLVM::LLVMPointerType> {
75 void OpenMPDialect::initialize() {
78 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
81 #define GET_ATTRDEF_LIST
82 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
85 #define GET_TYPEDEF_LIST
86 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
89 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
91 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
92 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
97 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
103 mlir::LLVM::GlobalOp::attachInterface<
106 mlir::LLVM::LLVMFuncOp::attachInterface<
109 mlir::func::FuncOp::attachInterface<
135 allocatorVars.push_back(operand);
136 allocatorTypes.push_back(type);
142 allocateVars.push_back(operand);
143 allocateTypes.push_back(type);
154 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
155 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
156 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
157 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
165 template <
typename ClauseAttr>
167 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
172 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
176 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
179 template <
typename ClauseAttr>
181 p << stringifyEnum(attr.getValue());
204 linearVars.push_back(var);
205 linearTypes.push_back(type);
206 linearStepVars.push_back(stepVar);
215 size_t linearVarsSize = linearVars.size();
216 for (
unsigned i = 0; i < linearVarsSize; ++i) {
217 std::string separator = i == linearVarsSize - 1 ?
"" :
", ";
219 if (linearStepVars.size() > i)
220 p <<
" = " << linearStepVars[i];
221 p <<
" : " << linearVars[i].getType() << separator;
234 for (
const auto &it : nontemporalVars)
235 if (!nontemporalItems.insert(it).second)
236 return op->
emitOpError() <<
"nontemporal variable used more than once";
245 std::optional<ArrayAttr> alignments,
248 if (!alignedVars.empty()) {
249 if (!alignments || alignments->size() != alignedVars.size())
251 <<
"expected as many alignment values as aligned variables";
254 return op->
emitOpError() <<
"unexpected alignment values attribute";
260 for (
auto it : alignedVars)
261 if (!alignedItems.insert(it).second)
262 return op->
emitOpError() <<
"aligned variable used more than once";
268 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
269 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
270 if (intAttr.getValue().sle(0))
271 return op->
emitOpError() <<
"alignment should be greater than 0";
273 return op->
emitOpError() <<
"expected integer alignment";
287 ArrayAttr &alignmentsAttr) {
290 if (parser.parseOperand(alignedVars.emplace_back()) ||
291 parser.parseColonType(alignedTypes.emplace_back()) ||
292 parser.parseArrow() ||
293 parser.parseAttribute(alignmentVec.emplace_back())) {
307 std::optional<ArrayAttr> alignments) {
308 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
311 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
312 p <<
" -> " << (*alignments)[i];
323 if (modifiers.size() > 2)
325 for (
const auto &mod : modifiers) {
328 auto symbol = symbolizeScheduleModifier(mod);
331 <<
" unknown modifier type: " << mod;
336 if (modifiers.size() == 1) {
337 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
338 modifiers.push_back(modifiers[0]);
339 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
341 }
else if (modifiers.size() == 2) {
344 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
345 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
347 <<
" incorrect modifier order";
363 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
364 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
369 std::optional<mlir::omp::ClauseScheduleKind> schedule =
370 symbolizeClauseScheduleKind(keyword);
376 case ClauseScheduleKind::Static:
377 case ClauseScheduleKind::Dynamic:
378 case ClauseScheduleKind::Guided:
384 chunkSize = std::nullopt;
387 case ClauseScheduleKind::Auto:
389 chunkSize = std::nullopt;
398 modifiers.push_back(mod);
404 if (!modifiers.empty()) {
406 if (std::optional<ScheduleModifier> mod =
407 symbolizeScheduleModifier(modifiers[0])) {
410 return parser.
emitError(loc,
"invalid schedule modifier");
413 if (modifiers.size() > 1) {
414 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
424 ClauseScheduleKindAttr scheduleKind,
425 ScheduleModifierAttr scheduleMod,
426 UnitAttr scheduleSimd,
Value scheduleChunk,
427 Type scheduleChunkType) {
428 p << stringifyClauseScheduleKind(scheduleKind.getValue());
430 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
432 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
444 ClauseOrderKindAttr &order,
445 OrderModifierAttr &orderMod) {
450 if (std::optional<OrderModifier> enumValue =
451 symbolizeOrderModifier(enumStr)) {
459 if (std::optional<ClauseOrderKind> enumValue =
460 symbolizeClauseOrderKind(enumStr)) {
464 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
468 ClauseOrderKindAttr order,
469 OrderModifierAttr orderMod) {
471 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
473 p << stringifyClauseOrderKind(order.getValue());
476 template <
typename ClauseTypeAttr,
typename ClauseType>
479 std::optional<OpAsmParser::UnresolvedOperand> &operand,
481 std::optional<ClauseType> (*symbolizeClause)(StringRef),
482 StringRef clauseName) {
485 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
491 <<
"invalid " << clauseName <<
" modifier : '" << enumStr <<
"'";
501 <<
"expected " << clauseName <<
" operand";
504 if (operand.has_value()) {
512 template <
typename ClauseTypeAttr,
typename ClauseType>
515 ClauseTypeAttr prescriptiveness,
Value operand,
517 StringRef (*stringifyClauseType)(ClauseType)) {
519 if (prescriptiveness)
520 p << stringifyClauseType(prescriptiveness.getValue()) <<
", ";
523 p << operand <<
": " << operandType;
533 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
534 Type &grainsizeType) {
535 return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
536 parser, grainsizeMod, grainsize, grainsizeType,
537 &symbolizeClauseGrainsizeType,
"grainsize");
541 ClauseGrainsizeTypeAttr grainsizeMod,
543 printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
544 p, op, grainsizeMod, grainsize, grainsizeType,
545 &stringifyClauseGrainsizeType);
555 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
556 Type &numTasksType) {
557 return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
558 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
563 ClauseNumTasksTypeAttr numTasksMod,
565 printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
566 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
574 struct MapParseArgs {
579 : vars(vars), types(types) {}
581 struct PrivateParseArgs {
585 UnitAttr &needsBarrier;
589 UnitAttr &needsBarrier,
591 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
592 mapIndices(mapIndices) {}
595 struct ReductionParseArgs {
600 ReductionModifierAttr *modifier;
603 ArrayAttr &syms, ReductionModifierAttr *mod =
nullptr)
604 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
607 struct AllRegionParseArgs {
608 std::optional<MapParseArgs> hasDeviceAddrArgs;
609 std::optional<MapParseArgs> hostEvalArgs;
610 std::optional<ReductionParseArgs> inReductionArgs;
611 std::optional<MapParseArgs> mapArgs;
612 std::optional<PrivateParseArgs> privateArgs;
613 std::optional<ReductionParseArgs> reductionArgs;
614 std::optional<ReductionParseArgs> taskReductionArgs;
615 std::optional<MapParseArgs> useDeviceAddrArgs;
616 std::optional<MapParseArgs> useDevicePtrArgs;
621 return "private_barrier";
631 ReductionModifierAttr *modifier =
nullptr,
632 UnitAttr *needsBarrier =
nullptr) {
636 unsigned regionArgOffset = regionPrivateArgs.size();
646 std::optional<ReductionModifier> enumValue =
647 symbolizeReductionModifier(enumStr);
648 if (!enumValue.has_value())
657 isByRefVec.push_back(
658 parser.parseOptionalKeyword(
"byref").succeeded());
660 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
663 if (parser.parseOperand(operands.emplace_back()) ||
664 parser.parseArrow() ||
665 parser.parseArgument(regionPrivateArgs.emplace_back()))
669 if (parser.parseOptionalLSquare().succeeded()) {
670 if (parser.parseKeyword(
"map_idx") || parser.parseEqual() ||
671 parser.parseInteger(mapIndicesVec.emplace_back()) ||
672 parser.parseRSquare())
675 mapIndicesVec.push_back(-1);
687 if (parser.parseType(types.emplace_back()))
694 if (operands.size() != types.size())
706 auto *argsBegin = regionPrivateArgs.begin();
708 argsBegin + regionArgOffset + types.size());
709 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
718 if (!mapIndicesVec.empty())
731 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
746 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
752 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
753 &privateArgs->syms, privateArgs->mapIndices,
nullptr,
754 nullptr, &privateArgs->needsBarrier)))
763 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
768 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
769 &reductionArgs->syms,
nullptr, &reductionArgs->byref,
770 reductionArgs->modifier)))
777 AllRegionParseArgs args) {
781 args.hasDeviceAddrArgs)))
783 <<
"invalid `has_device_addr` format";
788 <<
"invalid `host_eval` format";
791 args.inReductionArgs)))
793 <<
"invalid `in_reduction` format";
798 <<
"invalid `map_entries` format";
803 <<
"invalid `private` format";
806 args.reductionArgs)))
808 <<
"invalid `reduction` format";
811 args.taskReductionArgs)))
813 <<
"invalid `task_reduction` format";
816 args.useDeviceAddrArgs)))
818 <<
"invalid `use_device_addr` format";
821 args.useDevicePtrArgs)))
823 <<
"invalid `use_device_addr` format";
844 AllRegionParseArgs args;
845 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
846 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
847 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
848 inReductionByref, inReductionSyms);
849 args.mapArgs.emplace(mapVars, mapTypes);
850 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
851 privateNeedsBarrier, &privateMaps);
862 UnitAttr &privateNeedsBarrier) {
863 AllRegionParseArgs args;
864 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
865 inReductionByref, inReductionSyms);
866 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
867 privateNeedsBarrier);
878 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
881 ArrayAttr &reductionSyms) {
882 AllRegionParseArgs args;
883 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
884 inReductionByref, inReductionSyms);
885 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
886 privateNeedsBarrier);
887 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
888 reductionSyms, &reductionMod);
896 UnitAttr &privateNeedsBarrier) {
897 AllRegionParseArgs args;
898 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
899 privateNeedsBarrier);
907 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
910 ArrayAttr &reductionSyms) {
911 AllRegionParseArgs args;
912 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
913 privateNeedsBarrier);
914 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
915 reductionSyms, &reductionMod);
924 AllRegionParseArgs args;
925 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
926 taskReductionByref, taskReductionSyms);
936 AllRegionParseArgs args;
937 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
938 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
947 struct MapPrintArgs {
952 struct PrivatePrintArgs {
956 UnitAttr needsBarrier;
960 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
961 mapIndices(mapIndices) {}
963 struct ReductionPrintArgs {
968 ReductionModifierAttr modifier;
970 ArrayAttr syms, ReductionModifierAttr mod =
nullptr)
971 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
973 struct AllRegionPrintArgs {
974 std::optional<MapPrintArgs> hasDeviceAddrArgs;
975 std::optional<MapPrintArgs> hostEvalArgs;
976 std::optional<ReductionPrintArgs> inReductionArgs;
977 std::optional<MapPrintArgs> mapArgs;
978 std::optional<PrivatePrintArgs> privateArgs;
979 std::optional<ReductionPrintArgs> reductionArgs;
980 std::optional<ReductionPrintArgs> taskReductionArgs;
981 std::optional<MapPrintArgs> useDeviceAddrArgs;
982 std::optional<MapPrintArgs> useDevicePtrArgs;
991 ReductionModifierAttr modifier =
nullptr, UnitAttr needsBarrier =
nullptr) {
992 if (argsSubrange.empty())
995 p << clauseName <<
"(";
998 p <<
"mod: " << stringifyReductionModifier(modifier.getValue()) <<
", ";
1015 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1016 mapIndices.asArrayRef(),
1017 byref.asArrayRef()),
1019 auto [op, arg, sym, map, isByRef] = t;
1025 p << op <<
" -> " << arg;
1028 p <<
" [map_idx=" << map <<
"]";
1031 llvm::interleaveComma(types, p);
1039 StringRef clauseName,
ValueRange argsSubrange,
1040 std::optional<MapPrintArgs> mapArgs) {
1047 StringRef clauseName,
ValueRange argsSubrange,
1048 std::optional<PrivatePrintArgs> privateArgs) {
1051 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1052 privateArgs->syms, privateArgs->mapIndices,
nullptr,
1053 nullptr, privateArgs->needsBarrier);
1059 std::optional<ReductionPrintArgs> reductionArgs) {
1062 reductionArgs->vars, reductionArgs->types,
1063 reductionArgs->syms,
nullptr,
1064 reductionArgs->byref, reductionArgs->modifier);
1068 const AllRegionPrintArgs &args) {
1069 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1073 iface.getHasDeviceAddrBlockArgs(),
1074 args.hasDeviceAddrArgs);
1078 args.inReductionArgs);
1084 args.reductionArgs);
1086 iface.getTaskReductionBlockArgs(),
1087 args.taskReductionArgs);
1089 iface.getUseDeviceAddrBlockArgs(),
1090 args.useDeviceAddrArgs);
1092 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1106 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1108 AllRegionPrintArgs args;
1109 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1110 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1111 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1112 inReductionByref, inReductionSyms);
1113 args.mapArgs.emplace(mapVars, mapTypes);
1114 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1115 privateNeedsBarrier, privateMaps);
1123 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1124 AllRegionPrintArgs args;
1125 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1126 inReductionByref, inReductionSyms);
1127 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1128 privateNeedsBarrier,
1137 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1138 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1140 ArrayAttr reductionSyms) {
1141 AllRegionPrintArgs args;
1142 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1143 inReductionByref, inReductionSyms);
1144 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1145 privateNeedsBarrier,
1147 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1148 reductionSyms, reductionMod);
1154 ArrayAttr privateSyms,
1155 UnitAttr privateNeedsBarrier) {
1156 AllRegionPrintArgs args;
1157 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1158 privateNeedsBarrier,
1165 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1166 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1168 ArrayAttr reductionSyms) {
1169 AllRegionPrintArgs args;
1170 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1171 privateNeedsBarrier,
1173 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1174 reductionSyms, reductionMod);
1183 ArrayAttr taskReductionSyms) {
1184 AllRegionPrintArgs args;
1185 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1186 taskReductionByref, taskReductionSyms);
1196 AllRegionPrintArgs args;
1197 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1198 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1203 static LogicalResult
1207 if (!reductionVars.empty()) {
1208 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1210 <<
"expected as many reduction symbol references "
1211 "as reduction variables";
1212 if (reductionByref && reductionByref->size() != reductionVars.size())
1213 return op->
emitError() <<
"expected as many reduction variable by "
1214 "reference attributes as reduction variables";
1217 return op->
emitOpError() <<
"unexpected reduction symbol references";
1224 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
1225 Value accum = std::get<0>(args);
1227 if (!accumulators.insert(accum).second)
1228 return op->
emitOpError() <<
"accumulator variable used more than once";
1231 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1233 SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1235 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1236 <<
" to point to a reduction declaration";
1238 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1240 <<
"expected accumulator (" << varType
1241 <<
") to be the same type as reduction declaration ("
1242 << decl.getAccumulatorType() <<
")";
1261 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1262 parser.parseArrow() ||
1263 parser.parseAttribute(symsVec.emplace_back()) ||
1264 parser.parseColonType(copyprivateTypes.emplace_back()))
1278 std::optional<ArrayAttr> copyprivateSyms) {
1279 if (!copyprivateSyms.has_value())
1281 llvm::interleaveComma(
1282 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1283 [&](
const auto &args) {
1284 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
1285 << std::get<2>(args);
1290 static LogicalResult
1292 std::optional<ArrayAttr> copyprivateSyms) {
1293 size_t copyprivateSymsSize =
1294 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1295 if (copyprivateSymsSize != copyprivateVars.size())
1296 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
1297 << copyprivateVars.size()
1298 <<
") and functions (= " << copyprivateSymsSize
1299 <<
"), both must be equal";
1300 if (!copyprivateSyms.has_value())
1303 for (
auto copyprivateVarAndSym :
1304 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1306 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1307 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1309 if (mlir::func::FuncOp mlirFuncOp =
1310 SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1312 funcOp = mlirFuncOp;
1313 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1314 SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1316 funcOp = llvmFuncOp;
1318 auto getNumArguments = [&] {
1319 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
1322 auto getArgumentType = [&](
unsigned i) {
1323 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
1328 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1329 <<
" to point to a copy function";
1331 if (getNumArguments() != 2)
1333 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
1335 Type argTy = getArgumentType(0);
1336 if (argTy != getArgumentType(1))
1337 return op->
emitOpError() <<
"expected copy function " << symbolRef
1338 <<
" arguments to have the same type";
1340 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1341 if (argTy != varType)
1343 <<
"expected copy function arguments' type (" << argTy
1344 <<
") to be the same as copyprivate variable's type (" << varType
1365 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1366 parser.parseOperand(dependVars.emplace_back()) ||
1367 parser.parseColonType(dependTypes.emplace_back()))
1369 if (std::optional<ClauseTaskDepend> keywordDepend =
1370 (symbolizeClauseTaskDepend(keyword)))
1371 kindsVec.emplace_back(
1372 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1386 std::optional<ArrayAttr> dependKinds) {
1388 for (
unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1391 p << stringifyClauseTaskDepend(
1392 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1394 <<
" -> " << dependVars[i] <<
" : " << dependTypes[i];
1400 std::optional<ArrayAttr> dependKinds,
1402 if (!dependVars.empty()) {
1403 if (!dependKinds || dependKinds->size() != dependVars.size())
1404 return op->
emitOpError() <<
"expected as many depend values"
1405 " as depend variables";
1407 if (dependKinds && !dependKinds->empty())
1408 return op->
emitOpError() <<
"unexpected depend values";
1424 IntegerAttr &hintAttr) {
1425 StringRef hintKeyword;
1431 auto parseKeyword = [&]() -> ParseResult {
1434 if (hintKeyword ==
"uncontended")
1436 else if (hintKeyword ==
"contended")
1438 else if (hintKeyword ==
"nonspeculative")
1440 else if (hintKeyword ==
"speculative")
1444 << hintKeyword <<
" is not a valid hint";
1455 IntegerAttr hintAttr) {
1456 int64_t hint = hintAttr.getInt();
1464 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1466 bool uncontended = bitn(hint, 0);
1467 bool contended = bitn(hint, 1);
1468 bool nonspeculative = bitn(hint, 2);
1469 bool speculative = bitn(hint, 3);
1473 hints.push_back(
"uncontended");
1475 hints.push_back(
"contended");
1477 hints.push_back(
"nonspeculative");
1479 hints.push_back(
"speculative");
1481 llvm::interleaveComma(hints, p);
1488 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1490 bool uncontended = bitn(hint, 0);
1491 bool contended = bitn(hint, 1);
1492 bool nonspeculative = bitn(hint, 2);
1493 bool speculative = bitn(hint, 3);
1495 if (uncontended && contended)
1496 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
1497 "omp_sync_hint_contended cannot be combined";
1498 if (nonspeculative && speculative)
1499 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
1500 "omp_sync_hint_speculative cannot be combined.";
1510 llvm::omp::OpenMPOffloadMappingFlags flag) {
1511 return value & llvm::to_underlying(flag);
1520 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1521 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1525 auto parseTypeAndMod = [&]() -> ParseResult {
1526 StringRef mapTypeMod;
1530 if (mapTypeMod ==
"always")
1531 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1533 if (mapTypeMod ==
"implicit")
1534 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1536 if (mapTypeMod ==
"ompx_hold")
1537 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
1539 if (mapTypeMod ==
"close")
1540 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1542 if (mapTypeMod ==
"present")
1543 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1545 if (mapTypeMod ==
"to")
1546 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1548 if (mapTypeMod ==
"from")
1549 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1551 if (mapTypeMod ==
"tofrom")
1552 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1553 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1555 if (mapTypeMod ==
"delete")
1556 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1558 if (mapTypeMod ==
"return_param")
1559 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1569 llvm::to_underlying(mapTypeBits));
1577 IntegerAttr mapType) {
1578 uint64_t mapTypeBits = mapType.getUInt();
1580 bool emitAllocRelease =
true;
1586 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1587 mapTypeStrs.push_back(
"always");
1589 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1590 mapTypeStrs.push_back(
"implicit");
1592 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
1593 mapTypeStrs.push_back(
"ompx_hold");
1595 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1596 mapTypeStrs.push_back(
"close");
1598 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1599 mapTypeStrs.push_back(
"present");
1605 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1607 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1609 emitAllocRelease =
false;
1610 mapTypeStrs.push_back(
"tofrom");
1612 emitAllocRelease =
false;
1613 mapTypeStrs.push_back(
"from");
1615 emitAllocRelease =
false;
1616 mapTypeStrs.push_back(
"to");
1619 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1620 emitAllocRelease =
false;
1621 mapTypeStrs.push_back(
"delete");
1625 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
1626 emitAllocRelease =
false;
1627 mapTypeStrs.push_back(
"return_param");
1629 if (emitAllocRelease)
1630 mapTypeStrs.push_back(
"exit_release_or_enter_alloc");
1632 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1633 p << mapTypeStrs[i];
1634 if (i + 1 < mapTypeStrs.size()) {
1641 ArrayAttr &membersIdx) {
1644 auto parseIndices = [&]() -> ParseResult {
1649 APInt(64, value,
false)));
1667 if (!memberIdxs.empty())
1674 ArrayAttr membersIdx) {
1678 llvm::interleaveComma(membersIdx, p, [&p](
Attribute v) {
1680 auto memberIdx = cast<ArrayAttr>(v);
1681 llvm::interleaveComma(memberIdx.getValue(), p, [&p](
Attribute v2) {
1682 p << cast<IntegerAttr>(v2).getInt();
1689 VariableCaptureKindAttr mapCaptureType) {
1690 std::string typeCapStr;
1691 llvm::raw_string_ostream typeCap(typeCapStr);
1692 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1694 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1695 typeCap <<
"ByCopy";
1696 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1697 typeCap <<
"VLAType";
1698 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1704 VariableCaptureKindAttr &mapCaptureType) {
1705 StringRef mapCaptureKey;
1709 if (mapCaptureKey ==
"This")
1711 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
1712 if (mapCaptureKey ==
"ByRef")
1714 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
1715 if (mapCaptureKey ==
"ByCopy")
1717 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1718 if (mapCaptureKey ==
"VLAType")
1720 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
1729 for (
auto mapOp : mapVars) {
1730 if (!mapOp.getDefiningOp())
1733 if (
auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
1734 uint64_t mapTypeBits = mapInfoOp.getMapType();
1737 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1739 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1741 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1744 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1746 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1748 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1750 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1752 "to, from, tofrom and alloc map types are permitted");
1754 if (isa<TargetEnterDataOp>(op) && (from || del))
1755 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
1757 if (isa<TargetExitDataOp>(op) && to)
1759 "from, release and delete map types are permitted");
1761 if (isa<TargetUpdateOp>(op)) {
1764 "at least one of to or from map types must be "
1765 "specified, other map types are not permitted");
1770 "at least one of to or from map types must be "
1771 "specified, other map types are not permitted");
1774 auto updateVar = mapInfoOp.getVarPtr();
1776 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1777 (from && updateToVars.contains(updateVar))) {
1780 "either to or from map types can be specified, not both");
1783 if (always || close || implicit) {
1786 "present, mapper and iterator map type modifiers are permitted");
1789 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1791 }
else if (!isa<DeclareMapperInfoOp>(op)) {
1793 "map argument is not a map entry operation");
1801 std::optional<DenseI64ArrayAttr> privateMapIndices =
1802 targetOp.getPrivateMapsAttr();
1805 if (!privateMapIndices.has_value() || !privateMapIndices.value())
1810 if (privateMapIndices.value().size() !=
1811 static_cast<int64_t
>(privateVars.size()))
1812 return emitError(targetOp.getLoc(),
"sizes of `private` operand range and "
1813 "`private_maps` attribute mismatch");
1823 StringRef clauseName,
1825 for (
Value var : vars)
1826 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
1828 <<
"'" << clauseName
1829 <<
"' arguments must be defined by 'omp.map.info' ops";
1834 if (getMapperId() &&
1835 !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
1836 *
this, getMapperIdAttr())) {
1851 const TargetDataOperands &clauses) {
1852 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
1853 clauses.mapVars, clauses.useDeviceAddrVars,
1854 clauses.useDevicePtrVars);
1858 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1859 getUseDeviceAddrVars().empty()) {
1861 "At least one of map, use_device_ptr_vars, or "
1862 "use_device_addr_vars operand must be present");
1866 getUseDevicePtrVars())))
1870 getUseDeviceAddrVars())))
1880 void TargetEnterDataOp::build(
1884 TargetEnterDataOp::build(builder, state,
1886 clauses.dependVars, clauses.device, clauses.ifExpr,
1887 clauses.mapVars, clauses.nowait);
1891 LogicalResult verifyDependVars =
1893 return failed(verifyDependVars) ? verifyDependVars
1904 TargetExitDataOp::build(builder, state,
1906 clauses.dependVars, clauses.device, clauses.ifExpr,
1907 clauses.mapVars, clauses.nowait);
1911 LogicalResult verifyDependVars =
1913 return failed(verifyDependVars) ? verifyDependVars
1924 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
1925 clauses.dependVars, clauses.device, clauses.ifExpr,
1926 clauses.mapVars, clauses.nowait);
1930 LogicalResult verifyDependVars =
1932 return failed(verifyDependVars) ? verifyDependVars
1941 const TargetOperands &clauses) {
1945 TargetOp::build(builder, state, {}, {},
1947 clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
1948 clauses.hostEvalVars, clauses.ifExpr,
1950 nullptr, clauses.isDevicePtrVars,
1951 clauses.mapVars, clauses.nowait, clauses.privateVars,
1953 clauses.privateNeedsBarrier, clauses.threadLimit,
1962 getHasDeviceAddrVars())))
1971 LogicalResult TargetOp::verifyRegions() {
1972 auto teamsOps = getOps<TeamsOp>();
1973 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
1974 return emitError(
"target containing multiple 'omp.teams' nested ops");
1977 Operation *capturedOp = getInnermostCapturedOmpOp();
1978 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
1979 for (
Value hostEvalArg :
1980 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
1982 if (
auto teamsOp = dyn_cast<TeamsOp>(user)) {
1983 if (llvm::is_contained({teamsOp.getNumTeamsLower(),
1984 teamsOp.getNumTeamsUpper(),
1985 teamsOp.getThreadLimit()},
1989 return emitOpError() <<
"host_eval argument only legal as 'num_teams' "
1990 "and 'thread_limit' in 'omp.teams'";
1992 if (
auto parallelOp = dyn_cast<ParallelOp>(user)) {
1993 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
1994 parallelOp->isAncestor(capturedOp) &&
1995 hostEvalArg == parallelOp.getNumThreads())
1998 return emitOpError()
1999 <<
"host_eval argument only legal as 'num_threads' in "
2000 "'omp.parallel' when representing target SPMD";
2002 if (
auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2003 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2004 loopNestOp.getOperation() == capturedOp &&
2005 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2006 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2007 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2010 return emitOpError() <<
"host_eval argument only legal as loop bounds "
2011 "and steps in 'omp.loop_nest' when trip count "
2012 "must be evaluated in the host";
2015 return emitOpError() <<
"host_eval argument illegal use in '"
2016 << user->getName() <<
"' operation";
2025 assert(rootOp &&
"expected valid operation");
2037 return WalkResult::advance();
2042 bool isOmpDialect = op->
getDialect() == ompDialect;
2044 if (!isOmpDialect || !hasRegions)
2045 return WalkResult::skip();
2051 if (checkSingleMandatoryExec) {
2056 if (successor->isReachable(parentBlock))
2057 return WalkResult::interrupt();
2059 for (
Block &block : *parentRegion)
2061 !domInfo.
dominates(parentBlock, &block))
2062 return WalkResult::interrupt();
2068 if (&sibling != op && !siblingAllowedFn(&sibling))
2069 return WalkResult::interrupt();
2074 return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2075 : WalkResult::advance();
2081 Operation *TargetOp::getInnermostCapturedOmpOp() {
2094 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2097 memOp.getEffects(effects);
2098 return !llvm::any_of(
2100 return isa<MemoryEffects::Write>(effect.
getEffect()) &&
2101 isa<SideEffects::AutomaticAllocationScopeResource>(
2109 TargetRegionFlags TargetOp::getKernelExecFlags(
Operation *capturedOp) {
2114 assert((!capturedOp ||
2115 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2116 "unexpected captured op");
2119 if (!isa_and_present<LoopNestOp>(capturedOp))
2120 return TargetRegionFlags::generic;
2124 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2125 assert(!loopWrappers.empty());
2127 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2128 if (isa<SimdOp>(innermostWrapper))
2129 innermostWrapper = std::next(innermostWrapper);
2131 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2132 if (numWrappers != 1 && numWrappers != 2)
2133 return TargetRegionFlags::generic;
2136 if (numWrappers == 2) {
2137 if (!isa<WsloopOp>(innermostWrapper))
2138 return TargetRegionFlags::generic;
2140 innermostWrapper = std::next(innermostWrapper);
2141 if (!isa<DistributeOp>(innermostWrapper))
2142 return TargetRegionFlags::generic;
2145 if (!isa_and_present<ParallelOp>(parallelOp))
2146 return TargetRegionFlags::generic;
2149 if (!isa_and_present<TeamsOp>(teamsOp))
2150 return TargetRegionFlags::generic;
2152 if (teamsOp->
getParentOp() == targetOp.getOperation())
2153 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2156 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2158 if (!isa_and_present<TeamsOp>(teamsOp))
2159 return TargetRegionFlags::generic;
2161 if (teamsOp->
getParentOp() != targetOp.getOperation())
2162 return TargetRegionFlags::generic;
2164 if (isa<LoopOp>(innermostWrapper))
2165 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2175 Dialect *ompDialect = targetOp->getDialect();
2179 return sibling && (ompDialect != sibling->
getDialect() ||
2183 TargetRegionFlags result =
2184 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2189 while (nestedCapture->
getParentOp() != capturedOp)
2192 return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2196 else if (isa<WsloopOp>(innermostWrapper)) {
2198 if (!isa_and_present<ParallelOp>(parallelOp))
2199 return TargetRegionFlags::generic;
2201 if (parallelOp->
getParentOp() == targetOp.getOperation())
2202 return TargetRegionFlags::spmd;
2205 return TargetRegionFlags::generic;
2214 ParallelOp::build(builder, state,
ValueRange(),
2221 state.addAttributes(attributes);
2225 const ParallelOperands &clauses) {
2227 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2228 clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2230 clauses.privateNeedsBarrier, clauses.procBindKind,
2231 clauses.reductionMod, clauses.reductionVars,
2236 template <
typename OpType>
2238 auto privateVars = op.getPrivateVars();
2239 auto privateSyms = op.getPrivateSymsAttr();
2241 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
2244 auto numPrivateVars = privateVars.size();
2245 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
2247 if (numPrivateVars != numPrivateSyms)
2248 return op.emitError() <<
"inconsistent number of private variables and "
2249 "privatizer op symbols, private vars: "
2251 <<
" vs. privatizer op symbols: " << numPrivateSyms;
2253 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2254 Type varType = std::get<0>(privateVarInfo).getType();
2255 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2256 PrivateClauseOp privatizerOp =
2257 SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2259 if (privatizerOp ==
nullptr)
2260 return op.emitError() <<
"failed to lookup privatizer op with symbol: '"
2261 << privateSym <<
"'";
2263 Type privatizerType = privatizerOp.getArgType();
2265 if (privatizerType && (varType != privatizerType))
2266 return op.emitError()
2267 <<
"type mismatch between a "
2268 << (privatizerOp.getDataSharingType() ==
2269 DataSharingClauseType::Private
2272 <<
" variable and its privatizer op, var type: " << varType
2273 <<
" vs. privatizer op type: " << privatizerType;
2280 if (getAllocateVars().size() != getAllocatorVars().size())
2282 "expected equal sizes for allocate and allocator variables");
2288 getReductionByref());
2291 LogicalResult ParallelOp::verifyRegions() {
2292 auto distChildOps = getOps<DistributeOp>();
2293 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2294 if (numDistChildOps > 1)
2296 <<
"multiple 'omp.distribute' nested inside of 'omp.parallel'";
2298 if (numDistChildOps == 1) {
2301 <<
"'omp.composite' attribute missing from composite operation";
2304 Operation &distributeOp = **distChildOps.begin();
2306 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2310 return emitError() <<
"unexpected OpenMP operation inside of composite "
2312 << childOp.getName();
2314 }
else if (isComposite()) {
2316 <<
"'omp.composite' attribute present in non-composite operation";
2333 const TeamsOperands &clauses) {
2336 TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2337 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2339 nullptr, clauses.reductionMod,
2340 clauses.reductionVars,
2343 clauses.threadLimit);
2355 return emitError(
"expected to be nested inside of omp.target or not nested "
2356 "in any OpenMP dialect operations");
2359 if (
auto numTeamsLowerBound = getNumTeamsLower()) {
2360 auto numTeamsUpperBound = getNumTeamsUpper();
2361 if (!numTeamsUpperBound)
2362 return emitError(
"expected num_teams upper bound to be defined if the "
2363 "lower bound is defined");
2364 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2366 "expected num_teams upper bound and lower bound to be the same type");
2370 if (getAllocateVars().size() != getAllocatorVars().size())
2372 "expected equal sizes for allocate and allocator variables");
2375 getReductionByref());
2383 return getParentOp().getPrivateVars();
2387 return getParentOp().getReductionVars();
2395 const SectionsOperands &clauses) {
2398 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2401 clauses.reductionMod, clauses.reductionVars,
2407 if (getAllocateVars().size() != getAllocatorVars().size())
2409 "expected equal sizes for allocate and allocator variables");
2412 getReductionByref());
2415 LogicalResult SectionsOp::verifyRegions() {
2416 for (
auto &inst : *getRegion().begin()) {
2417 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2418 return emitOpError()
2419 <<
"expected omp.section op or terminator op inside region";
2431 const SingleOperands &clauses) {
2434 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2435 clauses.copyprivateVars,
2436 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2443 if (getAllocateVars().size() != getAllocatorVars().size())
2445 "expected equal sizes for allocate and allocator variables");
2448 getCopyprivateSyms());
2456 const WorkshareOperands &clauses) {
2457 WorkshareOp::build(builder, state, clauses.nowait);
2465 if (!(*this)->getParentOfType<WorkshareOp>())
2466 return emitOpError() <<
"must be nested in an omp.workshare";
2470 LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2471 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2473 return emitOpError() <<
"expected to be a standalone loop wrapper";
2482 LogicalResult LoopWrapperInterface::verifyImpl() {
2486 return emitOpError() <<
"loop wrapper must also have the `NoTerminator` "
2487 "and `SingleBlock` traits";
2490 return emitOpError() <<
"loop wrapper does not contain exactly one region";
2493 if (range_size(region.
getOps()) != 1)
2494 return emitOpError()
2495 <<
"loop wrapper does not contain exactly one nested op";
2498 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2499 return emitOpError() <<
"nested in loop wrapper is not another loop "
2500 "wrapper or `omp.loop_nest`";
2510 const LoopOperands &clauses) {
2513 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2515 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2516 clauses.reductionMod, clauses.reductionVars,
2523 getReductionByref());
2526 LogicalResult LoopOp::verifyRegions() {
2527 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2529 return emitOpError() <<
"expected to be a standalone loop wrapper";
2540 build(builder, state, {}, {},
2542 false,
nullptr,
nullptr,
2543 nullptr, {},
nullptr,
2550 state.addAttributes(attributes);
2554 const WsloopOperands &clauses) {
2559 {}, {}, clauses.linearVars,
2560 clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
2561 clauses.ordered, clauses.privateVars,
2562 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2563 clauses.reductionMod, clauses.reductionVars,
2565 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2566 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2571 getReductionByref());
2574 LogicalResult WsloopOp::verifyRegions() {
2575 bool isCompositeChildLeaf =
2576 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2578 if (LoopWrapperInterface nested = getNestedWrapper()) {
2581 <<
"'omp.composite' attribute missing from composite wrapper";
2585 if (!isa<SimdOp>(nested))
2586 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
2588 }
else if (isComposite() && !isCompositeChildLeaf) {
2590 <<
"'omp.composite' attribute present in non-composite wrapper";
2591 }
else if (!isComposite() && isCompositeChildLeaf) {
2593 <<
"'omp.composite' attribute missing from composite wrapper";
2604 const SimdOperands &clauses) {
2607 SimdOp::build(builder, state, clauses.alignedVars,
2610 clauses.nontemporalVars, clauses.order, clauses.orderMod,
2611 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
2612 clauses.privateNeedsBarrier, clauses.reductionMod,
2613 clauses.reductionVars,
2620 if (getSimdlen().has_value() && getSafelen().has_value() &&
2621 getSimdlen().value() > getSafelen().value())
2622 return emitOpError()
2623 <<
"simdlen clause and safelen clause are both present, but the "
2624 "simdlen value is not less than or equal to safelen value";
2632 bool isCompositeChildLeaf =
2633 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2635 if (!isComposite() && isCompositeChildLeaf)
2637 <<
"'omp.composite' attribute missing from composite wrapper";
2639 if (isComposite() && !isCompositeChildLeaf)
2641 <<
"'omp.composite' attribute present in non-composite wrapper";
2645 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
2647 for (
const Attribute &sym : *privateSyms) {
2648 auto symRef = cast<SymbolRefAttr>(sym);
2649 omp::PrivateClauseOp privatizer =
2650 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
2651 getOperation(), symRef);
2653 return emitError() <<
"Cannot find privatizer '" << symRef <<
"'";
2654 if (privatizer.getDataSharingType() ==
2655 DataSharingClauseType::FirstPrivate)
2656 return emitError() <<
"FIRSTPRIVATE cannot be used with SIMD";
2663 LogicalResult SimdOp::verifyRegions() {
2664 if (getNestedWrapper())
2665 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
2675 const DistributeOperands &clauses) {
2676 DistributeOp::build(builder, state, clauses.allocateVars,
2677 clauses.allocatorVars, clauses.distScheduleStatic,
2678 clauses.distScheduleChunkSize, clauses.order,
2679 clauses.orderMod, clauses.privateVars,
2681 clauses.privateNeedsBarrier);
2685 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2686 return emitOpError() <<
"chunk size set without "
2687 "dist_schedule_static being present";
2689 if (getAllocateVars().size() != getAllocatorVars().size())
2691 "expected equal sizes for allocate and allocator variables");
2696 LogicalResult DistributeOp::verifyRegions() {
2697 if (LoopWrapperInterface nested = getNestedWrapper()) {
2700 <<
"'omp.composite' attribute missing from composite wrapper";
2703 if (isa<WsloopOp>(nested)) {
2705 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2706 !cast<ComposableOpInterface>(parentOp).isComposite()) {
2707 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
2708 "when a composite 'omp.parallel' is the direct "
2711 }
else if (!isa<SimdOp>(nested))
2712 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
2714 }
else if (isComposite()) {
2716 <<
"'omp.composite' attribute present in non-composite wrapper";
2730 LogicalResult DeclareMapperOp::verifyRegions() {
2731 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
2732 getRegion().getBlocks().front().getTerminator()))
2733 return emitOpError() <<
"expected terminator to be a DeclareMapperInfoOp";
2742 LogicalResult DeclareReductionOp::verifyRegions() {
2743 if (!getAllocRegion().empty()) {
2744 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
2745 if (yieldOp.getResults().size() != 1 ||
2746 yieldOp.getResults().getTypes()[0] !=
getType())
2747 return emitOpError() <<
"expects alloc region to yield a value "
2748 "of the reduction type";
2752 if (getInitializerRegion().empty())
2753 return emitOpError() <<
"expects non-empty initializer region";
2754 Block &initializerEntryBlock = getInitializerRegion().
front();
2757 if (!getAllocRegion().empty())
2758 return emitOpError() <<
"expects two arguments to the initializer region "
2759 "when an allocation region is used";
2761 if (getAllocRegion().empty())
2762 return emitOpError() <<
"expects one argument to the initializer region "
2763 "when no allocation region is used";
2765 return emitOpError()
2766 <<
"expects one or two arguments to the initializer region";
2770 if (arg.getType() !=
getType())
2771 return emitOpError() <<
"expects initializer region argument to match "
2772 "the reduction type";
2774 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
2775 if (yieldOp.getResults().size() != 1 ||
2776 yieldOp.getResults().getTypes()[0] !=
getType())
2777 return emitOpError() <<
"expects initializer region to yield a value "
2778 "of the reduction type";
2781 if (getReductionRegion().empty())
2782 return emitOpError() <<
"expects non-empty reduction region";
2783 Block &reductionEntryBlock = getReductionRegion().
front();
2788 return emitOpError() <<
"expects reduction region with two arguments of "
2789 "the reduction type";
2790 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
2791 if (yieldOp.getResults().size() != 1 ||
2792 yieldOp.getResults().getTypes()[0] !=
getType())
2793 return emitOpError() <<
"expects reduction region to yield a value "
2794 "of the reduction type";
2797 if (!getAtomicReductionRegion().empty()) {
2798 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
2802 return emitOpError() <<
"expects atomic reduction region with two "
2803 "arguments of the same type";
2804 auto ptrType = llvm::dyn_cast<PointerLikeType>(
2807 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
2808 return emitOpError() <<
"expects atomic reduction region arguments to "
2809 "be accumulators containing the reduction type";
2812 if (getCleanupRegion().empty())
2814 Block &cleanupEntryBlock = getCleanupRegion().
front();
2817 return emitOpError() <<
"expects cleanup region with one argument "
2818 "of the reduction type";
2828 const TaskOperands &clauses) {
2830 TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2831 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2832 clauses.final, clauses.ifExpr, clauses.inReductionVars,
2834 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2835 clauses.priority, clauses.privateVars,
2837 clauses.privateNeedsBarrier, clauses.untied,
2838 clauses.eventHandle);
2842 LogicalResult verifyDependVars =
2844 return failed(verifyDependVars)
2847 getInReductionVars(),
2848 getInReductionByref());
2856 const TaskgroupOperands &clauses) {
2858 TaskgroupOp::build(builder, state, clauses.allocateVars,
2859 clauses.allocatorVars, clauses.taskReductionVars,
2866 getTaskReductionVars(),
2867 getTaskReductionByref());
2875 const TaskloopOperands &clauses) {
2878 builder, state, clauses.allocateVars, clauses.allocatorVars,
2879 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
2880 clauses.inReductionVars,
2882 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2883 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
2884 clauses.privateVars,
2886 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
2892 if (getAllocateVars().size() != getAllocatorVars().size())
2894 "expected equal sizes for allocate and allocator variables");
2896 getReductionVars(), getReductionByref())) ||
2898 getInReductionVars(),
2899 getInReductionByref())))
2902 if (!getReductionVars().empty() && getNogroup())
2903 return emitError(
"if a reduction clause is present on the taskloop "
2904 "directive, the nogroup clause must not be specified");
2905 for (
auto var : getReductionVars()) {
2906 if (llvm::is_contained(getInReductionVars(), var))
2907 return emitError(
"the same list item cannot appear in both a reduction "
2908 "and an in_reduction clause");
2911 if (getGrainsize() && getNumTasks()) {
2913 "the grainsize clause and num_tasks clause are mutually exclusive and "
2914 "may not appear on the same taskloop directive");
2920 LogicalResult TaskloopOp::verifyRegions() {
2921 if (LoopWrapperInterface nested = getNestedWrapper()) {
2924 <<
"'omp.composite' attribute missing from composite wrapper";
2928 if (!isa<SimdOp>(nested))
2929 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
2930 }
else if (isComposite()) {
2932 <<
"'omp.composite' attribute present in non-composite wrapper";
2956 for (
auto &iv : ivs)
2957 iv.type = loopVarType;
2986 Region ®ion = getRegion();
2988 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
2989 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
2990 if (getLoopInclusive())
2992 p <<
"step (" << getLoopSteps() <<
") ";
2997 const LoopNestOperands &clauses) {
2998 LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2999 clauses.loopUpperBounds, clauses.loopSteps,
3000 clauses.loopInclusive);
3004 if (getLoopLowerBounds().empty())
3005 return emitOpError() <<
"must represent at least one loop";
3007 if (getLoopLowerBounds().size() != getIVs().size())
3008 return emitOpError() <<
"number of range arguments and IVs do not match";
3010 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3011 if (lb.getType() != iv.getType())
3012 return emitOpError()
3013 <<
"range argument type does not match corresponding IV type";
3016 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3017 return emitOpError() <<
"expects parent op to be a loop wrapper";
3022 void LoopNestOp::gatherWrappers(
3025 while (
auto wrapper =
3026 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3027 wrappers.push_back(wrapper);
3036 std::tuple<NewCliOp, OpOperand *, OpOperand *>
3042 return {{},
nullptr,
nullptr};
3045 "Unexpected type of cli");
3051 auto op = cast<LoopTransformationInterface>(use.getOwner());
3053 unsigned opnum = use.getOperandNumber();
3054 if (op.isGeneratee(opnum)) {
3055 assert(!gen &&
"Each CLI may have at most one def");
3057 }
else if (op.isApplyee(opnum)) {
3058 assert(!cons &&
"Each CLI may have at most one consumer");
3061 llvm_unreachable(
"Unexpected operand for a CLI");
3065 return {create, gen, cons};
3074 Value result = getResult();
3075 auto [newCli, gen, cons] =
decodeCli(result);
3085 std::string cliName{
"cli"};
3089 .Case([&](CanonicalLoopOp op) {
3103 llvm::ReversePostOrderTraversal<Block *> traversal(
3106 for (
Block *b : traversal) {
3112 if (!op.getRegions().empty())
3116 llvm_unreachable(
"Operation not part of the region");
3118 size_t sequentialIdx = getSequentialIndex(r, o);
3119 components.push_back((
"s" + Twine(sequentialIdx)).str());
3129 for (
auto [idx, region] :
3134 llvm_unreachable(
"Region not child its parent operation");
3136 size_t regionIdx = getRegionIndex(parent, r);
3137 components.push_back((
"r" + Twine(regionIdx)).str());
3145 for (std::string s : reverse(components)) {
3152 .Case([&](UnrollHeuristicOp op) -> std::string {
3153 llvm_unreachable(
"heuristic unrolling does not generate a loop");
3156 assert(
false &&
"TODO: Custom name for this operation");
3157 return "transformed";
3161 setNameFn(result, cliName);
3165 Value cli = getResult();
3168 "Unexpected type of cli");
3174 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3176 unsigned opnum = use.getOperandNumber();
3177 if (op.isGeneratee(opnum)) {
3180 emitOpError(
"CLI must have at most one generator");
3182 .
append(
"first generator here:");
3184 .
append(
"second generator here:");
3189 }
else if (op.isApplyee(opnum)) {
3192 emitOpError(
"CLI must have at most one consumer");
3194 .
append(
"first consumer here:")
3198 .
append(
"second consumer here:")
3205 llvm_unreachable(
"Unexpected operand for a CLI");
3213 .
append(
"see consumer here: ")
3236 setNameFn(&getRegion().front(),
"body_entry");
3239 void CanonicalLoopOp::getAsmBlockArgumentNames(
Region ®ion,
3246 p <<
'(' << getCli() <<
')';
3247 p <<
' ' << getInductionVar() <<
" : " << getInductionVar().getType()
3248 <<
" in range(" << getTripCount() <<
") ";
3258 CanonicalLoopInfoType cliType =
3284 if (parser.
parseRegion(*region, {inductionVariable}))
3289 result.
operands.append(cliOperand);
3295 return mlir::success();
3301 if (!getRegion().empty()) {
3302 Region ®ion = getRegion();
3305 "Canonical loop region must have exactly one argument");
3309 "Region argument must be the same type as the trip count");
/// Returns the canonical loop's induction variable, which is argument 0 of the
/// loop body region (the custom parser binds the parsed induction variable as
/// the region's only argument).
/// NOTE(review): no guard for an empty region here; the verifier only checks
/// the argument count when the region is non-empty, so callers presumably rely
/// on the body region being populated -- confirm.
3315 Value CanonicalLoopOp::getInductionVar() {
return getRegion().getArgument(0); }
3317 std::pair<unsigned, unsigned>
3318 CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3323 std::pair<unsigned, unsigned>
3324 CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3325 return getODSOperandIndexAndLength(odsIndex_cli);
3339 p <<
'(' << getApplyee() <<
')';
3369 return mlir::success();
3372 std::pair<unsigned, unsigned>
3373 UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3374 return getODSOperandIndexAndLength(odsIndex_applyee);
3377 std::pair<unsigned, unsigned>
3378 UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3387 const CriticalDeclareOperands &clauses) {
3388 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3396 if (getNameAttr()) {
3397 SymbolRefAttr symbolRef = getNameAttr();
3401 return emitOpError() <<
"expected symbol reference " << symbolRef
3402 <<
" to point to a critical declaration";
3422 return op.
emitOpError() <<
"must be nested inside of a loop";
3426 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
3427 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3429 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
3430 "have an ordered clause";
3432 if (hasRegion && orderedAttr.getInt() != 0)
3433 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
3434 "have a parameter present";
3436 if (!hasRegion && orderedAttr.getInt() == 0)
3437 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
3438 "have a parameter present";
3439 }
else if (!isa<SimdOp>(wrapper)) {
3440 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
3441 "or worksharing simd loop";
3447 const OrderedOperands &clauses) {
3448 OrderedOp::build(builder, state, clauses.doacrossDependType,
3449 clauses.doacrossNumLoops, clauses.doacrossDependVars);
3456 auto wrapper = (*this)->getParentOfType<WsloopOp>();
3457 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3458 return emitOpError() <<
"number of variables in depend clause does not "
3459 <<
"match number of iteration variables in the "
3466 const OrderedRegionOperands &clauses) {
3467 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3477 const TaskwaitOperands &clauses) {
3479 TaskwaitOp::build(builder, state,
nullptr,
3488 if (verifyCommon().failed())
3489 return mlir::failure();
3491 if (
auto mo = getMemoryOrder()) {
3492 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3493 *mo == ClauseMemoryOrderKind::Release) {
3495 "memory-order must not be acq_rel or release for atomic reads");
3506 if (verifyCommon().failed())
3507 return mlir::failure();
3509 if (
auto mo = getMemoryOrder()) {
3510 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3511 *mo == ClauseMemoryOrderKind::Acquire) {
3513 "memory-order must not be acq_rel or acquire for atomic writes");
3523 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3529 if (
Value writeVal = op.getWriteOpVal()) {
3531 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3538 if (verifyCommon().failed())
3539 return mlir::failure();
3541 if (
auto mo = getMemoryOrder()) {
3542 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3543 *mo == ClauseMemoryOrderKind::Acquire) {
3545 "memory-order must not be acq_rel or acquire for atomic updates");
/// Region verification for omp.atomic.update. All region checks are delegated
/// to the shared verifyRegionsCommon() helper; no update-specific region checks
/// are performed here (operand/attribute checks live in verify()).
3552 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
3558 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3559 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3561 return dyn_cast<AtomicReadOp>(getSecondOp());
3564 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3565 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3567 return dyn_cast<AtomicWriteOp>(getSecondOp());
3570 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3571 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3573 return dyn_cast<AtomicUpdateOp>(getSecondOp());
3580 LogicalResult AtomicCaptureOp::verifyRegions() {
3581 if (verifyRegionsCommon().failed())
3582 return mlir::failure();
3584 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
3586 "operations inside capture region must not have hint clause");
3588 if (getFirstOp()->getAttr(
"memory_order") ||
3589 getSecondOp()->getAttr(
"memory_order"))
3591 "operations inside capture region must not have memory_order clause");
3600 const CancelOperands &clauses) {
3601 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
3615 ClauseCancellationConstructType cct = getCancelDirective();
3618 if (!structuralParent)
3619 return emitOpError() <<
"Orphaned cancel construct";
3621 if ((cct == ClauseCancellationConstructType::Parallel) &&
3622 !mlir::isa<ParallelOp>(structuralParent)) {
3623 return emitOpError() <<
"cancel parallel must appear "
3624 <<
"inside a parallel region";
3626 if (cct == ClauseCancellationConstructType::Loop) {
3629 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->
getParentOp());
3632 return emitOpError()
3633 <<
"cancel loop must appear inside a worksharing-loop region";
3635 if (wsloopOp.getNowaitAttr()) {
3636 return emitError() <<
"A worksharing construct that is canceled "
3637 <<
"must not have a nowait clause";
3639 if (wsloopOp.getOrderedAttr()) {
3640 return emitError() <<
"A worksharing construct that is canceled "
3641 <<
"must not have an ordered clause";
3644 }
else if (cct == ClauseCancellationConstructType::Sections) {
3648 mlir::dyn_cast<SectionsOp>(structuralParent->
getParentOp());
3650 return emitOpError() <<
"cancel sections must appear "
3651 <<
"inside a sections region";
3653 if (sectionsOp.getNowait()) {
3654 return emitError() <<
"A sections construct that is canceled "
3655 <<
"must not have a nowait clause";
3658 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3659 (!mlir::isa<omp::TaskOp>(structuralParent) &&
3660 !mlir::isa<omp::TaskloopOp>(structuralParent->
getParentOp()))) {
3661 return emitOpError() <<
"cancel taskgroup must appear "
3662 <<
"inside a task region";
3672 const CancellationPointOperands &clauses) {
3673 CancellationPointOp::build(builder, state, clauses.cancelDirective);
3677 ClauseCancellationConstructType cct = getCancelDirective();
3680 if (!structuralParent)
3681 return emitOpError() <<
"Orphaned cancellation point";
3683 if ((cct == ClauseCancellationConstructType::Parallel) &&
3684 !mlir::isa<ParallelOp>(structuralParent)) {
3685 return emitOpError() <<
"cancellation point parallel must appear "
3686 <<
"inside a parallel region";
3690 if ((cct == ClauseCancellationConstructType::Loop) &&
3691 !mlir::isa<WsloopOp>(structuralParent->
getParentOp())) {
3692 return emitOpError() <<
"cancellation point loop must appear "
3693 <<
"inside a worksharing-loop region";
3695 if ((cct == ClauseCancellationConstructType::Sections) &&
3696 !mlir::isa<omp::SectionOp>(structuralParent)) {
3697 return emitOpError() <<
"cancellation point sections must appear "
3698 <<
"inside a sections region";
3700 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3701 !mlir::isa<omp::TaskOp>(structuralParent)) {
3702 return emitOpError() <<
"cancellation point taskgroup must appear "
3703 <<
"inside a task region";
3713 auto extent = getExtent();
3715 if (!extent && !upperbound)
3716 return emitError(
"expected extent or upperbound.");
3723 PrivateClauseOp::build(
3724 odsBuilder, odsState, symName, type,
3726 DataSharingClauseType::Private));
3729 LogicalResult PrivateClauseOp::verifyRegions() {
3730 Type argType = getArgType();
3731 auto verifyTerminator = [&](
Operation *terminator,
3732 bool yieldsValue) -> LogicalResult {
3736 if (!llvm::isa<YieldOp>(terminator))
3738 <<
"expected exit block terminator to be an `omp.yield` op.";
3740 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
3741 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
3744 if (yieldedTypes.empty())
3748 <<
"Did not expect any values to be yielded.";
3751 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
3755 <<
"Invalid yielded value. Expected type: " << argType
3758 if (yieldedTypes.empty())
3761 error << yieldedTypes;
3767 StringRef regionName,
3768 bool yieldsValue) -> LogicalResult {
3769 assert(!region.
empty());
3773 <<
"`" << regionName <<
"`: "
3774 <<
"expected " << expectedNumArgs
3777 for (
Block &block : region) {
3779 if (!block.mightHaveTerminator())
3782 if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
3790 for (
Region *region : getRegions())
3791 for (
Type ty : region->getArgumentTypes())
3793 return emitError() <<
"Region argument type mismatch: got " << ty
3794 <<
" expected " << argType <<
".";
3797 if (!initRegion.
empty() &&
3802 DataSharingClauseType dsType = getDataSharingType();
3804 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
3805 return emitError(
"`private` clauses do not require a `copy` region.");
3807 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
3809 "`firstprivate` clauses require at least a `copy` region.");
3811 if (dsType == DataSharingClauseType::FirstPrivate &&
3816 if (!getDeallocRegion().empty() &&
3829 const MaskedOperands &clauses) {
3830 MaskedOp::build(builder, state, clauses.filteredThreadId);
3838 const ScanOperands &clauses) {
3839 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
3843 if (hasExclusiveVars() == hasInclusiveVars())
3845 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
3846 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
3847 if (parentWsLoopOp.getReductionModAttr() &&
3848 parentWsLoopOp.getReductionModAttr().getValue() ==
3849 ReductionModifier::inscan)
3852 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
3853 if (parentSimdOp.getReductionModAttr() &&
3854 parentSimdOp.getReductionModAttr().getValue() ==
3855 ReductionModifier::inscan)
3858 return emitError(
"SCAN directive needs to be enclosed within a parent "
3859 "worksharing loop construct or SIMD construct with INSCAN "
3860 "reduction modifier");
3866 std::optional<uint64_t> align = this->getAlign();
3868 if (align.has_value()) {
3869 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
3870 return emitError() <<
"ALIGN value : " << align.value()
3871 <<
" must be power of 2";
3877 #define GET_ATTRDEF_CLASSES
3878 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
3880 #define GET_OP_CLASSES
3881 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
3883 #define GET_TYPEDEF_CLASSES
3884 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion, AllRegionParseArgs args)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region ®ion, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > ®ionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static Operation * getParentInSameDialect(Operation *thisOp)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region ®ion, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > ©privateVars, SmallVectorImpl< Type > ©privateTypes, ArrayAttr ©privateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
BlockArgListType getArguments()
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Operation * getParentOp()
Return the parent operation this region is attached to.
unsigned getNumArguments()
BlockListType & getBlocks()
Location getLoc()
Return a location for this region.
BlockArgument getArgument(unsigned i)
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Runtime
Potential runtimes for AMD GPU kernels.
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.