26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/PostOrderIterator.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/STLForwardCompat.h"
30 #include "llvm/ADT/SmallString.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringRef.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/ADT/bit.h"
35 #include "llvm/Support/InterleavedRange.h"
41 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
44 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
65 struct MemRefPointerLikeModel
66 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
69 return llvm::cast<MemRefType>(pointer).getElementType();
73 struct LLVMPointerPointerLikeModel
74 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
75 LLVM::LLVMPointerType> {
100 bool isRegionArgOfOp;
110 assert(isRegionArgOfOp &&
"Must describe a region operand");
113 size_t &getArgIdx() {
114 assert(isRegionArgOfOp &&
"Must describe a region operand");
119 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
123 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
126 bool isLoopOp()
const {
127 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
128 return isa<CanonicalLoopOp>(op);
130 Region *&getParentRegion() {
131 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
134 size_t &getLoopDepth() {
135 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
139 void skipIf(
bool v =
true) { skip = skip || v; }
157 llvm::ReversePostOrderTraversal<Block *> traversal(&r->
getBlocks().front());
160 size_t sequentialIdx = -1;
161 bool isOnlyContainerOp =
true;
162 for (
Block *b : traversal) {
164 if (&op == o && !found) {
168 if (op.getNumRegions()) {
171 isOnlyContainerOp =
false;
173 if (found && !isOnlyContainerOp)
178 Component &containerOpInRegion = components.emplace_back();
179 containerOpInRegion.isRegionArgOfOp =
false;
180 containerOpInRegion.isUnique = isOnlyContainerOp;
181 containerOpInRegion.getContainerOp() = o;
182 containerOpInRegion.getOpPos() = sequentialIdx;
183 containerOpInRegion.getParentRegion() = r;
188 Component ®ionArgOfOperation = components.emplace_back();
189 regionArgOfOperation.isRegionArgOfOp =
true;
190 regionArgOfOperation.isUnique =
true;
191 regionArgOfOperation.getArgIdx() = 0;
192 regionArgOfOperation.getOwnerOp() = parent;
208 llvm_unreachable(
"Region not child of its parent operation");
210 regionArgOfOperation.isUnique =
false;
211 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
219 for (Component &c : components)
220 c.skipIf(c.isRegionArgOfOp && c.isUnique);
223 size_t numSurroundingLoops = 0;
224 for (Component &c : llvm::reverse(components)) {
229 if (c.isRegionArgOfOp) {
230 numSurroundingLoops = 0;
237 numSurroundingLoops = 0;
239 c.getLoopDepth() = numSurroundingLoops;
242 if (isa<CanonicalLoopOp>(c.getContainerOp()))
243 numSurroundingLoops += 1;
248 bool isLoopNest =
false;
249 for (Component &c : components) {
250 if (c.skip || c.isRegionArgOfOp)
253 if (!isLoopNest && c.getLoopDepth() >= 1) {
256 }
else if (isLoopNest) {
258 c.skipIf(c.isUnique);
262 if (c.getLoopDepth() == 0)
269 for (Component &c : components)
270 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
271 !isa<CanonicalLoopOp>(c.getContainerOp()));
275 bool newRegion =
true;
276 for (Component &c : llvm::reverse(components)) {
277 c.skipIf(newRegion && c.isUnique);
284 if (!c.isRegionArgOfOp && c.getContainerOp())
290 llvm::raw_svector_ostream NameOS(Name);
291 for (
auto &c : llvm::reverse(components)) {
295 if (c.isRegionArgOfOp)
296 NameOS <<
"_r" << c.getArgIdx();
297 else if (c.getLoopDepth() >= 1)
298 NameOS <<
"_d" << c.getLoopDepth();
300 NameOS <<
"_s" << c.getOpPos();
303 return NameOS.str().str();
306 void OpenMPDialect::initialize() {
309 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
312 #define GET_ATTRDEF_LIST
313 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
316 #define GET_TYPEDEF_LIST
317 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
320 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
322 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
323 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
328 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
334 mlir::LLVM::GlobalOp::attachInterface<
337 mlir::LLVM::LLVMFuncOp::attachInterface<
340 mlir::func::FuncOp::attachInterface<
366 allocatorVars.push_back(operand);
367 allocatorTypes.push_back(type);
373 allocateVars.push_back(operand);
374 allocateTypes.push_back(type);
385 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
386 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
387 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
388 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
396 template <
typename ClauseAttr>
398 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
403 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
407 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
410 template <
typename ClauseAttr>
412 p << stringifyEnum(attr.getValue());
435 linearVars.push_back(var);
436 linearTypes.push_back(type);
437 linearStepVars.push_back(stepVar);
446 size_t linearVarsSize = linearVars.size();
447 for (
unsigned i = 0; i < linearVarsSize; ++i) {
448 std::string separator = i == linearVarsSize - 1 ?
"" :
", ";
450 if (linearStepVars.size() > i)
451 p <<
" = " << linearStepVars[i];
452 p <<
" : " << linearVars[i].getType() << separator;
// Reject a nontemporal variable that is listed more than once: the insertion
// reports `second == false` when the element was already present.
465 for (
const auto &it : nontemporalVars)
466 if (!nontemporalItems.insert(it).second)
467 return op->
emitOpError() <<
"nontemporal variable used more than once";
476 std::optional<ArrayAttr> alignments,
479 if (!alignedVars.empty()) {
480 if (!alignments || alignments->size() != alignedVars.size())
482 <<
"expected as many alignment values as aligned variables";
485 return op->
emitOpError() <<
"unexpected alignment values attribute";
// Reject an aligned variable that is listed more than once: the insertion
// reports `second == false` when the element was already present.
491 for (
auto it : alignedVars)
492 if (!alignedItems.insert(it).second)
493 return op->
emitOpError() <<
"aligned variable used more than once";
499 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
500 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
501 if (intAttr.getValue().sle(0))
502 return op->
emitOpError() <<
"alignment should be greater than 0";
504 return op->
emitOpError() <<
"expected integer alignment";
518 ArrayAttr &alignmentsAttr) {
521 if (parser.parseOperand(alignedVars.emplace_back()) ||
522 parser.parseColonType(alignedTypes.emplace_back()) ||
523 parser.parseArrow() ||
524 parser.parseAttribute(alignmentVec.emplace_back())) {
538 std::optional<ArrayAttr> alignments) {
539 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
542 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
543 p <<
" -> " << (*alignments)[i];
554 if (modifiers.size() > 2)
556 for (
const auto &mod : modifiers) {
559 auto symbol = symbolizeScheduleModifier(mod);
562 <<
" unknown modifier type: " << mod;
567 if (modifiers.size() == 1) {
568 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
569 modifiers.push_back(modifiers[0]);
570 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
572 }
else if (modifiers.size() == 2) {
575 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
576 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
578 <<
" incorrect modifier order";
594 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
595 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
600 std::optional<mlir::omp::ClauseScheduleKind> schedule =
601 symbolizeClauseScheduleKind(keyword);
607 case ClauseScheduleKind::Static:
608 case ClauseScheduleKind::Dynamic:
609 case ClauseScheduleKind::Guided:
615 chunkSize = std::nullopt;
618 case ClauseScheduleKind::Auto:
620 chunkSize = std::nullopt;
629 modifiers.push_back(mod);
635 if (!modifiers.empty()) {
637 if (std::optional<ScheduleModifier> mod =
638 symbolizeScheduleModifier(modifiers[0])) {
641 return parser.
emitError(loc,
"invalid schedule modifier");
644 if (modifiers.size() > 1) {
645 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
655 ClauseScheduleKindAttr scheduleKind,
656 ScheduleModifierAttr scheduleMod,
657 UnitAttr scheduleSimd,
Value scheduleChunk,
658 Type scheduleChunkType) {
659 p << stringifyClauseScheduleKind(scheduleKind.getValue());
661 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
663 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
675 ClauseOrderKindAttr &order,
676 OrderModifierAttr &orderMod) {
681 if (std::optional<OrderModifier> enumValue =
682 symbolizeOrderModifier(enumStr)) {
690 if (std::optional<ClauseOrderKind> enumValue =
691 symbolizeClauseOrderKind(enumStr)) {
695 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
699 ClauseOrderKindAttr order,
700 OrderModifierAttr orderMod) {
702 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
704 p << stringifyClauseOrderKind(order.getValue());
707 template <
typename ClauseTypeAttr,
typename ClauseType>
710 std::optional<OpAsmParser::UnresolvedOperand> &operand,
712 std::optional<ClauseType> (*symbolizeClause)(StringRef),
713 StringRef clauseName) {
716 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
722 <<
"invalid " << clauseName <<
" modifier : '" << enumStr <<
"'";
732 <<
"expected " << clauseName <<
" operand";
735 if (operand.has_value()) {
743 template <
typename ClauseTypeAttr,
typename ClauseType>
746 ClauseTypeAttr prescriptiveness,
Value operand,
748 StringRef (*stringifyClauseType)(ClauseType)) {
750 if (prescriptiveness)
751 p << stringifyClauseType(prescriptiveness.getValue()) <<
", ";
754 p << operand <<
": " << operandType;
764 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
765 Type &grainsizeType) {
766 return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
767 parser, grainsizeMod, grainsize, grainsizeType,
768 &symbolizeClauseGrainsizeType,
"grainsize");
772 ClauseGrainsizeTypeAttr grainsizeMod,
774 printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
775 p, op, grainsizeMod, grainsize, grainsizeType,
776 &stringifyClauseGrainsizeType);
786 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
787 Type &numTasksType) {
788 return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
789 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
794 ClauseNumTasksTypeAttr numTasksMod,
796 printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
797 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
805 struct MapParseArgs {
810 : vars(vars), types(types) {}
812 struct PrivateParseArgs {
816 UnitAttr &needsBarrier;
820 UnitAttr &needsBarrier,
822 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
823 mapIndices(mapIndices) {}
826 struct ReductionParseArgs {
831 ReductionModifierAttr *modifier;
834 ArrayAttr &syms, ReductionModifierAttr *mod =
nullptr)
835 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
838 struct AllRegionParseArgs {
839 std::optional<MapParseArgs> hasDeviceAddrArgs;
840 std::optional<MapParseArgs> hostEvalArgs;
841 std::optional<ReductionParseArgs> inReductionArgs;
842 std::optional<MapParseArgs> mapArgs;
843 std::optional<PrivateParseArgs> privateArgs;
844 std::optional<ReductionParseArgs> reductionArgs;
845 std::optional<ReductionParseArgs> taskReductionArgs;
846 std::optional<MapParseArgs> useDeviceAddrArgs;
847 std::optional<MapParseArgs> useDevicePtrArgs;
852 return "private_barrier";
862 ReductionModifierAttr *modifier =
nullptr,
863 UnitAttr *needsBarrier =
nullptr) {
867 unsigned regionArgOffset = regionPrivateArgs.size();
877 std::optional<ReductionModifier> enumValue =
878 symbolizeReductionModifier(enumStr);
879 if (!enumValue.has_value())
888 isByRefVec.push_back(
889 parser.parseOptionalKeyword(
"byref").succeeded());
891 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
894 if (parser.parseOperand(operands.emplace_back()) ||
895 parser.parseArrow() ||
896 parser.parseArgument(regionPrivateArgs.emplace_back()))
900 if (parser.parseOptionalLSquare().succeeded()) {
901 if (parser.parseKeyword(
"map_idx") || parser.parseEqual() ||
902 parser.parseInteger(mapIndicesVec.emplace_back()) ||
903 parser.parseRSquare())
906 mapIndicesVec.push_back(-1);
918 if (parser.parseType(types.emplace_back()))
925 if (operands.size() != types.size())
937 auto *argsBegin = regionPrivateArgs.begin();
939 argsBegin + regionArgOffset + types.size());
940 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
949 if (!mapIndicesVec.empty())
962 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
977 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
983 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
984 &privateArgs->syms, privateArgs->mapIndices,
nullptr,
985 nullptr, &privateArgs->needsBarrier)))
994 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
999 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1000 &reductionArgs->syms,
nullptr, &reductionArgs->byref,
1001 reductionArgs->modifier)))
1008 AllRegionParseArgs args) {
1012 args.hasDeviceAddrArgs)))
1014 <<
"invalid `has_device_addr` format";
1017 args.hostEvalArgs)))
1019 <<
"invalid `host_eval` format";
1022 args.inReductionArgs)))
1024 <<
"invalid `in_reduction` format";
1029 <<
"invalid `map_entries` format";
1034 <<
"invalid `private` format";
1037 args.reductionArgs)))
1039 <<
"invalid `reduction` format";
1042 args.taskReductionArgs)))
1044 <<
"invalid `task_reduction` format";
1047 args.useDeviceAddrArgs)))
1049 <<
"invalid `use_device_addr` format";
1052 args.useDevicePtrArgs)))
1054 <<
"invalid `use_device_addr` format";
1056 return parser.
parseRegion(region, entryBlockArgs);
1075 AllRegionParseArgs args;
1076 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1077 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1078 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1079 inReductionByref, inReductionSyms);
1080 args.mapArgs.emplace(mapVars, mapTypes);
1081 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1082 privateNeedsBarrier, &privateMaps);
1093 UnitAttr &privateNeedsBarrier) {
1094 AllRegionParseArgs args;
1095 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1096 inReductionByref, inReductionSyms);
1097 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1098 privateNeedsBarrier);
1109 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1112 ArrayAttr &reductionSyms) {
1113 AllRegionParseArgs args;
1114 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1115 inReductionByref, inReductionSyms);
1116 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1117 privateNeedsBarrier);
1118 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1119 reductionSyms, &reductionMod);
1127 UnitAttr &privateNeedsBarrier) {
1128 AllRegionParseArgs args;
1129 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1130 privateNeedsBarrier);
1138 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1141 ArrayAttr &reductionSyms) {
1142 AllRegionParseArgs args;
1143 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1144 privateNeedsBarrier);
1145 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1146 reductionSyms, &reductionMod);
1155 AllRegionParseArgs args;
1156 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1157 taskReductionByref, taskReductionSyms);
1167 AllRegionParseArgs args;
1168 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1169 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1178 struct MapPrintArgs {
1183 struct PrivatePrintArgs {
1187 UnitAttr needsBarrier;
1191 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1192 mapIndices(mapIndices) {}
1194 struct ReductionPrintArgs {
1199 ReductionModifierAttr modifier;
1201 ArrayAttr syms, ReductionModifierAttr mod =
nullptr)
1202 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1204 struct AllRegionPrintArgs {
1205 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1206 std::optional<MapPrintArgs> hostEvalArgs;
1207 std::optional<ReductionPrintArgs> inReductionArgs;
1208 std::optional<MapPrintArgs> mapArgs;
1209 std::optional<PrivatePrintArgs> privateArgs;
1210 std::optional<ReductionPrintArgs> reductionArgs;
1211 std::optional<ReductionPrintArgs> taskReductionArgs;
1212 std::optional<MapPrintArgs> useDeviceAddrArgs;
1213 std::optional<MapPrintArgs> useDevicePtrArgs;
1222 ReductionModifierAttr modifier =
nullptr, UnitAttr needsBarrier =
nullptr) {
1223 if (argsSubrange.empty())
1226 p << clauseName <<
"(";
1229 p <<
"mod: " << stringifyReductionModifier(modifier.getValue()) <<
", ";
1246 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1247 mapIndices.asArrayRef(),
1248 byref.asArrayRef()),
1250 auto [op, arg, sym, map, isByRef] = t;
1256 p << op <<
" -> " << arg;
1259 p <<
" [map_idx=" << map <<
"]";
1262 llvm::interleaveComma(types, p);
1270 StringRef clauseName,
ValueRange argsSubrange,
1271 std::optional<MapPrintArgs> mapArgs) {
1278 StringRef clauseName,
ValueRange argsSubrange,
1279 std::optional<PrivatePrintArgs> privateArgs) {
1282 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1283 privateArgs->syms, privateArgs->mapIndices,
nullptr,
1284 nullptr, privateArgs->needsBarrier);
1290 std::optional<ReductionPrintArgs> reductionArgs) {
1293 reductionArgs->vars, reductionArgs->types,
1294 reductionArgs->syms,
nullptr,
1295 reductionArgs->byref, reductionArgs->modifier);
1299 const AllRegionPrintArgs &args) {
1300 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1304 iface.getHasDeviceAddrBlockArgs(),
1305 args.hasDeviceAddrArgs);
1309 args.inReductionArgs);
1315 args.reductionArgs);
1317 iface.getTaskReductionBlockArgs(),
1318 args.taskReductionArgs);
1320 iface.getUseDeviceAddrBlockArgs(),
1321 args.useDeviceAddrArgs);
1323 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1337 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1339 AllRegionPrintArgs args;
1340 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1341 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1342 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1343 inReductionByref, inReductionSyms);
1344 args.mapArgs.emplace(mapVars, mapTypes);
1345 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1346 privateNeedsBarrier, privateMaps);
1354 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1355 AllRegionPrintArgs args;
1356 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1357 inReductionByref, inReductionSyms);
1358 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1359 privateNeedsBarrier,
1368 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1369 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1371 ArrayAttr reductionSyms) {
1372 AllRegionPrintArgs args;
1373 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1374 inReductionByref, inReductionSyms);
1375 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1376 privateNeedsBarrier,
1378 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1379 reductionSyms, reductionMod);
1385 ArrayAttr privateSyms,
1386 UnitAttr privateNeedsBarrier) {
1387 AllRegionPrintArgs args;
1388 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1389 privateNeedsBarrier,
1396 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1397 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1399 ArrayAttr reductionSyms) {
1400 AllRegionPrintArgs args;
1401 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1402 privateNeedsBarrier,
1404 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1405 reductionSyms, reductionMod);
1414 ArrayAttr taskReductionSyms) {
1415 AllRegionPrintArgs args;
1416 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1417 taskReductionByref, taskReductionSyms);
1427 AllRegionPrintArgs args;
1428 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1429 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1434 static LogicalResult
1438 if (!reductionVars.empty()) {
1439 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1441 <<
"expected as many reduction symbol references "
1442 "as reduction variables";
1443 if (reductionByref && reductionByref->size() != reductionVars.size())
1444 return op->
emitError() <<
"expected as many reduction variable by "
1445 "reference attributes as reduction variables";
1448 return op->
emitOpError() <<
"unexpected reduction symbol references";
1455 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
1456 Value accum = std::get<0>(args);
1458 if (!accumulators.insert(accum).second)
1459 return op->
emitOpError() <<
"accumulator variable used more than once";
1462 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1464 SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1466 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1467 <<
" to point to a reduction declaration";
1469 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1471 <<
"expected accumulator (" << varType
1472 <<
") to be the same type as reduction declaration ("
1473 << decl.getAccumulatorType() <<
")";
1492 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1493 parser.parseArrow() ||
1494 parser.parseAttribute(symsVec.emplace_back()) ||
1495 parser.parseColonType(copyprivateTypes.emplace_back()))
1509 std::optional<ArrayAttr> copyprivateSyms) {
1510 if (!copyprivateSyms.has_value())
1512 llvm::interleaveComma(
1513 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1514 [&](
const auto &args) {
1515 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
1516 << std::get<2>(args);
1521 static LogicalResult
1523 std::optional<ArrayAttr> copyprivateSyms) {
// An absent symbol array counts as zero symbols, so it only matches an empty
// list of copyprivate variables.
1524 size_t copyprivateSymsSize =
1525 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
// The number of copyprivate variables and of function symbols must be equal;
// the diagnostic reports both counts.
1526 if (copyprivateSymsSize != copyprivateVars.size())
1527 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
1528 << copyprivateVars.size()
1529 <<
") and functions (= " << copyprivateSymsSize
1530 <<
"), both must be equal";
1531 if (!copyprivateSyms.has_value())
1534 for (
auto copyprivateVarAndSym :
1535 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1537 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1538 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1540 if (mlir::func::FuncOp mlirFuncOp =
1541 SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1543 funcOp = mlirFuncOp;
1544 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1545 SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1547 funcOp = llvmFuncOp;
1549 auto getNumArguments = [&] {
1550 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
1553 auto getArgumentType = [&](
unsigned i) {
1554 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
1559 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1560 <<
" to point to a copy function";
1562 if (getNumArguments() != 2)
1564 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
1566 Type argTy = getArgumentType(0);
1567 if (argTy != getArgumentType(1))
1568 return op->
emitOpError() <<
"expected copy function " << symbolRef
1569 <<
" arguments to have the same type";
1571 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1572 if (argTy != varType)
1574 <<
"expected copy function arguments' type (" << argTy
1575 <<
") to be the same as copyprivate variable's type (" << varType
1596 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1597 parser.parseOperand(dependVars.emplace_back()) ||
1598 parser.parseColonType(dependTypes.emplace_back()))
1600 if (std::optional<ClauseTaskDepend> keywordDepend =
1601 (symbolizeClauseTaskDepend(keyword)))
1602 kindsVec.emplace_back(
1603 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1617 std::optional<ArrayAttr> dependKinds) {
1619 for (
unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1622 p << stringifyClauseTaskDepend(
1623 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1625 <<
" -> " << dependVars[i] <<
" : " << dependTypes[i];
1631 std::optional<ArrayAttr> dependKinds,
1633 if (!dependVars.empty()) {
1634 if (!dependKinds || dependKinds->size() != dependVars.size())
1635 return op->
emitOpError() <<
"expected as many depend values"
1636 " as depend variables";
1638 if (dependKinds && !dependKinds->empty())
1639 return op->
emitOpError() <<
"unexpected depend values";
1655 IntegerAttr &hintAttr) {
1656 StringRef hintKeyword;
1662 auto parseKeyword = [&]() -> ParseResult {
1665 if (hintKeyword ==
"uncontended")
1667 else if (hintKeyword ==
"contended")
1669 else if (hintKeyword ==
"nonspeculative")
1671 else if (hintKeyword ==
"speculative")
1675 << hintKeyword <<
" is not a valid hint";
1686 IntegerAttr hintAttr) {
1687 int64_t hint = hintAttr.getInt();
// Reports whether bit `pos` (0 = least significant) is set in `bits`.
auto bitn = [](int bits, int pos) -> bool {
  const int mask = 1 << pos;
  return (bits & mask) != 0;
};
1697 bool uncontended = bitn(hint, 0);
1698 bool contended = bitn(hint, 1);
1699 bool nonspeculative = bitn(hint, 2);
1700 bool speculative = bitn(hint, 3);
1704 hints.push_back(
"uncontended");
1706 hints.push_back(
"contended");
1708 hints.push_back(
"nonspeculative");
1710 hints.push_back(
"speculative");
1712 llvm::interleaveComma(hints, p);
// Reports whether bit `pos` (0 = least significant) is set in `bits`.
auto bitn = [](int bits, int pos) -> bool {
  const int mask = 1 << pos;
  return (bits & mask) != 0;
};
1721 bool uncontended = bitn(hint, 0);
1722 bool contended = bitn(hint, 1);
1723 bool nonspeculative = bitn(hint, 2);
1724 bool speculative = bitn(hint, 3);
// The uncontended and contended hints (bits 0 and 1 of the hint value)
// cannot both be set.
1726 if (uncontended && contended)
1727 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
1728 "omp_sync_hint_contended cannot be combined";
// The nonspeculative and speculative hints (bits 2 and 3 of the hint value)
// cannot both be set either.
1729 if (nonspeculative && speculative)
1730 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
1731 "omp_sync_hint_speculative cannot be combined.";
1742 return (value & flag) == flag;
1751 ClauseMapFlagsAttr &mapType) {
1752 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
1755 auto parseTypeAndMod = [&]() -> ParseResult {
1756 StringRef mapTypeMod;
1760 if (mapTypeMod ==
"always")
1761 mapTypeBits |= ClauseMapFlags::always;
1763 if (mapTypeMod ==
"implicit")
1764 mapTypeBits |= ClauseMapFlags::implicit;
1766 if (mapTypeMod ==
"ompx_hold")
1767 mapTypeBits |= ClauseMapFlags::ompx_hold;
1769 if (mapTypeMod ==
"close")
1770 mapTypeBits |= ClauseMapFlags::close;
1772 if (mapTypeMod ==
"present")
1773 mapTypeBits |= ClauseMapFlags::present;
1775 if (mapTypeMod ==
"to")
1776 mapTypeBits |= ClauseMapFlags::to;
1778 if (mapTypeMod ==
"from")
1779 mapTypeBits |= ClauseMapFlags::from;
1781 if (mapTypeMod ==
"tofrom")
1782 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
1784 if (mapTypeMod ==
"delete")
1785 mapTypeBits |= ClauseMapFlags::del;
1787 if (mapTypeMod ==
"storage")
1788 mapTypeBits |= ClauseMapFlags::storage;
1790 if (mapTypeMod ==
"return_param")
1791 mapTypeBits |= ClauseMapFlags::return_param;
1793 if (mapTypeMod ==
"private")
1794 mapTypeBits |= ClauseMapFlags::priv;
1796 if (mapTypeMod ==
"literal")
1797 mapTypeBits |= ClauseMapFlags::literal;
1799 if (mapTypeMod ==
"attach")
1800 mapTypeBits |= ClauseMapFlags::attach;
1802 if (mapTypeMod ==
"attach_always")
1803 mapTypeBits |= ClauseMapFlags::attach_always;
1805 if (mapTypeMod ==
"attach_none")
1806 mapTypeBits |= ClauseMapFlags::attach_none;
1808 if (mapTypeMod ==
"attach_auto")
1809 mapTypeBits |= ClauseMapFlags::attach_auto;
1811 if (mapTypeMod ==
"ref_ptr")
1812 mapTypeBits |= ClauseMapFlags::ref_ptr;
1814 if (mapTypeMod ==
"ref_ptee")
1815 mapTypeBits |= ClauseMapFlags::ref_ptee;
1817 if (mapTypeMod ==
"ref_ptr_ptee")
1818 mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
1835 ClauseMapFlagsAttr mapType) {
1837 ClauseMapFlags mapFlags = mapType.getValue();
1842 mapTypeStrs.push_back(
"always");
1844 mapTypeStrs.push_back(
"implicit");
1846 mapTypeStrs.push_back(
"ompx_hold");
1848 mapTypeStrs.push_back(
"close");
1850 mapTypeStrs.push_back(
"present");
1859 mapTypeStrs.push_back(
"tofrom");
1861 mapTypeStrs.push_back(
"from");
1863 mapTypeStrs.push_back(
"to");
1866 mapTypeStrs.push_back(
"delete");
1868 mapTypeStrs.push_back(
"return_param");
1870 mapTypeStrs.push_back(
"storage");
1872 mapTypeStrs.push_back(
"private");
1874 mapTypeStrs.push_back(
"literal");
1876 mapTypeStrs.push_back(
"attach");
1878 mapTypeStrs.push_back(
"attach_always");
1880 mapTypeStrs.push_back(
"attach_none");
1882 mapTypeStrs.push_back(
"attach_auto");
1884 mapTypeStrs.push_back(
"ref_ptr");
1886 mapTypeStrs.push_back(
"ref_ptee");
1888 mapTypeStrs.push_back(
"ref_ptr_ptee");
// When the flag value is exactly `none`, add the explicit keyword "none" to
// the list of strings to print.
1889 if (mapFlags == ClauseMapFlags::none)
1890 mapTypeStrs.push_back(
"none");
1892 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1893 p << mapTypeStrs[i];
1894 if (i + 1 < mapTypeStrs.size()) {
1901 ArrayAttr &membersIdx) {
1904 auto parseIndices = [&]() -> ParseResult {
1909 APInt(64, value,
false)));
1927 if (!memberIdxs.empty())
1934 ArrayAttr membersIdx) {
1938 llvm::interleaveComma(membersIdx, p, [&p](
Attribute v) {
1940 auto memberIdx = cast<ArrayAttr>(v);
1941 llvm::interleaveComma(memberIdx.getValue(), p, [&p](
Attribute v2) {
1942 p << cast<IntegerAttr>(v2).getInt();
1949 VariableCaptureKindAttr mapCaptureType) {
1950 std::string typeCapStr;
1951 llvm::raw_string_ostream typeCap(typeCapStr);
1952 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1954 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1955 typeCap <<
"ByCopy";
1956 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1957 typeCap <<
"VLAType";
1958 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1964 VariableCaptureKindAttr &mapCaptureType) {
1965 StringRef mapCaptureKey;
1969 if (mapCaptureKey ==
"This")
1971 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
1972 if (mapCaptureKey ==
"ByRef")
1974 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
1975 if (mapCaptureKey ==
"ByCopy")
1977 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1978 if (mapCaptureKey ==
"VLAType")
1980 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
1989 for (
auto mapOp : mapVars) {
1990 if (!mapOp.getDefiningOp())
1993 if (
auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
1994 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
1997 bool from =
mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2000 bool always =
mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2001 bool close =
mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2002 bool implicit =
mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2004 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2006 "to, from, tofrom and alloc map types are permitted");
2008 if (isa<TargetEnterDataOp>(op) && (from || del))
2009 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
2011 if (isa<TargetExitDataOp>(op) && to)
2013 "from, release and delete map types are permitted");
2015 if (isa<TargetUpdateOp>(op)) {
2018 "at least one of to or from map types must be "
2019 "specified, other map types are not permitted");
2024 "at least one of to or from map types must be "
2025 "specified, other map types are not permitted");
2028 auto updateVar = mapInfoOp.getVarPtr();
2030 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2031 (from && updateToVars.contains(updateVar))) {
2034 "either to or from map types can be specified, not both");
2037 if (always || close || implicit) {
2040 "present, mapper and iterator map type modifiers are permitted");
2043 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2045 }
else if (!isa<DeclareMapperInfoOp>(op)) {
2047 "map argument is not a map entry operation");
2055 std::optional<DenseI64ArrayAttr> privateMapIndices =
2056 targetOp.getPrivateMapsAttr();
2059 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2064 if (privateMapIndices.value().size() !=
2065 static_cast<int64_t
>(privateVars.size()))
2066 return emitError(targetOp.getLoc(),
"sizes of `private` operand range and "
2067 "`private_maps` attribute mismatch");
2077 StringRef clauseName,
2079 for (
Value var : vars)
2080 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2082 <<
"'" << clauseName
2083 <<
"' arguments must be defined by 'omp.map.info' ops";
2088 if (getMapperId() &&
2089 !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
2090 *
this, getMapperIdAttr())) {
2105 const TargetDataOperands &clauses) {
2106 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2107 clauses.mapVars, clauses.useDeviceAddrVars,
2108 clauses.useDevicePtrVars);
2112 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2113 getUseDeviceAddrVars().empty()) {
2115 "At least one of map, use_device_ptr_vars, or "
2116 "use_device_addr_vars operand must be present");
2120 getUseDevicePtrVars())))
2124 getUseDeviceAddrVars())))
2134 void TargetEnterDataOp::build(
2138 TargetEnterDataOp::build(builder, state,
2140 clauses.dependVars, clauses.device, clauses.ifExpr,
2141 clauses.mapVars, clauses.nowait);
2145 LogicalResult verifyDependVars =
2147 return failed(verifyDependVars) ? verifyDependVars
2158 TargetExitDataOp::build(builder, state,
2160 clauses.dependVars, clauses.device, clauses.ifExpr,
2161 clauses.mapVars, clauses.nowait);
2165 LogicalResult verifyDependVars =
2167 return failed(verifyDependVars) ? verifyDependVars
2178 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
2179 clauses.dependVars, clauses.device, clauses.ifExpr,
2180 clauses.mapVars, clauses.nowait);
2184 LogicalResult verifyDependVars =
2186 return failed(verifyDependVars) ? verifyDependVars
2195 const TargetOperands &clauses) {
2199 TargetOp::build(builder, state, {}, {},
2201 clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
2202 clauses.hostEvalVars, clauses.ifExpr,
2204 nullptr, clauses.isDevicePtrVars,
2205 clauses.mapVars, clauses.nowait, clauses.privateVars,
2207 clauses.privateNeedsBarrier, clauses.threadLimit,
2216 getHasDeviceAddrVars())))
// Region-level verifier for `omp.target`.
// Visible checks: (1) `getOps<TeamsOp>()` must not yield more than one op, and
// (2) each use of a host_eval block argument must match one of the accepted
// forms below; any other use is reported as an error.
// NOTE(review): this excerpt omits several lines of the function (for example
// the declaration of `user` and the statements taken when a use is accepted).
2225 LogicalResult TargetOp::verifyRegions() {
2226 auto teamsOps = getOps<TeamsOp>();
2227 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2228 return emitError(
"target containing multiple 'omp.teams' nested ops");
// The execution flags computed for the innermost captured op are consulted by
// the `omp.parallel` and `omp.loop_nest` conditions further down.
2231 Operation *capturedOp = getInnermostCapturedOmpOp();
2232 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
2233 for (
Value hostEvalArg :
2234 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
// Use by `omp.teams`: tested against num_teams lower/upper and thread_limit.
2236 if (
auto teamsOp = dyn_cast<TeamsOp>(user)) {
2237 if (llvm::is_contained({teamsOp.getNumTeamsLower(),
2238 teamsOp.getNumTeamsUpper(),
2239 teamsOp.getThreadLimit()},
2243 return emitOpError() <<
"host_eval argument only legal as 'num_teams' "
2244 "and 'thread_limit' in 'omp.teams'";
// Use by `omp.parallel`: the condition requires the spmd flag, the parallel op
// to be an ancestor of the captured op, and the argument to be its num_threads.
2246 if (
auto parallelOp = dyn_cast<ParallelOp>(user)) {
2247 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2248 parallelOp->isAncestor(capturedOp) &&
2249 hostEvalArg == parallelOp.getNumThreads())
2252 return emitOpError()
2253 <<
"host_eval argument only legal as 'num_threads' in "
2254 "'omp.parallel' when representing target SPMD";
// Use by `omp.loop_nest`: the condition requires the trip_count flag, the loop
// nest to be the captured op itself, and the argument to be a bound or a step.
2256 if (
auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2257 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2258 loopNestOp.getOperation() == capturedOp &&
2259 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2260 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2261 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2264 return emitOpError() <<
"host_eval argument only legal as loop bounds "
2265 "and steps in 'omp.loop_nest' when trip count "
2266 "must be evaluated in the host";
// Any other user of a host_eval argument is rejected, naming the operation.
2269 return emitOpError() <<
"host_eval argument illegal use in '"
2270 << user->getName() <<
"' operation";
2279 assert(rootOp &&
"expected valid operation");
2291 return WalkResult::advance();
2296 bool isOmpDialect = op->
getDialect() == ompDialect;
2298 if (!isOmpDialect || !hasRegions)
2299 return WalkResult::skip();
2305 if (checkSingleMandatoryExec) {
2310 if (successor->isReachable(parentBlock))
2311 return WalkResult::interrupt();
2313 for (
Block &block : *parentRegion)
2315 !domInfo.
dominates(parentBlock, &block))
2316 return WalkResult::interrupt();
2322 if (&sibling != op && !siblingAllowedFn(&sibling))
2323 return WalkResult::interrupt();
2328 return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2329 : WalkResult::advance();
// Returns an `Operation *` for this target op (see signature).
// NOTE(review): most of the body is not visible in this excerpt. The visible
// part queries the memory effects of a `sibling` op that implements
// MemoryEffectOpInterface and yields false if any effect is a
// `MemoryEffects::Write` that also passes the
// `AutomaticAllocationScopeResource` check (its argument is cut off here).
2335 Operation *TargetOp::getInnermostCapturedOmpOp() {
2348 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2351 memOp.getEffects(effects);
2352 return !llvm::any_of(
2354 return isa<MemoryEffects::Write>(effect.
getEffect()) &&
2355 isa<SideEffects::AutomaticAllocationScopeResource>(
2365 WsloopOp *wsLoopOp) {
2367 if (teamsOp.getNumTeamsUpper())
2371 if (teamsOp.getNumReductionVars())
2373 if (wsLoopOp->getNumReductionVars())
2377 OffloadModuleInterface offloadMod =
2381 auto ompFlags = offloadMod.getFlags();
2384 return ompFlags.getAssumeTeamsOversubscription() &&
2385 ompFlags.getAssumeThreadsOversubscription();
// Computes the TargetRegionFlags for a target region from the loop wrappers
// gathered around `capturedOp`. Every path that does not match one of the
// recognised nestings returns `TargetRegionFlags::generic`.
// NOTE(review): this excerpt omits lines (e.g. the declarations of `targetOp`,
// `loopWrappers`, `parallelOp`, `nestedCapture`), so some conditions below are
// shown without the statement that introduces their operands.
2388 TargetRegionFlags TargetOp::getKernelExecFlags(
Operation *capturedOp) {
// A non-null `capturedOp` must be the innermost captured op of `targetOp`.
2393 assert((!capturedOp ||
2394 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2395 "unexpected captured op");
// Null or non-loop-nest captured ops are always generic.
2398 if (!isa_and_present<LoopNestOp>(capturedOp))
2399 return TargetRegionFlags::generic;
2403 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2404 assert(!loopWrappers.empty());
// Step past a leading `SimdOp` wrapper before counting the remaining wrappers.
2406 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2407 if (isa<SimdOp>(innermostWrapper))
2408 innermostWrapper = std::next(innermostWrapper);
// Only one or two remaining wrappers are classified.
2410 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2411 if (numWrappers != 1 && numWrappers != 2)
2412 return TargetRegionFlags::generic;
// Two wrappers: the first is cast to WsloopOp and the next must be a
// DistributeOp; `parallelOp` must be a ParallelOp whose parent is a TeamsOp.
2415 if (numWrappers == 2) {
2416 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2418 return TargetRegionFlags::generic;
2420 innermostWrapper = std::next(innermostWrapper);
2421 if (!isa<DistributeOp>(innermostWrapper))
2422 return TargetRegionFlags::generic;
2425 if (!isa_and_present<ParallelOp>(parallelOp))
2426 return TargetRegionFlags::generic;
2428 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->
getParentOp());
2430 return TargetRegionFlags::generic;
// Teams op directly inside the target: spmd | trip_count, with no_loop OR-ed in
// under a condition that is not visible in this excerpt.
2432 if (teamsOp->getParentOp() == targetOp.getOperation()) {
2433 TargetRegionFlags result =
2434 TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2436 result = result | TargetRegionFlags::no_loop;
// Single DistributeOp or LoopOp wrapper: requires a TeamsOp whose parent is the
// target op.
2441 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2443 if (!isa_and_present<TeamsOp>(teamsOp))
2444 return TargetRegionFlags::generic;
2446 if (teamsOp->
getParentOp() != targetOp.getOperation())
2447 return TargetRegionFlags::generic;
// A LoopOp wrapper is classified as spmd | trip_count right away.
2449 if (isa<LoopOp>(innermostWrapper))
2450 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2460 Dialect *ompDialect = targetOp->getDialect();
2464 return sibling && (ompDialect != sibling->
getDialect() ||
// Otherwise start from generic | trip_count; spmd is added when the op reached
// by walking `nestedCapture` up to a direct child of `capturedOp` is a
// ParallelOp.
2468 TargetRegionFlags result =
2469 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2474 while (nestedCapture->
getParentOp() != capturedOp)
2477 return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
// Single WsloopOp wrapper: spmd only when `parallelOp` is a ParallelOp whose
// parent is the target op.
2481 else if (isa<WsloopOp>(innermostWrapper)) {
2483 if (!isa_and_present<ParallelOp>(parallelOp))
2484 return TargetRegionFlags::generic;
2486 if (parallelOp->
getParentOp() == targetOp.getOperation())
2487 return TargetRegionFlags::spmd;
// Fallback for every nesting not matched above.
2490 return TargetRegionFlags::generic;
2499 ParallelOp::build(builder, state,
ValueRange(),
2506 state.addAttributes(attributes);
2510 const ParallelOperands &clauses) {
2512 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2513 clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2515 clauses.privateNeedsBarrier, clauses.procBindKind,
2516 clauses.reductionMod, clauses.reductionVars,
2521 template <
typename OpType>
2523 auto privateVars = op.getPrivateVars();
2524 auto privateSyms = op.getPrivateSymsAttr();
2526 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
2529 auto numPrivateVars = privateVars.size();
2530 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
2532 if (numPrivateVars != numPrivateSyms)
2533 return op.emitError() <<
"inconsistent number of private variables and "
2534 "privatizer op symbols, private vars: "
2536 <<
" vs. privatizer op symbols: " << numPrivateSyms;
2538 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2539 Type varType = std::get<0>(privateVarInfo).getType();
2540 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2541 PrivateClauseOp privatizerOp =
2542 SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2544 if (privatizerOp ==
nullptr)
2545 return op.emitError() <<
"failed to lookup privatizer op with symbol: '"
2546 << privateSym <<
"'";
2548 Type privatizerType = privatizerOp.getArgType();
2550 if (privatizerType && (varType != privatizerType))
2551 return op.emitError()
2552 <<
"type mismatch between a "
2553 << (privatizerOp.getDataSharingType() ==
2554 DataSharingClauseType::Private
2557 <<
" variable and its privatizer op, var type: " << varType
2558 <<
" vs. privatizer op type: " << privatizerType;
2565 if (getAllocateVars().size() != getAllocatorVars().size())
2567 "expected equal sizes for allocate and allocator variables");
2573 getReductionByref());
// Region-level verifier for `omp.parallel`, keyed on how many DistributeOp ops
// `getOps<DistributeOp>()` yields:
//   - more than one: error;
//   - exactly one: the op is treated as composite and its other child ops are
//     checked;
//   - none: the `omp.composite` attribute must not be set.
// NOTE(review): several lines (the `return emit...` prefixes, the loop over
// child ops, `ompDialect`) are not visible in this excerpt.
2576 LogicalResult ParallelOp::verifyRegions() {
2577 auto distChildOps = getOps<DistributeOp>();
2578 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2579 if (numDistChildOps > 1)
2581 <<
"multiple 'omp.distribute' nested inside of 'omp.parallel'";
2583 if (numDistChildOps == 1) {
2586 <<
"'omp.composite' attribute missing from composite operation";
2589 Operation &distributeOp = **distChildOps.begin();
// The distribute op itself and ops whose dialect differs from `ompDialect`
// satisfy this condition; the error below names the offending child op.
2591 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2595 return emitError() <<
"unexpected OpenMP operation inside of composite "
2597 << childOp.getName();
2599 }
else if (isComposite()) {
2601 <<
"'omp.composite' attribute present in non-composite operation";
2618 const TeamsOperands &clauses) {
2621 TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2622 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2624 nullptr, clauses.reductionMod,
2625 clauses.reductionVars,
2628 clauses.threadLimit);
2640 return emitError(
"expected to be nested inside of omp.target or not nested "
2641 "in any OpenMP dialect operations");
2644 if (
auto numTeamsLowerBound = getNumTeamsLower()) {
2645 auto numTeamsUpperBound = getNumTeamsUpper();
2646 if (!numTeamsUpperBound)
2647 return emitError(
"expected num_teams upper bound to be defined if the "
2648 "lower bound is defined");
2649 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2651 "expected num_teams upper bound and lower bound to be the same type");
2655 if (getAllocateVars().size() != getAllocatorVars().size())
2657 "expected equal sizes for allocate and allocator variables");
2660 getReductionByref());
2668 return getParentOp().getPrivateVars();
2672 return getParentOp().getReductionVars();
2680 const SectionsOperands &clauses) {
2683 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2686 clauses.reductionMod, clauses.reductionVars,
2692 if (getAllocateVars().size() != getAllocatorVars().size())
2694 "expected equal sizes for allocate and allocator variables");
2697 getReductionByref());
// Region-level verifier for `omp.sections`: every operation in the first block
// of the region must be either a SectionOp or a TerminatorOp.
2700 LogicalResult SectionsOp::verifyRegions() {
2701 for (
auto &inst : *getRegion().begin()) {
2702 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2703 return emitOpError()
2704 <<
"expected omp.section op or terminator op inside region";
2716 const SingleOperands &clauses) {
2719 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2720 clauses.copyprivateVars,
2721 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2728 if (getAllocateVars().size() != getAllocatorVars().size())
2730 "expected equal sizes for allocate and allocator variables");
2733 getCopyprivateSyms());
2741 const WorkshareOperands &clauses) {
2742 WorkshareOp::build(builder, state, clauses.nowait);
2750 if (!(*this)->getParentOfType<WorkshareOp>())
2751 return emitOpError() <<
"must be nested in an omp.workshare";
// Region-level verifier: rejects this op when its parent op is a
// LoopWrapperInterface.
// NOTE(review): the second operand of the `||` is not visible in this excerpt.
2755 LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2756 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2758 return emitOpError() <<
"expected to be a standalone loop wrapper";
// Shared verifier for loop wrapper ops. According to the diagnostics it
// requires the `NoTerminator` and `SingleBlock` traits, exactly one region,
// exactly one nested op, and that the nested op is either an `omp.loop_nest` or
// another loop wrapper.
// NOTE(review): the conditions guarding the first two diagnostics, and the
// definitions of `region` and `firstOp`, are not visible in this excerpt.
2767 LogicalResult LoopWrapperInterface::verifyImpl() {
2771 return emitOpError() <<
"loop wrapper must also have the `NoTerminator` "
2772 "and `SingleBlock` traits";
2775 return emitOpError() <<
"loop wrapper does not contain exactly one region";
// The region must hold exactly one operation.
2778 if (range_size(region.
getOps()) != 1)
2779 return emitOpError()
2780 <<
"loop wrapper does not contain exactly one nested op";
2783 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2784 return emitOpError() <<
"nested in loop wrapper is not another loop "
2785 "wrapper or `omp.loop_nest`";
2795 const LoopOperands &clauses) {
2798 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2800 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2801 clauses.reductionMod, clauses.reductionVars,
2808 getReductionByref());
// Region-level verifier for LoopOp: rejects the op when its parent op is a
// LoopWrapperInterface.
// NOTE(review): the second operand of the `||` is not visible in this excerpt.
2811 LogicalResult LoopOp::verifyRegions() {
2812 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2814 return emitOpError() <<
"expected to be a standalone loop wrapper";
2825 build(builder, state, {}, {},
2827 false,
nullptr,
nullptr,
2828 nullptr, {},
nullptr,
2835 state.addAttributes(attributes);
2839 const WsloopOperands &clauses) {
2844 {}, {}, clauses.linearVars,
2845 clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
2846 clauses.ordered, clauses.privateVars,
2847 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2848 clauses.reductionMod, clauses.reductionVars,
2850 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2851 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2856 getReductionByref());
// Region-level verifier for WsloopOp. Checks that the `omp.composite`
// attribute is consistent with the wrapper nesting:
//   - with a nested wrapper, the only accepted one is SimdOp (the visible
//     diagnostic also indicates the attribute is expected in this case);
//   - without a nested wrapper, the attribute must be set if and only if the
//     parent op is itself a loop wrapper.
// NOTE(review): the `return emit...` prefixes of several diagnostics are not
// visible in this excerpt.
2859 LogicalResult WsloopOp::verifyRegions() {
// True when the parent op implements LoopWrapperInterface.
2860 bool isCompositeChildLeaf =
2861 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2863 if (LoopWrapperInterface nested = getNestedWrapper()) {
2866 <<
"'omp.composite' attribute missing from composite wrapper";
2870 if (!isa<SimdOp>(nested))
2871 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
2873 }
else if (isComposite() && !isCompositeChildLeaf) {
2875 <<
"'omp.composite' attribute present in non-composite wrapper";
2876 }
else if (!isComposite() && isCompositeChildLeaf) {
2878 <<
"'omp.composite' attribute missing from composite wrapper";
2889 const SimdOperands &clauses) {
2892 SimdOp::build(builder, state, clauses.alignedVars,
2895 clauses.nontemporalVars, clauses.order, clauses.orderMod,
2896 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
2897 clauses.privateNeedsBarrier, clauses.reductionMod,
2898 clauses.reductionVars,
2905 if (getSimdlen().has_value() && getSafelen().has_value() &&
2906 getSimdlen().value() > getSafelen().value())
2907 return emitOpError()
2908 <<
"simdlen clause and safelen clause are both present, but the "
2909 "simdlen value is not less than or equal to safelen value";
2917 bool isCompositeChildLeaf =
2918 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2920 if (!isComposite() && isCompositeChildLeaf)
2922 <<
"'omp.composite' attribute missing from composite wrapper";
2924 if (isComposite() && !isCompositeChildLeaf)
2926 <<
"'omp.composite' attribute present in non-composite wrapper";
2930 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
2932 for (
const Attribute &sym : *privateSyms) {
2933 auto symRef = cast<SymbolRefAttr>(sym);
2934 omp::PrivateClauseOp privatizer =
2935 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
2936 getOperation(), symRef);
2938 return emitError() <<
"Cannot find privatizer '" << symRef <<
"'";
2939 if (privatizer.getDataSharingType() ==
2940 DataSharingClauseType::FirstPrivate)
2941 return emitError() <<
"FIRSTPRIVATE cannot be used with SIMD";
// Region-level verifier for SimdOp: the op must not contain a nested loop
// wrapper (`getNestedWrapper()` must be null).
2948 LogicalResult SimdOp::verifyRegions() {
2949 if (getNestedWrapper())
2950 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
2960 const DistributeOperands &clauses) {
2961 DistributeOp::build(builder, state, clauses.allocateVars,
2962 clauses.allocatorVars, clauses.distScheduleStatic,
2963 clauses.distScheduleChunkSize, clauses.order,
2964 clauses.orderMod, clauses.privateVars,
2966 clauses.privateNeedsBarrier);
2970 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2971 return emitOpError() <<
"chunk size set without "
2972 "dist_schedule_static being present";
2974 if (getAllocateVars().size() != getAllocatorVars().size())
2976 "expected equal sizes for allocate and allocator variables");
// Region-level verifier for DistributeOp.
//   - With a nested wrapper: a WsloopOp is accepted only when `parentOp` is a
//     ParallelOp that reports `isComposite()`; otherwise the nested wrapper
//     must be a SimdOp.
//   - Without a nested wrapper: the `omp.composite` attribute must not be set.
// NOTE(review): the definition of `parentOp`, the tails of two diagnostics and
// some `return emit...` prefixes are not visible in this excerpt.
2981 LogicalResult DistributeOp::verifyRegions() {
2982 if (LoopWrapperInterface nested = getNestedWrapper()) {
2985 <<
"'omp.composite' attribute missing from composite wrapper";
2988 if (isa<WsloopOp>(nested)) {
2990 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2991 !cast<ComposableOpInterface>(parentOp).isComposite()) {
2992 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
2993 "when a composite 'omp.parallel' is the direct "
2996 }
else if (!isa<SimdOp>(nested))
2997 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
2999 }
else if (isComposite()) {
3001 <<
"'omp.composite' attribute present in non-composite wrapper";
// Region-level verifier for DeclareMapperOp: the terminator of the first block
// of the region must be a DeclareMapperInfoOp.
3015 LogicalResult DeclareMapperOp::verifyRegions() {
3016 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3017 getRegion().getBlocks().front().getTerminator()))
3018 return emitOpError() <<
"expected terminator to be a DeclareMapperInfoOp";
// Region-level verifier for DeclareReductionOp. It checks each region against
// the reduction type returned by `getType()`:
//   - alloc region (optional): every YieldOp yields exactly one value of the
//     reduction type;
//   - initializer region (required): one or two arguments depending on whether
//     an alloc region exists, arguments and yields of the reduction type;
//   - reduction region (required): two arguments of the reduction type, yields
//     of the reduction type;
//   - atomic reduction region (optional): two arguments of the same
//     pointer-like type whose element type, if known, is the reduction type;
//   - cleanup region (optional): one argument of the reduction type.
// NOTE(review): several conditions (argument-count tests, loop headers, success
// returns) are not visible in this excerpt; the summary above leans on the
// diagnostics for those parts.
3027 LogicalResult DeclareReductionOp::verifyRegions() {
3028 if (!getAllocRegion().empty()) {
3029 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3030 if (yieldOp.getResults().size() != 1 ||
3031 yieldOp.getResults().getTypes()[0] !=
getType())
3032 return emitOpError() <<
"expects alloc region to yield a value "
3033 "of the reduction type";
// The initializer region is mandatory.
3037 if (getInitializerRegion().empty())
3038 return emitOpError() <<
"expects non-empty initializer region";
3039 Block &initializerEntryBlock = getInitializerRegion().
front();
// Argument-count consistency between the initializer and the alloc region.
3042 if (!getAllocRegion().empty())
3043 return emitOpError() <<
"expects two arguments to the initializer region "
3044 "when an allocation region is used";
3046 if (getAllocRegion().empty())
3047 return emitOpError() <<
"expects one argument to the initializer region "
3048 "when no allocation region is used";
3050 return emitOpError()
3051 <<
"expects one or two arguments to the initializer region";
3055 if (arg.getType() !=
getType())
3056 return emitOpError() <<
"expects initializer region argument to match "
3057 "the reduction type";
3059 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3060 if (yieldOp.getResults().size() != 1 ||
3061 yieldOp.getResults().getTypes()[0] !=
getType())
3062 return emitOpError() <<
"expects initializer region to yield a value "
3063 "of the reduction type";
// The reduction (combiner) region is mandatory.
3066 if (getReductionRegion().empty())
3067 return emitOpError() <<
"expects non-empty reduction region";
3068 Block &reductionEntryBlock = getReductionRegion().
front();
3073 return emitOpError() <<
"expects reduction region with two arguments of "
3074 "the reduction type";
3075 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3076 if (yieldOp.getResults().size() != 1 ||
3077 yieldOp.getResults().getTypes()[0] !=
getType())
3078 return emitOpError() <<
"expects reduction region to yield a value "
3079 "of the reduction type";
// The atomic reduction region is only checked when present.
3082 if (!getAtomicReductionRegion().empty()) {
3083 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
3087 return emitOpError() <<
"expects atomic reduction region with two "
3088 "arguments of the same type";
3089 auto ptrType = llvm::dyn_cast<PointerLikeType>(
// A non-null element type must equal the reduction type.
3092 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
3093 return emitOpError() <<
"expects atomic reduction region arguments to "
3094 "be accumulators containing the reduction type";
// The cleanup region is optional; the statement taken when it is empty is not
// visible here.
3097 if (getCleanupRegion().empty())
3099 Block &cleanupEntryBlock = getCleanupRegion().
front();
3102 return emitOpError() <<
"expects cleanup region with one argument "
3103 "of the reduction type";
3113 const TaskOperands &clauses) {
3115 TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3116 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3117 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3119 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3120 clauses.priority, clauses.privateVars,
3122 clauses.privateNeedsBarrier, clauses.untied,
3123 clauses.eventHandle);
3127 LogicalResult verifyDependVars =
3129 return failed(verifyDependVars)
3132 getInReductionVars(),
3133 getInReductionByref());
3141 const TaskgroupOperands &clauses) {
3143 TaskgroupOp::build(builder, state, clauses.allocateVars,
3144 clauses.allocatorVars, clauses.taskReductionVars,
3151 getTaskReductionVars(),
3152 getTaskReductionByref());
3160 const TaskloopOperands &clauses) {
3163 builder, state, clauses.allocateVars, clauses.allocatorVars,
3164 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3165 clauses.inReductionVars,
3167 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3168 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3169 clauses.privateVars,
3171 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3177 if (getAllocateVars().size() != getAllocatorVars().size())
3179 "expected equal sizes for allocate and allocator variables");
3181 getReductionVars(), getReductionByref())) ||
3183 getInReductionVars(),
3184 getInReductionByref())))
3187 if (!getReductionVars().empty() && getNogroup())
3188 return emitError(
"if a reduction clause is present on the taskloop "
3189 "directive, the nogroup clause must not be specified");
3190 for (
auto var : getReductionVars()) {
3191 if (llvm::is_contained(getInReductionVars(), var))
3192 return emitError(
"the same list item cannot appear in both a reduction "
3193 "and an in_reduction clause");
3196 if (getGrainsize() && getNumTasks()) {
3198 "the grainsize clause and num_tasks clause are mutually exclusive and "
3199 "may not appear on the same taskloop directive");
// Region-level verifier for TaskloopOp.
//   - With a nested wrapper: it must be a SimdOp (the visible diagnostic also
//     indicates the `omp.composite` attribute is expected in this case).
//   - Without a nested wrapper: the `omp.composite` attribute must not be set.
// NOTE(review): the `return emit...` prefixes of two diagnostics are not
// visible in this excerpt.
3205 LogicalResult TaskloopOp::verifyRegions() {
3206 if (LoopWrapperInterface nested = getNestedWrapper()) {
3209 <<
"'omp.composite' attribute missing from composite wrapper";
3213 if (!isa<SimdOp>(nested))
3214 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3215 }
else if (isComposite()) {
3217 <<
"'omp.composite' attribute present in non-composite wrapper";
3241 for (
auto &iv : ivs)
3242 iv.type = loopVarType;
3263 "collapse_num_loops",
3268 auto parseTiles = [&]() -> ParseResult {
3272 tiles.push_back(
tile);
3281 if (tiles.size() > 0)
// Fragment of a printer that emits the loop header in textual form:
//   (<args>) : <type of args[0]> = (<lower bounds>) to (<upper bounds>)
// followed by the optional `step (...)`, `collapse(...)` and `tiles(...)`
// parts.
// NOTE(review): the enclosing function header and the definitions of `p`,
// `args` and `tiles` are not visible in this excerpt; presumably this is the
// loop-nest op's custom printer - confirm against the full file.
// Fix: the reference declarator had been mis-decoded (`&reg` rendered as the
// registered-trademark sign), leaving `Region ®ion`; restored to `&region`.
3300 Region &region = getRegion();
3302 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
3303 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
// The statement executed for inclusive loops is not visible in this excerpt.
3304 if (getLoopInclusive())
3306 p <<
"step (" << getLoopSteps() <<
") ";
// `collapse(N)` is printed only for N > 1.
3307 if (int64_t numCollapse = getCollapseNumLoops())
3308 if (numCollapse > 1)
3309 p <<
"collapse(" << numCollapse <<
") ";
3312 p <<
"tiles(" << tiles.value() <<
") ";
3318 const LoopNestOperands &clauses) {
3320 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3321 clauses.loopLowerBounds, clauses.loopUpperBounds,
3322 clauses.loopSteps, clauses.loopInclusive,
3327 if (getLoopLowerBounds().empty())
3328 return emitOpError() <<
"must represent at least one loop";
3330 if (getLoopLowerBounds().size() != getIVs().size())
3331 return emitOpError() <<
"number of range arguments and IVs do not match";
3333 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3334 if (lb.getType() != iv.getType())
3335 return emitOpError()
3336 <<
"range argument type does not match corresponding IV type";
3339 uint64_t numIVs = getIVs().size();
3341 if (
const auto &numCollapse = getCollapseNumLoops())
3342 if (numCollapse > numIVs)
3343 return emitOpError()
3344 <<
"collapse value is larger than the number of loops";
3347 if (tiles.value().size() > numIVs)
3348 return emitOpError() <<
"too few canonical loops for tile dimensions";
3350 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3351 return emitOpError() <<
"expects parent op to be a loop wrapper";
// Appends to `wrappers` each `parent` op that implements LoopWrapperInterface,
// for as long as the loop condition holds.
// NOTE(review): the parameter list, the initialisation of `parent` and the
// statement that advances it are not visible in this excerpt.
3356 void LoopNestOp::gatherWrappers(
3359 while (
auto wrapper =
3360 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3361 wrappers.push_back(wrapper);
3370 std::tuple<NewCliOp, OpOperand *, OpOperand *>
3376 return {{},
nullptr,
nullptr};
3379 "Unexpected type of cli");
3385 auto op = cast<LoopTransformationInterface>(use.getOwner());
3387 unsigned opnum = use.getOperandNumber();
3388 if (op.isGeneratee(opnum)) {
3389 assert(!gen &&
"Each CLI may have at most one def");
3391 }
else if (op.isApplyee(opnum)) {
3392 assert(!cons &&
"Each CLI may have at most one consumer");
3395 llvm_unreachable(
"Unexpected operand for a CLI");
3399 return {create, gen, cons};
3408 Value result = getResult();
3409 auto [newCli, gen, cons] =
decodeCli(result);
3422 std::string cliName{
"cli"};
3426 .Case([&](CanonicalLoopOp op) {
3429 .Case([&](UnrollHeuristicOp op) -> std::string {
3430 llvm_unreachable(
"heuristic unrolling does not generate a loop");
3432 .Case([&](TileOp op) -> std::string {
3433 auto [generateesFirst, generateesCount] =
3434 op.getGenerateesODSOperandIndexAndLength();
3435 unsigned firstGrid = generateesFirst;
3436 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3437 unsigned end = generateesFirst + generateesCount;
3438 unsigned opnum =
generator->getOperandNumber();
3440 if (firstGrid <= opnum && opnum < firstIntratile) {
3441 unsigned gridnum = opnum - firstGrid + 1;
3442 return (
"grid" + Twine(gridnum)).str();
3444 if (firstIntratile <= opnum && opnum < end) {
3445 unsigned intratilenum = opnum - firstIntratile + 1;
3446 return (
"intratile" + Twine(intratilenum)).str();
3448 llvm_unreachable(
"Unexpected generatee argument");
3450 .DefaultUnreachable(
"TODO: Custom name for this operation");
3453 setNameFn(result, cliName);
3457 Value cli = getResult();
3460 "Unexpected type of cli");
3466 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3468 unsigned opnum = use.getOperandNumber();
3469 if (op.isGeneratee(opnum)) {
3472 emitOpError(
"CLI must have at most one generator");
3474 .
append(
"first generator here:");
3476 .
append(
"second generator here:");
3481 }
else if (op.isApplyee(opnum)) {
3484 emitOpError(
"CLI must have at most one consumer");
3486 .
append(
"first consumer here:")
3490 .
append(
"second consumer here:")
3497 llvm_unreachable(
"Unexpected operand for a CLI");
3505 .
append(
"see consumer here: ")
3528 setNameFn(&getRegion().front(),
"body_entry");
// Hook that receives a region of this op (first parameter, by reference).
// NOTE(review): the remaining parameters and the body are not visible in this
// excerpt.
// Fix: the first parameter had been mis-decoded (`&reg` rendered as the
// registered-trademark sign), leaving `Region ®ion`; restored to `&region`.
3531 void CanonicalLoopOp::getAsmBlockArgumentNames(
Region &region,
3539 p <<
'(' << getCli() <<
')';
3540 p <<
' ' << getInductionVar() <<
" : " << getInductionVar().getType()
3541 <<
" in range(" << getTripCount() <<
") ";
3551 CanonicalLoopInfoType cliType =
3577 if (parser.
parseRegion(*region, {inductionVariable}))
3582 result.
operands.append(cliOperand);
3588 return mlir::success();
// Fragment of a verifier: when the op has a non-empty region, the diagnostics
// below require exactly one region argument whose type matches the trip count.
// NOTE(review): the enclosing function header and the conditions guarding the
// two diagnostics are not visible in this excerpt.
// Fix: the reference declarator had been mis-decoded (`&reg` rendered as the
// registered-trademark sign), leaving `Region ®ion`; restored to `&region`.
3594 if (!getRegion().empty()) {
3595 Region &region = getRegion();
3598 "Canonical loop region must have exactly one argument");
3602 "Region argument must be the same type as the trip count");
// Returns the induction variable, i.e. argument 0 of this op's region.
3608 Value CanonicalLoopOp::getInductionVar() {
return getRegion().getArgument(0); }
3610 std::pair<unsigned, unsigned>
3611 CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
// Returns the (index, length) pair that `getODSOperandIndexAndLength` reports
// for the `cli` operand group, which is this op's generatee operand range.
3616 std::pair<unsigned, unsigned>
3617 CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3618 return getODSOperandIndexAndLength(odsIndex_cli);
3632 p <<
'(' << getApplyee() <<
')';
3662 return mlir::success();
// Returns the (index, length) pair that `getODSOperandIndexAndLength` reports
// for the `applyee` operand group, which is this op's applyee operand range.
3665 std::pair<unsigned, unsigned>
3666 UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3667 return getODSOperandIndexAndLength(odsIndex_applyee);
3670 std::pair<unsigned, unsigned>
3671 UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3682 if (!generatees.empty())
3683 p <<
'(' << llvm::interleaved(generatees) <<
')';
3685 if (!applyees.empty())
3686 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
3718 if (getApplyees().empty())
3719 return emitOpError() <<
"must apply to at least one loop";
3721 if (getSizes().size() != getApplyees().size())
3722 return emitOpError() <<
"there must be one tile size for each applyee";
3724 if (!getGeneratees().empty() &&
3725 2 * getSizes().size() != getGeneratees().size())
3726 return emitOpError()
3727 <<
"expecting two times the number of generatees than applyees";
3731 Value parent = getApplyees().front();
3732 for (
auto &&applyee : llvm::drop_begin(getApplyees())) {
3733 auto [parentCreate, parentGen, parentCons] =
decodeCli(parent);
3734 auto [create, gen, cons] =
decodeCli(applyee);
3737 return emitOpError() <<
"applyee CLI has no generator";
3739 auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner());
3741 return emitOpError()
3742 <<
"currently only supports omp.canonical_loop as applyee";
3744 parentIVs.insert(parentLoop.getInductionVar());
3747 return emitOpError() <<
"applyee CLI has no generator";
3748 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
3750 return emitOpError()
3751 <<
"currently only supports omp.canonical_loop as applyee";
3757 auto &parentBody = parentLoop.getRegion();
3758 if (!parentBody.hasOneBlock())
3760 auto &parentBlock = parentBody.getBlocks().front();
3762 auto nestedLoopIt = parentBlock.begin();
3763 if (nestedLoopIt == parentBlock.end() ||
3764 (&*nestedLoopIt != loop.getOperation()))
3767 auto termIt = std::next(nestedLoopIt);
3768 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
3771 if (std::next(termIt) != parentBlock.end())
3777 return emitOpError() <<
"tiled loop nest must be perfectly nested";
3779 if (parentIVs.contains(loop.getTripCount()))
3780 return emitOpError() <<
"tiled loop nest must be rectangular";
// Returns the (index, length) pair that `getODSOperandIndexAndLength` reports
// for the `applyees` operand group.
3799 std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
3800 return getODSOperandIndexAndLength(odsIndex_applyees);
// Returns the (index, length) pair that `getODSOperandIndexAndLength` reports
// for the `generatees` operand group.
3803 std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
3804 return getODSOperandIndexAndLength(odsIndex_generatees);
3812 const CriticalDeclareOperands &clauses) {
3813 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3821 if (getNameAttr()) {
3822 SymbolRefAttr symbolRef = getNameAttr();
3826 return emitOpError() <<
"expected symbol reference " << symbolRef
3827 <<
" to point to a critical declaration";
3847 return op.
emitOpError() <<
"must be nested inside of a loop";
3851 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
3852 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3854 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
3855 "have an ordered clause";
3857 if (hasRegion && orderedAttr.getInt() != 0)
3858 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
3859 "have a parameter present";
3861 if (!hasRegion && orderedAttr.getInt() == 0)
3862 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
3863 "have a parameter present";
3864 }
else if (!isa<SimdOp>(wrapper)) {
3865 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
3866 "or worksharing simd loop";
3872 const OrderedOperands &clauses) {
3873 OrderedOp::build(builder, state, clauses.doacrossDependType,
3874 clauses.doacrossNumLoops, clauses.doacrossDependVars);
3881 auto wrapper = (*this)->getParentOfType<WsloopOp>();
3882 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3883 return emitOpError() <<
"number of variables in depend clause does not "
3884 <<
"match number of iteration variables in the "
3891 const OrderedRegionOperands &clauses) {
3892 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3902 const TaskwaitOperands &clauses) {
3904 TaskwaitOp::build(builder, state,
nullptr,
3913 if (verifyCommon().
failed())
3914 return mlir::failure();
3916 if (
auto mo = getMemoryOrder()) {
3917 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3918 *mo == ClauseMemoryOrderKind::Release) {
3920 "memory-order must not be acq_rel or release for atomic reads");
3931 if (verifyCommon().
failed())
3932 return mlir::failure();
3934 if (
auto mo = getMemoryOrder()) {
3935 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3936 *mo == ClauseMemoryOrderKind::Acquire) {
3938 "memory-order must not be acq_rel or acquire for atomic writes");
3948 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3954 if (
Value writeVal = op.getWriteOpVal()) {
3956 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3963 if (verifyCommon().
failed())
3964 return mlir::failure();
3966 if (
auto mo = getMemoryOrder()) {
3967 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3968 *mo == ClauseMemoryOrderKind::Acquire) {
3970 "memory-order must not be acq_rel or acquire for atomic updates");
3977 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
3983 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3984 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3986 return dyn_cast<AtomicReadOp>(getSecondOp());
3989 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3990 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3992 return dyn_cast<AtomicWriteOp>(getSecondOp());
3995 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3996 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3998 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4005 LogicalResult AtomicCaptureOp::verifyRegions() {
4006 if (verifyRegionsCommon().
failed())
4007 return mlir::failure();
4009 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
4011 "operations inside capture region must not have hint clause");
4013 if (getFirstOp()->getAttr(
"memory_order") ||
4014 getSecondOp()->getAttr(
"memory_order"))
4016 "operations inside capture region must not have memory_order clause");
4025 const CancelOperands &clauses) {
4026 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4040 ClauseCancellationConstructType cct = getCancelDirective();
4043 if (!structuralParent)
4044 return emitOpError() <<
"Orphaned cancel construct";
4046 if ((cct == ClauseCancellationConstructType::Parallel) &&
4047 !mlir::isa<ParallelOp>(structuralParent)) {
4048 return emitOpError() <<
"cancel parallel must appear "
4049 <<
"inside a parallel region";
4051 if (cct == ClauseCancellationConstructType::Loop) {
4054 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->
getParentOp());
4057 return emitOpError()
4058 <<
"cancel loop must appear inside a worksharing-loop region";
4060 if (wsloopOp.getNowaitAttr()) {
4061 return emitError() <<
"A worksharing construct that is canceled "
4062 <<
"must not have a nowait clause";
4064 if (wsloopOp.getOrderedAttr()) {
4065 return emitError() <<
"A worksharing construct that is canceled "
4066 <<
"must not have an ordered clause";
4069 }
else if (cct == ClauseCancellationConstructType::Sections) {
4073 mlir::dyn_cast<SectionsOp>(structuralParent->
getParentOp());
4075 return emitOpError() <<
"cancel sections must appear "
4076 <<
"inside a sections region";
4078 if (sectionsOp.getNowait()) {
4079 return emitError() <<
"A sections construct that is canceled "
4080 <<
"must not have a nowait clause";
4083 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4084 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4085 !mlir::isa<omp::TaskloopOp>(structuralParent->
getParentOp()))) {
4086 return emitOpError() <<
"cancel taskgroup must appear "
4087 <<
"inside a task region";
4097 const CancellationPointOperands &clauses) {
4098 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4102 ClauseCancellationConstructType cct = getCancelDirective();
4105 if (!structuralParent)
4106 return emitOpError() <<
"Orphaned cancellation point";
4108 if ((cct == ClauseCancellationConstructType::Parallel) &&
4109 !mlir::isa<ParallelOp>(structuralParent)) {
4110 return emitOpError() <<
"cancellation point parallel must appear "
4111 <<
"inside a parallel region";
4115 if ((cct == ClauseCancellationConstructType::Loop) &&
4116 !mlir::isa<WsloopOp>(structuralParent->
getParentOp())) {
4117 return emitOpError() <<
"cancellation point loop must appear "
4118 <<
"inside a worksharing-loop region";
4120 if ((cct == ClauseCancellationConstructType::Sections) &&
4121 !mlir::isa<omp::SectionOp>(structuralParent)) {
4122 return emitOpError() <<
"cancellation point sections must appear "
4123 <<
"inside a sections region";
4125 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4126 !mlir::isa<omp::TaskOp>(structuralParent)) {
4127 return emitOpError() <<
"cancellation point taskgroup must appear "
4128 <<
"inside a task region";
4138 auto extent = getExtent();
4140 if (!extent && !upperbound)
4141 return emitError(
"expected extent or upperbound.");
4148 PrivateClauseOp::build(
4149 odsBuilder, odsState, symName, type,
4151 DataSharingClauseType::Private));
4154 LogicalResult PrivateClauseOp::verifyRegions() {
4155 Type argType = getArgType();
4156 auto verifyTerminator = [&](
Operation *terminator,
4157 bool yieldsValue) -> LogicalResult {
4161 if (!llvm::isa<YieldOp>(terminator))
4163 <<
"expected exit block terminator to be an `omp.yield` op.";
4165 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4166 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4169 if (yieldedTypes.empty())
4173 <<
"Did not expect any values to be yielded.";
4176 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4180 <<
"Invalid yielded value. Expected type: " << argType
4183 if (yieldedTypes.empty())
4186 error << yieldedTypes;
4192 StringRef regionName,
4193 bool yieldsValue) -> LogicalResult {
4194 assert(!region.
empty());
4198 <<
"`" << regionName <<
"`: "
4199 <<
"expected " << expectedNumArgs
4202 for (
Block &block : region) {
4204 if (!block.mightHaveTerminator())
4207 if (
failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4215 for (
Region *region : getRegions())
4216 for (
Type ty : region->getArgumentTypes())
4218 return emitError() <<
"Region argument type mismatch: got " << ty
4219 <<
" expected " << argType <<
".";
4222 if (!initRegion.
empty() &&
4227 DataSharingClauseType dsType = getDataSharingType();
4229 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4230 return emitError(
"`private` clauses do not require a `copy` region.");
4232 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4234 "`firstprivate` clauses require at least a `copy` region.");
4236 if (dsType == DataSharingClauseType::FirstPrivate &&
4241 if (!getDeallocRegion().empty() &&
4254 const MaskedOperands &clauses) {
4255 MaskedOp::build(builder, state, clauses.filteredThreadId);
4263 const ScanOperands &clauses) {
4264 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4268 if (hasExclusiveVars() == hasInclusiveVars())
4270 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4271 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4272 if (parentWsLoopOp.getReductionModAttr() &&
4273 parentWsLoopOp.getReductionModAttr().getValue() ==
4274 ReductionModifier::inscan)
4277 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4278 if (parentSimdOp.getReductionModAttr() &&
4279 parentSimdOp.getReductionModAttr().getValue() ==
4280 ReductionModifier::inscan)
4283 return emitError(
"SCAN directive needs to be enclosed within a parent "
4284 "worksharing loop construct or SIMD construct with INSCAN "
4285 "reduction modifier");
4291 std::optional<uint64_t> align = this->getAlign();
4293 if (align.has_value()) {
4294 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4295 return emitError() <<
"ALIGN value : " << align.value()
4296 <<
" must be power of 2";
4306 mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4307 return getInTypeAttr().getValue();
4316 bool hasOperands =
false;
4317 std::int32_t typeparamsSize = 0;
4323 return mlir::failure();
4325 return mlir::failure();
4327 return mlir::failure();
4331 return mlir::failure();
4339 return mlir::failure();
4340 typeparamsSize = operands.size();
4343 std::int32_t shapeSize = 0;
4347 return mlir::failure();
4348 shapeSize = operands.size() - typeparamsSize;
4350 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4351 typeVec.push_back(idxTy);
4357 return mlir::failure();
4362 return mlir::failure();
4369 return mlir::failure();
4370 return mlir::success();
4385 if (!getTypeparams().empty()) {
4386 p <<
'(' << getTypeparams() <<
" : " << getTypeparams().getTypes() <<
')';
4393 {
"in_type",
"operandSegmentSizes"});
4398 if (!mlir::dyn_cast<IntegerType>(outType))
4399 return emitOpError(
"must be a integer type");
4400 return mlir::success();
4409 Region ®ion = getRegion();
4411 return emitOpError(
"region cannot be empty");
4414 if (entryBlock.
empty())
4415 return emitOpError(
"region must contain a structured block");
4417 bool hasTerminator =
false;
4418 for (
Block &block : region) {
4419 if (isa<TerminatorOp>(block.back())) {
4420 if (hasTerminator) {
4421 return emitOpError(
"region must have exactly one terminator");
4423 hasTerminator =
true;
4426 if (!hasTerminator) {
4427 return emitOpError(
"region must be terminated with omp.terminator");
4431 if (isa<BarrierOp>(op)) {
4433 "explicit barriers are not allowed in workdistribute region");
4436 if (isa<ParallelOp>(op)) {
4438 "nested parallel constructs not allowed in workdistribute");
4440 if (isa<TeamsOp>(op)) {
4442 "nested teams constructs not allowed in workdistribute");
4444 return WalkResult::advance();
4446 if (walkResult.wasInterrupted())
4450 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4451 return emitOpError(
"workdistribute must be nested under teams");
4455 #define GET_ATTRDEF_CLASSES
4456 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
4458 #define GET_OP_CLASSES
4459 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
4461 #define GET_TYPEDEF_CLASSES
4462 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static Operation * getParentInSameDialect(Operation *thisOp)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static void printMapClause(OpAsmPrinter &p, Operation *op, ClauseMapFlagsAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseMapClause(OpAsmParser &parser, ClauseMapFlagsAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static bool isUnique(It begin, It end)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Operation * getParentOp()
Return the parent operation this region is attached to.
unsigned getNumArguments()
BlockListType & getBlocks()
Location getLoc()
Return a location for this region.
BlockArgument getArgument(unsigned i)
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
bool isPerfectlyNested(ArrayRef< AffineForOp > loops)
Returns true if loops is a perfectly nested loop nest, where loops appear in it from outermost to inn...
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling for imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Runtime
Potential runtimes for AMD GPU kernels.
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.