26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/PostOrderIterator.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/STLForwardCompat.h"
30 #include "llvm/ADT/SmallString.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringRef.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/ADT/bit.h"
35 #include "llvm/Frontend/OpenMP/OMPConstants.h"
36 #include "llvm/Support/InterleavedRange.h"
42 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
44 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
45 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
66 struct MemRefPointerLikeModel
67 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
70 return llvm::cast<MemRefType>(pointer).getElementType();
74 struct LLVMPointerPointerLikeModel
75 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
76 LLVM::LLVMPointerType> {
101 bool isRegionArgOfOp;
111 assert(isRegionArgOfOp &&
"Must describe a region operand");
114 size_t &getArgIdx() {
115 assert(isRegionArgOfOp &&
"Must describe a region operand");
120 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
124 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
127 bool isLoopOp()
const {
128 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
129 return isa<CanonicalLoopOp>(op);
131 Region *&getParentRegion() {
132 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
135 size_t &getLoopDepth() {
136 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
140 void skipIf(
bool v =
true) { skip = skip || v; }
158 llvm::ReversePostOrderTraversal<Block *> traversal(&r->
getBlocks().front());
161 size_t sequentialIdx = -1;
162 bool isOnlyContainerOp =
true;
163 for (
Block *b : traversal) {
165 if (&op == o && !found) {
169 if (op.getNumRegions()) {
172 isOnlyContainerOp =
false;
174 if (found && !isOnlyContainerOp)
179 Component &containerOpInRegion = components.emplace_back();
180 containerOpInRegion.isRegionArgOfOp =
false;
181 containerOpInRegion.isUnique = isOnlyContainerOp;
182 containerOpInRegion.getContainerOp() = o;
183 containerOpInRegion.getOpPos() = sequentialIdx;
184 containerOpInRegion.getParentRegion() = r;
// Bind a reference to the component just appended; the following lines fill
// it in through the name `regionArgOfOperation`.
189 Component &regionArgOfOperation = components.emplace_back();
190 regionArgOfOperation.isRegionArgOfOp =
true;
191 regionArgOfOperation.isUnique =
true;
192 regionArgOfOperation.getArgIdx() = 0;
193 regionArgOfOperation.getOwnerOp() = parent;
209 llvm_unreachable(
"Region not child of its parent operation");
211 regionArgOfOperation.isUnique =
false;
212 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
220 for (Component &c : components)
221 c.skipIf(c.isRegionArgOfOp && c.isUnique);
224 size_t numSurroundingLoops = 0;
225 for (Component &c : llvm::reverse(components)) {
230 if (c.isRegionArgOfOp) {
231 numSurroundingLoops = 0;
238 numSurroundingLoops = 0;
240 c.getLoopDepth() = numSurroundingLoops;
243 if (isa<CanonicalLoopOp>(c.getContainerOp()))
244 numSurroundingLoops += 1;
249 bool isLoopNest =
false;
250 for (Component &c : components) {
251 if (c.skip || c.isRegionArgOfOp)
254 if (!isLoopNest && c.getLoopDepth() >= 1) {
257 }
else if (isLoopNest) {
259 c.skipIf(c.isUnique);
263 if (c.getLoopDepth() == 0)
270 for (Component &c : components)
271 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
272 !isa<CanonicalLoopOp>(c.getContainerOp()));
276 bool newRegion =
true;
277 for (Component &c : llvm::reverse(components)) {
278 c.skipIf(newRegion && c.isUnique);
285 if (!c.isRegionArgOfOp && c.getContainerOp())
291 llvm::raw_svector_ostream NameOS(Name);
292 for (
auto &c : llvm::reverse(components)) {
296 if (c.isRegionArgOfOp)
297 NameOS <<
"_r" << c.getArgIdx();
298 else if (c.getLoopDepth() >= 1)
299 NameOS <<
"_d" << c.getLoopDepth();
301 NameOS <<
"_s" << c.getOpPos();
304 return NameOS.str().str();
307 void OpenMPDialect::initialize() {
310 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
313 #define GET_ATTRDEF_LIST
314 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
317 #define GET_TYPEDEF_LIST
318 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
321 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
323 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
324 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
329 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
335 mlir::LLVM::GlobalOp::attachInterface<
338 mlir::LLVM::LLVMFuncOp::attachInterface<
341 mlir::func::FuncOp::attachInterface<
367 allocatorVars.push_back(operand);
368 allocatorTypes.push_back(type);
374 allocateVars.push_back(operand);
375 allocateTypes.push_back(type);
386 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
387 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
388 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
389 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
397 template <
typename ClauseAttr>
399 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
404 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
408 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
411 template <
typename ClauseAttr>
413 p << stringifyEnum(attr.getValue());
436 linearVars.push_back(var);
437 linearTypes.push_back(type);
438 linearStepVars.push_back(stepVar);
447 size_t linearVarsSize = linearVars.size();
448 for (
unsigned i = 0; i < linearVarsSize; ++i) {
449 std::string separator = i == linearVarsSize - 1 ?
"" :
", ";
451 if (linearStepVars.size() > i)
452 p <<
" = " << linearStepVars[i];
453 p <<
" : " << linearVars[i].getType() << separator;
466 for (
const auto &it : nontemporalVars)
467 if (!nontemporalItems.insert(it).second)
468 return op->
emitOpError() <<
"nontemporal variable used more than once";
477 std::optional<ArrayAttr> alignments,
480 if (!alignedVars.empty()) {
481 if (!alignments || alignments->size() != alignedVars.size())
483 <<
"expected as many alignment values as aligned variables";
486 return op->
emitOpError() <<
"unexpected alignment values attribute";
492 for (
auto it : alignedVars)
493 if (!alignedItems.insert(it).second)
494 return op->
emitOpError() <<
"aligned variable used more than once";
500 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
501 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
502 if (intAttr.getValue().sle(0))
503 return op->
emitOpError() <<
"alignment should be greater than 0";
505 return op->
emitOpError() <<
"expected integer alignment";
519 ArrayAttr &alignmentsAttr) {
522 if (parser.parseOperand(alignedVars.emplace_back()) ||
523 parser.parseColonType(alignedTypes.emplace_back()) ||
524 parser.parseArrow() ||
525 parser.parseAttribute(alignmentVec.emplace_back())) {
539 std::optional<ArrayAttr> alignments) {
540 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
543 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
544 p <<
" -> " << (*alignments)[i];
555 if (modifiers.size() > 2)
557 for (
const auto &mod : modifiers) {
560 auto symbol = symbolizeScheduleModifier(mod);
563 <<
" unknown modifier type: " << mod;
568 if (modifiers.size() == 1) {
569 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
570 modifiers.push_back(modifiers[0]);
571 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
573 }
else if (modifiers.size() == 2) {
576 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
577 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
579 <<
" incorrect modifier order";
595 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
596 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
601 std::optional<mlir::omp::ClauseScheduleKind> schedule =
602 symbolizeClauseScheduleKind(keyword);
608 case ClauseScheduleKind::Static:
609 case ClauseScheduleKind::Dynamic:
610 case ClauseScheduleKind::Guided:
616 chunkSize = std::nullopt;
619 case ClauseScheduleKind::Auto:
621 chunkSize = std::nullopt;
630 modifiers.push_back(mod);
636 if (!modifiers.empty()) {
638 if (std::optional<ScheduleModifier> mod =
639 symbolizeScheduleModifier(modifiers[0])) {
642 return parser.
emitError(loc,
"invalid schedule modifier");
645 if (modifiers.size() > 1) {
646 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
656 ClauseScheduleKindAttr scheduleKind,
657 ScheduleModifierAttr scheduleMod,
658 UnitAttr scheduleSimd,
Value scheduleChunk,
659 Type scheduleChunkType) {
660 p << stringifyClauseScheduleKind(scheduleKind.getValue());
662 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
664 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
676 ClauseOrderKindAttr &order,
677 OrderModifierAttr &orderMod) {
682 if (std::optional<OrderModifier> enumValue =
683 symbolizeOrderModifier(enumStr)) {
691 if (std::optional<ClauseOrderKind> enumValue =
692 symbolizeClauseOrderKind(enumStr)) {
696 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
700 ClauseOrderKindAttr order,
701 OrderModifierAttr orderMod) {
703 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
705 p << stringifyClauseOrderKind(order.getValue());
708 template <
typename ClauseTypeAttr,
typename ClauseType>
711 std::optional<OpAsmParser::UnresolvedOperand> &operand,
713 std::optional<ClauseType> (*symbolizeClause)(StringRef),
714 StringRef clauseName) {
717 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
723 <<
"invalid " << clauseName <<
" modifier : '" << enumStr <<
"'";
733 <<
"expected " << clauseName <<
" operand";
736 if (operand.has_value()) {
744 template <
typename ClauseTypeAttr,
typename ClauseType>
747 ClauseTypeAttr prescriptiveness,
Value operand,
749 StringRef (*stringifyClauseType)(ClauseType)) {
751 if (prescriptiveness)
752 p << stringifyClauseType(prescriptiveness.getValue()) <<
", ";
755 p << operand <<
": " << operandType;
765 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
766 Type &grainsizeType) {
767 return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
768 parser, grainsizeMod, grainsize, grainsizeType,
769 &symbolizeClauseGrainsizeType,
"grainsize");
773 ClauseGrainsizeTypeAttr grainsizeMod,
775 printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
776 p, op, grainsizeMod, grainsize, grainsizeType,
777 &stringifyClauseGrainsizeType);
787 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
788 Type &numTasksType) {
789 return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
790 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
795 ClauseNumTasksTypeAttr numTasksMod,
797 printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
798 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
806 struct MapParseArgs {
811 : vars(vars), types(types) {}
813 struct PrivateParseArgs {
817 UnitAttr &needsBarrier;
821 UnitAttr &needsBarrier,
823 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
824 mapIndices(mapIndices) {}
827 struct ReductionParseArgs {
832 ReductionModifierAttr *modifier;
835 ArrayAttr &syms, ReductionModifierAttr *mod =
nullptr)
836 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
839 struct AllRegionParseArgs {
840 std::optional<MapParseArgs> hasDeviceAddrArgs;
841 std::optional<MapParseArgs> hostEvalArgs;
842 std::optional<ReductionParseArgs> inReductionArgs;
843 std::optional<MapParseArgs> mapArgs;
844 std::optional<PrivateParseArgs> privateArgs;
845 std::optional<ReductionParseArgs> reductionArgs;
846 std::optional<ReductionParseArgs> taskReductionArgs;
847 std::optional<MapParseArgs> useDeviceAddrArgs;
848 std::optional<MapParseArgs> useDevicePtrArgs;
853 return "private_barrier";
863 ReductionModifierAttr *modifier =
nullptr,
864 UnitAttr *needsBarrier =
nullptr) {
868 unsigned regionArgOffset = regionPrivateArgs.size();
878 std::optional<ReductionModifier> enumValue =
879 symbolizeReductionModifier(enumStr);
880 if (!enumValue.has_value())
889 isByRefVec.push_back(
890 parser.parseOptionalKeyword(
"byref").succeeded());
892 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
895 if (parser.parseOperand(operands.emplace_back()) ||
896 parser.parseArrow() ||
897 parser.parseArgument(regionPrivateArgs.emplace_back()))
901 if (parser.parseOptionalLSquare().succeeded()) {
902 if (parser.parseKeyword(
"map_idx") || parser.parseEqual() ||
903 parser.parseInteger(mapIndicesVec.emplace_back()) ||
904 parser.parseRSquare())
907 mapIndicesVec.push_back(-1);
919 if (parser.parseType(types.emplace_back()))
926 if (operands.size() != types.size())
938 auto *argsBegin = regionPrivateArgs.begin();
940 argsBegin + regionArgOffset + types.size());
941 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
950 if (!mapIndicesVec.empty())
963 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
978 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
984 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
985 &privateArgs->syms, privateArgs->mapIndices,
nullptr,
986 nullptr, &privateArgs->needsBarrier)))
995 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
1000 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1001 &reductionArgs->syms,
nullptr, &reductionArgs->byref,
1002 reductionArgs->modifier)))
1009 AllRegionParseArgs args) {
1013 args.hasDeviceAddrArgs)))
1015 <<
"invalid `has_device_addr` format";
1018 args.hostEvalArgs)))
1020 <<
"invalid `host_eval` format";
1023 args.inReductionArgs)))
1025 <<
"invalid `in_reduction` format";
1030 <<
"invalid `map_entries` format";
1035 <<
"invalid `private` format";
1038 args.reductionArgs)))
1040 <<
"invalid `reduction` format";
1043 args.taskReductionArgs)))
1045 <<
"invalid `task_reduction` format";
1048 args.useDeviceAddrArgs)))
1050 <<
"invalid `use_device_addr` format";
1053 args.useDevicePtrArgs)))
// The diagnostic names the clause whose arguments failed to parse; this branch
// handles `useDevicePtrArgs`, so it must not reuse the `use_device_addr` text.
1055 <<
"invalid `use_device_ptr` format";
1057 return parser.
parseRegion(region, entryBlockArgs);
1076 AllRegionParseArgs args;
1077 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1078 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1079 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1080 inReductionByref, inReductionSyms);
1081 args.mapArgs.emplace(mapVars, mapTypes);
1082 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1083 privateNeedsBarrier, &privateMaps);
1094 UnitAttr &privateNeedsBarrier) {
1095 AllRegionParseArgs args;
1096 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1097 inReductionByref, inReductionSyms);
1098 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1099 privateNeedsBarrier);
1110 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1113 ArrayAttr &reductionSyms) {
1114 AllRegionParseArgs args;
1115 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1116 inReductionByref, inReductionSyms);
1117 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1118 privateNeedsBarrier);
1119 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1120 reductionSyms, &reductionMod);
1128 UnitAttr &privateNeedsBarrier) {
1129 AllRegionParseArgs args;
1130 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1131 privateNeedsBarrier);
1139 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1142 ArrayAttr &reductionSyms) {
1143 AllRegionParseArgs args;
1144 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1145 privateNeedsBarrier);
1146 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1147 reductionSyms, &reductionMod);
1156 AllRegionParseArgs args;
1157 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1158 taskReductionByref, taskReductionSyms);
1168 AllRegionParseArgs args;
1169 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1170 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1179 struct MapPrintArgs {
1184 struct PrivatePrintArgs {
1188 UnitAttr needsBarrier;
1192 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1193 mapIndices(mapIndices) {}
1195 struct ReductionPrintArgs {
1200 ReductionModifierAttr modifier;
1202 ArrayAttr syms, ReductionModifierAttr mod =
nullptr)
1203 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1205 struct AllRegionPrintArgs {
1206 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1207 std::optional<MapPrintArgs> hostEvalArgs;
1208 std::optional<ReductionPrintArgs> inReductionArgs;
1209 std::optional<MapPrintArgs> mapArgs;
1210 std::optional<PrivatePrintArgs> privateArgs;
1211 std::optional<ReductionPrintArgs> reductionArgs;
1212 std::optional<ReductionPrintArgs> taskReductionArgs;
1213 std::optional<MapPrintArgs> useDeviceAddrArgs;
1214 std::optional<MapPrintArgs> useDevicePtrArgs;
1223 ReductionModifierAttr modifier =
nullptr, UnitAttr needsBarrier =
nullptr) {
1224 if (argsSubrange.empty())
1227 p << clauseName <<
"(";
1230 p <<
"mod: " << stringifyReductionModifier(modifier.getValue()) <<
", ";
1247 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1248 mapIndices.asArrayRef(),
1249 byref.asArrayRef()),
1251 auto [op, arg, sym, map, isByRef] = t;
1257 p << op <<
" -> " << arg;
1260 p <<
" [map_idx=" << map <<
"]";
1263 llvm::interleaveComma(types, p);
1271 StringRef clauseName,
ValueRange argsSubrange,
1272 std::optional<MapPrintArgs> mapArgs) {
1279 StringRef clauseName,
ValueRange argsSubrange,
1280 std::optional<PrivatePrintArgs> privateArgs) {
1283 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1284 privateArgs->syms, privateArgs->mapIndices,
nullptr,
1285 nullptr, privateArgs->needsBarrier);
1291 std::optional<ReductionPrintArgs> reductionArgs) {
1294 reductionArgs->vars, reductionArgs->types,
1295 reductionArgs->syms,
nullptr,
1296 reductionArgs->byref, reductionArgs->modifier);
1300 const AllRegionPrintArgs &args) {
1301 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1305 iface.getHasDeviceAddrBlockArgs(),
1306 args.hasDeviceAddrArgs);
1310 args.inReductionArgs);
1316 args.reductionArgs);
1318 iface.getTaskReductionBlockArgs(),
1319 args.taskReductionArgs);
1321 iface.getUseDeviceAddrBlockArgs(),
1322 args.useDeviceAddrArgs);
1324 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1338 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1340 AllRegionPrintArgs args;
1341 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1342 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1343 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1344 inReductionByref, inReductionSyms);
1345 args.mapArgs.emplace(mapVars, mapTypes);
1346 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1347 privateNeedsBarrier, privateMaps);
1355 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1356 AllRegionPrintArgs args;
1357 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1358 inReductionByref, inReductionSyms);
1359 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1360 privateNeedsBarrier,
1369 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1370 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1372 ArrayAttr reductionSyms) {
1373 AllRegionPrintArgs args;
1374 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1375 inReductionByref, inReductionSyms);
1376 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1377 privateNeedsBarrier,
1379 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1380 reductionSyms, reductionMod);
1386 ArrayAttr privateSyms,
1387 UnitAttr privateNeedsBarrier) {
1388 AllRegionPrintArgs args;
1389 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1390 privateNeedsBarrier,
1397 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1398 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1400 ArrayAttr reductionSyms) {
1401 AllRegionPrintArgs args;
1402 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1403 privateNeedsBarrier,
1405 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1406 reductionSyms, reductionMod);
1415 ArrayAttr taskReductionSyms) {
1416 AllRegionPrintArgs args;
1417 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1418 taskReductionByref, taskReductionSyms);
1428 AllRegionPrintArgs args;
1429 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1430 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1435 static LogicalResult
1439 if (!reductionVars.empty()) {
1440 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1442 <<
"expected as many reduction symbol references "
1443 "as reduction variables";
1444 if (reductionByref && reductionByref->size() != reductionVars.size())
1445 return op->
emitError() <<
"expected as many reduction variable by "
1446 "reference attributes as reduction variables";
1449 return op->
emitOpError() <<
"unexpected reduction symbol references";
1456 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
1457 Value accum = std::get<0>(args);
1459 if (!accumulators.insert(accum).second)
1460 return op->
emitOpError() <<
"accumulator variable used more than once";
1463 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1465 SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1467 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1468 <<
" to point to a reduction declaration";
1470 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1472 <<
"expected accumulator (" << varType
1473 <<
") to be the same type as reduction declaration ("
1474 << decl.getAccumulatorType() <<
")";
1493 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1494 parser.parseArrow() ||
1495 parser.parseAttribute(symsVec.emplace_back()) ||
1496 parser.parseColonType(copyprivateTypes.emplace_back()))
1510 std::optional<ArrayAttr> copyprivateSyms) {
1511 if (!copyprivateSyms.has_value())
1513 llvm::interleaveComma(
1514 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1515 [&](
const auto &args) {
1516 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
1517 << std::get<2>(args);
1522 static LogicalResult
1524 std::optional<ArrayAttr> copyprivateSyms) {
1525 size_t copyprivateSymsSize =
1526 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1527 if (copyprivateSymsSize != copyprivateVars.size())
1528 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
1529 << copyprivateVars.size()
1530 <<
") and functions (= " << copyprivateSymsSize
1531 <<
"), both must be equal";
1532 if (!copyprivateSyms.has_value())
1535 for (
auto copyprivateVarAndSym :
1536 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1538 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1539 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1541 if (mlir::func::FuncOp mlirFuncOp =
1542 SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1544 funcOp = mlirFuncOp;
1545 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1546 SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1548 funcOp = llvmFuncOp;
1550 auto getNumArguments = [&] {
1551 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
1554 auto getArgumentType = [&](
unsigned i) {
1555 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
1560 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1561 <<
" to point to a copy function";
1563 if (getNumArguments() != 2)
1565 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
1567 Type argTy = getArgumentType(0);
1568 if (argTy != getArgumentType(1))
1569 return op->
emitOpError() <<
"expected copy function " << symbolRef
1570 <<
" arguments to have the same type";
1572 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1573 if (argTy != varType)
1575 <<
"expected copy function arguments' type (" << argTy
1576 <<
") to be the same as copyprivate variable's type (" << varType
1597 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1598 parser.parseOperand(dependVars.emplace_back()) ||
1599 parser.parseColonType(dependTypes.emplace_back()))
1601 if (std::optional<ClauseTaskDepend> keywordDepend =
1602 (symbolizeClauseTaskDepend(keyword)))
1603 kindsVec.emplace_back(
1604 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1618 std::optional<ArrayAttr> dependKinds) {
1620 for (
unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1623 p << stringifyClauseTaskDepend(
1624 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1626 <<
" -> " << dependVars[i] <<
" : " << dependTypes[i];
1632 std::optional<ArrayAttr> dependKinds,
1634 if (!dependVars.empty()) {
1635 if (!dependKinds || dependKinds->size() != dependVars.size())
1636 return op->
emitOpError() <<
"expected as many depend values"
1637 " as depend variables";
1639 if (dependKinds && !dependKinds->empty())
1640 return op->
emitOpError() <<
"unexpected depend values";
1656 IntegerAttr &hintAttr) {
1657 StringRef hintKeyword;
1663 auto parseKeyword = [&]() -> ParseResult {
1666 if (hintKeyword ==
"uncontended")
1668 else if (hintKeyword ==
"contended")
1670 else if (hintKeyword ==
"nonspeculative")
1672 else if (hintKeyword ==
"speculative")
1676 << hintKeyword <<
" is not a valid hint";
1687 IntegerAttr hintAttr) {
1688 int64_t hint = hintAttr.getInt();
1696 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1698 bool uncontended = bitn(hint, 0);
1699 bool contended = bitn(hint, 1);
1700 bool nonspeculative = bitn(hint, 2);
1701 bool speculative = bitn(hint, 3);
1705 hints.push_back(
"uncontended");
1707 hints.push_back(
"contended");
1709 hints.push_back(
"nonspeculative");
1711 hints.push_back(
"speculative");
1713 llvm::interleaveComma(hints, p);
1720 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1722 bool uncontended = bitn(hint, 0);
1723 bool contended = bitn(hint, 1);
1724 bool nonspeculative = bitn(hint, 2);
1725 bool speculative = bitn(hint, 3);
1727 if (uncontended && contended)
1728 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
1729 "omp_sync_hint_contended cannot be combined";
1730 if (nonspeculative && speculative)
1731 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
1732 "omp_sync_hint_speculative cannot be combined.";
1742 llvm::omp::OpenMPOffloadMappingFlags flag) {
1743 return value & llvm::to_underlying(flag);
1752 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1753 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1757 auto parseTypeAndMod = [&]() -> ParseResult {
1758 StringRef mapTypeMod;
1762 if (mapTypeMod ==
"always")
1763 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1765 if (mapTypeMod ==
"implicit")
1766 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1768 if (mapTypeMod ==
"ompx_hold")
1769 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
1771 if (mapTypeMod ==
"close")
1772 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1774 if (mapTypeMod ==
"present")
1775 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1777 if (mapTypeMod ==
"to")
1778 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1780 if (mapTypeMod ==
"from")
1781 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1783 if (mapTypeMod ==
"tofrom")
1784 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1785 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1787 if (mapTypeMod ==
"delete")
1788 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1790 if (mapTypeMod ==
"return_param")
1791 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1801 llvm::to_underlying(mapTypeBits));
1809 IntegerAttr mapType) {
1810 uint64_t mapTypeBits = mapType.getUInt();
1812 bool emitAllocRelease =
true;
1818 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1819 mapTypeStrs.push_back(
"always");
1821 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1822 mapTypeStrs.push_back(
"implicit");
1824 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
1825 mapTypeStrs.push_back(
"ompx_hold");
1827 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1828 mapTypeStrs.push_back(
"close");
1830 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1831 mapTypeStrs.push_back(
"present");
1837 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1839 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1841 emitAllocRelease =
false;
1842 mapTypeStrs.push_back(
"tofrom");
1844 emitAllocRelease =
false;
1845 mapTypeStrs.push_back(
"from");
1847 emitAllocRelease =
false;
1848 mapTypeStrs.push_back(
"to");
1851 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1852 emitAllocRelease =
false;
1853 mapTypeStrs.push_back(
"delete");
1857 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
1858 emitAllocRelease =
false;
1859 mapTypeStrs.push_back(
"return_param");
1861 if (emitAllocRelease)
1862 mapTypeStrs.push_back(
"exit_release_or_enter_alloc");
1864 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1865 p << mapTypeStrs[i];
1866 if (i + 1 < mapTypeStrs.size()) {
1873 ArrayAttr &membersIdx) {
1876 auto parseIndices = [&]() -> ParseResult {
1881 APInt(64, value,
false)));
1899 if (!memberIdxs.empty())
1906 ArrayAttr membersIdx) {
1910 llvm::interleaveComma(membersIdx, p, [&p](
Attribute v) {
1912 auto memberIdx = cast<ArrayAttr>(v);
1913 llvm::interleaveComma(memberIdx.getValue(), p, [&p](
Attribute v2) {
1914 p << cast<IntegerAttr>(v2).getInt();
1921 VariableCaptureKindAttr mapCaptureType) {
1922 std::string typeCapStr;
1923 llvm::raw_string_ostream typeCap(typeCapStr);
1924 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1926 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1927 typeCap <<
"ByCopy";
1928 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1929 typeCap <<
"VLAType";
1930 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1936 VariableCaptureKindAttr &mapCaptureType) {
1937 StringRef mapCaptureKey;
1941 if (mapCaptureKey ==
"This")
1943 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
1944 if (mapCaptureKey ==
"ByRef")
1946 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
1947 if (mapCaptureKey ==
"ByCopy")
1949 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1950 if (mapCaptureKey ==
"VLAType")
1952 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
1961 for (
auto mapOp : mapVars) {
1962 if (!mapOp.getDefiningOp())
1965 if (
auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
1966 uint64_t mapTypeBits = mapInfoOp.getMapType();
1969 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1971 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1973 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1976 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1978 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1980 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1982 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1984 "to, from, tofrom and alloc map types are permitted");
1986 if (isa<TargetEnterDataOp>(op) && (from || del))
1987 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
1989 if (isa<TargetExitDataOp>(op) && to)
1991 "from, release and delete map types are permitted");
1993 if (isa<TargetUpdateOp>(op)) {
1996 "at least one of to or from map types must be "
1997 "specified, other map types are not permitted");
2002 "at least one of to or from map types must be "
2003 "specified, other map types are not permitted");
2006 auto updateVar = mapInfoOp.getVarPtr();
2008 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2009 (from && updateToVars.contains(updateVar))) {
2012 "either to or from map types can be specified, not both");
2015 if (always || close || implicit) {
2018 "present, mapper and iterator map type modifiers are permitted");
2021 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2023 }
else if (!isa<DeclareMapperInfoOp>(op)) {
2025 "map argument is not a map entry operation");
2033 std::optional<DenseI64ArrayAttr> privateMapIndices =
2034 targetOp.getPrivateMapsAttr();
2037 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2042 if (privateMapIndices.value().size() !=
2043 static_cast<int64_t
>(privateVars.size()))
2044 return emitError(targetOp.getLoc(),
"sizes of `private` operand range and "
2045 "`private_maps` attribute mismatch");
2055 StringRef clauseName,
2057 for (
Value var : vars)
2058 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2060 <<
"'" << clauseName
2061 <<
"' arguments must be defined by 'omp.map.info' ops";
2066 if (getMapperId() &&
2067 !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
2068 *
this, getMapperIdAttr())) {
2083 const TargetDataOperands &clauses) {
2084 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2085 clauses.mapVars, clauses.useDeviceAddrVars,
2086 clauses.useDevicePtrVars);
2090 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2091 getUseDeviceAddrVars().empty()) {
2093 "At least one of map, use_device_ptr_vars, or "
2094 "use_device_addr_vars operand must be present");
2098 getUseDevicePtrVars())))
2102 getUseDeviceAddrVars())))
2112 void TargetEnterDataOp::build(
2116 TargetEnterDataOp::build(builder, state,
2118 clauses.dependVars, clauses.device, clauses.ifExpr,
2119 clauses.mapVars, clauses.nowait);
2123 LogicalResult verifyDependVars =
2125 return failed(verifyDependVars) ? verifyDependVars
2136 TargetExitDataOp::build(builder, state,
2138 clauses.dependVars, clauses.device, clauses.ifExpr,
2139 clauses.mapVars, clauses.nowait);
2143 LogicalResult verifyDependVars =
2145 return failed(verifyDependVars) ? verifyDependVars
2156 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
2157 clauses.dependVars, clauses.device, clauses.ifExpr,
2158 clauses.mapVars, clauses.nowait);
2162 LogicalResult verifyDependVars =
2164 return failed(verifyDependVars) ? verifyDependVars
2173 const TargetOperands &clauses) {
2177 TargetOp::build(builder, state, {}, {},
2179 clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
2180 clauses.hostEvalVars, clauses.ifExpr,
2182 nullptr, clauses.isDevicePtrVars,
2183 clauses.mapVars, clauses.nowait, clauses.privateVars,
2185 clauses.privateNeedsBarrier, clauses.threadLimit,
2194 getHasDeviceAddrVars())))
/// Region verifier for TargetOp. Emits an error when getOps<TeamsOp>() yields
/// more than one op, then checks how each host_eval block argument is used,
/// based on the op returned by getInnermostCapturedOmpOp() and the flags
/// returned by getKernelExecFlags() for it.
2203 LogicalResult TargetOp::verifyRegions() {
2204 auto teamsOps = getOps<TeamsOp>();
// Count the TeamsOp range; more than one entry is rejected.
2205 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2206 return emitError(
"target containing multiple 'omp.teams' nested ops");
// Both values are consulted by the ParallelOp and LoopNestOp checks below.
2209 Operation *capturedOp = getInnermostCapturedOmpOp();
2210 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
// Visit every host_eval block argument exposed via BlockArgOpenMPOpInterface.
2211 for (
Value hostEvalArg :
2212 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
// NOTE(review): `user` is presumably one user of hostEvalArg; its declaration
// is not visible here - confirm.
// TeamsOp user: the candidate operands are the num_teams lower/upper bounds
// and thread_limit.
// NOTE(review): the statement guarded by the is_contained test is not visible;
// presumably a match is accepted and the diagnostic below covers the
// remaining TeamsOp uses - confirm.
2214 if (
auto teamsOp = dyn_cast<TeamsOp>(user)) {
2215 if (llvm::is_contained({teamsOp.getNumTeamsLower(),
2216 teamsOp.getNumTeamsUpper(),
2217 teamsOp.getThreadLimit()},
2221 return emitOpError() <<
"host_eval argument only legal as 'num_teams' "
2222 "and 'thread_limit' in 'omp.teams'";
// ParallelOp user: the condition requires the spmd flag, that parallelOp is an
// ancestor of capturedOp, and that the argument is its num_threads operand.
// NOTE(review): the guarded statement is not visible; presumably a match is
// accepted before the diagnostic below - confirm.
2224 if (
auto parallelOp = dyn_cast<ParallelOp>(user)) {
2225 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2226 parallelOp->isAncestor(capturedOp) &&
2227 hostEvalArg == parallelOp.getNumThreads())
2230 return emitOpError()
2231 <<
"host_eval argument only legal as 'num_threads' in "
2232 "'omp.parallel' when representing target SPMD";
// LoopNestOp user: the condition requires the trip_count flag, that the loop
// nest is exactly capturedOp, and that the argument appears among its lower
// bounds, upper bounds or steps.
// NOTE(review): the guarded statement is not visible; presumably a match is
// accepted before the diagnostic below - confirm.
2234 if (
auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2235 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2236 loopNestOp.getOperation() == capturedOp &&
2237 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2238 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2239 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2242 return emitOpError() <<
"host_eval argument only legal as loop bounds "
2243 "and steps in 'omp.loop_nest' when trip count "
2244 "must be evaluated in the host";
// Diagnostic naming the user operation for a host_eval use that none of the
// three op kinds above matched.
2247 return emitOpError() <<
"host_eval argument illegal use in '"
2248 << user->getName() <<
"' operation";
2257 assert(rootOp &&
"expected valid operation");
2269 return WalkResult::advance();
2274 bool isOmpDialect = op->
getDialect() == ompDialect;
2276 if (!isOmpDialect || !hasRegions)
2277 return WalkResult::skip();
2283 if (checkSingleMandatoryExec) {
2288 if (successor->isReachable(parentBlock))
2289 return WalkResult::interrupt();
2291 for (
Block &block : *parentRegion)
2293 !domInfo.
dominates(parentBlock, &block))
2294 return WalkResult::interrupt();
2300 if (&sibling != op && !siblingAllowedFn(&sibling))
2301 return WalkResult::interrupt();
2306 return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2307 : WalkResult::advance();
2313 Operation *TargetOp::getInnermostCapturedOmpOp() {
2326 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2329 memOp.getEffects(effects);
2330 return !llvm::any_of(
2332 return isa<MemoryEffects::Write>(effect.
getEffect()) &&
2333 isa<SideEffects::AutomaticAllocationScopeResource>(
2343 WsloopOp *wsLoopOp) {
2345 if (teamsOp.getNumTeamsUpper())
2349 if (teamsOp.getNumReductionVars())
2351 if (wsLoopOp->getNumReductionVars())
2355 OffloadModuleInterface offloadMod =
2359 auto ompFlags = offloadMod.getFlags();
2362 return ompFlags.getAssumeTeamsOversubscription() &&
2363 ompFlags.getAssumeThreadsOversubscription();
2366 TargetRegionFlags TargetOp::getKernelExecFlags(
Operation *capturedOp) {
2371 assert((!capturedOp ||
2372 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2373 "unexpected captured op");
2376 if (!isa_and_present<LoopNestOp>(capturedOp))
2377 return TargetRegionFlags::generic;
2381 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2382 assert(!loopWrappers.empty());
2384 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2385 if (isa<SimdOp>(innermostWrapper))
2386 innermostWrapper = std::next(innermostWrapper);
2388 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2389 if (numWrappers != 1 && numWrappers != 2)
2390 return TargetRegionFlags::generic;
2393 if (numWrappers == 2) {
2394 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2396 return TargetRegionFlags::generic;
2398 innermostWrapper = std::next(innermostWrapper);
2399 if (!isa<DistributeOp>(innermostWrapper))
2400 return TargetRegionFlags::generic;
2403 if (!isa_and_present<ParallelOp>(parallelOp))
2404 return TargetRegionFlags::generic;
2406 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->
getParentOp());
2408 return TargetRegionFlags::generic;
2410 if (teamsOp->getParentOp() == targetOp.getOperation()) {
2411 TargetRegionFlags result =
2412 TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2414 result = result | TargetRegionFlags::no_loop;
2419 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2421 if (!isa_and_present<TeamsOp>(teamsOp))
2422 return TargetRegionFlags::generic;
2424 if (teamsOp->
getParentOp() != targetOp.getOperation())
2425 return TargetRegionFlags::generic;
2427 if (isa<LoopOp>(innermostWrapper))
2428 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2438 Dialect *ompDialect = targetOp->getDialect();
2442 return sibling && (ompDialect != sibling->
getDialect() ||
2446 TargetRegionFlags result =
2447 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2452 while (nestedCapture->
getParentOp() != capturedOp)
2455 return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2459 else if (isa<WsloopOp>(innermostWrapper)) {
2461 if (!isa_and_present<ParallelOp>(parallelOp))
2462 return TargetRegionFlags::generic;
2464 if (parallelOp->
getParentOp() == targetOp.getOperation())
2465 return TargetRegionFlags::spmd;
2468 return TargetRegionFlags::generic;
2477 ParallelOp::build(builder, state,
ValueRange(),
2484 state.addAttributes(attributes);
2488 const ParallelOperands &clauses) {
2490 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2491 clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2493 clauses.privateNeedsBarrier, clauses.procBindKind,
2494 clauses.reductionMod, clauses.reductionVars,
2499 template <
typename OpType>
2501 auto privateVars = op.getPrivateVars();
2502 auto privateSyms = op.getPrivateSymsAttr();
2504 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
2507 auto numPrivateVars = privateVars.size();
2508 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
2510 if (numPrivateVars != numPrivateSyms)
2511 return op.emitError() <<
"inconsistent number of private variables and "
2512 "privatizer op symbols, private vars: "
2514 <<
" vs. privatizer op symbols: " << numPrivateSyms;
2516 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2517 Type varType = std::get<0>(privateVarInfo).getType();
2518 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2519 PrivateClauseOp privatizerOp =
2520 SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2522 if (privatizerOp ==
nullptr)
2523 return op.emitError() <<
"failed to lookup privatizer op with symbol: '"
2524 << privateSym <<
"'";
2526 Type privatizerType = privatizerOp.getArgType();
2528 if (privatizerType && (varType != privatizerType))
2529 return op.emitError()
2530 <<
"type mismatch between a "
2531 << (privatizerOp.getDataSharingType() ==
2532 DataSharingClauseType::Private
2535 <<
" variable and its privatizer op, var type: " << varType
2536 <<
" vs. privatizer op type: " << privatizerType;
2543 if (getAllocateVars().size() != getAllocatorVars().size())
2545 "expected equal sizes for allocate and allocator variables");
2551 getReductionByref());
2554 LogicalResult ParallelOp::verifyRegions() {
2555 auto distChildOps = getOps<DistributeOp>();
2556 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2557 if (numDistChildOps > 1)
2559 <<
"multiple 'omp.distribute' nested inside of 'omp.parallel'";
2561 if (numDistChildOps == 1) {
2564 <<
"'omp.composite' attribute missing from composite operation";
2567 Operation &distributeOp = **distChildOps.begin();
2569 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2573 return emitError() <<
"unexpected OpenMP operation inside of composite "
2575 << childOp.getName();
2577 }
else if (isComposite()) {
2579 <<
"'omp.composite' attribute present in non-composite operation";
2596 const TeamsOperands &clauses) {
2599 TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2600 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2602 nullptr, clauses.reductionMod,
2603 clauses.reductionVars,
2606 clauses.threadLimit);
2618 return emitError(
"expected to be nested inside of omp.target or not nested "
2619 "in any OpenMP dialect operations");
2622 if (
auto numTeamsLowerBound = getNumTeamsLower()) {
2623 auto numTeamsUpperBound = getNumTeamsUpper();
2624 if (!numTeamsUpperBound)
2625 return emitError(
"expected num_teams upper bound to be defined if the "
2626 "lower bound is defined");
2627 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2629 "expected num_teams upper bound and lower bound to be the same type");
2633 if (getAllocateVars().size() != getAllocatorVars().size())
2635 "expected equal sizes for allocate and allocator variables");
2638 getReductionByref());
2646 return getParentOp().getPrivateVars();
2650 return getParentOp().getReductionVars();
2658 const SectionsOperands &clauses) {
2661 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2664 clauses.reductionMod, clauses.reductionVars,
2670 if (getAllocateVars().size() != getAllocatorVars().size())
2672 "expected equal sizes for allocate and allocator variables");
2675 getReductionByref());
/// Region verifier for SectionsOp: every operation in the first block of the
/// op's region must be either a SectionOp or a TerminatorOp.
2678 LogicalResult SectionsOp::verifyRegions() {
// Iterate over the operations of the region's first block.
2679 for (
auto &inst : *getRegion().begin()) {
// Any other operation kind produces an op error.
2680 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2681 return emitOpError()
2682 <<
"expected omp.section op or terminator op inside region";
2694 const SingleOperands &clauses) {
2697 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2698 clauses.copyprivateVars,
2699 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2706 if (getAllocateVars().size() != getAllocatorVars().size())
2708 "expected equal sizes for allocate and allocator variables");
2711 getCopyprivateSyms());
2719 const WorkshareOperands &clauses) {
2720 WorkshareOp::build(builder, state, clauses.nowait);
2728 if (!(*this)->getParentOfType<WorkshareOp>())
2729 return emitOpError() <<
"must be nested in an omp.workshare";
2733 LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2734 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2736 return emitOpError() <<
"expected to be a standalone loop wrapper";
2745 LogicalResult LoopWrapperInterface::verifyImpl() {
2749 return emitOpError() <<
"loop wrapper must also have the `NoTerminator` "
2750 "and `SingleBlock` traits";
2753 return emitOpError() <<
"loop wrapper does not contain exactly one region";
2756 if (range_size(region.
getOps()) != 1)
2757 return emitOpError()
2758 <<
"loop wrapper does not contain exactly one nested op";
2761 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2762 return emitOpError() <<
"nested in loop wrapper is not another loop "
2763 "wrapper or `omp.loop_nest`";
2773 const LoopOperands &clauses) {
2776 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2778 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2779 clauses.reductionMod, clauses.reductionVars,
2786 getReductionByref());
2789 LogicalResult LoopOp::verifyRegions() {
2790 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2792 return emitOpError() <<
"expected to be a standalone loop wrapper";
2803 build(builder, state, {}, {},
2805 false,
nullptr,
nullptr,
2806 nullptr, {},
nullptr,
2813 state.addAttributes(attributes);
2817 const WsloopOperands &clauses) {
2822 {}, {}, clauses.linearVars,
2823 clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
2824 clauses.ordered, clauses.privateVars,
2825 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2826 clauses.reductionMod, clauses.reductionVars,
2828 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2829 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2834 getReductionByref());
/// Region verifier for WsloopOp. Checks that the 'omp.composite' attribute is
/// consistent with whether this op wraps another loop wrapper and whether its
/// parent op is itself a loop wrapper.
2837 LogicalResult WsloopOp::verifyRegions() {
// True when the parent op implements LoopWrapperInterface (the cast result is
// converted to bool).
2838 bool isCompositeChildLeaf =
2839 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
// Case 1: this op directly wraps another loop wrapper.
// NOTE(review): the head of the first diagnostic in this branch is not
// visible; judging by its message it fires when the composite attribute is
// absent - confirm.
2841 if (LoopWrapperInterface nested = getNestedWrapper()) {
2844 <<
"'omp.composite' attribute missing from composite wrapper";
// SimdOp is the only nested wrapper accepted here.
2848 if (!isa<SimdOp>(nested))
2849 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
// Case 2: no nested wrapper and no loop-wrapper parent, yet marked composite.
2851 }
else if (isComposite() && !isCompositeChildLeaf) {
2853 <<
"'omp.composite' attribute present in non-composite wrapper";
// Case 3: nested under a loop wrapper but not marked composite.
2854 }
else if (!isComposite() && isCompositeChildLeaf) {
2856 <<
"'omp.composite' attribute missing from composite wrapper";
2867 const SimdOperands &clauses) {
2870 SimdOp::build(builder, state, clauses.alignedVars,
2873 clauses.nontemporalVars, clauses.order, clauses.orderMod,
2874 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
2875 clauses.privateNeedsBarrier, clauses.reductionMod,
2876 clauses.reductionVars,
2883 if (getSimdlen().has_value() && getSafelen().has_value() &&
2884 getSimdlen().value() > getSafelen().value())
2885 return emitOpError()
2886 <<
"simdlen clause and safelen clause are both present, but the "
2887 "simdlen value is not less than or equal to safelen value";
2895 bool isCompositeChildLeaf =
2896 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2898 if (!isComposite() && isCompositeChildLeaf)
2900 <<
"'omp.composite' attribute missing from composite wrapper";
2902 if (isComposite() && !isCompositeChildLeaf)
2904 <<
"'omp.composite' attribute present in non-composite wrapper";
2908 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
2910 for (
const Attribute &sym : *privateSyms) {
2911 auto symRef = cast<SymbolRefAttr>(sym);
2912 omp::PrivateClauseOp privatizer =
2913 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
2914 getOperation(), symRef);
2916 return emitError() <<
"Cannot find privatizer '" << symRef <<
"'";
2917 if (privatizer.getDataSharingType() ==
2918 DataSharingClauseType::FirstPrivate)
2919 return emitError() <<
"FIRSTPRIVATE cannot be used with SIMD";
/// Region verifier for SimdOp: emits an op error when getNestedWrapper()
/// returns a non-null wrapper, i.e. this op may not wrap another loop wrapper.
2926 LogicalResult SimdOp::verifyRegions() {
2927 if (getNestedWrapper())
2928 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
2938 const DistributeOperands &clauses) {
2939 DistributeOp::build(builder, state, clauses.allocateVars,
2940 clauses.allocatorVars, clauses.distScheduleStatic,
2941 clauses.distScheduleChunkSize, clauses.order,
2942 clauses.orderMod, clauses.privateVars,
2944 clauses.privateNeedsBarrier);
2948 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2949 return emitOpError() <<
"chunk size set without "
2950 "dist_schedule_static being present";
2952 if (getAllocateVars().size() != getAllocatorVars().size())
2954 "expected equal sizes for allocate and allocator variables");
2959 LogicalResult DistributeOp::verifyRegions() {
2960 if (LoopWrapperInterface nested = getNestedWrapper()) {
2963 <<
"'omp.composite' attribute missing from composite wrapper";
2966 if (isa<WsloopOp>(nested)) {
2968 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2969 !cast<ComposableOpInterface>(parentOp).isComposite()) {
2970 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
2971 "when a composite 'omp.parallel' is the direct "
2974 }
else if (!isa<SimdOp>(nested))
2975 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
2977 }
else if (isComposite()) {
2979 <<
"'omp.composite' attribute present in non-composite wrapper";
/// Region verifier for DeclareMapperOp: the terminator of the first block of
/// the op's region must be a DeclareMapperInfoOp.
2993 LogicalResult DeclareMapperOp::verifyRegions() {
// isa_and_present also evaluates to false for a null terminator pointer, so
// that case reaches the same diagnostic.
2994 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
2995 getRegion().getBlocks().front().getTerminator()))
2996 return emitOpError() <<
"expected terminator to be a DeclareMapperInfoOp";
3005 LogicalResult DeclareReductionOp::verifyRegions() {
3006 if (!getAllocRegion().empty()) {
3007 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3008 if (yieldOp.getResults().size() != 1 ||
3009 yieldOp.getResults().getTypes()[0] !=
getType())
3010 return emitOpError() <<
"expects alloc region to yield a value "
3011 "of the reduction type";
3015 if (getInitializerRegion().empty())
3016 return emitOpError() <<
"expects non-empty initializer region";
3017 Block &initializerEntryBlock = getInitializerRegion().
front();
3020 if (!getAllocRegion().empty())
3021 return emitOpError() <<
"expects two arguments to the initializer region "
3022 "when an allocation region is used";
3024 if (getAllocRegion().empty())
3025 return emitOpError() <<
"expects one argument to the initializer region "
3026 "when no allocation region is used";
3028 return emitOpError()
3029 <<
"expects one or two arguments to the initializer region";
3033 if (arg.getType() !=
getType())
3034 return emitOpError() <<
"expects initializer region argument to match "
3035 "the reduction type";
3037 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3038 if (yieldOp.getResults().size() != 1 ||
3039 yieldOp.getResults().getTypes()[0] !=
getType())
3040 return emitOpError() <<
"expects initializer region to yield a value "
3041 "of the reduction type";
3044 if (getReductionRegion().empty())
3045 return emitOpError() <<
"expects non-empty reduction region";
3046 Block &reductionEntryBlock = getReductionRegion().
front();
3051 return emitOpError() <<
"expects reduction region with two arguments of "
3052 "the reduction type";
3053 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3054 if (yieldOp.getResults().size() != 1 ||
3055 yieldOp.getResults().getTypes()[0] !=
getType())
3056 return emitOpError() <<
"expects reduction region to yield a value "
3057 "of the reduction type";
3060 if (!getAtomicReductionRegion().empty()) {
3061 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
3065 return emitOpError() <<
"expects atomic reduction region with two "
3066 "arguments of the same type";
3067 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3070 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
3071 return emitOpError() <<
"expects atomic reduction region arguments to "
3072 "be accumulators containing the reduction type";
3075 if (getCleanupRegion().empty())
3077 Block &cleanupEntryBlock = getCleanupRegion().
front();
3080 return emitOpError() <<
"expects cleanup region with one argument "
3081 "of the reduction type";
3091 const TaskOperands &clauses) {
3093 TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3094 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3095 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3097 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3098 clauses.priority, clauses.privateVars,
3100 clauses.privateNeedsBarrier, clauses.untied,
3101 clauses.eventHandle);
3105 LogicalResult verifyDependVars =
3107 return failed(verifyDependVars)
3110 getInReductionVars(),
3111 getInReductionByref());
3119 const TaskgroupOperands &clauses) {
3121 TaskgroupOp::build(builder, state, clauses.allocateVars,
3122 clauses.allocatorVars, clauses.taskReductionVars,
3129 getTaskReductionVars(),
3130 getTaskReductionByref());
3138 const TaskloopOperands &clauses) {
3141 builder, state, clauses.allocateVars, clauses.allocatorVars,
3142 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3143 clauses.inReductionVars,
3145 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3146 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3147 clauses.privateVars,
3149 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3155 if (getAllocateVars().size() != getAllocatorVars().size())
3157 "expected equal sizes for allocate and allocator variables");
3159 getReductionVars(), getReductionByref())) ||
3161 getInReductionVars(),
3162 getInReductionByref())))
3165 if (!getReductionVars().empty() && getNogroup())
3166 return emitError(
"if a reduction clause is present on the taskloop "
3167 "directive, the nogroup clause must not be specified");
3168 for (
auto var : getReductionVars()) {
3169 if (llvm::is_contained(getInReductionVars(), var))
3170 return emitError(
"the same list item cannot appear in both a reduction "
3171 "and an in_reduction clause");
3174 if (getGrainsize() && getNumTasks()) {
3176 "the grainsize clause and num_tasks clause are mutually exclusive and "
3177 "may not appear on the same taskloop directive");
/// Region verifier for TaskloopOp. Checks that the 'omp.composite' attribute
/// agrees with whether this op wraps another loop wrapper, and restricts the
/// nested wrapper kind.
3183 LogicalResult TaskloopOp::verifyRegions() {
// Case 1: this op directly wraps another loop wrapper.
// NOTE(review): the head of the first diagnostic in this branch is not
// visible; judging by its message it fires when the composite attribute is
// absent - confirm.
3184 if (LoopWrapperInterface nested = getNestedWrapper()) {
3187 <<
"'omp.composite' attribute missing from composite wrapper";
// SimdOp is the only nested wrapper accepted here.
3191 if (!isa<SimdOp>(nested))
3192 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
// Case 2: no nested wrapper, yet the op is marked composite.
3193 }
else if (isComposite()) {
3195 <<
"'omp.composite' attribute present in non-composite wrapper";
3219 for (
auto &iv : ivs)
3220 iv.type = loopVarType;
3241 "collapse_num_loops",
3246 auto parseTiles = [&]() -> ParseResult {
3250 tiles.push_back(
tile);
3259 if (tiles.size() > 0)
3278 Region ®ion = getRegion();
3280 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
3281 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
3282 if (getLoopInclusive())
3284 p <<
"step (" << getLoopSteps() <<
") ";
3285 if (int64_t numCollapse = getCollapseNumLoops())
3286 if (numCollapse > 1)
3287 p <<
"collapse(" << numCollapse <<
") ";
3290 p <<
"tiles(" << tiles.value() <<
") ";
3296 const LoopNestOperands &clauses) {
3298 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3299 clauses.loopLowerBounds, clauses.loopUpperBounds,
3300 clauses.loopSteps, clauses.loopInclusive,
3305 if (getLoopLowerBounds().empty())
3306 return emitOpError() <<
"must represent at least one loop";
3308 if (getLoopLowerBounds().size() != getIVs().size())
3309 return emitOpError() <<
"number of range arguments and IVs do not match";
3311 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3312 if (lb.getType() != iv.getType())
3313 return emitOpError()
3314 <<
"range argument type does not match corresponding IV type";
3317 uint64_t numIVs = getIVs().size();
3319 if (
const auto &numCollapse = getCollapseNumLoops())
3320 if (numCollapse > numIVs)
3321 return emitOpError()
3322 <<
"collapse value is larger than the number of loops";
3325 if (tiles.value().size() > numIVs)
3326 return emitOpError() <<
"too few canonical loops for tile dimensions";
3328 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3329 return emitOpError() <<
"expects parent op to be a loop wrapper";
3334 void LoopNestOp::gatherWrappers(
3337 while (
auto wrapper =
3338 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3339 wrappers.push_back(wrapper);
3348 std::tuple<NewCliOp, OpOperand *, OpOperand *>
3354 return {{},
nullptr,
nullptr};
3357 "Unexpected type of cli");
3363 auto op = cast<LoopTransformationInterface>(use.getOwner());
3365 unsigned opnum = use.getOperandNumber();
3366 if (op.isGeneratee(opnum)) {
3367 assert(!gen &&
"Each CLI may have at most one def");
3369 }
else if (op.isApplyee(opnum)) {
3370 assert(!cons &&
"Each CLI may have at most one consumer");
3373 llvm_unreachable(
"Unexpected operand for a CLI");
3377 return {create, gen, cons};
3386 Value result = getResult();
3387 auto [newCli, gen, cons] =
decodeCli(result);
3400 std::string cliName{
"cli"};
3404 .Case([&](CanonicalLoopOp op) {
3407 .Case([&](UnrollHeuristicOp op) -> std::string {
3408 llvm_unreachable(
"heuristic unrolling does not generate a loop");
3410 .Case([&](TileOp op) -> std::string {
3411 auto [generateesFirst, generateesCount] =
3412 op.getGenerateesODSOperandIndexAndLength();
3413 unsigned firstGrid = generateesFirst;
3414 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3415 unsigned end = generateesFirst + generateesCount;
3416 unsigned opnum =
generator->getOperandNumber();
3418 if (firstGrid <= opnum && opnum < firstIntratile) {
3419 unsigned gridnum = opnum - firstGrid + 1;
3420 return (
"grid" + Twine(gridnum)).str();
3422 if (firstIntratile <= opnum && opnum < end) {
3423 unsigned intratilenum = opnum - firstIntratile + 1;
3424 return (
"intratile" + Twine(intratilenum)).str();
3426 llvm_unreachable(
"Unexpected generatee argument");
3428 .DefaultUnreachable(
"TODO: Custom name for this operation");
3431 setNameFn(result, cliName);
3435 Value cli = getResult();
3438 "Unexpected type of cli");
3444 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3446 unsigned opnum = use.getOperandNumber();
3447 if (op.isGeneratee(opnum)) {
3450 emitOpError(
"CLI must have at most one generator");
3452 .
append(
"first generator here:");
3454 .
append(
"second generator here:");
3459 }
else if (op.isApplyee(opnum)) {
3462 emitOpError(
"CLI must have at most one consumer");
3464 .
append(
"first consumer here:")
3468 .
append(
"second consumer here:")
3475 llvm_unreachable(
"Unexpected operand for a CLI");
3483 .
append(
"see consumer here: ")
3506 setNameFn(&getRegion().front(),
"body_entry");
3509 void CanonicalLoopOp::getAsmBlockArgumentNames(
Region ®ion,
3517 p <<
'(' << getCli() <<
')';
3518 p <<
' ' << getInductionVar() <<
" : " << getInductionVar().getType()
3519 <<
" in range(" << getTripCount() <<
") ";
3529 CanonicalLoopInfoType cliType =
3555 if (parser.
parseRegion(*region, {inductionVariable}))
3560 result.
operands.append(cliOperand);
3566 return mlir::success();
3572 if (!getRegion().empty()) {
3573 Region ®ion = getRegion();
3576 "Canonical loop region must have exactly one argument");
3580 "Region argument must be the same type as the trip count");
/// Returns argument 0 of this op's region, which the op exposes as its
/// induction variable.
3586 Value CanonicalLoopOp::getInductionVar() {
return getRegion().getArgument(0); }
3588 std::pair<unsigned, unsigned>
3589 CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
/// Returns the pair produced by getODSOperandIndexAndLength() for the
/// `odsIndex_cli` operand group, which this op reports as its generatees.
3594 std::pair<unsigned, unsigned>
3595 CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3596 return getODSOperandIndexAndLength(odsIndex_cli);
3610 p <<
'(' << getApplyee() <<
')';
3640 return mlir::success();
/// Returns the pair produced by getODSOperandIndexAndLength() for the
/// `odsIndex_applyee` operand group, which this op reports as its applyees.
3643 std::pair<unsigned, unsigned>
3644 UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3645 return getODSOperandIndexAndLength(odsIndex_applyee);
3648 std::pair<unsigned, unsigned>
3649 UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3660 if (!generatees.empty())
3661 p <<
'(' << llvm::interleaved(generatees) <<
')';
3663 if (!applyees.empty())
3664 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
3696 if (getApplyees().empty())
3697 return emitOpError() <<
"must apply to at least one loop";
3699 if (getSizes().size() != getApplyees().size())
3700 return emitOpError() <<
"there must be one tile size for each applyee";
3702 if (!getGeneratees().empty() &&
3703 2 * getSizes().size() != getGeneratees().size())
3704 return emitOpError()
3705 <<
"expecting two times the number of generatees than applyees";
3709 Value parent = getApplyees().front();
3710 for (
auto &&applyee : llvm::drop_begin(getApplyees())) {
3711 auto [parentCreate, parentGen, parentCons] =
decodeCli(parent);
3712 auto [create, gen, cons] =
decodeCli(applyee);
3715 return emitOpError() <<
"applyee CLI has no generator";
3717 auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner());
3719 return emitOpError()
3720 <<
"currently only supports omp.canonical_loop as applyee";
3722 parentIVs.insert(parentLoop.getInductionVar());
3725 return emitOpError() <<
"applyee CLI has no generator";
3726 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
3728 return emitOpError()
3729 <<
"currently only supports omp.canonical_loop as applyee";
3735 auto &parentBody = parentLoop.getRegion();
3736 if (!parentBody.hasOneBlock())
3738 auto &parentBlock = parentBody.getBlocks().front();
3740 auto nestedLoopIt = parentBlock.begin();
3741 if (nestedLoopIt == parentBlock.end() ||
3742 (&*nestedLoopIt != loop.getOperation()))
3745 auto termIt = std::next(nestedLoopIt);
3746 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
3749 if (std::next(termIt) != parentBlock.end())
3755 return emitOpError() <<
"tiled loop nest must be perfectly nested";
3757 if (parentIVs.contains(loop.getTripCount()))
3758 return emitOpError() <<
"tiled loop nest must be rectangular";
3777 std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
3778 return getODSOperandIndexAndLength(odsIndex_applyees);
3781 std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
3782 return getODSOperandIndexAndLength(odsIndex_generatees);
3790 const CriticalDeclareOperands &clauses) {
3791 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3799 if (getNameAttr()) {
3800 SymbolRefAttr symbolRef = getNameAttr();
3804 return emitOpError() <<
"expected symbol reference " << symbolRef
3805 <<
" to point to a critical declaration";
3825 return op.
emitOpError() <<
"must be nested inside of a loop";
3829 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
3830 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3832 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
3833 "have an ordered clause";
3835 if (hasRegion && orderedAttr.getInt() != 0)
3836 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
3837 "have a parameter present";
3839 if (!hasRegion && orderedAttr.getInt() == 0)
3840 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
3841 "have a parameter present";
3842 }
else if (!isa<SimdOp>(wrapper)) {
3843 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
3844 "or worksharing simd loop";
3850 const OrderedOperands &clauses) {
3851 OrderedOp::build(builder, state, clauses.doacrossDependType,
3852 clauses.doacrossNumLoops, clauses.doacrossDependVars);
3859 auto wrapper = (*this)->getParentOfType<WsloopOp>();
3860 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3861 return emitOpError() <<
"number of variables in depend clause does not "
3862 <<
"match number of iteration variables in the "
3869 const OrderedRegionOperands &clauses) {
3870 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3880 const TaskwaitOperands &clauses) {
3882 TaskwaitOp::build(builder, state,
nullptr,
3891 if (verifyCommon().
failed())
3892 return mlir::failure();
3894 if (
auto mo = getMemoryOrder()) {
3895 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3896 *mo == ClauseMemoryOrderKind::Release) {
3898 "memory-order must not be acq_rel or release for atomic reads");
3909 if (verifyCommon().
failed())
3910 return mlir::failure();
3912 if (
auto mo = getMemoryOrder()) {
3913 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3914 *mo == ClauseMemoryOrderKind::Acquire) {
3916 "memory-order must not be acq_rel or acquire for atomic writes");
3926 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3932 if (
Value writeVal = op.getWriteOpVal()) {
3934 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3941 if (verifyCommon().
failed())
3942 return mlir::failure();
3944 if (
auto mo = getMemoryOrder()) {
3945 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3946 *mo == ClauseMemoryOrderKind::Acquire) {
3948 "memory-order must not be acq_rel or acquire for atomic updates");
/// Verifies the regions of an AtomicUpdateOp by delegating entirely to
/// verifyRegionsCommon(); that call's result is returned unchanged.
3955 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
3961 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3962 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3964 return dyn_cast<AtomicReadOp>(getSecondOp());
3967 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3968 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3970 return dyn_cast<AtomicWriteOp>(getSecondOp());
3973 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3974 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3976 return dyn_cast<AtomicUpdateOp>(getSecondOp());
3983 LogicalResult AtomicCaptureOp::verifyRegions() {
3984 if (verifyRegionsCommon().
failed())
3985 return mlir::failure();
3987 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
3989 "operations inside capture region must not have hint clause");
3991 if (getFirstOp()->getAttr(
"memory_order") ||
3992 getSecondOp()->getAttr(
"memory_order"))
3994 "operations inside capture region must not have memory_order clause");
4003 const CancelOperands &clauses) {
4004 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4018 ClauseCancellationConstructType cct = getCancelDirective();
4021 if (!structuralParent)
4022 return emitOpError() <<
"Orphaned cancel construct";
4024 if ((cct == ClauseCancellationConstructType::Parallel) &&
4025 !mlir::isa<ParallelOp>(structuralParent)) {
4026 return emitOpError() <<
"cancel parallel must appear "
4027 <<
"inside a parallel region";
4029 if (cct == ClauseCancellationConstructType::Loop) {
4032 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->
getParentOp());
4035 return emitOpError()
4036 <<
"cancel loop must appear inside a worksharing-loop region";
4038 if (wsloopOp.getNowaitAttr()) {
4039 return emitError() <<
"A worksharing construct that is canceled "
4040 <<
"must not have a nowait clause";
4042 if (wsloopOp.getOrderedAttr()) {
4043 return emitError() <<
"A worksharing construct that is canceled "
4044 <<
"must not have an ordered clause";
4047 }
else if (cct == ClauseCancellationConstructType::Sections) {
4051 mlir::dyn_cast<SectionsOp>(structuralParent->
getParentOp());
4053 return emitOpError() <<
"cancel sections must appear "
4054 <<
"inside a sections region";
4056 if (sectionsOp.getNowait()) {
4057 return emitError() <<
"A sections construct that is canceled "
4058 <<
"must not have a nowait clause";
4061 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4062 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4063 !mlir::isa<omp::TaskloopOp>(structuralParent->
getParentOp()))) {
4064 return emitOpError() <<
"cancel taskgroup must appear "
4065 <<
"inside a task region";
4075 const CancellationPointOperands &clauses) {
4076 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4080 ClauseCancellationConstructType cct = getCancelDirective();
4083 if (!structuralParent)
4084 return emitOpError() <<
"Orphaned cancellation point";
4086 if ((cct == ClauseCancellationConstructType::Parallel) &&
4087 !mlir::isa<ParallelOp>(structuralParent)) {
4088 return emitOpError() <<
"cancellation point parallel must appear "
4089 <<
"inside a parallel region";
4093 if ((cct == ClauseCancellationConstructType::Loop) &&
4094 !mlir::isa<WsloopOp>(structuralParent->
getParentOp())) {
4095 return emitOpError() <<
"cancellation point loop must appear "
4096 <<
"inside a worksharing-loop region";
4098 if ((cct == ClauseCancellationConstructType::Sections) &&
4099 !mlir::isa<omp::SectionOp>(structuralParent)) {
4100 return emitOpError() <<
"cancellation point sections must appear "
4101 <<
"inside a sections region";
4103 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4104 !mlir::isa<omp::TaskOp>(structuralParent)) {
4105 return emitOpError() <<
"cancellation point taskgroup must appear "
4106 <<
"inside a task region";
4116 auto extent = getExtent();
4118 if (!extent && !upperbound)
4119 return emitError(
"expected extent or upperbound.");
4126 PrivateClauseOp::build(
4127 odsBuilder, odsState, symName, type,
4129 DataSharingClauseType::Private));
4132 LogicalResult PrivateClauseOp::verifyRegions() {
4133 Type argType = getArgType();
4134 auto verifyTerminator = [&](
Operation *terminator,
4135 bool yieldsValue) -> LogicalResult {
4139 if (!llvm::isa<YieldOp>(terminator))
4141 <<
"expected exit block terminator to be an `omp.yield` op.";
4143 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4144 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4147 if (yieldedTypes.empty())
4151 <<
"Did not expect any values to be yielded.";
4154 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4158 <<
"Invalid yielded value. Expected type: " << argType
4161 if (yieldedTypes.empty())
4164 error << yieldedTypes;
4170 StringRef regionName,
4171 bool yieldsValue) -> LogicalResult {
4172 assert(!region.
empty());
4176 <<
"`" << regionName <<
"`: "
4177 <<
"expected " << expectedNumArgs
4180 for (
Block &block : region) {
4182 if (!block.mightHaveTerminator())
4185 if (
failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4193 for (
Region *region : getRegions())
4194 for (
Type ty : region->getArgumentTypes())
4196 return emitError() <<
"Region argument type mismatch: got " << ty
4197 <<
" expected " << argType <<
".";
4200 if (!initRegion.
empty() &&
4205 DataSharingClauseType dsType = getDataSharingType();
4207 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4208 return emitError(
"`private` clauses do not require a `copy` region.");
4210 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4212 "`firstprivate` clauses require at least a `copy` region.");
4214 if (dsType == DataSharingClauseType::FirstPrivate &&
4219 if (!getDeallocRegion().empty() &&
4232 const MaskedOperands &clauses) {
4233 MaskedOp::build(builder, state, clauses.filteredThreadId);
4241 const ScanOperands &clauses) {
4242 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4246 if (hasExclusiveVars() == hasInclusiveVars())
4248 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4249 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4250 if (parentWsLoopOp.getReductionModAttr() &&
4251 parentWsLoopOp.getReductionModAttr().getValue() ==
4252 ReductionModifier::inscan)
4255 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4256 if (parentSimdOp.getReductionModAttr() &&
4257 parentSimdOp.getReductionModAttr().getValue() ==
4258 ReductionModifier::inscan)
4261 return emitError(
"SCAN directive needs to be enclosed within a parent "
4262 "worksharing loop construct or SIMD construct with INSCAN "
4263 "reduction modifier");
4269 std::optional<uint64_t> align = this->getAlign();
4271 if (align.has_value()) {
4272 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4273 return emitError() <<
"ALIGN value : " << align.value()
4274 <<
" must be power of 2";
4284 mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4285 return getInTypeAttr().getValue();
4294 bool hasOperands =
false;
4295 std::int32_t typeparamsSize = 0;
4301 return mlir::failure();
4303 return mlir::failure();
4305 return mlir::failure();
4309 return mlir::failure();
4317 return mlir::failure();
4318 typeparamsSize = operands.size();
4321 std::int32_t shapeSize = 0;
4325 return mlir::failure();
4326 shapeSize = operands.size() - typeparamsSize;
4328 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4329 typeVec.push_back(idxTy);
4335 return mlir::failure();
4340 return mlir::failure();
4347 return mlir::failure();
4348 return mlir::success();
4363 if (!getTypeparams().empty()) {
4364 p <<
'(' << getTypeparams() <<
" : " << getTypeparams().getTypes() <<
')';
4371 {
"in_type",
"operandSegmentSizes"});
4376 if (!mlir::dyn_cast<IntegerType>(outType))
4377 return emitOpError(
"must be a integer type");
4378 return mlir::success();
4387 Region ®ion = getRegion();
4389 return emitOpError(
"region cannot be empty");
4392 if (entryBlock.
empty())
4393 return emitOpError(
"region must contain a structured block");
4395 bool hasTerminator =
false;
4396 for (
Block &block : region) {
4397 if (isa<TerminatorOp>(block.back())) {
4398 if (hasTerminator) {
4399 return emitOpError(
"region must have exactly one terminator");
4401 hasTerminator =
true;
4404 if (!hasTerminator) {
4405 return emitOpError(
"region must be terminated with omp.terminator");
4409 if (isa<BarrierOp>(op)) {
4411 "explicit barriers are not allowed in workdistribute region");
4414 if (isa<ParallelOp>(op)) {
4416 "nested parallel constructs not allowed in workdistribute");
4418 if (isa<TeamsOp>(op)) {
4420 "nested teams constructs not allowed in workdistribute");
4422 return WalkResult::advance();
4424 if (walkResult.wasInterrupted())
4428 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4429 return emitOpError(
"workdistribute must be nested under teams");
4433 #define GET_ATTRDEF_CLASSES
4434 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
4436 #define GET_OP_CLASSES
4437 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
4439 #define GET_TYPEDEF_CLASSES
4440 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static Operation * getParentInSameDialect(Operation *thisOp)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
bool isUnique(It begin, It end)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Operation * getParentOp()
Return the parent operation this region is attached to.
unsigned getNumArguments()
BlockListType & getBlocks()
Location getLoc()
Return a location for this region.
BlockArgument getArgument(unsigned i)
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef< AffineForOp > loops)
Returns true if loops is a perfectly nested loop nest, where loops appear in it from outermost to inn...
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Runtime
Potential runtimes for AMD GPU kernels.
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.