26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/PostOrderIterator.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/STLForwardCompat.h"
30 #include "llvm/ADT/SmallString.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringRef.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/ADT/bit.h"
35 #include "llvm/Frontend/OpenMP/OMPConstants.h"
41 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
44 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
65 struct MemRefPointerLikeModel
66 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
69 return llvm::cast<MemRefType>(pointer).getElementType();
73 struct LLVMPointerPointerLikeModel
74 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
75 LLVM::LLVMPointerType> {
80 void OpenMPDialect::initialize() {
83 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
86 #define GET_ATTRDEF_LIST
87 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
90 #define GET_TYPEDEF_LIST
91 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
94 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
96 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
97 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
102 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
108 mlir::LLVM::GlobalOp::attachInterface<
111 mlir::LLVM::LLVMFuncOp::attachInterface<
114 mlir::func::FuncOp::attachInterface<
140 allocatorVars.push_back(operand);
141 allocatorTypes.push_back(type);
147 allocateVars.push_back(operand);
148 allocateTypes.push_back(type);
159 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
160 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
161 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
162 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
170 template <
typename ClauseAttr>
172 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
177 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
181 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
184 template <
typename ClauseAttr>
186 p << stringifyEnum(attr.getValue());
209 linearVars.push_back(var);
210 linearTypes.push_back(type);
211 linearStepVars.push_back(stepVar);
220 size_t linearVarsSize = linearVars.size();
221 for (
unsigned i = 0; i < linearVarsSize; ++i) {
222 std::string separator = i == linearVarsSize - 1 ?
"" :
", ";
224 if (linearStepVars.size() > i)
225 p <<
" = " << linearStepVars[i];
226 p <<
" : " << linearVars[i].getType() << separator;
239 for (
const auto &it : nontemporalVars)
240 if (!nontemporalItems.insert(it).second)
241 return op->
emitOpError() <<
"nontemporal variable used more than once";
250 std::optional<ArrayAttr> alignments,
253 if (!alignedVars.empty()) {
254 if (!alignments || alignments->size() != alignedVars.size())
256 <<
"expected as many alignment values as aligned variables";
259 return op->
emitOpError() <<
"unexpected alignment values attribute";
265 for (
auto it : alignedVars)
266 if (!alignedItems.insert(it).second)
267 return op->
emitOpError() <<
"aligned variable used more than once";
273 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
274 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
275 if (intAttr.getValue().sle(0))
276 return op->
emitOpError() <<
"alignment should be greater than 0";
278 return op->
emitOpError() <<
"expected integer alignment";
292 ArrayAttr &alignmentsAttr) {
295 if (parser.parseOperand(alignedVars.emplace_back()) ||
296 parser.parseColonType(alignedTypes.emplace_back()) ||
297 parser.parseArrow() ||
298 parser.parseAttribute(alignmentVec.emplace_back())) {
312 std::optional<ArrayAttr> alignments) {
313 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
316 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
317 p <<
" -> " << (*alignments)[i];
328 if (modifiers.size() > 2)
330 for (
const auto &mod : modifiers) {
333 auto symbol = symbolizeScheduleModifier(mod);
336 <<
" unknown modifier type: " << mod;
341 if (modifiers.size() == 1) {
342 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
343 modifiers.push_back(modifiers[0]);
344 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
346 }
else if (modifiers.size() == 2) {
349 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
350 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
352 <<
" incorrect modifier order";
368 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
369 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
374 std::optional<mlir::omp::ClauseScheduleKind> schedule =
375 symbolizeClauseScheduleKind(keyword);
381 case ClauseScheduleKind::Static:
382 case ClauseScheduleKind::Dynamic:
383 case ClauseScheduleKind::Guided:
389 chunkSize = std::nullopt;
392 case ClauseScheduleKind::Auto:
394 chunkSize = std::nullopt;
403 modifiers.push_back(mod);
409 if (!modifiers.empty()) {
411 if (std::optional<ScheduleModifier> mod =
412 symbolizeScheduleModifier(modifiers[0])) {
415 return parser.
emitError(loc,
"invalid schedule modifier");
418 if (modifiers.size() > 1) {
419 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
429 ClauseScheduleKindAttr scheduleKind,
430 ScheduleModifierAttr scheduleMod,
431 UnitAttr scheduleSimd,
Value scheduleChunk,
432 Type scheduleChunkType) {
433 p << stringifyClauseScheduleKind(scheduleKind.getValue());
435 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
437 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
449 ClauseOrderKindAttr &order,
450 OrderModifierAttr &orderMod) {
455 if (std::optional<OrderModifier> enumValue =
456 symbolizeOrderModifier(enumStr)) {
464 if (std::optional<ClauseOrderKind> enumValue =
465 symbolizeClauseOrderKind(enumStr)) {
469 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
473 ClauseOrderKindAttr order,
474 OrderModifierAttr orderMod) {
476 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
478 p << stringifyClauseOrderKind(order.getValue());
481 template <
typename ClauseTypeAttr,
typename ClauseType>
484 std::optional<OpAsmParser::UnresolvedOperand> &operand,
486 std::optional<ClauseType> (*symbolizeClause)(StringRef),
487 StringRef clauseName) {
490 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
496 <<
"invalid " << clauseName <<
" modifier : '" << enumStr <<
"'";
506 <<
"expected " << clauseName <<
" operand";
509 if (operand.has_value()) {
517 template <
typename ClauseTypeAttr,
typename ClauseType>
520 ClauseTypeAttr prescriptiveness,
Value operand,
522 StringRef (*stringifyClauseType)(ClauseType)) {
524 if (prescriptiveness)
525 p << stringifyClauseType(prescriptiveness.getValue()) <<
", ";
528 p << operand <<
": " << operandType;
538 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
539 Type &grainsizeType) {
540 return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
541 parser, grainsizeMod, grainsize, grainsizeType,
542 &symbolizeClauseGrainsizeType,
"grainsize");
546 ClauseGrainsizeTypeAttr grainsizeMod,
548 printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
549 p, op, grainsizeMod, grainsize, grainsizeType,
550 &stringifyClauseGrainsizeType);
560 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
561 Type &numTasksType) {
562 return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
563 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
568 ClauseNumTasksTypeAttr numTasksMod,
570 printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
571 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
579 struct MapParseArgs {
584 : vars(vars), types(types) {}
586 struct PrivateParseArgs {
590 UnitAttr &needsBarrier;
594 UnitAttr &needsBarrier,
596 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
597 mapIndices(mapIndices) {}
600 struct ReductionParseArgs {
605 ReductionModifierAttr *modifier;
608 ArrayAttr &syms, ReductionModifierAttr *mod =
nullptr)
609 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
612 struct AllRegionParseArgs {
613 std::optional<MapParseArgs> hasDeviceAddrArgs;
614 std::optional<MapParseArgs> hostEvalArgs;
615 std::optional<ReductionParseArgs> inReductionArgs;
616 std::optional<MapParseArgs> mapArgs;
617 std::optional<PrivateParseArgs> privateArgs;
618 std::optional<ReductionParseArgs> reductionArgs;
619 std::optional<ReductionParseArgs> taskReductionArgs;
620 std::optional<MapParseArgs> useDeviceAddrArgs;
621 std::optional<MapParseArgs> useDevicePtrArgs;
626 return "private_barrier";
636 ReductionModifierAttr *modifier =
nullptr,
637 UnitAttr *needsBarrier =
nullptr) {
641 unsigned regionArgOffset = regionPrivateArgs.size();
651 std::optional<ReductionModifier> enumValue =
652 symbolizeReductionModifier(enumStr);
653 if (!enumValue.has_value())
662 isByRefVec.push_back(
663 parser.parseOptionalKeyword(
"byref").succeeded());
665 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
668 if (parser.parseOperand(operands.emplace_back()) ||
669 parser.parseArrow() ||
670 parser.parseArgument(regionPrivateArgs.emplace_back()))
674 if (parser.parseOptionalLSquare().succeeded()) {
675 if (parser.parseKeyword(
"map_idx") || parser.parseEqual() ||
676 parser.parseInteger(mapIndicesVec.emplace_back()) ||
677 parser.parseRSquare())
680 mapIndicesVec.push_back(-1);
692 if (parser.parseType(types.emplace_back()))
699 if (operands.size() != types.size())
711 auto *argsBegin = regionPrivateArgs.begin();
713 argsBegin + regionArgOffset + types.size());
714 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
723 if (!mapIndicesVec.empty())
736 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
751 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
757 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
758 &privateArgs->syms, privateArgs->mapIndices,
nullptr,
759 nullptr, &privateArgs->needsBarrier)))
768 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
773 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
774 &reductionArgs->syms,
nullptr, &reductionArgs->byref,
775 reductionArgs->modifier)))
782 AllRegionParseArgs args) {
786 args.hasDeviceAddrArgs)))
788 <<
"invalid `has_device_addr` format";
793 <<
"invalid `host_eval` format";
796 args.inReductionArgs)))
798 <<
"invalid `in_reduction` format";
803 <<
"invalid `map_entries` format";
808 <<
"invalid `private` format";
811 args.reductionArgs)))
813 <<
"invalid `reduction` format";
816 args.taskReductionArgs)))
818 <<
"invalid `task_reduction` format";
821 args.useDeviceAddrArgs)))
823 <<
"invalid `use_device_addr` format";
826 args.useDevicePtrArgs)))
828 <<
"invalid `use_device_ptr` format"; // Name the clause that actually failed (useDevicePtrArgs), not use_device_addr.
849 AllRegionParseArgs args;
850 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
851 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
852 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
853 inReductionByref, inReductionSyms);
854 args.mapArgs.emplace(mapVars, mapTypes);
855 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
856 privateNeedsBarrier, &privateMaps);
867 UnitAttr &privateNeedsBarrier) {
868 AllRegionParseArgs args;
869 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
870 inReductionByref, inReductionSyms);
871 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
872 privateNeedsBarrier);
883 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
886 ArrayAttr &reductionSyms) {
887 AllRegionParseArgs args;
888 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
889 inReductionByref, inReductionSyms);
890 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
891 privateNeedsBarrier);
892 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
893 reductionSyms, &reductionMod);
901 UnitAttr &privateNeedsBarrier) {
902 AllRegionParseArgs args;
903 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
904 privateNeedsBarrier);
912 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
915 ArrayAttr &reductionSyms) {
916 AllRegionParseArgs args;
917 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
918 privateNeedsBarrier);
919 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
920 reductionSyms, &reductionMod);
929 AllRegionParseArgs args;
930 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
931 taskReductionByref, taskReductionSyms);
941 AllRegionParseArgs args;
942 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
943 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
952 struct MapPrintArgs {
957 struct PrivatePrintArgs {
961 UnitAttr needsBarrier;
965 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
966 mapIndices(mapIndices) {}
968 struct ReductionPrintArgs {
973 ReductionModifierAttr modifier;
975 ArrayAttr syms, ReductionModifierAttr mod =
nullptr)
976 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
978 struct AllRegionPrintArgs {
979 std::optional<MapPrintArgs> hasDeviceAddrArgs;
980 std::optional<MapPrintArgs> hostEvalArgs;
981 std::optional<ReductionPrintArgs> inReductionArgs;
982 std::optional<MapPrintArgs> mapArgs;
983 std::optional<PrivatePrintArgs> privateArgs;
984 std::optional<ReductionPrintArgs> reductionArgs;
985 std::optional<ReductionPrintArgs> taskReductionArgs;
986 std::optional<MapPrintArgs> useDeviceAddrArgs;
987 std::optional<MapPrintArgs> useDevicePtrArgs;
996 ReductionModifierAttr modifier =
nullptr, UnitAttr needsBarrier =
nullptr) {
997 if (argsSubrange.empty())
1000 p << clauseName <<
"(";
1003 p <<
"mod: " << stringifyReductionModifier(modifier.getValue()) <<
", ";
1020 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1021 mapIndices.asArrayRef(),
1022 byref.asArrayRef()),
1024 auto [op, arg, sym, map, isByRef] = t;
1030 p << op <<
" -> " << arg;
1033 p <<
" [map_idx=" << map <<
"]";
1036 llvm::interleaveComma(types, p);
1044 StringRef clauseName,
ValueRange argsSubrange,
1045 std::optional<MapPrintArgs> mapArgs) {
1052 StringRef clauseName,
ValueRange argsSubrange,
1053 std::optional<PrivatePrintArgs> privateArgs) {
1056 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1057 privateArgs->syms, privateArgs->mapIndices,
nullptr,
1058 nullptr, privateArgs->needsBarrier);
1064 std::optional<ReductionPrintArgs> reductionArgs) {
1067 reductionArgs->vars, reductionArgs->types,
1068 reductionArgs->syms,
nullptr,
1069 reductionArgs->byref, reductionArgs->modifier);
1073 const AllRegionPrintArgs &args) {
1074 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1078 iface.getHasDeviceAddrBlockArgs(),
1079 args.hasDeviceAddrArgs);
1083 args.inReductionArgs);
1089 args.reductionArgs);
1091 iface.getTaskReductionBlockArgs(),
1092 args.taskReductionArgs);
1094 iface.getUseDeviceAddrBlockArgs(),
1095 args.useDeviceAddrArgs);
1097 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1111 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1113 AllRegionPrintArgs args;
1114 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1115 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1116 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1117 inReductionByref, inReductionSyms);
1118 args.mapArgs.emplace(mapVars, mapTypes);
1119 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1120 privateNeedsBarrier, privateMaps);
1128 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1129 AllRegionPrintArgs args;
1130 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1131 inReductionByref, inReductionSyms);
1132 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1133 privateNeedsBarrier,
1142 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1143 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1145 ArrayAttr reductionSyms) {
1146 AllRegionPrintArgs args;
1147 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1148 inReductionByref, inReductionSyms);
1149 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1150 privateNeedsBarrier,
1152 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1153 reductionSyms, reductionMod);
1159 ArrayAttr privateSyms,
1160 UnitAttr privateNeedsBarrier) {
1161 AllRegionPrintArgs args;
1162 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1163 privateNeedsBarrier,
1170 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1171 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1173 ArrayAttr reductionSyms) {
1174 AllRegionPrintArgs args;
1175 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1176 privateNeedsBarrier,
1178 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1179 reductionSyms, reductionMod);
1188 ArrayAttr taskReductionSyms) {
1189 AllRegionPrintArgs args;
1190 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1191 taskReductionByref, taskReductionSyms);
1201 AllRegionPrintArgs args;
1202 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1203 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1208 static LogicalResult
1212 if (!reductionVars.empty()) {
1213 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1215 <<
"expected as many reduction symbol references "
1216 "as reduction variables";
1217 if (reductionByref && reductionByref->size() != reductionVars.size())
1218 return op->
emitError() <<
"expected as many reduction variable by "
1219 "reference attributes as reduction variables";
1222 return op->
emitOpError() <<
"unexpected reduction symbol references";
1229 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
1230 Value accum = std::get<0>(args);
1232 if (!accumulators.insert(accum).second)
1233 return op->
emitOpError() <<
"accumulator variable used more than once";
1236 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1238 SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1240 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1241 <<
" to point to a reduction declaration";
1243 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1245 <<
"expected accumulator (" << varType
1246 <<
") to be the same type as reduction declaration ("
1247 << decl.getAccumulatorType() <<
")";
1266 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1267 parser.parseArrow() ||
1268 parser.parseAttribute(symsVec.emplace_back()) ||
1269 parser.parseColonType(copyprivateTypes.emplace_back()))
1283 std::optional<ArrayAttr> copyprivateSyms) {
1284 if (!copyprivateSyms.has_value())
1286 llvm::interleaveComma(
1287 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1288 [&](
const auto &args) {
1289 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
1290 << std::get<2>(args);
1295 static LogicalResult
1297 std::optional<ArrayAttr> copyprivateSyms) {
1298 size_t copyprivateSymsSize =
1299 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1300 if (copyprivateSymsSize != copyprivateVars.size())
1301 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
1302 << copyprivateVars.size()
1303 <<
") and functions (= " << copyprivateSymsSize
1304 <<
"), both must be equal";
1305 if (!copyprivateSyms.has_value())
1308 for (
auto copyprivateVarAndSym :
1309 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1311 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1312 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1314 if (mlir::func::FuncOp mlirFuncOp =
1315 SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1317 funcOp = mlirFuncOp;
1318 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1319 SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1321 funcOp = llvmFuncOp;
1323 auto getNumArguments = [&] {
1324 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
1327 auto getArgumentType = [&](
unsigned i) {
1328 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
1333 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1334 <<
" to point to a copy function";
1336 if (getNumArguments() != 2)
1338 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
1340 Type argTy = getArgumentType(0);
1341 if (argTy != getArgumentType(1))
1342 return op->
emitOpError() <<
"expected copy function " << symbolRef
1343 <<
" arguments to have the same type";
1345 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1346 if (argTy != varType)
1348 <<
"expected copy function arguments' type (" << argTy
1349 <<
") to be the same as copyprivate variable's type (" << varType
1370 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1371 parser.parseOperand(dependVars.emplace_back()) ||
1372 parser.parseColonType(dependTypes.emplace_back()))
1374 if (std::optional<ClauseTaskDepend> keywordDepend =
1375 (symbolizeClauseTaskDepend(keyword)))
1376 kindsVec.emplace_back(
1377 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1391 std::optional<ArrayAttr> dependKinds) {
1393 for (
unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1396 p << stringifyClauseTaskDepend(
1397 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1399 <<
" -> " << dependVars[i] <<
" : " << dependTypes[i];
1405 std::optional<ArrayAttr> dependKinds,
1407 if (!dependVars.empty()) {
1408 if (!dependKinds || dependKinds->size() != dependVars.size())
1409 return op->
emitOpError() <<
"expected as many depend values"
1410 " as depend variables";
1412 if (dependKinds && !dependKinds->empty())
1413 return op->
emitOpError() <<
"unexpected depend values";
1429 IntegerAttr &hintAttr) {
1430 StringRef hintKeyword;
1436 auto parseKeyword = [&]() -> ParseResult {
1439 if (hintKeyword ==
"uncontended")
1441 else if (hintKeyword ==
"contended")
1443 else if (hintKeyword ==
"nonspeculative")
1445 else if (hintKeyword ==
"speculative")
1449 << hintKeyword <<
" is not a valid hint";
1460 IntegerAttr hintAttr) {
1461 int64_t hint = hintAttr.getInt();
1469 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1471 bool uncontended = bitn(hint, 0);
1472 bool contended = bitn(hint, 1);
1473 bool nonspeculative = bitn(hint, 2);
1474 bool speculative = bitn(hint, 3);
1478 hints.push_back(
"uncontended");
1480 hints.push_back(
"contended");
1482 hints.push_back(
"nonspeculative");
1484 hints.push_back(
"speculative");
1486 llvm::interleaveComma(hints, p);
1493 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1495 bool uncontended = bitn(hint, 0);
1496 bool contended = bitn(hint, 1);
1497 bool nonspeculative = bitn(hint, 2);
1498 bool speculative = bitn(hint, 3);
1500 if (uncontended && contended)
1501 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
1502 "omp_sync_hint_contended cannot be combined";
1503 if (nonspeculative && speculative)
1504 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
1505 "omp_sync_hint_speculative cannot be combined.";
1515 llvm::omp::OpenMPOffloadMappingFlags flag) {
1516 return value & llvm::to_underlying(flag);
1525 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1526 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1530 auto parseTypeAndMod = [&]() -> ParseResult {
1531 StringRef mapTypeMod;
1535 if (mapTypeMod ==
"always")
1536 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1538 if (mapTypeMod ==
"implicit")
1539 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1541 if (mapTypeMod ==
"ompx_hold")
1542 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
1544 if (mapTypeMod ==
"close")
1545 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1547 if (mapTypeMod ==
"present")
1548 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1550 if (mapTypeMod ==
"to")
1551 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1553 if (mapTypeMod ==
"from")
1554 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1556 if (mapTypeMod ==
"tofrom")
1557 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1558 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1560 if (mapTypeMod ==
"delete")
1561 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1563 if (mapTypeMod ==
"return_param")
1564 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1574 llvm::to_underlying(mapTypeBits));
1582 IntegerAttr mapType) {
1583 uint64_t mapTypeBits = mapType.getUInt();
1585 bool emitAllocRelease =
true;
1591 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1592 mapTypeStrs.push_back(
"always");
1594 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1595 mapTypeStrs.push_back(
"implicit");
1597 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
1598 mapTypeStrs.push_back(
"ompx_hold");
1600 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1601 mapTypeStrs.push_back(
"close");
1603 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1604 mapTypeStrs.push_back(
"present");
1610 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1612 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1614 emitAllocRelease =
false;
1615 mapTypeStrs.push_back(
"tofrom");
1617 emitAllocRelease =
false;
1618 mapTypeStrs.push_back(
"from");
1620 emitAllocRelease =
false;
1621 mapTypeStrs.push_back(
"to");
1624 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1625 emitAllocRelease =
false;
1626 mapTypeStrs.push_back(
"delete");
1630 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
1631 emitAllocRelease =
false;
1632 mapTypeStrs.push_back(
"return_param");
1634 if (emitAllocRelease)
1635 mapTypeStrs.push_back(
"exit_release_or_enter_alloc");
1637 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1638 p << mapTypeStrs[i];
1639 if (i + 1 < mapTypeStrs.size()) {
1646 ArrayAttr &membersIdx) {
1649 auto parseIndices = [&]() -> ParseResult {
1654 APInt(64, value,
false)));
1672 if (!memberIdxs.empty())
1679 ArrayAttr membersIdx) {
1683 llvm::interleaveComma(membersIdx, p, [&p](
Attribute v) {
1685 auto memberIdx = cast<ArrayAttr>(v);
1686 llvm::interleaveComma(memberIdx.getValue(), p, [&p](
Attribute v2) {
1687 p << cast<IntegerAttr>(v2).getInt();
1694 VariableCaptureKindAttr mapCaptureType) {
1695 std::string typeCapStr;
1696 llvm::raw_string_ostream typeCap(typeCapStr);
1697 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1699 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1700 typeCap <<
"ByCopy";
1701 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1702 typeCap <<
"VLAType";
1703 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1709 VariableCaptureKindAttr &mapCaptureType) {
1710 StringRef mapCaptureKey;
1714 if (mapCaptureKey ==
"This")
1716 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
1717 if (mapCaptureKey ==
"ByRef")
1719 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
1720 if (mapCaptureKey ==
"ByCopy")
1722 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1723 if (mapCaptureKey ==
"VLAType")
1725 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
1734 for (
auto mapOp : mapVars) {
1735 if (!mapOp.getDefiningOp())
1738 if (
auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
1739 uint64_t mapTypeBits = mapInfoOp.getMapType();
1742 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1744 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1746 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1749 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1751 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1753 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1755 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1757 "to, from, tofrom and alloc map types are permitted");
1759 if (isa<TargetEnterDataOp>(op) && (from || del))
1760 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
1762 if (isa<TargetExitDataOp>(op) && to)
1764 "from, release and delete map types are permitted");
1766 if (isa<TargetUpdateOp>(op)) {
1769 "at least one of to or from map types must be "
1770 "specified, other map types are not permitted");
1775 "at least one of to or from map types must be "
1776 "specified, other map types are not permitted");
1779 auto updateVar = mapInfoOp.getVarPtr();
1781 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1782 (from && updateToVars.contains(updateVar))) {
1785 "either to or from map types can be specified, not both");
1788 if (always || close || implicit) {
1791 "present, mapper and iterator map type modifiers are permitted");
1794 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1796 }
else if (!isa<DeclareMapperInfoOp>(op)) {
1798 "map argument is not a map entry operation");
1806 std::optional<DenseI64ArrayAttr> privateMapIndices =
1807 targetOp.getPrivateMapsAttr();
1810 if (!privateMapIndices.has_value() || !privateMapIndices.value())
1815 if (privateMapIndices.value().size() !=
1816 static_cast<int64_t
>(privateVars.size()))
1817 return emitError(targetOp.getLoc(),
"sizes of `private` operand range and "
1818 "`private_maps` attribute mismatch");
1828 StringRef clauseName,
1830 for (
Value var : vars)
1831 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
1833 <<
"'" << clauseName
1834 <<
"' arguments must be defined by 'omp.map.info' ops";
1839 if (getMapperId() &&
1840 !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
1841 *
this, getMapperIdAttr())) {
1856 const TargetDataOperands &clauses) {
1857 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
1858 clauses.mapVars, clauses.useDeviceAddrVars,
1859 clauses.useDevicePtrVars);
1863 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1864 getUseDeviceAddrVars().empty()) {
1866 "At least one of map, use_device_ptr_vars, or "
1867 "use_device_addr_vars operand must be present");
1871 getUseDevicePtrVars())))
1875 getUseDeviceAddrVars())))
1885 void TargetEnterDataOp::build(
1889 TargetEnterDataOp::build(builder, state,
1891 clauses.dependVars, clauses.device, clauses.ifExpr,
1892 clauses.mapVars, clauses.nowait);
1896 LogicalResult verifyDependVars =
1898 return failed(verifyDependVars) ? verifyDependVars
1909 TargetExitDataOp::build(builder, state,
1911 clauses.dependVars, clauses.device, clauses.ifExpr,
1912 clauses.mapVars, clauses.nowait);
1916 LogicalResult verifyDependVars =
1918 return failed(verifyDependVars) ? verifyDependVars
1929 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
1930 clauses.dependVars, clauses.device, clauses.ifExpr,
1931 clauses.mapVars, clauses.nowait);
1935 LogicalResult verifyDependVars =
1937 return failed(verifyDependVars) ? verifyDependVars
1946 const TargetOperands &clauses) {
1950 TargetOp::build(builder, state, {}, {},
1952 clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
1953 clauses.hostEvalVars, clauses.ifExpr,
1955 nullptr, clauses.isDevicePtrVars,
1956 clauses.mapVars, clauses.nowait, clauses.privateVars,
1958 clauses.privateNeedsBarrier, clauses.threadLimit,
1967 getHasDeviceAddrVars())))
1976 LogicalResult TargetOp::verifyRegions() {
1977 auto teamsOps = getOps<TeamsOp>();
1978 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
1979 return emitError(
"target containing multiple 'omp.teams' nested ops");
1982 Operation *capturedOp = getInnermostCapturedOmpOp();
1983 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
1984 for (
Value hostEvalArg :
1985 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
1987 if (
auto teamsOp = dyn_cast<TeamsOp>(user)) {
1988 if (llvm::is_contained({teamsOp.getNumTeamsLower(),
1989 teamsOp.getNumTeamsUpper(),
1990 teamsOp.getThreadLimit()},
1994 return emitOpError() <<
"host_eval argument only legal as 'num_teams' "
1995 "and 'thread_limit' in 'omp.teams'";
1997 if (
auto parallelOp = dyn_cast<ParallelOp>(user)) {
1998 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
1999 parallelOp->isAncestor(capturedOp) &&
2000 hostEvalArg == parallelOp.getNumThreads())
2003 return emitOpError()
2004 <<
"host_eval argument only legal as 'num_threads' in "
2005 "'omp.parallel' when representing target SPMD";
2007 if (
auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2008 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2009 loopNestOp.getOperation() == capturedOp &&
2010 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2011 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2012 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2015 return emitOpError() <<
"host_eval argument only legal as loop bounds "
2016 "and steps in 'omp.loop_nest' when trip count "
2017 "must be evaluated in the host";
2020 return emitOpError() <<
"host_eval argument illegal use in '"
2021 << user->getName() <<
"' operation";
2030 assert(rootOp &&
"expected valid operation");
2042 return WalkResult::advance();
2047 bool isOmpDialect = op->
getDialect() == ompDialect;
2049 if (!isOmpDialect || !hasRegions)
2050 return WalkResult::skip();
2056 if (checkSingleMandatoryExec) {
2061 if (successor->isReachable(parentBlock))
2062 return WalkResult::interrupt();
2064 for (
Block &block : *parentRegion)
2066 !domInfo.
dominates(parentBlock, &block))
2067 return WalkResult::interrupt();
2073 if (&sibling != op && !siblingAllowedFn(&sibling))
2074 return WalkResult::interrupt();
2079 return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2080 : WalkResult::advance();
2086 Operation *TargetOp::getInnermostCapturedOmpOp() {
2099 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2102 memOp.getEffects(effects);
2103 return !llvm::any_of(
2105 return isa<MemoryEffects::Write>(effect.
getEffect()) &&
2106 isa<SideEffects::AutomaticAllocationScopeResource>(
2114 TargetRegionFlags TargetOp::getKernelExecFlags(
Operation *capturedOp) {
2119 assert((!capturedOp ||
2120 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2121 "unexpected captured op");
2124 if (!isa_and_present<LoopNestOp>(capturedOp))
2125 return TargetRegionFlags::generic;
2129 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2130 assert(!loopWrappers.empty());
2132 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2133 if (isa<SimdOp>(innermostWrapper))
2134 innermostWrapper = std::next(innermostWrapper);
2136 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2137 if (numWrappers != 1 && numWrappers != 2)
2138 return TargetRegionFlags::generic;
2141 if (numWrappers == 2) {
2142 if (!isa<WsloopOp>(innermostWrapper))
2143 return TargetRegionFlags::generic;
2145 innermostWrapper = std::next(innermostWrapper);
2146 if (!isa<DistributeOp>(innermostWrapper))
2147 return TargetRegionFlags::generic;
2150 if (!isa_and_present<ParallelOp>(parallelOp))
2151 return TargetRegionFlags::generic;
2154 if (!isa_and_present<TeamsOp>(teamsOp))
2155 return TargetRegionFlags::generic;
2157 if (teamsOp->
getParentOp() == targetOp.getOperation())
2158 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2161 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2163 if (!isa_and_present<TeamsOp>(teamsOp))
2164 return TargetRegionFlags::generic;
2166 if (teamsOp->
getParentOp() != targetOp.getOperation())
2167 return TargetRegionFlags::generic;
2169 if (isa<LoopOp>(innermostWrapper))
2170 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2180 Dialect *ompDialect = targetOp->getDialect();
2184 return sibling && (ompDialect != sibling->
getDialect() ||
2188 TargetRegionFlags result =
2189 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2194 while (nestedCapture->
getParentOp() != capturedOp)
2197 return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2201 else if (isa<WsloopOp>(innermostWrapper)) {
2203 if (!isa_and_present<ParallelOp>(parallelOp))
2204 return TargetRegionFlags::generic;
2206 if (parallelOp->
getParentOp() == targetOp.getOperation())
2207 return TargetRegionFlags::spmd;
2210 return TargetRegionFlags::generic;
2219 ParallelOp::build(builder, state,
ValueRange(),
2226 state.addAttributes(attributes);
2230 const ParallelOperands &clauses) {
2232 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2233 clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2235 clauses.privateNeedsBarrier, clauses.procBindKind,
2236 clauses.reductionMod, clauses.reductionVars,
2241 template <
typename OpType>
2243 auto privateVars = op.getPrivateVars();
2244 auto privateSyms = op.getPrivateSymsAttr();
2246 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
2249 auto numPrivateVars = privateVars.size();
2250 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
2252 if (numPrivateVars != numPrivateSyms)
2253 return op.emitError() <<
"inconsistent number of private variables and "
2254 "privatizer op symbols, private vars: "
2256 <<
" vs. privatizer op symbols: " << numPrivateSyms;
2258 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2259 Type varType = std::get<0>(privateVarInfo).getType();
2260 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2261 PrivateClauseOp privatizerOp =
2262 SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2264 if (privatizerOp ==
nullptr)
2265 return op.emitError() <<
"failed to lookup privatizer op with symbol: '"
2266 << privateSym <<
"'";
2268 Type privatizerType = privatizerOp.getArgType();
2270 if (privatizerType && (varType != privatizerType))
2271 return op.emitError()
2272 <<
"type mismatch between a "
2273 << (privatizerOp.getDataSharingType() ==
2274 DataSharingClauseType::Private
2277 <<
" variable and its privatizer op, var type: " << varType
2278 <<
" vs. privatizer op type: " << privatizerType;
2285 if (getAllocateVars().size() != getAllocatorVars().size())
2287 "expected equal sizes for allocate and allocator variables");
2293 getReductionByref());
2296 LogicalResult ParallelOp::verifyRegions() {
2297 auto distChildOps = getOps<DistributeOp>();
2298 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2299 if (numDistChildOps > 1)
2301 <<
"multiple 'omp.distribute' nested inside of 'omp.parallel'";
2303 if (numDistChildOps == 1) {
2306 <<
"'omp.composite' attribute missing from composite operation";
2309 Operation &distributeOp = **distChildOps.begin();
2311 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2315 return emitError() <<
"unexpected OpenMP operation inside of composite "
2317 << childOp.getName();
2319 }
else if (isComposite()) {
2321 <<
"'omp.composite' attribute present in non-composite operation";
2338 const TeamsOperands &clauses) {
2341 TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2342 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2344 nullptr, clauses.reductionMod,
2345 clauses.reductionVars,
2348 clauses.threadLimit);
2360 return emitError(
"expected to be nested inside of omp.target or not nested "
2361 "in any OpenMP dialect operations");
2364 if (
auto numTeamsLowerBound = getNumTeamsLower()) {
2365 auto numTeamsUpperBound = getNumTeamsUpper();
2366 if (!numTeamsUpperBound)
2367 return emitError(
"expected num_teams upper bound to be defined if the "
2368 "lower bound is defined");
2369 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2371 "expected num_teams upper bound and lower bound to be the same type");
2375 if (getAllocateVars().size() != getAllocatorVars().size())
2377 "expected equal sizes for allocate and allocator variables");
2380 getReductionByref());
2388 return getParentOp().getPrivateVars();
2392 return getParentOp().getReductionVars();
2400 const SectionsOperands &clauses) {
2403 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2406 clauses.reductionMod, clauses.reductionVars,
2412 if (getAllocateVars().size() != getAllocatorVars().size())
2414 "expected equal sizes for allocate and allocator variables");
2417 getReductionByref());
2420 LogicalResult SectionsOp::verifyRegions() {
2421 for (
auto &inst : *getRegion().begin()) {
2422 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2423 return emitOpError()
2424 <<
"expected omp.section op or terminator op inside region";
2436 const SingleOperands &clauses) {
2439 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2440 clauses.copyprivateVars,
2441 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2448 if (getAllocateVars().size() != getAllocatorVars().size())
2450 "expected equal sizes for allocate and allocator variables");
2453 getCopyprivateSyms());
2461 const WorkshareOperands &clauses) {
2462 WorkshareOp::build(builder, state, clauses.nowait);
2470 if (!(*this)->getParentOfType<WorkshareOp>())
2471 return emitOpError() <<
"must be nested in an omp.workshare";
2475 LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2476 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2478 return emitOpError() <<
"expected to be a standalone loop wrapper";
2487 LogicalResult LoopWrapperInterface::verifyImpl() {
2491 return emitOpError() <<
"loop wrapper must also have the `NoTerminator` "
2492 "and `SingleBlock` traits";
2495 return emitOpError() <<
"loop wrapper does not contain exactly one region";
2498 if (range_size(region.
getOps()) != 1)
2499 return emitOpError()
2500 <<
"loop wrapper does not contain exactly one nested op";
2503 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2504 return emitOpError() <<
"nested in loop wrapper is not another loop "
2505 "wrapper or `omp.loop_nest`";
2515 const LoopOperands &clauses) {
2518 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2520 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2521 clauses.reductionMod, clauses.reductionVars,
2528 getReductionByref());
2531 LogicalResult LoopOp::verifyRegions() {
2532 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2534 return emitOpError() <<
"expected to be a standalone loop wrapper";
2545 build(builder, state, {}, {},
2547 false,
nullptr,
nullptr,
2548 nullptr, {},
nullptr,
2555 state.addAttributes(attributes);
2559 const WsloopOperands &clauses) {
2564 {}, {}, clauses.linearVars,
2565 clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
2566 clauses.ordered, clauses.privateVars,
2567 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2568 clauses.reductionMod, clauses.reductionVars,
2570 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2571 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2576 getReductionByref());
2579 LogicalResult WsloopOp::verifyRegions() {
2580 bool isCompositeChildLeaf =
2581 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2583 if (LoopWrapperInterface nested = getNestedWrapper()) {
2586 <<
"'omp.composite' attribute missing from composite wrapper";
2590 if (!isa<SimdOp>(nested))
2591 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
2593 }
else if (isComposite() && !isCompositeChildLeaf) {
2595 <<
"'omp.composite' attribute present in non-composite wrapper";
2596 }
else if (!isComposite() && isCompositeChildLeaf) {
2598 <<
"'omp.composite' attribute missing from composite wrapper";
2609 const SimdOperands &clauses) {
2612 SimdOp::build(builder, state, clauses.alignedVars,
2615 clauses.nontemporalVars, clauses.order, clauses.orderMod,
2616 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
2617 clauses.privateNeedsBarrier, clauses.reductionMod,
2618 clauses.reductionVars,
2625 if (getSimdlen().has_value() && getSafelen().has_value() &&
2626 getSimdlen().value() > getSafelen().value())
2627 return emitOpError()
2628 <<
"simdlen clause and safelen clause are both present, but the "
2629 "simdlen value is not less than or equal to safelen value";
2637 bool isCompositeChildLeaf =
2638 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2640 if (!isComposite() && isCompositeChildLeaf)
2642 <<
"'omp.composite' attribute missing from composite wrapper";
2644 if (isComposite() && !isCompositeChildLeaf)
2646 <<
"'omp.composite' attribute present in non-composite wrapper";
2650 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
2652 for (
const Attribute &sym : *privateSyms) {
2653 auto symRef = cast<SymbolRefAttr>(sym);
2654 omp::PrivateClauseOp privatizer =
2655 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
2656 getOperation(), symRef);
2658 return emitError() <<
"Cannot find privatizer '" << symRef <<
"'";
2659 if (privatizer.getDataSharingType() ==
2660 DataSharingClauseType::FirstPrivate)
2661 return emitError() <<
"FIRSTPRIVATE cannot be used with SIMD";
2668 LogicalResult SimdOp::verifyRegions() {
2669 if (getNestedWrapper())
2670 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
2680 const DistributeOperands &clauses) {
2681 DistributeOp::build(builder, state, clauses.allocateVars,
2682 clauses.allocatorVars, clauses.distScheduleStatic,
2683 clauses.distScheduleChunkSize, clauses.order,
2684 clauses.orderMod, clauses.privateVars,
2686 clauses.privateNeedsBarrier);
2690 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2691 return emitOpError() <<
"chunk size set without "
2692 "dist_schedule_static being present";
2694 if (getAllocateVars().size() != getAllocatorVars().size())
2696 "expected equal sizes for allocate and allocator variables");
2701 LogicalResult DistributeOp::verifyRegions() {
2702 if (LoopWrapperInterface nested = getNestedWrapper()) {
2705 <<
"'omp.composite' attribute missing from composite wrapper";
2708 if (isa<WsloopOp>(nested)) {
2710 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2711 !cast<ComposableOpInterface>(parentOp).isComposite()) {
2712 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
2713 "when a composite 'omp.parallel' is the direct "
2716 }
else if (!isa<SimdOp>(nested))
2717 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
2719 }
else if (isComposite()) {
2721 <<
"'omp.composite' attribute present in non-composite wrapper";
2735 LogicalResult DeclareMapperOp::verifyRegions() {
2736 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
2737 getRegion().getBlocks().front().getTerminator()))
2738 return emitOpError() <<
"expected terminator to be a DeclareMapperInfoOp";
2747 LogicalResult DeclareReductionOp::verifyRegions() {
2748 if (!getAllocRegion().empty()) {
2749 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
2750 if (yieldOp.getResults().size() != 1 ||
2751 yieldOp.getResults().getTypes()[0] !=
getType())
2752 return emitOpError() <<
"expects alloc region to yield a value "
2753 "of the reduction type";
2757 if (getInitializerRegion().empty())
2758 return emitOpError() <<
"expects non-empty initializer region";
2759 Block &initializerEntryBlock = getInitializerRegion().
front();
2762 if (!getAllocRegion().empty())
2763 return emitOpError() <<
"expects two arguments to the initializer region "
2764 "when an allocation region is used";
2766 if (getAllocRegion().empty())
2767 return emitOpError() <<
"expects one argument to the initializer region "
2768 "when no allocation region is used";
2770 return emitOpError()
2771 <<
"expects one or two arguments to the initializer region";
2775 if (arg.getType() !=
getType())
2776 return emitOpError() <<
"expects initializer region argument to match "
2777 "the reduction type";
2779 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
2780 if (yieldOp.getResults().size() != 1 ||
2781 yieldOp.getResults().getTypes()[0] !=
getType())
2782 return emitOpError() <<
"expects initializer region to yield a value "
2783 "of the reduction type";
2786 if (getReductionRegion().empty())
2787 return emitOpError() <<
"expects non-empty reduction region";
2788 Block &reductionEntryBlock = getReductionRegion().
front();
2793 return emitOpError() <<
"expects reduction region with two arguments of "
2794 "the reduction type";
2795 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
2796 if (yieldOp.getResults().size() != 1 ||
2797 yieldOp.getResults().getTypes()[0] !=
getType())
2798 return emitOpError() <<
"expects reduction region to yield a value "
2799 "of the reduction type";
2802 if (!getAtomicReductionRegion().empty()) {
2803 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
2807 return emitOpError() <<
"expects atomic reduction region with two "
2808 "arguments of the same type";
2809 auto ptrType = llvm::dyn_cast<PointerLikeType>(
2812 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
2813 return emitOpError() <<
"expects atomic reduction region arguments to "
2814 "be accumulators containing the reduction type";
2817 if (getCleanupRegion().empty())
2819 Block &cleanupEntryBlock = getCleanupRegion().
front();
2822 return emitOpError() <<
"expects cleanup region with one argument "
2823 "of the reduction type";
2833 const TaskOperands &clauses) {
2835 TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2836 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2837 clauses.final, clauses.ifExpr, clauses.inReductionVars,
2839 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2840 clauses.priority, clauses.privateVars,
2842 clauses.privateNeedsBarrier, clauses.untied,
2843 clauses.eventHandle);
2847 LogicalResult verifyDependVars =
2849 return failed(verifyDependVars)
2852 getInReductionVars(),
2853 getInReductionByref());
2861 const TaskgroupOperands &clauses) {
2863 TaskgroupOp::build(builder, state, clauses.allocateVars,
2864 clauses.allocatorVars, clauses.taskReductionVars,
2871 getTaskReductionVars(),
2872 getTaskReductionByref());
2880 const TaskloopOperands &clauses) {
2883 builder, state, clauses.allocateVars, clauses.allocatorVars,
2884 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
2885 clauses.inReductionVars,
2887 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2888 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
2889 clauses.privateVars,
2891 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
2897 if (getAllocateVars().size() != getAllocatorVars().size())
2899 "expected equal sizes for allocate and allocator variables");
2901 getReductionVars(), getReductionByref())) ||
2903 getInReductionVars(),
2904 getInReductionByref())))
2907 if (!getReductionVars().empty() && getNogroup())
2908 return emitError(
"if a reduction clause is present on the taskloop "
2909 "directive, the nogroup clause must not be specified");
2910 for (
auto var : getReductionVars()) {
2911 if (llvm::is_contained(getInReductionVars(), var))
2912 return emitError(
"the same list item cannot appear in both a reduction "
2913 "and an in_reduction clause");
2916 if (getGrainsize() && getNumTasks()) {
2918 "the grainsize clause and num_tasks clause are mutually exclusive and "
2919 "may not appear on the same taskloop directive");
2925 LogicalResult TaskloopOp::verifyRegions() {
2926 if (LoopWrapperInterface nested = getNestedWrapper()) {
2929 <<
"'omp.composite' attribute missing from composite wrapper";
2933 if (!isa<SimdOp>(nested))
2934 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
2935 }
else if (isComposite()) {
2937 <<
"'omp.composite' attribute present in non-composite wrapper";
2961 for (
auto &iv : ivs)
2962 iv.type = loopVarType;
2983 "collapse_num_loops",
2988 auto parseTiles = [&]() -> ParseResult {
2992 tiles.push_back(
tile);
3001 if (tiles.size() > 0)
3020 Region ®ion = getRegion();
3022 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
3023 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
3024 if (getLoopInclusive())
3026 p <<
"step (" << getLoopSteps() <<
") ";
3027 if (int64_t numCollapse = getCollapseNumLoops())
3028 if (numCollapse > 1)
3029 p <<
"collapse(" << numCollapse <<
") ";
3032 p <<
"tiles(" << tiles.value() <<
") ";
3038 const LoopNestOperands &clauses) {
3040 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3041 clauses.loopLowerBounds, clauses.loopUpperBounds,
3042 clauses.loopSteps, clauses.loopInclusive,
3047 if (getLoopLowerBounds().empty())
3048 return emitOpError() <<
"must represent at least one loop";
3050 if (getLoopLowerBounds().size() != getIVs().size())
3051 return emitOpError() <<
"number of range arguments and IVs do not match";
3053 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3054 if (lb.getType() != iv.getType())
3055 return emitOpError()
3056 <<
"range argument type does not match corresponding IV type";
3059 uint64_t numIVs = getIVs().size();
3061 if (
const auto &numCollapse = getCollapseNumLoops())
3062 if (numCollapse > numIVs)
3063 return emitOpError()
3064 <<
"collapse value is larger than the number of loops";
3067 if (tiles.value().size() > numIVs)
3068 return emitOpError() <<
"too few canonical loops for tile dimensions";
3070 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3071 return emitOpError() <<
"expects parent op to be a loop wrapper";
3076 void LoopNestOp::gatherWrappers(
3079 while (
auto wrapper =
3080 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3081 wrappers.push_back(wrapper);
3090 std::tuple<NewCliOp, OpOperand *, OpOperand *>
3096 return {{},
nullptr,
nullptr};
3099 "Unexpected type of cli");
3105 auto op = cast<LoopTransformationInterface>(use.getOwner());
3107 unsigned opnum = use.getOperandNumber();
3108 if (op.isGeneratee(opnum)) {
3109 assert(!gen &&
"Each CLI may have at most one def");
3111 }
else if (op.isApplyee(opnum)) {
3112 assert(!cons &&
"Each CLI may have at most one consumer");
3115 llvm_unreachable(
"Unexpected operand for a CLI");
3119 return {create, gen, cons};
3128 Value result = getResult();
3129 auto [newCli, gen, cons] =
decodeCli(result);
3139 std::string cliName{
"cli"};
3143 .Case([&](CanonicalLoopOp op) {
3157 llvm::ReversePostOrderTraversal<Block *> traversal(
3160 for (
Block *b : traversal) {
3166 if (!op.getRegions().empty())
3170 llvm_unreachable(
"Operation not part of the region");
3172 size_t sequentialIdx = getSequentialIndex(r, o);
3173 components.push_back((
"s" + Twine(sequentialIdx)).str());
3183 for (
auto [idx, region] :
3188 llvm_unreachable(
"Region not child its parent operation");
3190 size_t regionIdx = getRegionIndex(parent, r);
3191 components.push_back((
"r" + Twine(regionIdx)).str());
3199 for (
const std::string &s : reverse(components)) {
3206 .Case([&](UnrollHeuristicOp op) -> std::string {
3207 llvm_unreachable(
"heuristic unrolling does not generate a loop");
3210 assert(
false &&
"TODO: Custom name for this operation");
3211 return "transformed";
3215 setNameFn(result, cliName);
3219 Value cli = getResult();
3222 "Unexpected type of cli");
3228 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3230 unsigned opnum = use.getOperandNumber();
3231 if (op.isGeneratee(opnum)) {
3234 emitOpError(
"CLI must have at most one generator");
3236 .
append(
"first generator here:");
3238 .
append(
"second generator here:");
3243 }
else if (op.isApplyee(opnum)) {
3246 emitOpError(
"CLI must have at most one consumer");
3248 .
append(
"first consumer here:")
3252 .
append(
"second consumer here:")
3259 llvm_unreachable(
"Unexpected operand for a CLI");
3267 .
append(
"see consumer here: ")
3290 setNameFn(&getRegion().front(),
"body_entry");
3293 void CanonicalLoopOp::getAsmBlockArgumentNames(
Region ®ion,
3300 p <<
'(' << getCli() <<
')';
3301 p <<
' ' << getInductionVar() <<
" : " << getInductionVar().getType()
3302 <<
" in range(" << getTripCount() <<
") ";
3312 CanonicalLoopInfoType cliType =
3338 if (parser.
parseRegion(*region, {inductionVariable}))
3343 result.
operands.append(cliOperand);
3349 return mlir::success();
3355 if (!getRegion().empty()) {
3356 Region ®ion = getRegion();
3359 "Canonical loop region must have exactly one argument");
3363 "Region argument must be the same type as the trip count");
3369 Value CanonicalLoopOp::getInductionVar() {
return getRegion().getArgument(0); }
3371 std::pair<unsigned, unsigned>
3372 CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3377 std::pair<unsigned, unsigned>
3378 CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3379 return getODSOperandIndexAndLength(odsIndex_cli);
3393 p <<
'(' << getApplyee() <<
')';
3423 return mlir::success();
3426 std::pair<unsigned, unsigned>
3427 UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3428 return getODSOperandIndexAndLength(odsIndex_applyee);
3431 std::pair<unsigned, unsigned>
3432 UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3441 const CriticalDeclareOperands &clauses) {
3442 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3450 if (getNameAttr()) {
3451 SymbolRefAttr symbolRef = getNameAttr();
3455 return emitOpError() <<
"expected symbol reference " << symbolRef
3456 <<
" to point to a critical declaration";
3476 return op.
emitOpError() <<
"must be nested inside of a loop";
3480 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
3481 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3483 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
3484 "have an ordered clause";
3486 if (hasRegion && orderedAttr.getInt() != 0)
3487 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
3488 "have a parameter present";
3490 if (!hasRegion && orderedAttr.getInt() == 0)
3491 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
3492 "have a parameter present";
3493 }
else if (!isa<SimdOp>(wrapper)) {
3494 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
3495 "or worksharing simd loop";
3501 const OrderedOperands &clauses) {
3502 OrderedOp::build(builder, state, clauses.doacrossDependType,
3503 clauses.doacrossNumLoops, clauses.doacrossDependVars);
3510 auto wrapper = (*this)->getParentOfType<WsloopOp>();
3511 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3512 return emitOpError() <<
"number of variables in depend clause does not "
3513 <<
"match number of iteration variables in the "
3520 const OrderedRegionOperands &clauses) {
3521 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3531 const TaskwaitOperands &clauses) {
3533 TaskwaitOp::build(builder, state,
nullptr,
3542 if (verifyCommon().
failed())
3543 return mlir::failure();
3545 if (
auto mo = getMemoryOrder()) {
3546 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3547 *mo == ClauseMemoryOrderKind::Release) {
3549 "memory-order must not be acq_rel or release for atomic reads");
3560 if (verifyCommon().
failed())
3561 return mlir::failure();
3563 if (
auto mo = getMemoryOrder()) {
3564 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3565 *mo == ClauseMemoryOrderKind::Acquire) {
3567 "memory-order must not be acq_rel or acquire for atomic writes");
3577 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3583 if (
Value writeVal = op.getWriteOpVal()) {
3585 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3592 if (verifyCommon().
failed())
3593 return mlir::failure();
3595 if (
auto mo = getMemoryOrder()) {
3596 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3597 *mo == ClauseMemoryOrderKind::Acquire) {
3599 "memory-order must not be acq_rel or acquire for atomic updates");
3606 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
3612 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3613 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3615 return dyn_cast<AtomicReadOp>(getSecondOp());
3618 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3619 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3621 return dyn_cast<AtomicWriteOp>(getSecondOp());
3624 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3625 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3627 return dyn_cast<AtomicUpdateOp>(getSecondOp());
3634 LogicalResult AtomicCaptureOp::verifyRegions() {
3635 if (verifyRegionsCommon().
failed())
3636 return mlir::failure();
3638 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
3640 "operations inside capture region must not have hint clause");
3642 if (getFirstOp()->getAttr(
"memory_order") ||
3643 getSecondOp()->getAttr(
"memory_order"))
3645 "operations inside capture region must not have memory_order clause");
3654 const CancelOperands &clauses) {
3655 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
3669 ClauseCancellationConstructType cct = getCancelDirective();
3672 if (!structuralParent)
3673 return emitOpError() <<
"Orphaned cancel construct";
3675 if ((cct == ClauseCancellationConstructType::Parallel) &&
3676 !mlir::isa<ParallelOp>(structuralParent)) {
3677 return emitOpError() <<
"cancel parallel must appear "
3678 <<
"inside a parallel region";
3680 if (cct == ClauseCancellationConstructType::Loop) {
3683 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->
getParentOp());
3686 return emitOpError()
3687 <<
"cancel loop must appear inside a worksharing-loop region";
3689 if (wsloopOp.getNowaitAttr()) {
3690 return emitError() <<
"A worksharing construct that is canceled "
3691 <<
"must not have a nowait clause";
3693 if (wsloopOp.getOrderedAttr()) {
3694 return emitError() <<
"A worksharing construct that is canceled "
3695 <<
"must not have an ordered clause";
3698 }
else if (cct == ClauseCancellationConstructType::Sections) {
3702 mlir::dyn_cast<SectionsOp>(structuralParent->
getParentOp());
3704 return emitOpError() <<
"cancel sections must appear "
3705 <<
"inside a sections region";
3707 if (sectionsOp.getNowait()) {
3708 return emitError() <<
"A sections construct that is canceled "
3709 <<
"must not have a nowait clause";
3712 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3713 (!mlir::isa<omp::TaskOp>(structuralParent) &&
3714 !mlir::isa<omp::TaskloopOp>(structuralParent->
getParentOp()))) {
3715 return emitOpError() <<
"cancel taskgroup must appear "
3716 <<
"inside a task region";
3726 const CancellationPointOperands &clauses) {
3727 CancellationPointOp::build(builder, state, clauses.cancelDirective);
3731 ClauseCancellationConstructType cct = getCancelDirective();
3734 if (!structuralParent)
3735 return emitOpError() <<
"Orphaned cancellation point";
3737 if ((cct == ClauseCancellationConstructType::Parallel) &&
3738 !mlir::isa<ParallelOp>(structuralParent)) {
3739 return emitOpError() <<
"cancellation point parallel must appear "
3740 <<
"inside a parallel region";
3744 if ((cct == ClauseCancellationConstructType::Loop) &&
3745 !mlir::isa<WsloopOp>(structuralParent->
getParentOp())) {
3746 return emitOpError() <<
"cancellation point loop must appear "
3747 <<
"inside a worksharing-loop region";
3749 if ((cct == ClauseCancellationConstructType::Sections) &&
3750 !mlir::isa<omp::SectionOp>(structuralParent)) {
3751 return emitOpError() <<
"cancellation point sections must appear "
3752 <<
"inside a sections region";
3754 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3755 !mlir::isa<omp::TaskOp>(structuralParent)) {
3756 return emitOpError() <<
"cancellation point taskgroup must appear "
3757 <<
"inside a task region";
3767 auto extent = getExtent();
3769 if (!extent && !upperbound)
3770 return emitError(
"expected extent or upperbound.");
3777 PrivateClauseOp::build(
3778 odsBuilder, odsState, symName, type,
3780 DataSharingClauseType::Private));
3783 LogicalResult PrivateClauseOp::verifyRegions() {
3784 Type argType = getArgType();
3785 auto verifyTerminator = [&](
Operation *terminator,
3786 bool yieldsValue) -> LogicalResult {
3790 if (!llvm::isa<YieldOp>(terminator))
3792 <<
"expected exit block terminator to be an `omp.yield` op.";
3794 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
3795 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
3798 if (yieldedTypes.empty())
3802 <<
"Did not expect any values to be yielded.";
3805 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
3809 <<
"Invalid yielded value. Expected type: " << argType
3812 if (yieldedTypes.empty())
3815 error << yieldedTypes;
3821 StringRef regionName,
3822 bool yieldsValue) -> LogicalResult {
3823 assert(!region.
empty());
3827 <<
"`" << regionName <<
"`: "
3828 <<
"expected " << expectedNumArgs
3831 for (
Block &block : region) {
3833 if (!block.mightHaveTerminator())
3836 if (
failed(verifyTerminator(block.getTerminator(), yieldsValue)))
3844 for (
Region *region : getRegions())
3845 for (
Type ty : region->getArgumentTypes())
3847 return emitError() <<
"Region argument type mismatch: got " << ty
3848 <<
" expected " << argType <<
".";
3851 if (!initRegion.
empty() &&
3856 DataSharingClauseType dsType = getDataSharingType();
3858 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
3859 return emitError(
"`private` clauses do not require a `copy` region.");
3861 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
3863 "`firstprivate` clauses require at least a `copy` region.");
3865 if (dsType == DataSharingClauseType::FirstPrivate &&
3870 if (!getDeallocRegion().empty() &&
3883 const MaskedOperands &clauses) {
3884 MaskedOp::build(builder, state, clauses.filteredThreadId);
3892 const ScanOperands &clauses) {
3893 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
3897 if (hasExclusiveVars() == hasInclusiveVars())
3899 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
3900 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
3901 if (parentWsLoopOp.getReductionModAttr() &&
3902 parentWsLoopOp.getReductionModAttr().getValue() ==
3903 ReductionModifier::inscan)
3906 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
3907 if (parentSimdOp.getReductionModAttr() &&
3908 parentSimdOp.getReductionModAttr().getValue() ==
3909 ReductionModifier::inscan)
3912 return emitError(
"SCAN directive needs to be enclosed within a parent "
3913 "worksharing loop construct or SIMD construct with INSCAN "
3914 "reduction modifier");
3920 std::optional<uint64_t> align = this->getAlign();
3922 if (align.has_value()) {
3923 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
3924 return emitError() <<
"ALIGN value : " << align.value()
3925 <<
" must be power of 2";
3935 mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
3936 return getInTypeAttr().getValue();
3945 bool hasOperands =
false;
3946 std::int32_t typeparamsSize = 0;
3952 return mlir::failure();
3954 return mlir::failure();
3956 return mlir::failure();
3960 return mlir::failure();
3968 return mlir::failure();
3969 typeparamsSize = operands.size();
3972 std::int32_t shapeSize = 0;
3976 return mlir::failure();
3977 shapeSize = operands.size() - typeparamsSize;
3979 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
3980 typeVec.push_back(idxTy);
3986 return mlir::failure();
3991 return mlir::failure();
3998 return mlir::failure();
3999 return mlir::success();
4014 if (!getTypeparams().empty()) {
4015 p <<
'(' << getTypeparams() <<
" : " << getTypeparams().getTypes() <<
')';
4022 {
"in_type",
"operandSegmentSizes"});
4027 if (!mlir::dyn_cast<IntegerType>(outType))
4028 return emitOpError(
"must be a integer type");
4029 return mlir::success();
4038 Region ®ion = getRegion();
4040 return emitOpError(
"region cannot be empty");
4043 if (entryBlock.
empty())
4044 return emitOpError(
"region must contain a structured block");
4046 bool hasTerminator =
false;
4047 for (
Block &block : region) {
4048 if (isa<TerminatorOp>(block.back())) {
4049 if (hasTerminator) {
4050 return emitOpError(
"region must have exactly one terminator");
4052 hasTerminator =
true;
4055 if (!hasTerminator) {
4056 return emitOpError(
"region must be terminated with omp.terminator");
4060 if (isa<BarrierOp>(op)) {
4062 "explicit barriers are not allowed in workdistribute region");
4065 if (isa<ParallelOp>(op)) {
4067 "nested parallel constructs not allowed in workdistribute");
4069 if (isa<TeamsOp>(op)) {
4071 "nested teams constructs not allowed in workdistribute");
4073 return WalkResult::advance();
4075 if (walkResult.wasInterrupted())
4079 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4080 return emitOpError(
"workdistribute must be nested under teams");
4084 #define GET_ATTRDEF_CLASSES
4085 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
4087 #define GET_OP_CLASSES
4088 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
4090 #define GET_TYPEDEF_CLASSES
4091 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static Operation * getParentInSameDialect(Operation *thisOp)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Operation * getParentOp()
Return the parent operation this region is attached to.
unsigned getNumArguments()
BlockListType & getBlocks()
Location getLoc()
Return a location for this region.
BlockArgument getArgument(unsigned i)
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Runtime
Potential runtimes for AMD GPU kernels.
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.