25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/BitVector.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/STLForwardCompat.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Frontend/OpenMP/OMPConstants.h"
34 #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
40 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
41 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
59 struct MemRefPointerLikeModel
60 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
63 return llvm::cast<MemRefType>(pointer).getElementType();
67 struct LLVMPointerPointerLikeModel
68 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
69 LLVM::LLVMPointerType> {
74 void OpenMPDialect::initialize() {
77 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
80 #define GET_ATTRDEF_LIST
81 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
84 #define GET_TYPEDEF_LIST
85 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
88 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
90 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
91 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
96 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
102 mlir::LLVM::GlobalOp::attachInterface<
105 mlir::LLVM::LLVMFuncOp::attachInterface<
108 mlir::func::FuncOp::attachInterface<
134 allocatorVars.push_back(operand);
135 allocatorTypes.push_back(type);
141 allocateVars.push_back(operand);
142 allocateTypes.push_back(type);
153 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
154 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
155 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
156 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
164 template <
typename ClauseAttr>
166 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
171 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
175 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
178 template <
typename ClauseAttr>
180 p << stringifyEnum(attr.getValue());
203 linearVars.push_back(var);
204 linearTypes.push_back(type);
205 linearStepVars.push_back(stepVar);
214 size_t linearVarsSize = linearVars.size();
215 for (
unsigned i = 0; i < linearVarsSize; ++i) {
216 std::string separator = i == linearVarsSize - 1 ?
"" :
", ";
218 if (linearStepVars.size() > i)
219 p <<
" = " << linearStepVars[i];
220 p <<
" : " << linearVars[i].getType() << separator;
233 for (
const auto &it : nontemporalVars)
234 if (!nontemporalItems.insert(it).second)
235 return op->
emitOpError() <<
"nontemporal variable used more than once";
244 std::optional<ArrayAttr> alignments,
247 if (!alignedVars.empty()) {
248 if (!alignments || alignments->size() != alignedVars.size())
250 <<
"expected as many alignment values as aligned variables";
253 return op->
emitOpError() <<
"unexpected alignment values attribute";
259 for (
auto it : alignedVars)
260 if (!alignedItems.insert(it).second)
261 return op->
emitOpError() <<
"aligned variable used more than once";
267 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
268 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
269 if (intAttr.getValue().sle(0))
270 return op->
emitOpError() <<
"alignment should be greater than 0";
272 return op->
emitOpError() <<
"expected integer alignment";
286 ArrayAttr &alignmentsAttr) {
289 if (parser.parseOperand(alignedVars.emplace_back()) ||
290 parser.parseColonType(alignedTypes.emplace_back()) ||
291 parser.parseArrow() ||
292 parser.parseAttribute(alignmentVec.emplace_back())) {
306 std::optional<ArrayAttr> alignments) {
307 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
310 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
311 p <<
" -> " << (*alignments)[i];
322 if (modifiers.size() > 2)
324 for (
const auto &mod : modifiers) {
327 auto symbol = symbolizeScheduleModifier(mod);
330 <<
" unknown modifier type: " << mod;
335 if (modifiers.size() == 1) {
336 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
337 modifiers.push_back(modifiers[0]);
338 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
340 }
else if (modifiers.size() == 2) {
343 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
344 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
346 <<
" incorrect modifier order";
362 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
363 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
368 std::optional<mlir::omp::ClauseScheduleKind> schedule =
369 symbolizeClauseScheduleKind(keyword);
375 case ClauseScheduleKind::Static:
376 case ClauseScheduleKind::Dynamic:
377 case ClauseScheduleKind::Guided:
383 chunkSize = std::nullopt;
386 case ClauseScheduleKind::Auto:
388 chunkSize = std::nullopt;
397 modifiers.push_back(mod);
403 if (!modifiers.empty()) {
405 if (std::optional<ScheduleModifier> mod =
406 symbolizeScheduleModifier(modifiers[0])) {
409 return parser.
emitError(loc,
"invalid schedule modifier");
412 if (modifiers.size() > 1) {
413 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
423 ClauseScheduleKindAttr scheduleKind,
424 ScheduleModifierAttr scheduleMod,
425 UnitAttr scheduleSimd,
Value scheduleChunk,
426 Type scheduleChunkType) {
427 p << stringifyClauseScheduleKind(scheduleKind.getValue());
429 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
431 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
443 ClauseOrderKindAttr &order,
444 OrderModifierAttr &orderMod) {
449 if (std::optional<OrderModifier> enumValue =
450 symbolizeOrderModifier(enumStr)) {
458 if (std::optional<ClauseOrderKind> enumValue =
459 symbolizeClauseOrderKind(enumStr)) {
463 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
467 ClauseOrderKindAttr order,
468 OrderModifierAttr orderMod) {
470 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
472 p << stringifyClauseOrderKind(order.getValue());
480 struct MapParseArgs {
485 : vars(vars), types(types) {}
487 struct PrivateParseArgs {
495 : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
498 struct ReductionParseArgs {
503 ReductionModifierAttr *modifier;
506 ArrayAttr &syms, ReductionModifierAttr *mod =
nullptr)
507 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
/// Aggregates the optional parse-argument bundles for the clauses handled by
/// the block-argument region parser. Each custom-directive parser emplaces
/// only the members for the clauses it receives storage for; members that are
/// never emplaced remain std::nullopt.
510 struct AllRegionParseArgs {
511 std::optional<MapParseArgs> hostEvalArgs;
512 std::optional<ReductionParseArgs> inReductionArgs;
513 std::optional<MapParseArgs> mapArgs;
514 std::optional<PrivateParseArgs> privateArgs;
515 std::optional<ReductionParseArgs> reductionArgs;
516 std::optional<ReductionParseArgs> taskReductionArgs;
517 std::optional<MapParseArgs> useDeviceAddrArgs;
518 std::optional<MapParseArgs> useDevicePtrArgs;
529 ReductionModifierAttr *modifier =
nullptr) {
533 unsigned regionArgOffset = regionPrivateArgs.size();
543 std::optional<ReductionModifier> enumValue =
544 symbolizeReductionModifier(enumStr);
545 if (!enumValue.has_value())
554 isByRefVec.push_back(
555 parser.parseOptionalKeyword(
"byref").succeeded());
557 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
560 if (parser.parseOperand(operands.emplace_back()) ||
561 parser.parseArrow() ||
562 parser.parseArgument(regionPrivateArgs.emplace_back()))
566 if (parser.parseOptionalLSquare().succeeded()) {
567 if (parser.parseKeyword(
"map_idx") || parser.parseEqual() ||
568 parser.parseInteger(mapIndicesVec.emplace_back()) ||
569 parser.parseRSquare())
572 mapIndicesVec.push_back(-1);
579 if (parser.parseColon())
582 if (parser.parseCommaSeparatedList([&]() {
583 if (parser.parseType(types.emplace_back()))
590 if (operands.size() != types.size())
593 if (parser.parseRParen())
596 auto *argsBegin = regionPrivateArgs.begin();
598 argsBegin + regionArgOffset + types.size());
599 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
608 if (!mapIndicesVec.empty())
621 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
636 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
642 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
643 &privateArgs->syms, privateArgs->mapIndices)))
652 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
657 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
658 &reductionArgs->syms,
nullptr, &reductionArgs->byref,
659 reductionArgs->modifier)))
666 AllRegionParseArgs args) {
672 <<
"invalid `host_eval` format";
675 args.inReductionArgs)))
677 <<
"invalid `in_reduction` format";
682 <<
"invalid `map_entries` format";
687 <<
"invalid `private` format";
690 args.reductionArgs)))
692 <<
"invalid `reduction` format";
695 args.taskReductionArgs)))
697 <<
"invalid `task_reduction` format";
700 args.useDeviceAddrArgs)))
702 <<
"invalid `use_device_addr` format";
705 args.useDevicePtrArgs)))
// This branch parses the `use_device_ptr` clause, so the diagnostic must name
// that clause rather than repeating the `use_device_addr` message above.
707 <<
"invalid `use_device_ptr` format";
724 AllRegionParseArgs args;
725 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
726 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
727 inReductionByref, inReductionSyms);
728 args.mapArgs.emplace(mapVars, mapTypes);
729 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
741 AllRegionParseArgs args;
742 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
743 inReductionByref, inReductionSyms);
744 args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
755 ReductionModifierAttr &reductionMod,
758 ArrayAttr &reductionSyms) {
759 AllRegionParseArgs args;
760 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
761 inReductionByref, inReductionSyms);
762 args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
763 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
764 reductionSyms, &reductionMod);
772 AllRegionParseArgs args;
773 args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
781 ReductionModifierAttr &reductionMod,
784 ArrayAttr &reductionSyms) {
785 AllRegionParseArgs args;
786 args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
787 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
788 reductionSyms, &reductionMod);
797 AllRegionParseArgs args;
798 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
799 taskReductionByref, taskReductionSyms);
809 AllRegionParseArgs args;
810 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
811 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
820 struct MapPrintArgs {
825 struct PrivatePrintArgs {
832 : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
834 struct ReductionPrintArgs {
839 ReductionModifierAttr modifier;
841 ArrayAttr syms, ReductionModifierAttr mod =
nullptr)
842 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
/// Aggregates the optional print-argument bundles, one member per clause
/// whose entry block arguments may be printed. Each custom-directive printer
/// emplaces only the members for the clauses it is given; members that are
/// never emplaced remain std::nullopt.
844 struct AllRegionPrintArgs {
845 std::optional<MapPrintArgs> hostEvalArgs;
846 std::optional<ReductionPrintArgs> inReductionArgs;
847 std::optional<MapPrintArgs> mapArgs;
848 std::optional<PrivatePrintArgs> privateArgs;
849 std::optional<ReductionPrintArgs> reductionArgs;
850 std::optional<ReductionPrintArgs> taskReductionArgs;
851 std::optional<MapPrintArgs> useDeviceAddrArgs;
852 std::optional<MapPrintArgs> useDevicePtrArgs;
861 ReductionModifierAttr modifier =
nullptr) {
862 if (argsSubrange.empty())
865 p << clauseName <<
"(";
868 p <<
"mod: " << stringifyReductionModifier(modifier.getValue()) <<
", ";
885 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
886 mapIndices.asArrayRef(),
889 auto [op, arg, sym, map, isByRef] = t;
895 p << op <<
" -> " << arg;
898 p <<
" [map_idx=" << map <<
"]";
901 llvm::interleaveComma(types, p);
906 StringRef clauseName,
ValueRange argsSubrange,
907 std::optional<MapPrintArgs> mapArgs) {
914 StringRef clauseName,
ValueRange argsSubrange,
915 std::optional<PrivatePrintArgs> privateArgs) {
918 privateArgs->vars, privateArgs->types,
919 privateArgs->syms, privateArgs->mapIndices);
925 std::optional<ReductionPrintArgs> reductionArgs) {
928 reductionArgs->vars, reductionArgs->types,
929 reductionArgs->syms,
nullptr,
930 reductionArgs->byref, reductionArgs->modifier);
934 const AllRegionPrintArgs &args) {
935 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
941 args.inReductionArgs);
949 iface.getTaskReductionBlockArgs(),
950 args.taskReductionArgs);
952 iface.getUseDeviceAddrBlockArgs(),
953 args.useDeviceAddrArgs);
955 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
967 AllRegionPrintArgs args;
968 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
969 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
970 inReductionByref, inReductionSyms);
971 args.mapArgs.emplace(mapVars, mapTypes);
972 args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
980 ArrayAttr privateSyms) {
981 AllRegionPrintArgs args;
982 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
983 inReductionByref, inReductionSyms);
984 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
993 ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
996 AllRegionPrintArgs args;
997 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
998 inReductionByref, inReductionSyms);
999 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1001 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1002 reductionSyms, reductionMod);
1008 ArrayAttr privateSyms) {
1009 AllRegionPrintArgs args;
1010 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1017 TypeRange privateTypes, ArrayAttr privateSyms,
1018 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1020 ArrayAttr reductionSyms) {
1021 AllRegionPrintArgs args;
1022 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1024 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1025 reductionSyms, reductionMod);
1034 ArrayAttr taskReductionSyms) {
1035 AllRegionPrintArgs args;
1036 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1037 taskReductionByref, taskReductionSyms);
1047 AllRegionPrintArgs args;
1048 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1049 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1054 static LogicalResult
1058 if (!reductionVars.empty()) {
1059 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1061 <<
"expected as many reduction symbol references "
1062 "as reduction variables";
1063 if (reductionByref && reductionByref->size() != reductionVars.size())
1064 return op->
emitError() <<
"expected as many reduction variable by "
1065 "reference attributes as reduction variables";
1068 return op->
emitOpError() <<
"unexpected reduction symbol references";
1075 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
1076 Value accum = std::get<0>(args);
1078 if (!accumulators.insert(accum).second)
1079 return op->
emitOpError() <<
"accumulator variable used more than once";
1082 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1084 SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1086 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1087 <<
" to point to a reduction declaration";
1089 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1091 <<
"expected accumulator (" << varType
1092 <<
") to be the same type as reduction declaration ("
1093 << decl.getAccumulatorType() <<
")";
1112 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1113 parser.parseArrow() ||
1114 parser.parseAttribute(symsVec.emplace_back()) ||
1115 parser.parseColonType(copyprivateTypes.emplace_back()))
1129 std::optional<ArrayAttr> copyprivateSyms) {
1130 if (!copyprivateSyms.has_value())
1132 llvm::interleaveComma(
1133 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1134 [&](
const auto &args) {
1135 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
1136 << std::get<2>(args);
1141 static LogicalResult
1143 std::optional<ArrayAttr> copyprivateSyms) {
1144 size_t copyprivateSymsSize =
1145 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1146 if (copyprivateSymsSize != copyprivateVars.size())
1147 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
1148 << copyprivateVars.size()
1149 <<
") and functions (= " << copyprivateSymsSize
1150 <<
"), both must be equal";
1151 if (!copyprivateSyms.has_value())
1154 for (
auto copyprivateVarAndSym :
1155 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1157 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1158 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1160 if (mlir::func::FuncOp mlirFuncOp =
1161 SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1163 funcOp = mlirFuncOp;
1164 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1165 SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1167 funcOp = llvmFuncOp;
1169 auto getNumArguments = [&] {
1170 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
1173 auto getArgumentType = [&](
unsigned i) {
1174 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
1179 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1180 <<
" to point to a copy function";
1182 if (getNumArguments() != 2)
1184 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
1186 Type argTy = getArgumentType(0);
1187 if (argTy != getArgumentType(1))
1188 return op->
emitOpError() <<
"expected copy function " << symbolRef
1189 <<
" arguments to have the same type";
1191 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1192 if (argTy != varType)
1194 <<
"expected copy function arguments' type (" << argTy
1195 <<
") to be the same as copyprivate variable's type (" << varType
1216 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1217 parser.parseOperand(dependVars.emplace_back()) ||
1218 parser.parseColonType(dependTypes.emplace_back()))
1220 if (std::optional<ClauseTaskDepend> keywordDepend =
1221 (symbolizeClauseTaskDepend(keyword)))
1222 kindsVec.emplace_back(
1223 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1237 std::optional<ArrayAttr> dependKinds) {
1239 for (
unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1242 p << stringifyClauseTaskDepend(
1243 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1245 <<
" -> " << dependVars[i] <<
" : " << dependTypes[i];
1251 std::optional<ArrayAttr> dependKinds,
1253 if (!dependVars.empty()) {
1254 if (!dependKinds || dependKinds->size() != dependVars.size())
1255 return op->
emitOpError() <<
"expected as many depend values"
1256 " as depend variables";
1258 if (dependKinds && !dependKinds->empty())
1259 return op->
emitOpError() <<
"unexpected depend values";
1275 IntegerAttr &hintAttr) {
1276 StringRef hintKeyword;
1282 auto parseKeyword = [&]() -> ParseResult {
1285 if (hintKeyword ==
"uncontended")
1287 else if (hintKeyword ==
"contended")
1289 else if (hintKeyword ==
"nonspeculative")
1291 else if (hintKeyword ==
"speculative")
1295 << hintKeyword <<
" is not a valid hint";
1306 IntegerAttr hintAttr) {
1307 int64_t hint = hintAttr.getInt();
1315 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1317 bool uncontended = bitn(hint, 0);
1318 bool contended = bitn(hint, 1);
1319 bool nonspeculative = bitn(hint, 2);
1320 bool speculative = bitn(hint, 3);
1324 hints.push_back(
"uncontended");
1326 hints.push_back(
"contended");
1328 hints.push_back(
"nonspeculative");
1330 hints.push_back(
"speculative");
1332 llvm::interleaveComma(hints, p);
1339 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1341 bool uncontended = bitn(hint, 0);
1342 bool contended = bitn(hint, 1);
1343 bool nonspeculative = bitn(hint, 2);
1344 bool speculative = bitn(hint, 3);
1346 if (uncontended && contended)
1347 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
1348 "omp_sync_hint_contended cannot be combined";
1349 if (nonspeculative && speculative)
1350 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
1351 "omp_sync_hint_speculative cannot be combined.";
1361 llvm::omp::OpenMPOffloadMappingFlags flag) {
1362 return value & llvm::to_underlying(flag);
1371 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1372 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1376 auto parseTypeAndMod = [&]() -> ParseResult {
1377 StringRef mapTypeMod;
1381 if (mapTypeMod ==
"always")
1382 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1384 if (mapTypeMod ==
"implicit")
1385 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1387 if (mapTypeMod ==
"close")
1388 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1390 if (mapTypeMod ==
"present")
1391 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1393 if (mapTypeMod ==
"to")
1394 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1396 if (mapTypeMod ==
"from")
1397 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1399 if (mapTypeMod ==
"tofrom")
1400 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1401 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1403 if (mapTypeMod ==
"delete")
1404 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1414 llvm::to_underlying(mapTypeBits));
1422 IntegerAttr mapType) {
1423 uint64_t mapTypeBits = mapType.getUInt();
1425 bool emitAllocRelease =
true;
1431 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1432 mapTypeStrs.push_back(
"always");
1434 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1435 mapTypeStrs.push_back(
"implicit");
1437 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1438 mapTypeStrs.push_back(
"close");
1440 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1441 mapTypeStrs.push_back(
"present");
1447 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1449 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1451 emitAllocRelease =
false;
1452 mapTypeStrs.push_back(
"tofrom");
1454 emitAllocRelease =
false;
1455 mapTypeStrs.push_back(
"from");
1457 emitAllocRelease =
false;
1458 mapTypeStrs.push_back(
"to");
1461 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1462 emitAllocRelease =
false;
1463 mapTypeStrs.push_back(
"delete");
1465 if (emitAllocRelease)
1466 mapTypeStrs.push_back(
"exit_release_or_enter_alloc");
1468 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1469 p << mapTypeStrs[i];
1470 if (i + 1 < mapTypeStrs.size()) {
1477 ArrayAttr &membersIdx) {
1480 auto parseIndices = [&]() -> ParseResult {
1485 APInt(64, value,
false)));
1503 if (!memberIdxs.empty())
1510 ArrayAttr membersIdx) {
1514 llvm::interleaveComma(membersIdx, p, [&p](
Attribute v) {
1516 auto memberIdx = cast<ArrayAttr>(v);
1517 llvm::interleaveComma(memberIdx.getValue(), p, [&p](
Attribute v2) {
1518 p << cast<IntegerAttr>(v2).getInt();
1525 VariableCaptureKindAttr mapCaptureType) {
1526 std::string typeCapStr;
1527 llvm::raw_string_ostream typeCap(typeCapStr);
1528 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1530 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1531 typeCap <<
"ByCopy";
1532 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1533 typeCap <<
"VLAType";
1534 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1540 VariableCaptureKindAttr &mapCaptureType) {
1541 StringRef mapCaptureKey;
1545 if (mapCaptureKey ==
"This")
1547 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
1548 if (mapCaptureKey ==
"ByRef")
1550 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
1551 if (mapCaptureKey ==
"ByCopy")
1553 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1554 if (mapCaptureKey ==
"VLAType")
1556 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
1565 for (
auto mapOp : mapVars) {
1566 if (!mapOp.getDefiningOp())
1569 if (
auto mapInfoOp =
1570 mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1571 if (!mapInfoOp.getMapType().has_value())
1574 if (!mapInfoOp.getMapCaptureType().has_value())
1577 uint64_t mapTypeBits = mapInfoOp.getMapType().value();
1580 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1582 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1584 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1587 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1589 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1591 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1593 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1595 "to, from, tofrom and alloc map types are permitted");
1597 if (isa<TargetEnterDataOp>(op) && (from || del))
1598 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
1600 if (isa<TargetExitDataOp>(op) && to)
1602 "from, release and delete map types are permitted");
1604 if (isa<TargetUpdateOp>(op)) {
1607 "at least one of to or from map types must be "
1608 "specified, other map types are not permitted");
1613 "at least one of to or from map types must be "
1614 "specified, other map types are not permitted");
1617 auto updateVar = mapInfoOp.getVarPtr();
1619 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1620 (from && updateToVars.contains(updateVar))) {
1623 "either to or from map types can be specified, not both");
1626 if (always || close || implicit) {
1629 "present, mapper and iterator map type modifiers are permitted");
1632 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1643 std::optional<DenseI64ArrayAttr> privateMapIndices =
1644 targetOp.getPrivateMapsAttr();
1647 if (!privateMapIndices.has_value() || !privateMapIndices.value())
1652 if (privateMapIndices.value().size() !=
1653 static_cast<int64_t
>(privateVars.size()))
1654 return emitError(targetOp.getLoc(),
"sizes of `private` operand range and "
1655 "`private_maps` attribute mismatch");
1665 const TargetDataOperands &clauses) {
1666 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
1667 clauses.mapVars, clauses.useDeviceAddrVars,
1668 clauses.useDevicePtrVars);
1672 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1673 getUseDeviceAddrVars().empty()) {
1675 "At least one of map, use_device_ptr_vars, or "
1676 "use_device_addr_vars operand must be present");
1685 void TargetEnterDataOp::build(
1689 TargetEnterDataOp::build(builder, state,
1691 clauses.dependVars, clauses.device, clauses.ifExpr,
1692 clauses.mapVars, clauses.nowait);
1696 LogicalResult verifyDependVars =
1698 return failed(verifyDependVars) ? verifyDependVars
1709 TargetExitDataOp::build(builder, state,
1711 clauses.dependVars, clauses.device, clauses.ifExpr,
1712 clauses.mapVars, clauses.nowait);
1716 LogicalResult verifyDependVars =
1718 return failed(verifyDependVars) ? verifyDependVars
1729 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
1730 clauses.dependVars, clauses.device, clauses.ifExpr,
1731 clauses.mapVars, clauses.nowait);
1735 LogicalResult verifyDependVars =
1737 return failed(verifyDependVars) ? verifyDependVars
1746 const TargetOperands &clauses) {
1750 TargetOp::build(builder, state, {}, {},
1752 clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
1753 clauses.hostEvalVars, clauses.ifExpr,
1755 nullptr, clauses.isDevicePtrVars,
1756 clauses.mapVars, clauses.nowait, clauses.privateVars,
1757 makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
1762 LogicalResult verifyDependVars =
1765 if (failed(verifyDependVars))
1766 return verifyDependVars;
1770 if (failed(verifyMapVars))
1771 return verifyMapVars;
1776 LogicalResult TargetOp::verifyRegions() {
1777 auto teamsOps = getOps<TeamsOp>();
1778 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
1779 return emitError(
"target containing multiple 'omp.teams' nested ops");
1782 llvm::omp::OMPTgtExecModeFlags execFlags = getKernelExecFlags();
1783 for (
Value hostEvalArg :
1784 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
1786 if (
auto teamsOp = dyn_cast<TeamsOp>(user)) {
1787 if (llvm::is_contained({teamsOp.getNumTeamsLower(),
1788 teamsOp.getNumTeamsUpper(),
1789 teamsOp.getThreadLimit()},
1793 return emitOpError() <<
"host_eval argument only legal as 'num_teams' "
1794 "and 'thread_limit' in 'omp.teams'";
1796 if (
auto parallelOp = dyn_cast<ParallelOp>(user)) {
1797 if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
1798 hostEvalArg == parallelOp.getNumThreads())
1801 return emitOpError()
1802 <<
"host_eval argument only legal as 'num_threads' in "
1803 "'omp.parallel' when representing target SPMD";
1805 if (
auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1806 if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
1807 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
1808 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
1809 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
1812 return emitOpError() <<
"host_eval argument only legal as loop bounds "
1813 "and steps in 'omp.loop_nest' when "
1814 "representing target SPMD or Generic-SPMD";
1817 return emitOpError() <<
"host_eval argument illegal use in '"
1818 << user->getName() <<
"' operation";
1837 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
1839 memOp.getEffects(effects);
1841 return isa<MemoryEffects::Write>(effect.
getEffect()) &&
1842 isa<SideEffects::AutomaticAllocationScopeResource>(
1849 Operation *TargetOp::getInnermostCapturedOmpOp() {
1850 Dialect *ompDialect = (*this)->getDialect();
1858 walk<WalkOrder::PreOrder>([&](
Operation *op) {
1860 return WalkResult::advance();
1865 bool isOmpDialect = op->
getDialect() == ompDialect;
1867 if (!isOmpDialect || !hasRegions)
1868 return WalkResult::skip();
1878 if (successor->isReachable(parentBlock))
1879 return WalkResult::interrupt();
1881 for (
Block &block : *parentRegion)
1883 !domInfo.
dominates(parentBlock, &block))
1884 return WalkResult::interrupt();
1890 return WalkResult::interrupt();
1895 return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
1902 llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags() {
1903 using namespace llvm::omp;
1907 Operation *capturedOp = getInnermostCapturedOmpOp();
1908 if (!isa_and_present<LoopNestOp>(capturedOp))
1909 return OMP_TGT_EXEC_MODE_GENERIC;
1912 cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
1913 assert(!wrappers.empty());
1916 auto *innermostWrapper = wrappers.begin();
1917 if (isa<SimdOp>(innermostWrapper))
1918 innermostWrapper = std::next(innermostWrapper);
1920 long numWrappers = std::distance(innermostWrapper, wrappers.end());
1923 if (numWrappers == 1) {
1924 if (!isa<DistributeOp>(innermostWrapper))
1925 return OMP_TGT_EXEC_MODE_GENERIC;
1928 if (!isa_and_present<TeamsOp>(teamsOp))
1929 return OMP_TGT_EXEC_MODE_GENERIC;
1932 return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
1936 if (numWrappers == 2) {
1937 if (!isa<WsloopOp>(innermostWrapper))
1938 return OMP_TGT_EXEC_MODE_GENERIC;
1940 innermostWrapper = std::next(innermostWrapper);
1941 if (!isa<DistributeOp>(innermostWrapper))
1942 return OMP_TGT_EXEC_MODE_GENERIC;
1945 if (!isa_and_present<ParallelOp>(parallelOp))
1946 return OMP_TGT_EXEC_MODE_GENERIC;
1949 if (!isa_and_present<TeamsOp>(teamsOp))
1950 return OMP_TGT_EXEC_MODE_GENERIC;
1953 return OMP_TGT_EXEC_MODE_SPMD;
1956 return OMP_TGT_EXEC_MODE_GENERIC;
1965 ParallelOp::build(builder, state,
ValueRange(),
1971 state.addAttributes(attributes);
1975 const ParallelOperands &clauses) {
1977 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1978 clauses.ifExpr, clauses.numThreads, clauses.privateVars,
1980 clauses.procBindKind, clauses.reductionMod,
1981 clauses.reductionVars,
1986 template <
typename OpType>
1988 auto privateVars = op.getPrivateVars();
1989 auto privateSyms = op.getPrivateSymsAttr();
1991 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
1994 auto numPrivateVars = privateVars.size();
1995 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
1997 if (numPrivateVars != numPrivateSyms)
1998 return op.emitError() <<
"inconsistent number of private variables and "
1999 "privatizer op symbols, private vars: "
2001 <<
" vs. privatizer op symbols: " << numPrivateSyms;
2003 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2004 Type varType = std::get<0>(privateVarInfo).getType();
2005 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2006 PrivateClauseOp privatizerOp =
2007 SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2009 if (privatizerOp ==
nullptr)
2010 return op.emitError() <<
"failed to lookup privatizer op with symbol: '"
2011 << privateSym <<
"'";
2013 Type privatizerType = privatizerOp.getArgType();
2015 if (privatizerType && (varType != privatizerType))
2016 return op.emitError()
2017 <<
"type mismatch between a "
2018 << (privatizerOp.getDataSharingType() ==
2019 DataSharingClauseType::Private
2022 <<
" variable and its privatizer op, var type: " << varType
2023 <<
" vs. privatizer op type: " << privatizerType;
2030 if (getAllocateVars().size() != getAllocatorVars().size())
2032 "expected equal sizes for allocate and allocator variables");
2038 getReductionByref());
2041 LogicalResult ParallelOp::verifyRegions() {
2042 auto distributeChildOps = getOps<DistributeOp>();
2043 if (!distributeChildOps.empty()) {
2046 <<
"'omp.composite' attribute missing from composite operation";
2049 Operation &distributeOp = **distributeChildOps.begin();
2051 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2055 return emitError() <<
"unexpected OpenMP operation inside of composite "
2058 }
else if (isComposite()) {
2060 <<
"'omp.composite' attribute present in non-composite operation";
2077 const TeamsOperands &clauses) {
2080 TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2081 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2083 clauses.reductionMod, clauses.reductionVars,
2086 clauses.threadLimit);
2098 return emitError(
"expected to be nested inside of omp.target or not nested "
2099 "in any OpenMP dialect operations");
2102 if (
auto numTeamsLowerBound = getNumTeamsLower()) {
2103 auto numTeamsUpperBound = getNumTeamsUpper();
2104 if (!numTeamsUpperBound)
2105 return emitError(
"expected num_teams upper bound to be defined if the "
2106 "lower bound is defined");
2107 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2109 "expected num_teams upper bound and lower bound to be the same type");
2113 if (getAllocateVars().size() != getAllocatorVars().size())
2115 "expected equal sizes for allocate and allocator variables");
2118 getReductionByref());
2125 unsigned SectionOp::numPrivateBlockArgs() {
2126 return getParentOp().numPrivateBlockArgs();
2129 unsigned SectionOp::numReductionBlockArgs() {
2130 return getParentOp().numReductionBlockArgs();
2138 const SectionsOperands &clauses) {
2141 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2143 nullptr, clauses.reductionMod,
2144 clauses.reductionVars,
2150 if (getAllocateVars().size() != getAllocatorVars().size())
2152 "expected equal sizes for allocate and allocator variables");
2155 getReductionByref());
2158 LogicalResult SectionsOp::verifyRegions() {
2159 for (
auto &inst : *getRegion().begin()) {
2160 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2161 return emitOpError()
2162 <<
"expected omp.section op or terminator op inside region";
2174 const SingleOperands &clauses) {
2177 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2178 clauses.copyprivateVars,
2179 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2185 if (getAllocateVars().size() != getAllocatorVars().size())
2187 "expected equal sizes for allocate and allocator variables");
2190 getCopyprivateSyms());
2198 const WorkshareOperands &clauses) {
2199 WorkshareOp::build(builder, state, clauses.nowait);
2207 if (!(*this)->getParentOfType<WorkshareOp>())
2208 return emitError() <<
"must be nested in an omp.workshare";
2209 if (getNestedWrapper())
2210 return emitError() <<
"cannot be composite";
2218 LogicalResult LoopWrapperInterface::verifyImpl() {
2222 return emitOpError() <<
"loop wrapper must also have the `NoTerminator` "
2223 "and `SingleBlock` traits";
2226 return emitOpError() <<
"loop wrapper does not contain exactly one region";
2229 if (range_size(region.
getOps()) != 1)
2230 return emitOpError()
2231 <<
"loop wrapper does not contain exactly one nested op";
2234 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2235 return emitOpError() <<
"op nested in loop wrapper is not another loop "
2236 "wrapper or `omp.loop_nest`";
2246 const LoopOperands &clauses) {
2249 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2251 clauses.orderMod, clauses.reductionMod, clauses.reductionVars,
2258 getReductionByref());
2261 LogicalResult LoopOp::verifyRegions() {
2262 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2264 return emitError() <<
"`omp.loop` expected to be a standalone loop wrapper";
2275 build(builder, state, {}, {},
2277 false,
nullptr,
nullptr,
2278 nullptr, {},
nullptr,
2284 state.addAttributes(attributes);
2288 const WsloopOperands &clauses) {
2292 WsloopOp::build(builder, state,
2294 clauses.linearVars, clauses.linearStepVars, clauses.nowait,
2295 clauses.order, clauses.orderMod, clauses.ordered,
2296 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
2297 clauses.reductionMod, clauses.reductionVars,
2300 clauses.scheduleKind, clauses.scheduleChunk,
2301 clauses.scheduleMod, clauses.scheduleSimd);
2306 getReductionByref());
2309 LogicalResult WsloopOp::verifyRegions() {
2310 bool isCompositeChildLeaf =
2311 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2313 if (LoopWrapperInterface nested = getNestedWrapper()) {
2316 <<
"'omp.composite' attribute missing from composite wrapper";
2320 if (!isa<SimdOp>(nested))
2321 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
2323 }
else if (isComposite() && !isCompositeChildLeaf) {
2325 <<
"'omp.composite' attribute present in non-composite wrapper";
2326 }
else if (!isComposite() && isCompositeChildLeaf) {
2328 <<
"'omp.composite' attribute missing from composite wrapper";
2339 const SimdOperands &clauses) {
2343 SimdOp::build(builder, state, clauses.alignedVars,
2346 clauses.nontemporalVars, clauses.order, clauses.orderMod,
2347 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
2348 clauses.reductionMod, clauses.reductionVars,
2355 if (getSimdlen().has_value() && getSafelen().has_value() &&
2356 getSimdlen().value() > getSafelen().value())
2357 return emitOpError()
2358 <<
"simdlen clause and safelen clause are both present, but the "
2359 "simdlen value is not less than or equal to safelen value";
2367 bool isCompositeChildLeaf =
2368 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2370 if (!isComposite() && isCompositeChildLeaf)
2372 <<
"'omp.composite' attribute missing from composite wrapper";
2374 if (isComposite() && !isCompositeChildLeaf)
2376 <<
"'omp.composite' attribute present in non-composite wrapper";
2381 LogicalResult SimdOp::verifyRegions() {
2382 if (getNestedWrapper())
2383 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
2393 const DistributeOperands &clauses) {
2394 DistributeOp::build(builder, state, clauses.allocateVars,
2395 clauses.allocatorVars, clauses.distScheduleStatic,
2396 clauses.distScheduleChunkSize, clauses.order,
2397 clauses.orderMod, clauses.privateVars,
2402 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2403 return emitOpError() <<
"chunk size set without "
2404 "dist_schedule_static being present";
2406 if (getAllocateVars().size() != getAllocatorVars().size())
2408 "expected equal sizes for allocate and allocator variables");
2413 LogicalResult DistributeOp::verifyRegions() {
2414 if (LoopWrapperInterface nested = getNestedWrapper()) {
2417 <<
"'omp.composite' attribute missing from composite wrapper";
2420 if (isa<WsloopOp>(nested)) {
2421 if (!llvm::dyn_cast_if_present<ParallelOp>((*this)->getParentOp()))
2422 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
2423 "when 'omp.parallel' is the direct parent";
2424 }
else if (!isa<SimdOp>(nested))
2425 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
2427 }
else if (isComposite()) {
2429 <<
"'omp.composite' attribute present in non-composite wrapper";
2439 LogicalResult DeclareReductionOp::verifyRegions() {
2440 if (!getAllocRegion().empty()) {
2441 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
2442 if (yieldOp.getResults().size() != 1 ||
2443 yieldOp.getResults().getTypes()[0] !=
getType())
2444 return emitOpError() <<
"expects alloc region to yield a value "
2445 "of the reduction type";
2449 if (getInitializerRegion().empty())
2450 return emitOpError() <<
"expects non-empty initializer region";
2451 Block &initializerEntryBlock = getInitializerRegion().
front();
2454 if (!getAllocRegion().empty())
2455 return emitOpError() <<
"expects two arguments to the initializer region "
2456 "when an allocation region is used";
2458 if (getAllocRegion().empty())
2459 return emitOpError() <<
"expects one argument to the initializer region "
2460 "when no allocation region is used";
2462 return emitOpError()
2463 <<
"expects one or two arguments to the initializer region";
2467 if (arg.getType() !=
getType())
2468 return emitOpError() <<
"expects initializer region argument to match "
2469 "the reduction type";
2471 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
2472 if (yieldOp.getResults().size() != 1 ||
2473 yieldOp.getResults().getTypes()[0] !=
getType())
2474 return emitOpError() <<
"expects initializer region to yield a value "
2475 "of the reduction type";
2478 if (getReductionRegion().empty())
2479 return emitOpError() <<
"expects non-empty reduction region";
2480 Block &reductionEntryBlock = getReductionRegion().
front();
2485 return emitOpError() <<
"expects reduction region with two arguments of "
2486 "the reduction type";
2487 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
2488 if (yieldOp.getResults().size() != 1 ||
2489 yieldOp.getResults().getTypes()[0] !=
getType())
2490 return emitOpError() <<
"expects reduction region to yield a value "
2491 "of the reduction type";
2494 if (!getAtomicReductionRegion().empty()) {
2495 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
2499 return emitOpError() <<
"expects atomic reduction region with two "
2500 "arguments of the same type";
2501 auto ptrType = llvm::dyn_cast<PointerLikeType>(
2504 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
2505 return emitOpError() <<
"expects atomic reduction region arguments to "
2506 "be accumulators containing the reduction type";
2509 if (getCleanupRegion().empty())
2511 Block &cleanupEntryBlock = getCleanupRegion().
front();
2514 return emitOpError() <<
"expects cleanup region with one argument "
2515 "of the reduction type";
2525 const TaskOperands &clauses) {
2527 TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2528 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2529 clauses.final, clauses.ifExpr, clauses.inReductionVars,
2531 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2532 clauses.priority, clauses.privateVars,
2534 clauses.untied, clauses.eventHandle);
2538 LogicalResult verifyDependVars =
2540 return failed(verifyDependVars)
2543 getInReductionVars(),
2544 getInReductionByref());
2552 const TaskgroupOperands &clauses) {
2554 TaskgroupOp::build(builder, state, clauses.allocateVars,
2555 clauses.allocatorVars, clauses.taskReductionVars,
2562 getTaskReductionVars(),
2563 getTaskReductionByref());
2571 const TaskloopOperands &clauses) {
2575 builder, state, clauses.allocateVars, clauses.allocatorVars,
2576 clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
2578 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2579 clauses.nogroup, clauses.numTasks, clauses.priority, {},
2580 nullptr, clauses.reductionMod, clauses.reductionVars,
2587 getInReductionVars().end());
2588 allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
2589 getReductionVars().end());
2590 return allReductionNvars;
2594 if (getAllocateVars().size() != getAllocatorVars().size())
2596 "expected equal sizes for allocate and allocator variables");
2598 getReductionVars(), getReductionByref())) ||
2600 getInReductionVars(),
2601 getInReductionByref())))
2604 if (!getReductionVars().empty() && getNogroup())
2605 return emitError(
"if a reduction clause is present on the taskloop "
2606 "directive, the nogroup clause must not be specified");
2607 for (
auto var : getReductionVars()) {
2608 if (llvm::is_contained(getInReductionVars(), var))
2609 return emitError(
"the same list item cannot appear in both a reduction "
2610 "and an in_reduction clause");
2613 if (getGrainsize() && getNumTasks()) {
2615 "the grainsize clause and num_tasks clause are mutually exclusive and "
2616 "may not appear on the same taskloop directive");
2622 LogicalResult TaskloopOp::verifyRegions() {
2623 if (LoopWrapperInterface nested = getNestedWrapper()) {
2626 <<
"'omp.composite' attribute missing from composite wrapper";
2630 if (!isa<SimdOp>(nested))
2631 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
2632 }
else if (isComposite()) {
2634 <<
"'omp.composite' attribute present in non-composite wrapper";
2658 for (
auto &iv : ivs)
2659 iv.type = loopVarType;
2688 Region ®ion = getRegion();
2690 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
2691 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
2692 if (getLoopInclusive())
2694 p <<
"step (" << getLoopSteps() <<
") ";
2699 const LoopNestOperands &clauses) {
2700 LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2701 clauses.loopUpperBounds, clauses.loopSteps,
2702 clauses.loopInclusive);
2706 if (getLoopLowerBounds().empty())
2707 return emitOpError() <<
"must represent at least one loop";
2709 if (getLoopLowerBounds().size() != getIVs().size())
2710 return emitOpError() <<
"number of range arguments and IVs do not match";
2712 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
2713 if (lb.getType() != iv.getType())
2714 return emitOpError()
2715 <<
"range argument type does not match corresponding IV type";
2718 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
2719 return emitOpError() <<
"expects parent op to be a loop wrapper";
2724 void LoopNestOp::gatherWrappers(
2727 while (
auto wrapper =
2728 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
2729 wrappers.push_back(wrapper);
2739 const CriticalDeclareOperands &clauses) {
2740 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
2748 if (getNameAttr()) {
2749 SymbolRefAttr symbolRef = getNameAttr();
2753 return emitOpError() <<
"expected symbol reference " << symbolRef
2754 <<
" to point to a critical declaration";
2774 return op.
emitOpError() <<
"must be nested inside of a loop";
2778 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
2779 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
2781 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
2782 "have an ordered clause";
2784 if (hasRegion && orderedAttr.getInt() != 0)
2785 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
2786 "have a parameter present";
2788 if (!hasRegion && orderedAttr.getInt() == 0)
2789 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
2790 "have a parameter present";
2791 }
else if (!isa<SimdOp>(wrapper)) {
2792 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
2793 "or worksharing simd loop";
2799 const OrderedOperands &clauses) {
2800 OrderedOp::build(builder, state, clauses.doacrossDependType,
2801 clauses.doacrossNumLoops, clauses.doacrossDependVars);
2808 auto wrapper = (*this)->getParentOfType<WsloopOp>();
2809 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
2810 return emitOpError() <<
"number of variables in depend clause does not "
2811 <<
"match number of iteration variables in the "
2818 const OrderedRegionOperands &clauses) {
2819 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
2829 const TaskwaitOperands &clauses) {
2831 TaskwaitOp::build(builder, state,
nullptr,
2840 if (verifyCommon().failed())
2841 return mlir::failure();
2843 if (
auto mo = getMemoryOrder()) {
2844 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2845 *mo == ClauseMemoryOrderKind::Release) {
2847 "memory-order must not be acq_rel or release for atomic reads");
2858 if (verifyCommon().failed())
2859 return mlir::failure();
2861 if (
auto mo = getMemoryOrder()) {
2862 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2863 *mo == ClauseMemoryOrderKind::Acquire) {
2865 "memory-order must not be acq_rel or acquire for atomic writes");
2875 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2881 if (
Value writeVal = op.getWriteOpVal()) {
2883 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
2890 if (verifyCommon().failed())
2891 return mlir::failure();
2893 if (
auto mo = getMemoryOrder()) {
2894 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2895 *mo == ClauseMemoryOrderKind::Acquire) {
2897 "memory-order must not be acq_rel or acquire for atomic updates");
/// Region verifier for AtomicUpdateOp. It adds no checks of its own: the
/// result is exactly whatever verifyRegionsCommon() returns.
2904 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
2910 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2911 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2913 return dyn_cast<AtomicReadOp>(getSecondOp());
2916 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2917 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2919 return dyn_cast<AtomicWriteOp>(getSecondOp());
2922 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2923 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2925 return dyn_cast<AtomicUpdateOp>(getSecondOp());
2932 LogicalResult AtomicCaptureOp::verifyRegions() {
2933 if (verifyRegionsCommon().failed())
2934 return mlir::failure();
2936 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
2938 "operations inside capture region must not have hint clause");
2940 if (getFirstOp()->getAttr(
"memory_order") ||
2941 getSecondOp()->getAttr(
"memory_order"))
2943 "operations inside capture region must not have memory_order clause");
2952 const CancelOperands &clauses) {
2953 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
2957 ClauseCancellationConstructType cct = getCancelDirective();
2961 return emitOpError() <<
"must be used within a region supporting "
2965 if ((cct == ClauseCancellationConstructType::Parallel) &&
2966 !isa<ParallelOp>(parentOp)) {
2967 return emitOpError() <<
"cancel parallel must appear "
2968 <<
"inside a parallel region";
2970 if (cct == ClauseCancellationConstructType::Loop) {
2971 auto loopOp = dyn_cast<LoopNestOp>(parentOp);
2972 auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
2973 loopOp ? loopOp->getParentOp() :
nullptr);
2976 return emitOpError()
2977 <<
"cancel loop must appear inside a worksharing-loop region";
2979 if (wsloopOp.getNowaitAttr()) {
2980 return emitError() <<
"A worksharing construct that is canceled "
2981 <<
"must not have a nowait clause";
2983 if (wsloopOp.getOrderedAttr()) {
2984 return emitError() <<
"A worksharing construct that is canceled "
2985 <<
"must not have an ordered clause";
2988 }
else if (cct == ClauseCancellationConstructType::Sections) {
2989 if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2990 return emitOpError() <<
"cancel sections must appear "
2991 <<
"inside a sections region";
2993 if (isa_and_nonnull<SectionsOp>(parentOp->
getParentOp()) &&
2994 cast<SectionsOp>(parentOp->
getParentOp()).getNowaitAttr()) {
2995 return emitError() <<
"A sections construct that is canceled "
2996 <<
"must not have a nowait clause";
3008 const CancellationPointOperands &clauses) {
3009 CancellationPointOp::build(builder, state, clauses.cancelDirective);
3013 ClauseCancellationConstructType cct = getCancelDirective();
3017 return emitOpError() <<
"must be used within a region supporting "
3018 "cancellation point directive";
3021 if ((cct == ClauseCancellationConstructType::Parallel) &&
3022 !(isa<ParallelOp>(parentOp))) {
3023 return emitOpError() <<
"cancellation point parallel must appear "
3024 <<
"inside a parallel region";
3026 if ((cct == ClauseCancellationConstructType::Loop) &&
3027 (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->
getParentOp()))) {
3028 return emitOpError() <<
"cancellation point loop must appear "
3029 <<
"inside a worksharing-loop region";
3031 if ((cct == ClauseCancellationConstructType::Sections) &&
3032 !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
3033 return emitOpError() <<
"cancellation point sections must appear "
3034 <<
"inside a sections region";
3045 auto extent = getExtent();
3047 if (!extent && !upperbound)
3048 return emitError(
"expected extent or upperbound.");
3055 PrivateClauseOp::build(
3056 odsBuilder, odsState, symName, type,
3058 DataSharingClauseType::Private));
3061 LogicalResult PrivateClauseOp::verifyRegions() {
3062 Type argType = getArgType();
3063 auto verifyTerminator = [&](
Operation *terminator,
3064 bool yieldsValue) -> LogicalResult {
3068 if (!llvm::isa<YieldOp>(terminator))
3070 <<
"expected exit block terminator to be an `omp.yield` op.";
3072 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
3073 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
3076 if (yieldedTypes.empty())
3080 <<
"Did not expect any values to be yielded.";
3083 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
3087 <<
"Invalid yielded value. Expected type: " << argType
3090 if (yieldedTypes.empty())
3093 error << yieldedTypes;
3099 StringRef regionName,
3100 bool yieldsValue) -> LogicalResult {
3101 assert(!region.
empty());
3105 <<
"`" << regionName <<
"`: "
3106 <<
"expected " << expectedNumArgs
3109 for (
Block &block : region) {
3111 if (!block.mightHaveTerminator())
3114 if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
3122 for (
Region *region : getRegions())
3123 for (
Type ty : region->getArgumentTypes())
3125 return emitError() <<
"Region argument type mismatch: got " << ty
3126 <<
" expected " << argType <<
".";
3129 if (!initRegion.
empty() &&
3134 DataSharingClauseType dsType = getDataSharingType();
3136 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
3137 return emitError(
"`private` clauses do not require a `copy` region.");
3139 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
3141 "`firstprivate` clauses require at least a `copy` region.");
3143 if (dsType == DataSharingClauseType::FirstPrivate &&
3148 if (!getDeallocRegion().empty() &&
3161 const MaskedOperands &clauses) {
3162 MaskedOp::build(builder, state, clauses.filteredThreadId);
3170 const ScanOperands &clauses) {
3171 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
3175 if (hasExclusiveVars() == hasInclusiveVars())
3177 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
3178 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
3179 if (parentWsLoopOp.getReductionModAttr() &&
3180 parentWsLoopOp.getReductionModAttr().getValue() ==
3181 ReductionModifier::inscan)
3184 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
3185 if (parentSimdOp.getReductionModAttr() &&
3186 parentSimdOp.getReductionModAttr().getValue() ==
3187 ReductionModifier::inscan)
3190 return emitError(
"SCAN directive needs to be enclosed within a parent "
3191 "worksharing loop construct or SIMD construct with INSCAN "
3192 "reduction modifier");
3195 #define GET_ATTRDEF_CLASSES
3196 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
3198 #define GET_OP_CLASSES
3199 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
3201 #define GET_TYPEDEF_CLASSES
3202 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr)
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static bool siblingAllowedInCapture(Operation *op)
Only allow OpenMP terminators and non-OpenMP ops that have known memory effects, but don't include a ...
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static void printHostEvalInReductionMapPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, DenseI64ArrayAttr privateMaps)
static ParseResult parseHostEvalInReductionMapPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, DenseI64ArrayAttr &privateMaps)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
BlockArgListType getArguments()
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
unsigned getNumArguments()
Location getLoc()
Return a location for this region.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Runtime
Potential runtimes for AMD GPU kernels.
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.