25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/BitVector.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/STLForwardCompat.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Frontend/OpenMP/OMPConstants.h"
34 #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
40 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
41 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
59 struct MemRefPointerLikeModel
60 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
63 return llvm::cast<MemRefType>(pointer).getElementType();
67 struct LLVMPointerPointerLikeModel
68 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
69 LLVM::LLVMPointerType> {
74 void OpenMPDialect::initialize() {
77 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
80 #define GET_ATTRDEF_LIST
81 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
84 #define GET_TYPEDEF_LIST
85 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
88 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
90 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
91 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
96 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
102 mlir::LLVM::GlobalOp::attachInterface<
105 mlir::LLVM::LLVMFuncOp::attachInterface<
108 mlir::func::FuncOp::attachInterface<
134 allocatorVars.push_back(operand);
135 allocatorTypes.push_back(type);
141 allocateVars.push_back(operand);
142 allocateTypes.push_back(type);
153 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
154 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
155 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
156 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
164 template <
typename ClauseAttr>
166 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
171 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
175 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
178 template <
typename ClauseAttr>
180 p << stringifyEnum(attr.getValue());
203 linearVars.push_back(var);
204 linearTypes.push_back(type);
205 linearStepVars.push_back(stepVar);
214 size_t linearVarsSize = linearVars.size();
215 for (
unsigned i = 0; i < linearVarsSize; ++i) {
216 std::string separator = i == linearVarsSize - 1 ?
"" :
", ";
218 if (linearStepVars.size() > i)
219 p <<
" = " << linearStepVars[i];
220 p <<
" : " << linearVars[i].getType() << separator;
233 for (
const auto &it : nontemporalVars)
234 if (!nontemporalItems.insert(it).second)
235 return op->
emitOpError() <<
"nontemporal variable used more than once";
244 std::optional<ArrayAttr> alignments,
247 if (!alignedVars.empty()) {
248 if (!alignments || alignments->size() != alignedVars.size())
250 <<
"expected as many alignment values as aligned variables";
253 return op->
emitOpError() <<
"unexpected alignment values attribute";
259 for (
auto it : alignedVars)
260 if (!alignedItems.insert(it).second)
261 return op->
emitOpError() <<
"aligned variable used more than once";
267 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
268 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
269 if (intAttr.getValue().sle(0))
270 return op->
emitOpError() <<
"alignment should be greater than 0";
272 return op->
emitOpError() <<
"expected integer alignment";
286 ArrayAttr &alignmentsAttr) {
289 if (parser.parseOperand(alignedVars.emplace_back()) ||
290 parser.parseColonType(alignedTypes.emplace_back()) ||
291 parser.parseArrow() ||
292 parser.parseAttribute(alignmentVec.emplace_back())) {
306 std::optional<ArrayAttr> alignments) {
307 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
310 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
311 p <<
" -> " << (*alignments)[i];
322 if (modifiers.size() > 2)
324 for (
const auto &mod : modifiers) {
327 auto symbol = symbolizeScheduleModifier(mod);
330 <<
" unknown modifier type: " << mod;
335 if (modifiers.size() == 1) {
336 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
337 modifiers.push_back(modifiers[0]);
338 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
340 }
else if (modifiers.size() == 2) {
343 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
344 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
346 <<
" incorrect modifier order";
362 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
363 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
368 std::optional<mlir::omp::ClauseScheduleKind> schedule =
369 symbolizeClauseScheduleKind(keyword);
375 case ClauseScheduleKind::Static:
376 case ClauseScheduleKind::Dynamic:
377 case ClauseScheduleKind::Guided:
383 chunkSize = std::nullopt;
386 case ClauseScheduleKind::Auto:
388 chunkSize = std::nullopt;
397 modifiers.push_back(mod);
403 if (!modifiers.empty()) {
405 if (std::optional<ScheduleModifier> mod =
406 symbolizeScheduleModifier(modifiers[0])) {
409 return parser.
emitError(loc,
"invalid schedule modifier");
412 if (modifiers.size() > 1) {
413 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
423 ClauseScheduleKindAttr scheduleKind,
424 ScheduleModifierAttr scheduleMod,
425 UnitAttr scheduleSimd,
Value scheduleChunk,
426 Type scheduleChunkType) {
427 p << stringifyClauseScheduleKind(scheduleKind.getValue());
429 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
431 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
443 ClauseOrderKindAttr &order,
444 OrderModifierAttr &orderMod) {
449 if (std::optional<OrderModifier> enumValue =
450 symbolizeOrderModifier(enumStr)) {
458 if (std::optional<ClauseOrderKind> enumValue =
459 symbolizeClauseOrderKind(enumStr)) {
463 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
467 ClauseOrderKindAttr order,
468 OrderModifierAttr orderMod) {
470 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
472 p << stringifyClauseOrderKind(order.getValue());
475 template <
typename ClauseTypeAttr,
typename ClauseType>
478 std::optional<OpAsmParser::UnresolvedOperand> &operand,
480 std::optional<ClauseType> (*symbolizeClause)(StringRef),
481 StringRef clauseName) {
484 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
490 <<
"invalid " << clauseName <<
" modifier : '" << enumStr <<
"'";
500 <<
"expected " << clauseName <<
" operand";
503 if (operand.has_value()) {
511 template <
typename ClauseTypeAttr,
typename ClauseType>
514 ClauseTypeAttr prescriptiveness,
Value operand,
516 StringRef (*stringifyClauseType)(ClauseType)) {
518 if (prescriptiveness)
519 p << stringifyClauseType(prescriptiveness.getValue()) <<
", ";
522 p << operand <<
": " << operandType;
532 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
533 Type &grainsizeType) {
534 return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
535 parser, grainsizeMod, grainsize, grainsizeType,
536 &symbolizeClauseGrainsizeType,
"grainsize");
540 ClauseGrainsizeTypeAttr grainsizeMod,
542 printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
543 p, op, grainsizeMod, grainsize, grainsizeType,
544 &stringifyClauseGrainsizeType);
554 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
555 Type &numTasksType) {
556 return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
557 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
562 ClauseNumTasksTypeAttr numTasksMod,
564 printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
565 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
573 struct MapParseArgs {
578 : vars(vars), types(types) {}
580 struct PrivateParseArgs {
588 : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
591 struct ReductionParseArgs {
596 ReductionModifierAttr *modifier;
599 ArrayAttr &syms, ReductionModifierAttr *mod =
nullptr)
600 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
603 struct AllRegionParseArgs {
604 std::optional<MapParseArgs> hasDeviceAddrArgs;
605 std::optional<MapParseArgs> hostEvalArgs;
606 std::optional<ReductionParseArgs> inReductionArgs;
607 std::optional<MapParseArgs> mapArgs;
608 std::optional<PrivateParseArgs> privateArgs;
609 std::optional<ReductionParseArgs> reductionArgs;
610 std::optional<ReductionParseArgs> taskReductionArgs;
611 std::optional<MapParseArgs> useDeviceAddrArgs;
612 std::optional<MapParseArgs> useDevicePtrArgs;
623 ReductionModifierAttr *modifier =
nullptr) {
627 unsigned regionArgOffset = regionPrivateArgs.size();
637 std::optional<ReductionModifier> enumValue =
638 symbolizeReductionModifier(enumStr);
639 if (!enumValue.has_value())
648 isByRefVec.push_back(
649 parser.parseOptionalKeyword(
"byref").succeeded());
651 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
654 if (parser.parseOperand(operands.emplace_back()) ||
655 parser.parseArrow() ||
656 parser.parseArgument(regionPrivateArgs.emplace_back()))
660 if (parser.parseOptionalLSquare().succeeded()) {
661 if (parser.parseKeyword(
"map_idx") || parser.parseEqual() ||
662 parser.parseInteger(mapIndicesVec.emplace_back()) ||
663 parser.parseRSquare())
666 mapIndicesVec.push_back(-1);
677 if (parser.parseType(types.emplace_back()))
684 if (operands.size() != types.size())
690 auto *argsBegin = regionPrivateArgs.begin();
692 argsBegin + regionArgOffset + types.size());
693 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
702 if (!mapIndicesVec.empty())
715 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
730 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
736 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
737 &privateArgs->syms, privateArgs->mapIndices)))
746 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
751 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
752 &reductionArgs->syms,
nullptr, &reductionArgs->byref,
753 reductionArgs->modifier)))
760 AllRegionParseArgs args) {
764 args.hasDeviceAddrArgs)))
766 <<
"invalid `has_device_addr` format";
771 <<
"invalid `host_eval` format";
774 args.inReductionArgs)))
776 <<
"invalid `in_reduction` format";
781 <<
"invalid `map_entries` format";
786 <<
"invalid `private` format";
789 args.reductionArgs)))
791 <<
"invalid `reduction` format";
794 args.taskReductionArgs)))
796 <<
"invalid `task_reduction` format";
799 args.useDeviceAddrArgs)))
801 <<
"invalid `use_device_addr` format";
804 args.useDevicePtrArgs)))
806 <<
"invalid `use_device_addr` format";
827 AllRegionParseArgs args;
828 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
829 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
830 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
831 inReductionByref, inReductionSyms);
832 args.mapArgs.emplace(mapVars, mapTypes);
833 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
845 AllRegionParseArgs args;
846 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
847 inReductionByref, inReductionSyms);
848 args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
859 ReductionModifierAttr &reductionMod,
862 ArrayAttr &reductionSyms) {
863 AllRegionParseArgs args;
864 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
865 inReductionByref, inReductionSyms);
866 args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
867 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
868 reductionSyms, &reductionMod);
876 AllRegionParseArgs args;
877 args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
885 ReductionModifierAttr &reductionMod,
888 ArrayAttr &reductionSyms) {
889 AllRegionParseArgs args;
890 args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
891 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
892 reductionSyms, &reductionMod);
901 AllRegionParseArgs args;
902 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
903 taskReductionByref, taskReductionSyms);
913 AllRegionParseArgs args;
914 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
915 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
924 struct MapPrintArgs {
929 struct PrivatePrintArgs {
936 : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
938 struct ReductionPrintArgs {
943 ReductionModifierAttr modifier;
945 ArrayAttr syms, ReductionModifierAttr mod =
nullptr)
946 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
948 struct AllRegionPrintArgs {
949 std::optional<MapPrintArgs> hasDeviceAddrArgs;
950 std::optional<MapPrintArgs> hostEvalArgs;
951 std::optional<ReductionPrintArgs> inReductionArgs;
952 std::optional<MapPrintArgs> mapArgs;
953 std::optional<PrivatePrintArgs> privateArgs;
954 std::optional<ReductionPrintArgs> reductionArgs;
955 std::optional<ReductionPrintArgs> taskReductionArgs;
956 std::optional<MapPrintArgs> useDeviceAddrArgs;
957 std::optional<MapPrintArgs> useDevicePtrArgs;
966 ReductionModifierAttr modifier =
nullptr) {
967 if (argsSubrange.empty())
970 p << clauseName <<
"(";
973 p <<
"mod: " << stringifyReductionModifier(modifier.getValue()) <<
", ";
990 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
991 mapIndices.asArrayRef(),
994 auto [op, arg, sym, map, isByRef] = t;
1000 p << op <<
" -> " << arg;
1003 p <<
" [map_idx=" << map <<
"]";
1006 llvm::interleaveComma(types, p);
1011 StringRef clauseName,
ValueRange argsSubrange,
1012 std::optional<MapPrintArgs> mapArgs) {
1019 StringRef clauseName,
ValueRange argsSubrange,
1020 std::optional<PrivatePrintArgs> privateArgs) {
1023 privateArgs->vars, privateArgs->types,
1024 privateArgs->syms, privateArgs->mapIndices);
1030 std::optional<ReductionPrintArgs> reductionArgs) {
1033 reductionArgs->vars, reductionArgs->types,
1034 reductionArgs->syms,
nullptr,
1035 reductionArgs->byref, reductionArgs->modifier);
1039 const AllRegionPrintArgs &args) {
1040 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1044 iface.getHasDeviceAddrBlockArgs(),
1045 args.hasDeviceAddrArgs);
1049 args.inReductionArgs);
1055 args.reductionArgs);
1057 iface.getTaskReductionBlockArgs(),
1058 args.taskReductionArgs);
1060 iface.getUseDeviceAddrBlockArgs(),
1061 args.useDeviceAddrArgs);
1063 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1076 ArrayAttr inReductionSyms,
ValueRange mapVars,
1078 TypeRange privateTypes, ArrayAttr privateSyms,
1080 AllRegionPrintArgs args;
1081 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1082 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1083 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1084 inReductionByref, inReductionSyms);
1085 args.mapArgs.emplace(mapVars, mapTypes);
1086 args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
1094 ArrayAttr privateSyms) {
1095 AllRegionPrintArgs args;
1096 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1097 inReductionByref, inReductionSyms);
1098 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1107 ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
1110 AllRegionPrintArgs args;
1111 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1112 inReductionByref, inReductionSyms);
1113 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1115 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1116 reductionSyms, reductionMod);
1122 ArrayAttr privateSyms) {
1123 AllRegionPrintArgs args;
1124 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1131 TypeRange privateTypes, ArrayAttr privateSyms,
1132 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1134 ArrayAttr reductionSyms) {
1135 AllRegionPrintArgs args;
1136 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1138 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1139 reductionSyms, reductionMod);
1148 ArrayAttr taskReductionSyms) {
1149 AllRegionPrintArgs args;
1150 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1151 taskReductionByref, taskReductionSyms);
1161 AllRegionPrintArgs args;
1162 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1163 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1168 static LogicalResult
1172 if (!reductionVars.empty()) {
1173 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1175 <<
"expected as many reduction symbol references "
1176 "as reduction variables";
1177 if (reductionByref && reductionByref->size() != reductionVars.size())
1178 return op->
emitError() <<
"expected as many reduction variable by "
1179 "reference attributes as reduction variables";
1182 return op->
emitOpError() <<
"unexpected reduction symbol references";
1189 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
1190 Value accum = std::get<0>(args);
1192 if (!accumulators.insert(accum).second)
1193 return op->
emitOpError() <<
"accumulator variable used more than once";
1196 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1198 SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1200 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1201 <<
" to point to a reduction declaration";
1203 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1205 <<
"expected accumulator (" << varType
1206 <<
") to be the same type as reduction declaration ("
1207 << decl.getAccumulatorType() <<
")";
1226 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1227 parser.parseArrow() ||
1228 parser.parseAttribute(symsVec.emplace_back()) ||
1229 parser.parseColonType(copyprivateTypes.emplace_back()))
1243 std::optional<ArrayAttr> copyprivateSyms) {
1244 if (!copyprivateSyms.has_value())
1246 llvm::interleaveComma(
1247 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1248 [&](
const auto &args) {
1249 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
1250 << std::get<2>(args);
1255 static LogicalResult
1257 std::optional<ArrayAttr> copyprivateSyms) {
1258 size_t copyprivateSymsSize =
1259 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1260 if (copyprivateSymsSize != copyprivateVars.size())
1261 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
1262 << copyprivateVars.size()
1263 <<
") and functions (= " << copyprivateSymsSize
1264 <<
"), both must be equal";
1265 if (!copyprivateSyms.has_value())
1268 for (
auto copyprivateVarAndSym :
1269 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1271 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1272 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1274 if (mlir::func::FuncOp mlirFuncOp =
1275 SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1277 funcOp = mlirFuncOp;
1278 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1279 SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1281 funcOp = llvmFuncOp;
1283 auto getNumArguments = [&] {
1284 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
1287 auto getArgumentType = [&](
unsigned i) {
1288 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
1293 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1294 <<
" to point to a copy function";
1296 if (getNumArguments() != 2)
1298 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
1300 Type argTy = getArgumentType(0);
1301 if (argTy != getArgumentType(1))
1302 return op->
emitOpError() <<
"expected copy function " << symbolRef
1303 <<
" arguments to have the same type";
1305 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1306 if (argTy != varType)
1308 <<
"expected copy function arguments' type (" << argTy
1309 <<
") to be the same as copyprivate variable's type (" << varType
1330 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1331 parser.parseOperand(dependVars.emplace_back()) ||
1332 parser.parseColonType(dependTypes.emplace_back()))
1334 if (std::optional<ClauseTaskDepend> keywordDepend =
1335 (symbolizeClauseTaskDepend(keyword)))
1336 kindsVec.emplace_back(
1337 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1351 std::optional<ArrayAttr> dependKinds) {
1353 for (
unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1356 p << stringifyClauseTaskDepend(
1357 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1359 <<
" -> " << dependVars[i] <<
" : " << dependTypes[i];
1365 std::optional<ArrayAttr> dependKinds,
1367 if (!dependVars.empty()) {
1368 if (!dependKinds || dependKinds->size() != dependVars.size())
1369 return op->
emitOpError() <<
"expected as many depend values"
1370 " as depend variables";
1372 if (dependKinds && !dependKinds->empty())
1373 return op->
emitOpError() <<
"unexpected depend values";
1389 IntegerAttr &hintAttr) {
1390 StringRef hintKeyword;
1396 auto parseKeyword = [&]() -> ParseResult {
1399 if (hintKeyword ==
"uncontended")
1401 else if (hintKeyword ==
"contended")
1403 else if (hintKeyword ==
"nonspeculative")
1405 else if (hintKeyword ==
"speculative")
1409 << hintKeyword <<
" is not a valid hint";
1420 IntegerAttr hintAttr) {
1421 int64_t hint = hintAttr.getInt();
1429 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1431 bool uncontended = bitn(hint, 0);
1432 bool contended = bitn(hint, 1);
1433 bool nonspeculative = bitn(hint, 2);
1434 bool speculative = bitn(hint, 3);
1438 hints.push_back(
"uncontended");
1440 hints.push_back(
"contended");
1442 hints.push_back(
"nonspeculative");
1444 hints.push_back(
"speculative");
1446 llvm::interleaveComma(hints, p);
1453 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1455 bool uncontended = bitn(hint, 0);
1456 bool contended = bitn(hint, 1);
1457 bool nonspeculative = bitn(hint, 2);
1458 bool speculative = bitn(hint, 3);
1460 if (uncontended && contended)
1461 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
1462 "omp_sync_hint_contended cannot be combined";
1463 if (nonspeculative && speculative)
1464 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
1465 "omp_sync_hint_speculative cannot be combined.";
1475 llvm::omp::OpenMPOffloadMappingFlags flag) {
1476 return value & llvm::to_underlying(flag);
1485 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1486 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1490 auto parseTypeAndMod = [&]() -> ParseResult {
1491 StringRef mapTypeMod;
1495 if (mapTypeMod ==
"always")
1496 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1498 if (mapTypeMod ==
"implicit")
1499 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1501 if (mapTypeMod ==
"ompx_hold")
1502 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
1504 if (mapTypeMod ==
"close")
1505 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1507 if (mapTypeMod ==
"present")
1508 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1510 if (mapTypeMod ==
"to")
1511 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1513 if (mapTypeMod ==
"from")
1514 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1516 if (mapTypeMod ==
"tofrom")
1517 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1518 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1520 if (mapTypeMod ==
"delete")
1521 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1531 llvm::to_underlying(mapTypeBits));
1539 IntegerAttr mapType) {
1540 uint64_t mapTypeBits = mapType.getUInt();
1542 bool emitAllocRelease =
true;
1548 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1549 mapTypeStrs.push_back(
"always");
1551 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1552 mapTypeStrs.push_back(
"implicit");
1554 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
1555 mapTypeStrs.push_back(
"ompx_hold");
1557 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1558 mapTypeStrs.push_back(
"close");
1560 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1561 mapTypeStrs.push_back(
"present");
1567 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1569 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1571 emitAllocRelease =
false;
1572 mapTypeStrs.push_back(
"tofrom");
1574 emitAllocRelease =
false;
1575 mapTypeStrs.push_back(
"from");
1577 emitAllocRelease =
false;
1578 mapTypeStrs.push_back(
"to");
1581 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1582 emitAllocRelease =
false;
1583 mapTypeStrs.push_back(
"delete");
1585 if (emitAllocRelease)
1586 mapTypeStrs.push_back(
"exit_release_or_enter_alloc");
1588 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1589 p << mapTypeStrs[i];
1590 if (i + 1 < mapTypeStrs.size()) {
1597 ArrayAttr &membersIdx) {
1600 auto parseIndices = [&]() -> ParseResult {
1605 APInt(64, value,
false)));
1623 if (!memberIdxs.empty())
1630 ArrayAttr membersIdx) {
1634 llvm::interleaveComma(membersIdx, p, [&p](
Attribute v) {
1636 auto memberIdx = cast<ArrayAttr>(v);
1637 llvm::interleaveComma(memberIdx.getValue(), p, [&p](
Attribute v2) {
1638 p << cast<IntegerAttr>(v2).getInt();
1645 VariableCaptureKindAttr mapCaptureType) {
1646 std::string typeCapStr;
1647 llvm::raw_string_ostream typeCap(typeCapStr);
1648 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1650 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1651 typeCap <<
"ByCopy";
1652 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1653 typeCap <<
"VLAType";
1654 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1660 VariableCaptureKindAttr &mapCaptureType) {
1661 StringRef mapCaptureKey;
1665 if (mapCaptureKey ==
"This")
1667 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
1668 if (mapCaptureKey ==
"ByRef")
1670 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
1671 if (mapCaptureKey ==
"ByCopy")
1673 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1674 if (mapCaptureKey ==
"VLAType")
1676 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
1685 for (
auto mapOp : mapVars) {
1686 if (!mapOp.getDefiningOp())
1689 if (
auto mapInfoOp =
1690 mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1691 uint64_t mapTypeBits = mapInfoOp.getMapType();
1694 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1696 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1698 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1701 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1703 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1705 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1707 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1709 "to, from, tofrom and alloc map types are permitted");
1711 if (isa<TargetEnterDataOp>(op) && (from || del))
1712 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
1714 if (isa<TargetExitDataOp>(op) && to)
1716 "from, release and delete map types are permitted");
1718 if (isa<TargetUpdateOp>(op)) {
1721 "at least one of to or from map types must be "
1722 "specified, other map types are not permitted");
1727 "at least one of to or from map types must be "
1728 "specified, other map types are not permitted");
1731 auto updateVar = mapInfoOp.getVarPtr();
1733 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1734 (from && updateToVars.contains(updateVar))) {
1737 "either to or from map types can be specified, not both");
1740 if (always || close || implicit) {
1743 "present, mapper and iterator map type modifiers are permitted");
1746 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1748 }
else if (!isa<DeclareMapperInfoOp>(op)) {
1750 "map argument is not a map entry operation");
1758 std::optional<DenseI64ArrayAttr> privateMapIndices =
1759 targetOp.getPrivateMapsAttr();
1762 if (!privateMapIndices.has_value() || !privateMapIndices.value())
1767 if (privateMapIndices.value().size() !=
1768 static_cast<int64_t
>(privateVars.size()))
1769 return emitError(targetOp.getLoc(),
"sizes of `private` operand range and "
1770 "`private_maps` attribute mismatch");
1780 if (getMapperId() &&
1781 !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
1782 *
this, getMapperIdAttr())) {
1794 const TargetDataOperands &clauses) {
1795 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
1796 clauses.mapVars, clauses.useDeviceAddrVars,
1797 clauses.useDevicePtrVars);
1801 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1802 getUseDeviceAddrVars().empty()) {
1804 "At least one of map, use_device_ptr_vars, or "
1805 "use_device_addr_vars operand must be present");
1814 void TargetEnterDataOp::build(
1818 TargetEnterDataOp::build(builder, state,
1820 clauses.dependVars, clauses.device, clauses.ifExpr,
1821 clauses.mapVars, clauses.nowait);
1825 LogicalResult verifyDependVars =
1827 return failed(verifyDependVars) ? verifyDependVars
1838 TargetExitDataOp::build(builder, state,
1840 clauses.dependVars, clauses.device, clauses.ifExpr,
1841 clauses.mapVars, clauses.nowait);
1845 LogicalResult verifyDependVars =
1847 return failed(verifyDependVars) ? verifyDependVars
1858 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
1859 clauses.dependVars, clauses.device, clauses.ifExpr,
1860 clauses.mapVars, clauses.nowait);
1864 LogicalResult verifyDependVars =
1866 return failed(verifyDependVars) ? verifyDependVars
1875 const TargetOperands &clauses) {
1879 TargetOp::build(builder, state, {}, {},
1881 clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
1882 clauses.hostEvalVars, clauses.ifExpr,
1884 nullptr, clauses.isDevicePtrVars,
1885 clauses.mapVars, clauses.nowait, clauses.privateVars,
1886 makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
1891 LogicalResult verifyDependVars =
1894 if (failed(verifyDependVars))
1895 return verifyDependVars;
1899 if (failed(verifyMapVars))
1900 return verifyMapVars;
1905 LogicalResult TargetOp::verifyRegions() {
1906 auto teamsOps = getOps<TeamsOp>();
1907 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
1908 return emitError(
"target containing multiple 'omp.teams' nested ops");
1911 Operation *capturedOp = getInnermostCapturedOmpOp();
1912 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
1913 for (
Value hostEvalArg :
1914 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
1916 if (
auto teamsOp = dyn_cast<TeamsOp>(user)) {
1917 if (llvm::is_contained({teamsOp.getNumTeamsLower(),
1918 teamsOp.getNumTeamsUpper(),
1919 teamsOp.getThreadLimit()},
1923 return emitOpError() <<
"host_eval argument only legal as 'num_teams' "
1924 "and 'thread_limit' in 'omp.teams'";
1926 if (
auto parallelOp = dyn_cast<ParallelOp>(user)) {
1927 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
1928 parallelOp->isAncestor(capturedOp) &&
1929 hostEvalArg == parallelOp.getNumThreads())
1932 return emitOpError()
1933 <<
"host_eval argument only legal as 'num_threads' in "
1934 "'omp.parallel' when representing target SPMD";
1936 if (
auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1937 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
1938 loopNestOp.getOperation() == capturedOp &&
1939 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
1940 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
1941 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
1944 return emitOpError() <<
"host_eval argument only legal as loop bounds "
1945 "and steps in 'omp.loop_nest' when trip count "
1946 "must be evaluated in the host";
1949 return emitOpError() <<
"host_eval argument illegal use in '"
1950 << user->getName() <<
"' operation";
1959 assert(rootOp &&
"expected valid operation");
1971 return WalkResult::advance();
1976 bool isOmpDialect = op->
getDialect() == ompDialect;
1978 if (!isOmpDialect || !hasRegions)
1979 return WalkResult::skip();
1985 if (checkSingleMandatoryExec) {
1990 if (successor->isReachable(parentBlock))
1991 return WalkResult::interrupt();
1993 for (
Block &block : *parentRegion)
1995 !domInfo.
dominates(parentBlock, &block))
1996 return WalkResult::interrupt();
2002 if (&sibling != op && !siblingAllowedFn(&sibling))
2003 return WalkResult::interrupt();
2008 return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2009 : WalkResult::advance();
2015 Operation *TargetOp::getInnermostCapturedOmpOp() {
2028 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2031 memOp.getEffects(effects);
2032 return !llvm::any_of(
2034 return isa<MemoryEffects::Write>(effect.
getEffect()) &&
2035 isa<SideEffects::AutomaticAllocationScopeResource>(
2043 TargetRegionFlags TargetOp::getKernelExecFlags(
Operation *capturedOp) {
2048 assert((!capturedOp ||
2049 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2050 "unexpected captured op");
2053 if (!isa_and_present<LoopNestOp>(capturedOp))
2054 return TargetRegionFlags::generic;
2058 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2059 assert(!loopWrappers.empty());
2061 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2062 if (isa<SimdOp>(innermostWrapper))
2063 innermostWrapper = std::next(innermostWrapper);
2065 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2066 if (numWrappers != 1 && numWrappers != 2)
2067 return TargetRegionFlags::generic;
2070 if (numWrappers == 2) {
2071 if (!isa<WsloopOp>(innermostWrapper))
2072 return TargetRegionFlags::generic;
2074 innermostWrapper = std::next(innermostWrapper);
2075 if (!isa<DistributeOp>(innermostWrapper))
2076 return TargetRegionFlags::generic;
2079 if (!isa_and_present<ParallelOp>(parallelOp))
2080 return TargetRegionFlags::generic;
2083 if (!isa_and_present<TeamsOp>(teamsOp))
2084 return TargetRegionFlags::generic;
2086 if (teamsOp->
getParentOp() == targetOp.getOperation())
2087 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2090 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2092 if (!isa_and_present<TeamsOp>(teamsOp))
2093 return TargetRegionFlags::generic;
2095 if (teamsOp->
getParentOp() != targetOp.getOperation())
2096 return TargetRegionFlags::generic;
2098 if (isa<LoopOp>(innermostWrapper))
2099 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2109 Dialect *ompDialect = targetOp->getDialect();
2113 return sibling && (ompDialect != sibling->
getDialect() ||
2117 TargetRegionFlags result =
2118 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2123 while (nestedCapture->
getParentOp() != capturedOp)
2126 return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2130 else if (isa<WsloopOp>(innermostWrapper)) {
2132 if (!isa_and_present<ParallelOp>(parallelOp))
2133 return TargetRegionFlags::generic;
2135 if (parallelOp->
getParentOp() == targetOp.getOperation())
2136 return TargetRegionFlags::spmd;
2139 return TargetRegionFlags::generic;
2148 ParallelOp::build(builder, state,
ValueRange(),
2154 state.addAttributes(attributes);
2158 const ParallelOperands &clauses) {
2160 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2161 clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2163 clauses.procBindKind, clauses.reductionMod,
2164 clauses.reductionVars,
2169 template <
typename OpType>
2171 auto privateVars = op.getPrivateVars();
2172 auto privateSyms = op.getPrivateSymsAttr();
2174 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
2177 auto numPrivateVars = privateVars.size();
2178 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
2180 if (numPrivateVars != numPrivateSyms)
2181 return op.emitError() <<
"inconsistent number of private variables and "
2182 "privatizer op symbols, private vars: "
2184 <<
" vs. privatizer op symbols: " << numPrivateSyms;
2186 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2187 Type varType = std::get<0>(privateVarInfo).getType();
2188 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2189 PrivateClauseOp privatizerOp =
2190 SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2192 if (privatizerOp ==
nullptr)
2193 return op.emitError() <<
"failed to lookup privatizer op with symbol: '"
2194 << privateSym <<
"'";
2196 Type privatizerType = privatizerOp.getArgType();
2198 if (privatizerType && (varType != privatizerType))
2199 return op.emitError()
2200 <<
"type mismatch between a "
2201 << (privatizerOp.getDataSharingType() ==
2202 DataSharingClauseType::Private
2205 <<
" variable and its privatizer op, var type: " << varType
2206 <<
" vs. privatizer op type: " << privatizerType;
2213 if (getAllocateVars().size() != getAllocatorVars().size())
2215 "expected equal sizes for allocate and allocator variables");
2221 getReductionByref());
2224 LogicalResult ParallelOp::verifyRegions() {
2225 auto distChildOps = getOps<DistributeOp>();
2226 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2227 if (numDistChildOps > 1)
2229 <<
"multiple 'omp.distribute' nested inside of 'omp.parallel'";
2231 if (numDistChildOps == 1) {
2234 <<
"'omp.composite' attribute missing from composite operation";
2237 Operation &distributeOp = **distChildOps.begin();
2239 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2243 return emitError() <<
"unexpected OpenMP operation inside of composite "
2245 << childOp.getName();
2247 }
else if (isComposite()) {
2249 <<
"'omp.composite' attribute present in non-composite operation";
2266 const TeamsOperands &clauses) {
2269 TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2270 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2272 clauses.reductionMod, clauses.reductionVars,
2275 clauses.threadLimit);
2287 return emitError(
"expected to be nested inside of omp.target or not nested "
2288 "in any OpenMP dialect operations");
2291 if (
auto numTeamsLowerBound = getNumTeamsLower()) {
2292 auto numTeamsUpperBound = getNumTeamsUpper();
2293 if (!numTeamsUpperBound)
2294 return emitError(
"expected num_teams upper bound to be defined if the "
2295 "lower bound is defined");
2296 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2298 "expected num_teams upper bound and lower bound to be the same type");
2302 if (getAllocateVars().size() != getAllocatorVars().size())
2304 "expected equal sizes for allocate and allocator variables");
2307 getReductionByref());
2315 return getParentOp().getPrivateVars();
2319 return getParentOp().getReductionVars();
2327 const SectionsOperands &clauses) {
2330 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2332 nullptr, clauses.reductionMod,
2333 clauses.reductionVars,
2339 if (getAllocateVars().size() != getAllocatorVars().size())
2341 "expected equal sizes for allocate and allocator variables");
2344 getReductionByref());
2347 LogicalResult SectionsOp::verifyRegions() {
2348 for (
auto &inst : *getRegion().begin()) {
2349 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2350 return emitOpError()
2351 <<
"expected omp.section op or terminator op inside region";
2363 const SingleOperands &clauses) {
2366 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2367 clauses.copyprivateVars,
2368 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2374 if (getAllocateVars().size() != getAllocatorVars().size())
2376 "expected equal sizes for allocate and allocator variables");
2379 getCopyprivateSyms());
2387 const WorkshareOperands &clauses) {
2388 WorkshareOp::build(builder, state, clauses.nowait);
2396 if (!(*this)->getParentOfType<WorkshareOp>())
2397 return emitOpError() <<
"must be nested in an omp.workshare";
2401 LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2402 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2404 return emitOpError() <<
"expected to be a standalone loop wrapper";
2413 LogicalResult LoopWrapperInterface::verifyImpl() {
2417 return emitOpError() <<
"loop wrapper must also have the `NoTerminator` "
2418 "and `SingleBlock` traits";
2421 return emitOpError() <<
"loop wrapper does not contain exactly one region";
2424 if (range_size(region.
getOps()) != 1)
2425 return emitOpError()
2426 <<
"loop wrapper does not contain exactly one nested op";
2429 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2430 return emitOpError() <<
"nested in loop wrapper is not another loop "
2431 "wrapper or `omp.loop_nest`";
2441 const LoopOperands &clauses) {
2444 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2446 clauses.orderMod, clauses.reductionMod, clauses.reductionVars,
2453 getReductionByref());
2456 LogicalResult LoopOp::verifyRegions() {
2457 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2459 return emitOpError() <<
"expected to be a standalone loop wrapper";
2470 build(builder, state, {}, {},
2472 false,
nullptr,
nullptr,
2473 nullptr, {},
nullptr,
2479 state.addAttributes(attributes);
2483 const WsloopOperands &clauses) {
2487 WsloopOp::build(builder, state,
2489 clauses.linearVars, clauses.linearStepVars, clauses.nowait,
2490 clauses.order, clauses.orderMod, clauses.ordered,
2491 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
2492 clauses.reductionMod, clauses.reductionVars,
2495 clauses.scheduleKind, clauses.scheduleChunk,
2496 clauses.scheduleMod, clauses.scheduleSimd);
2501 getReductionByref());
2504 LogicalResult WsloopOp::verifyRegions() {
2505 bool isCompositeChildLeaf =
2506 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2508 if (LoopWrapperInterface nested = getNestedWrapper()) {
2511 <<
"'omp.composite' attribute missing from composite wrapper";
2515 if (!isa<SimdOp>(nested))
2516 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
2518 }
else if (isComposite() && !isCompositeChildLeaf) {
2520 <<
"'omp.composite' attribute present in non-composite wrapper";
2521 }
else if (!isComposite() && isCompositeChildLeaf) {
2523 <<
"'omp.composite' attribute missing from composite wrapper";
2534 const SimdOperands &clauses) {
2538 SimdOp::build(builder, state, clauses.alignedVars,
2541 clauses.nontemporalVars, clauses.order, clauses.orderMod,
2542 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
2543 clauses.reductionMod, clauses.reductionVars,
2550 if (getSimdlen().has_value() && getSafelen().has_value() &&
2551 getSimdlen().value() > getSafelen().value())
2552 return emitOpError()
2553 <<
"simdlen clause and safelen clause are both present, but the "
2554 "simdlen value is not less than or equal to safelen value";
2562 bool isCompositeChildLeaf =
2563 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2565 if (!isComposite() && isCompositeChildLeaf)
2567 <<
"'omp.composite' attribute missing from composite wrapper";
2569 if (isComposite() && !isCompositeChildLeaf)
2571 <<
"'omp.composite' attribute present in non-composite wrapper";
2576 LogicalResult SimdOp::verifyRegions() {
2577 if (getNestedWrapper())
2578 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
2588 const DistributeOperands &clauses) {
2589 DistributeOp::build(builder, state, clauses.allocateVars,
2590 clauses.allocatorVars, clauses.distScheduleStatic,
2591 clauses.distScheduleChunkSize, clauses.order,
2592 clauses.orderMod, clauses.privateVars,
2597 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2598 return emitOpError() <<
"chunk size set without "
2599 "dist_schedule_static being present";
2601 if (getAllocateVars().size() != getAllocatorVars().size())
2603 "expected equal sizes for allocate and allocator variables");
2608 LogicalResult DistributeOp::verifyRegions() {
2609 if (LoopWrapperInterface nested = getNestedWrapper()) {
2612 <<
"'omp.composite' attribute missing from composite wrapper";
2615 if (isa<WsloopOp>(nested)) {
2617 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2618 !cast<ComposableOpInterface>(parentOp).isComposite()) {
2619 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
2620 "when a composite 'omp.parallel' is the direct "
2623 }
else if (!isa<SimdOp>(nested))
2624 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
2626 }
else if (isComposite()) {
2628 <<
"'omp.composite' attribute present in non-composite wrapper";
2642 LogicalResult DeclareMapperOp::verifyRegions() {
2643 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
2644 getRegion().getBlocks().front().getTerminator()))
2645 return emitOpError() <<
"expected terminator to be a DeclareMapperInfoOp";
2654 LogicalResult DeclareReductionOp::verifyRegions() {
2655 if (!getAllocRegion().empty()) {
2656 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
2657 if (yieldOp.getResults().size() != 1 ||
2658 yieldOp.getResults().getTypes()[0] !=
getType())
2659 return emitOpError() <<
"expects alloc region to yield a value "
2660 "of the reduction type";
2664 if (getInitializerRegion().empty())
2665 return emitOpError() <<
"expects non-empty initializer region";
2666 Block &initializerEntryBlock = getInitializerRegion().
front();
2669 if (!getAllocRegion().empty())
2670 return emitOpError() <<
"expects two arguments to the initializer region "
2671 "when an allocation region is used";
2673 if (getAllocRegion().empty())
2674 return emitOpError() <<
"expects one argument to the initializer region "
2675 "when no allocation region is used";
2677 return emitOpError()
2678 <<
"expects one or two arguments to the initializer region";
2682 if (arg.getType() !=
getType())
2683 return emitOpError() <<
"expects initializer region argument to match "
2684 "the reduction type";
2686 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
2687 if (yieldOp.getResults().size() != 1 ||
2688 yieldOp.getResults().getTypes()[0] !=
getType())
2689 return emitOpError() <<
"expects initializer region to yield a value "
2690 "of the reduction type";
2693 if (getReductionRegion().empty())
2694 return emitOpError() <<
"expects non-empty reduction region";
2695 Block &reductionEntryBlock = getReductionRegion().
front();
2700 return emitOpError() <<
"expects reduction region with two arguments of "
2701 "the reduction type";
2702 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
2703 if (yieldOp.getResults().size() != 1 ||
2704 yieldOp.getResults().getTypes()[0] !=
getType())
2705 return emitOpError() <<
"expects reduction region to yield a value "
2706 "of the reduction type";
2709 if (!getAtomicReductionRegion().empty()) {
2710 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
2714 return emitOpError() <<
"expects atomic reduction region with two "
2715 "arguments of the same type";
2716 auto ptrType = llvm::dyn_cast<PointerLikeType>(
2719 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
2720 return emitOpError() <<
"expects atomic reduction region arguments to "
2721 "be accumulators containing the reduction type";
2724 if (getCleanupRegion().empty())
2726 Block &cleanupEntryBlock = getCleanupRegion().
front();
2729 return emitOpError() <<
"expects cleanup region with one argument "
2730 "of the reduction type";
2740 const TaskOperands &clauses) {
2742 TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2743 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2744 clauses.final, clauses.ifExpr, clauses.inReductionVars,
2746 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2747 clauses.priority, clauses.privateVars,
2749 clauses.untied, clauses.eventHandle);
2753 LogicalResult verifyDependVars =
2755 return failed(verifyDependVars)
2758 getInReductionVars(),
2759 getInReductionByref());
2767 const TaskgroupOperands &clauses) {
2769 TaskgroupOp::build(builder, state, clauses.allocateVars,
2770 clauses.allocatorVars, clauses.taskReductionVars,
2777 getTaskReductionVars(),
2778 getTaskReductionByref());
2786 const TaskloopOperands &clauses) {
2789 TaskloopOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2790 clauses.final, clauses.grainsizeMod, clauses.grainsize,
2791 clauses.ifExpr, clauses.inReductionVars,
2794 clauses.mergeable, clauses.nogroup, clauses.numTasksMod,
2795 clauses.numTasks, clauses.priority, {},
2796 nullptr, clauses.reductionMod,
2797 clauses.reductionVars,
2803 if (getAllocateVars().size() != getAllocatorVars().size())
2805 "expected equal sizes for allocate and allocator variables");
2807 getReductionVars(), getReductionByref())) ||
2809 getInReductionVars(),
2810 getInReductionByref())))
2813 if (!getReductionVars().empty() && getNogroup())
2814 return emitError(
"if a reduction clause is present on the taskloop "
2815 "directive, the nogroup clause must not be specified");
2816 for (
auto var : getReductionVars()) {
2817 if (llvm::is_contained(getInReductionVars(), var))
2818 return emitError(
"the same list item cannot appear in both a reduction "
2819 "and an in_reduction clause");
2822 if (getGrainsize() && getNumTasks()) {
2824 "the grainsize clause and num_tasks clause are mutually exclusive and "
2825 "may not appear on the same taskloop directive");
2831 LogicalResult TaskloopOp::verifyRegions() {
2832 if (LoopWrapperInterface nested = getNestedWrapper()) {
2835 <<
"'omp.composite' attribute missing from composite wrapper";
2839 if (!isa<SimdOp>(nested))
2840 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
2841 }
else if (isComposite()) {
2843 <<
"'omp.composite' attribute present in non-composite wrapper";
2867 for (
auto &iv : ivs)
2868 iv.type = loopVarType;
2897 Region ®ion = getRegion();
2899 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
2900 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
2901 if (getLoopInclusive())
2903 p <<
"step (" << getLoopSteps() <<
") ";
2908 const LoopNestOperands &clauses) {
2909 LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2910 clauses.loopUpperBounds, clauses.loopSteps,
2911 clauses.loopInclusive);
2915 if (getLoopLowerBounds().empty())
2916 return emitOpError() <<
"must represent at least one loop";
2918 if (getLoopLowerBounds().size() != getIVs().size())
2919 return emitOpError() <<
"number of range arguments and IVs do not match";
2921 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
2922 if (lb.getType() != iv.getType())
2923 return emitOpError()
2924 <<
"range argument type does not match corresponding IV type";
2927 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
2928 return emitOpError() <<
"expects parent op to be a loop wrapper";
2933 void LoopNestOp::gatherWrappers(
2936 while (
auto wrapper =
2937 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
2938 wrappers.push_back(wrapper);
2948 const CriticalDeclareOperands &clauses) {
2949 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
2957 if (getNameAttr()) {
2958 SymbolRefAttr symbolRef = getNameAttr();
2962 return emitOpError() <<
"expected symbol reference " << symbolRef
2963 <<
" to point to a critical declaration";
2983 return op.
emitOpError() <<
"must be nested inside of a loop";
2987 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
2988 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
2990 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
2991 "have an ordered clause";
2993 if (hasRegion && orderedAttr.getInt() != 0)
2994 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
2995 "have a parameter present";
2997 if (!hasRegion && orderedAttr.getInt() == 0)
2998 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
2999 "have a parameter present";
3000 }
else if (!isa<SimdOp>(wrapper)) {
3001 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
3002 "or worksharing simd loop";
3008 const OrderedOperands &clauses) {
3009 OrderedOp::build(builder, state, clauses.doacrossDependType,
3010 clauses.doacrossNumLoops, clauses.doacrossDependVars);
3017 auto wrapper = (*this)->getParentOfType<WsloopOp>();
3018 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3019 return emitOpError() <<
"number of variables in depend clause does not "
3020 <<
"match number of iteration variables in the "
3027 const OrderedRegionOperands &clauses) {
3028 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3038 const TaskwaitOperands &clauses) {
3040 TaskwaitOp::build(builder, state,
nullptr,
3049 if (verifyCommon().failed())
3050 return mlir::failure();
3052 if (
auto mo = getMemoryOrder()) {
3053 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3054 *mo == ClauseMemoryOrderKind::Release) {
3056 "memory-order must not be acq_rel or release for atomic reads");
3067 if (verifyCommon().failed())
3068 return mlir::failure();
3070 if (
auto mo = getMemoryOrder()) {
3071 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3072 *mo == ClauseMemoryOrderKind::Acquire) {
3074 "memory-order must not be acq_rel or acquire for atomic writes");
3084 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3090 if (
Value writeVal = op.getWriteOpVal()) {
3092 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3099 if (verifyCommon().failed())
3100 return mlir::failure();
3102 if (
auto mo = getMemoryOrder()) {
3103 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3104 *mo == ClauseMemoryOrderKind::Acquire) {
3106 "memory-order must not be acq_rel or acquire for atomic updates");
3113 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
3119 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3120 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3122 return dyn_cast<AtomicReadOp>(getSecondOp());
3125 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3126 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3128 return dyn_cast<AtomicWriteOp>(getSecondOp());
3131 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3132 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3134 return dyn_cast<AtomicUpdateOp>(getSecondOp());
3141 LogicalResult AtomicCaptureOp::verifyRegions() {
3142 if (verifyRegionsCommon().failed())
3143 return mlir::failure();
3145 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
3147 "operations inside capture region must not have hint clause");
3149 if (getFirstOp()->getAttr(
"memory_order") ||
3150 getSecondOp()->getAttr(
"memory_order"))
3152 "operations inside capture region must not have memory_order clause");
3161 const CancelOperands &clauses) {
3162 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
3176 ClauseCancellationConstructType cct = getCancelDirective();
3179 if (!structuralParent)
3180 return emitOpError() <<
"Orphaned cancel construct";
3182 if ((cct == ClauseCancellationConstructType::Parallel) &&
3183 !mlir::isa<ParallelOp>(structuralParent)) {
3184 return emitOpError() <<
"cancel parallel must appear "
3185 <<
"inside a parallel region";
3187 if (cct == ClauseCancellationConstructType::Loop) {
3190 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->
getParentOp());
3193 return emitOpError()
3194 <<
"cancel loop must appear inside a worksharing-loop region";
3196 if (wsloopOp.getNowaitAttr()) {
3197 return emitError() <<
"A worksharing construct that is canceled "
3198 <<
"must not have a nowait clause";
3200 if (wsloopOp.getOrderedAttr()) {
3201 return emitError() <<
"A worksharing construct that is canceled "
3202 <<
"must not have an ordered clause";
3205 }
else if (cct == ClauseCancellationConstructType::Sections) {
3209 mlir::dyn_cast<SectionsOp>(structuralParent->
getParentOp());
3211 return emitOpError() <<
"cancel sections must appear "
3212 <<
"inside a sections region";
3214 if (sectionsOp.getNowait()) {
3215 return emitError() <<
"A sections construct that is canceled "
3216 <<
"must not have a nowait clause";
3228 const CancellationPointOperands &clauses) {
3229 CancellationPointOp::build(builder, state, clauses.cancelDirective);
3233 ClauseCancellationConstructType cct = getCancelDirective();
3236 if (!structuralParent)
3237 return emitOpError() <<
"Orphaned cancellation point";
3239 if ((cct == ClauseCancellationConstructType::Parallel) &&
3240 !mlir::isa<ParallelOp>(structuralParent)) {
3241 return emitOpError() <<
"cancellation point parallel must appear "
3242 <<
"inside a parallel region";
3246 if ((cct == ClauseCancellationConstructType::Loop) &&
3247 !mlir::isa<WsloopOp>(structuralParent->
getParentOp())) {
3248 return emitOpError() <<
"cancellation point loop must appear "
3249 <<
"inside a worksharing-loop region";
3251 if ((cct == ClauseCancellationConstructType::Sections) &&
3252 !mlir::isa<omp::SectionOp>(structuralParent)) {
3253 return emitOpError() <<
"cancellation point sections must appear "
3254 <<
"inside a sections region";
3265 auto extent = getExtent();
3267 if (!extent && !upperbound)
3268 return emitError(
"expected extent or upperbound.");
3275 PrivateClauseOp::build(
3276 odsBuilder, odsState, symName, type,
3278 DataSharingClauseType::Private));
3281 LogicalResult PrivateClauseOp::verifyRegions() {
3282 Type argType = getArgType();
3283 auto verifyTerminator = [&](
Operation *terminator,
3284 bool yieldsValue) -> LogicalResult {
3288 if (!llvm::isa<YieldOp>(terminator))
3290 <<
"expected exit block terminator to be an `omp.yield` op.";
3292 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
3293 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
3296 if (yieldedTypes.empty())
3300 <<
"Did not expect any values to be yielded.";
3303 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
3307 <<
"Invalid yielded value. Expected type: " << argType
3310 if (yieldedTypes.empty())
3313 error << yieldedTypes;
3319 StringRef regionName,
3320 bool yieldsValue) -> LogicalResult {
3321 assert(!region.
empty());
3325 <<
"`" << regionName <<
"`: "
3326 <<
"expected " << expectedNumArgs
3329 for (
Block &block : region) {
3331 if (!block.mightHaveTerminator())
3334 if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
3342 for (
Region *region : getRegions())
3343 for (
Type ty : region->getArgumentTypes())
3345 return emitError() <<
"Region argument type mismatch: got " << ty
3346 <<
" expected " << argType <<
".";
3349 if (!initRegion.
empty() &&
3354 DataSharingClauseType dsType = getDataSharingType();
3356 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
3357 return emitError(
"`private` clauses do not require a `copy` region.");
3359 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
3361 "`firstprivate` clauses require at least a `copy` region.");
3363 if (dsType == DataSharingClauseType::FirstPrivate &&
3368 if (!getDeallocRegion().empty() &&
3381 const MaskedOperands &clauses) {
3382 MaskedOp::build(builder, state, clauses.filteredThreadId);
3390 const ScanOperands &clauses) {
3391 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
3395 if (hasExclusiveVars() == hasInclusiveVars())
3397 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
3398 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
3399 if (parentWsLoopOp.getReductionModAttr() &&
3400 parentWsLoopOp.getReductionModAttr().getValue() ==
3401 ReductionModifier::inscan)
3404 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
3405 if (parentSimdOp.getReductionModAttr() &&
3406 parentSimdOp.getReductionModAttr().getValue() ==
3407 ReductionModifier::inscan)
3410 return emitError(
"SCAN directive needs to be enclosed within a parent "
3411 "worksharing loop construct or SIMD construct with INSCAN "
3412 "reduction modifier");
3415 #define GET_ATTRDEF_CLASSES
3416 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
3418 #define GET_OP_CLASSES
3419 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
3421 #define GET_TYPEDEF_CLASSES
3422 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region ®ion, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, DenseI64ArrayAttr privateMaps)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion, AllRegionParseArgs args)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr)
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static Operation * getParentInSameDialect(Operation *thisOp)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, DenseI64ArrayAttr &privateMaps)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
BlockArgListType getArguments()
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
unsigned getNumArguments()
Location getLoc()
Return a location for this region.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Runtime
Potential runtimes for AMD GPU kernels.
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.