24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/BitVector.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/STLForwardCompat.h"
28 #include "llvm/ADT/SmallString.h"
29 #include "llvm/ADT/StringExtras.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Frontend/OpenMP/OMPConstants.h"
38 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
39 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
40 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
41 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
57 struct MemRefPointerLikeModel
58 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
61 return llvm::cast<MemRefType>(pointer).getElementType();
65 struct LLVMPointerPointerLikeModel
66 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
67 LLVM::LLVMPointerType> {
72 void OpenMPDialect::initialize() {
75 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
78 #define GET_ATTRDEF_LIST
79 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
82 #define GET_TYPEDEF_LIST
83 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
86 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
87 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
92 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
98 mlir::LLVM::GlobalOp::attachInterface<
101 mlir::LLVM::LLVMFuncOp::attachInterface<
104 mlir::func::FuncOp::attachInterface<
130 allocatorVars.push_back(operand);
131 allocatorTypes.push_back(type);
137 allocateVars.push_back(operand);
138 allocateTypes.push_back(type);
149 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
150 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
151 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
152 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
160 template <
typename ClauseAttr>
162 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
167 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
171 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
174 template <
typename ClauseAttr>
176 p << stringifyEnum(attr.getValue());
199 linearVars.push_back(var);
200 linearTypes.push_back(type);
201 linearStepVars.push_back(stepVar);
210 size_t linearVarsSize = linearVars.size();
211 for (
unsigned i = 0; i < linearVarsSize; ++i) {
212 std::string separator = i == linearVarsSize - 1 ?
"" :
", ";
214 if (linearStepVars.size() > i)
215 p <<
" = " << linearStepVars[i];
216 p <<
" : " << linearVars[i].getType() << separator;
229 for (
const auto &it : nontemporalVars)
230 if (!nontemporalItems.insert(it).second)
231 return op->
emitOpError() <<
"nontemporal variable used more than once";
240 std::optional<ArrayAttr> alignments,
243 if (!alignedVars.empty()) {
244 if (!alignments || alignments->size() != alignedVars.size())
246 <<
"expected as many alignment values as aligned variables";
249 return op->
emitOpError() <<
"unexpected alignment values attribute";
255 for (
auto it : alignedVars)
256 if (!alignedItems.insert(it).second)
257 return op->
emitOpError() <<
"aligned variable used more than once";
263 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
264 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
265 if (intAttr.getValue().sle(0))
266 return op->
emitOpError() <<
"alignment should be greater than 0";
268 return op->
emitOpError() <<
"expected integer alignment";
282 ArrayAttr &alignmentsAttr) {
285 if (parser.parseOperand(alignedVars.emplace_back()) ||
286 parser.parseColonType(alignedTypes.emplace_back()) ||
287 parser.parseArrow() ||
288 parser.parseAttribute(alignmentVec.emplace_back())) {
302 std::optional<ArrayAttr> alignments) {
303 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
306 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
307 p <<
" -> " << (*alignments)[i];
318 if (modifiers.size() > 2)
320 for (
const auto &mod : modifiers) {
323 auto symbol = symbolizeScheduleModifier(mod);
326 <<
" unknown modifier type: " << mod;
331 if (modifiers.size() == 1) {
332 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
333 modifiers.push_back(modifiers[0]);
334 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
336 }
else if (modifiers.size() == 2) {
339 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
340 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
342 <<
" incorrect modifier order";
358 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
359 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
364 std::optional<mlir::omp::ClauseScheduleKind> schedule =
365 symbolizeClauseScheduleKind(keyword);
371 case ClauseScheduleKind::Static:
372 case ClauseScheduleKind::Dynamic:
373 case ClauseScheduleKind::Guided:
379 chunkSize = std::nullopt;
382 case ClauseScheduleKind::Auto:
384 chunkSize = std::nullopt;
393 modifiers.push_back(mod);
399 if (!modifiers.empty()) {
401 if (std::optional<ScheduleModifier> mod =
402 symbolizeScheduleModifier(modifiers[0])) {
405 return parser.
emitError(loc,
"invalid schedule modifier");
408 if (modifiers.size() > 1) {
409 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
419 ClauseScheduleKindAttr scheduleKind,
420 ScheduleModifierAttr scheduleMod,
421 UnitAttr scheduleSimd,
Value scheduleChunk,
422 Type scheduleChunkType) {
423 p << stringifyClauseScheduleKind(scheduleKind.getValue());
425 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
427 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
439 ClauseOrderKindAttr &order,
440 OrderModifierAttr &orderMod) {
445 if (std::optional<OrderModifier> enumValue =
446 symbolizeOrderModifier(enumStr)) {
454 if (std::optional<ClauseOrderKind> enumValue =
455 symbolizeClauseOrderKind(enumStr)) {
459 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
463 ClauseOrderKindAttr order,
464 OrderModifierAttr orderMod) {
466 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
468 p << stringifyClauseOrderKind(order.getValue());
482 unsigned regionArgOffset = regionPrivateArgs.size();
486 ParseResult optionalByref = parser.parseOptionalKeyword(
"byref");
487 if (parser.parseAttribute(reductionVec.emplace_back()) ||
488 parser.parseOperand(operands.emplace_back()) ||
489 parser.parseArrow() ||
490 parser.parseArgument(regionPrivateArgs.emplace_back()) ||
491 parser.parseColonType(types.emplace_back()))
493 isByRefVec.push_back(optionalByref.succeeded());
499 auto *argsBegin = regionPrivateArgs.begin();
501 argsBegin + regionArgOffset + types.size());
502 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
515 if (!clauseName.empty())
516 p << clauseName <<
"(";
518 llvm::interleaveComma(llvm::zip_equal(symbols, operands, argsSubrange, types,
521 auto [sym, op, arg, type, isByRef] = t;
522 p << (isByRef ?
"byref " :
"") << sym <<
" " << op
523 <<
" -> " << arg <<
" : " << type;
526 if (!clauseName.empty())
534 ArrayAttr &reductionSyms,
541 reductionTypes, reductionByref,
542 reductionSyms, regionPrivateArgs)))
549 privateTypes, privateByref,
550 privateSyms, regionPrivateArgs)))
552 if (llvm::any_of(privateByref.asArrayRef(),
553 [](
bool byref) { return byref; })) {
555 "private clause cannot have byref attributes");
560 return parser.
parseRegion(region, regionPrivateArgs);
567 ArrayAttr reductionSyms,
ValueRange privateVars,
568 TypeRange privateTypes, ArrayAttr privateSyms) {
571 MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size());
573 reductionTypes, reductionByref, reductionSyms);
579 argsBegin + reductionVars.size() +
580 privateTypes.size());
582 isByRefVec.resize(privateTypes.size(),
false);
587 privateTypes, isByRef, privateSyms);
600 ArrayAttr &reductionSyms) {
604 ParseResult optionalByref = parser.parseOptionalKeyword(
"byref");
605 if (parser.parseAttribute(reductionVec.emplace_back()) ||
606 parser.parseArrow() ||
607 parser.parseOperand(reductionVars.emplace_back()) ||
608 parser.parseColonType(reductionTypes.emplace_back()))
610 isByRefVec.push_back(optionalByref.succeeded());
624 std::optional<DenseBoolArrayAttr> reductionByref,
625 std::optional<ArrayAttr> reductionSyms) {
626 auto getByRef = [&](
unsigned i) ->
const char * {
627 if (!reductionByref || !*reductionByref)
629 assert(reductionByref->empty() || i < reductionByref->size());
630 if (!reductionByref->empty() && (*reductionByref)[i])
635 for (
unsigned i = 0, e = reductionVars.size(); i < e; ++i) {
638 p << getByRef(i) << (*reductionSyms)[i] <<
" -> " << reductionVars[i]
639 <<
" : " << reductionVars[i].
getType();
648 if (!reductionVars.empty()) {
649 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
651 <<
"expected as many reduction symbol references "
652 "as reduction variables";
653 if (reductionByref && reductionByref->size() != reductionVars.size())
654 return op->
emitError() <<
"expected as many reduction variable by "
655 "reference attributes as reduction variables";
658 return op->
emitOpError() <<
"unexpected reduction symbol references";
665 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
666 Value accum = std::get<0>(args);
668 if (!accumulators.insert(accum).second)
669 return op->
emitOpError() <<
"accumulator variable used more than once";
672 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
674 SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
676 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
677 <<
" to point to a reduction declaration";
679 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
681 <<
"expected accumulator (" << varType
682 <<
") to be the same type as reduction declaration ("
683 << decl.getAccumulatorType() <<
")";
702 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
703 parser.parseArrow() ||
704 parser.parseAttribute(symsVec.emplace_back()) ||
705 parser.parseColonType(copyprivateTypes.emplace_back()))
719 std::optional<ArrayAttr> copyprivateSyms) {
720 if (!copyprivateSyms.has_value())
722 llvm::interleaveComma(
723 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
724 [&](
const auto &args) {
725 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
726 << std::get<2>(args);
733 std::optional<ArrayAttr> copyprivateSyms) {
734 size_t copyprivateSymsSize =
735 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
736 if (copyprivateSymsSize != copyprivateVars.size())
737 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
738 << copyprivateVars.size()
739 <<
") and functions (= " << copyprivateSymsSize
740 <<
"), both must be equal";
741 if (!copyprivateSyms.has_value())
744 for (
auto copyprivateVarAndSym :
745 llvm::zip(copyprivateVars, *copyprivateSyms)) {
747 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
748 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
750 if (mlir::func::FuncOp mlirFuncOp =
751 SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
754 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
755 SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
759 auto getNumArguments = [&] {
760 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
763 auto getArgumentType = [&](
unsigned i) {
764 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
769 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
770 <<
" to point to a copy function";
772 if (getNumArguments() != 2)
774 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
776 Type argTy = getArgumentType(0);
777 if (argTy != getArgumentType(1))
778 return op->
emitOpError() <<
"expected copy function " << symbolRef
779 <<
" arguments to have the same type";
781 Type varType = std::get<0>(copyprivateVarAndSym).getType();
782 if (argTy != varType)
784 <<
"expected copy function arguments' type (" << argTy
785 <<
") to be the same as copyprivate variable's type (" << varType
806 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
807 parser.parseOperand(dependVars.emplace_back()) ||
808 parser.parseColonType(dependTypes.emplace_back()))
810 if (std::optional<ClauseTaskDepend> keywordDepend =
811 (symbolizeClauseTaskDepend(keyword)))
812 kindsVec.emplace_back(
813 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
827 std::optional<ArrayAttr> dependKinds) {
829 for (
unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
832 p << stringifyClauseTaskDepend(
833 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
835 <<
" -> " << dependVars[i] <<
" : " << dependTypes[i];
841 std::optional<ArrayAttr> dependKinds,
843 if (!dependVars.empty()) {
844 if (!dependKinds || dependKinds->size() != dependVars.size())
845 return op->
emitOpError() <<
"expected as many depend values"
846 " as depend variables";
848 if (dependKinds && !dependKinds->empty())
849 return op->
emitOpError() <<
"unexpected depend values";
865 IntegerAttr &hintAttr) {
866 StringRef hintKeyword;
872 auto parseKeyword = [&]() -> ParseResult {
875 if (hintKeyword ==
"uncontended")
877 else if (hintKeyword ==
"contended")
879 else if (hintKeyword ==
"nonspeculative")
881 else if (hintKeyword ==
"speculative")
885 << hintKeyword <<
" is not a valid hint";
896 IntegerAttr hintAttr) {
897 int64_t hint = hintAttr.getInt();
905 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
907 bool uncontended = bitn(hint, 0);
908 bool contended = bitn(hint, 1);
909 bool nonspeculative = bitn(hint, 2);
910 bool speculative = bitn(hint, 3);
914 hints.push_back(
"uncontended");
916 hints.push_back(
"contended");
918 hints.push_back(
"nonspeculative");
920 hints.push_back(
"speculative");
922 llvm::interleaveComma(hints, p);
929 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
931 bool uncontended = bitn(hint, 0);
932 bool contended = bitn(hint, 1);
933 bool nonspeculative = bitn(hint, 2);
934 bool speculative = bitn(hint, 3);
936 if (uncontended && contended)
937 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
938 "omp_sync_hint_contended cannot be combined";
939 if (nonspeculative && speculative)
940 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
941 "omp_sync_hint_speculative cannot be combined.";
951 llvm::omp::OpenMPOffloadMappingFlags flag) {
952 return value & llvm::to_underlying(flag);
961 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
962 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
966 auto parseTypeAndMod = [&]() -> ParseResult {
967 StringRef mapTypeMod;
971 if (mapTypeMod ==
"always")
972 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
974 if (mapTypeMod ==
"implicit")
975 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
977 if (mapTypeMod ==
"close")
978 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
980 if (mapTypeMod ==
"present")
981 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
983 if (mapTypeMod ==
"to")
984 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
986 if (mapTypeMod ==
"from")
987 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
989 if (mapTypeMod ==
"tofrom")
990 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
991 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
993 if (mapTypeMod ==
"delete")
994 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1004 llvm::to_underlying(mapTypeBits));
1012 IntegerAttr mapType) {
1013 uint64_t mapTypeBits = mapType.getUInt();
1015 bool emitAllocRelease =
true;
1021 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1022 mapTypeStrs.push_back(
"always");
1024 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1025 mapTypeStrs.push_back(
"implicit");
1027 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1028 mapTypeStrs.push_back(
"close");
1030 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1031 mapTypeStrs.push_back(
"present");
1037 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1039 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1041 emitAllocRelease =
false;
1042 mapTypeStrs.push_back(
"tofrom");
1044 emitAllocRelease =
false;
1045 mapTypeStrs.push_back(
"from");
1047 emitAllocRelease =
false;
1048 mapTypeStrs.push_back(
"to");
1051 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1052 emitAllocRelease =
false;
1053 mapTypeStrs.push_back(
"delete");
1055 if (emitAllocRelease)
1056 mapTypeStrs.push_back(
"exit_release_or_enter_alloc");
1058 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1059 p << mapTypeStrs[i];
1060 if (i + 1 < mapTypeStrs.size()) {
1070 int64_t shape[2] = {0, 0};
1071 unsigned shapeTmp = 0;
1072 auto parseIndices = [&]() -> ParseResult {
1076 values.push_back(APInt(32, value));
1093 shape[1] = shapeTmp;
1099 if (shapeTmp != shape[1])
1106 if (!values.empty()) {
1107 ShapedType valueType =
1118 assert(shape.size() <= 2);
1123 for (
int i = 0; i < shape[0]; ++i) {
1125 int rowOffset = i * shape[1];
1126 for (
int j = 0;
j < shape[1]; ++
j) {
1127 p << membersIdx.getValues<int32_t>()[rowOffset +
j];
1128 if ((
j + 1) < shape[1])
1133 if ((i + 1) < shape[0])
1145 auto parseEntries = [&]() -> ParseResult {
1150 mapVars.push_back(arg);
1154 auto parseTypes = [&]() -> ParseResult {
1157 mapTypes.push_back(argType);
1179 unsigned argIndex = 0;
1181 for (
const auto &mapOp : mapVars) {
1184 const auto &blockArg = entryBlock->
getArgument(argIndex);
1185 p <<
" -> " << blockArg;
1188 if (argIndex < mapVars.size())
1194 for (
const auto &mapType : mapTypes) {
1197 if (argIndex < mapVars.size())
1210 if (parser.parseAttribute(privateSymRefs.emplace_back()) ||
1211 parser.parseOperand(privateVars.emplace_back()) ||
1212 parser.parseArrow() ||
1213 parser.parseArgument(regionPrivateArgs.emplace_back()) ||
1214 parser.parseColonType(privateTypes.emplace_back()))
1221 privateSymRefs.end());
1229 ArrayAttr privateSyms) {
1231 auto targetOp = mlir::dyn_cast<mlir::omp::TargetOp>(op);
1236 MutableArrayRef argsSubrange(argsBegin + targetOp.getMapVars().size(),
1237 argsBegin + targetOp.getMapVars().size() +
1238 privateTypes.size());
1240 isByRefVec.resize(privateTypes.size(),
false);
1245 llvm::StringRef{}, privateVars,
1246 privateTypes, isByRef, privateSyms);
1250 VariableCaptureKindAttr mapCaptureType) {
1251 std::string typeCapStr;
1252 llvm::raw_string_ostream typeCap(typeCapStr);
1253 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1255 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1256 typeCap <<
"ByCopy";
1257 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1258 typeCap <<
"VLAType";
1259 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1265 VariableCaptureKindAttr &mapCaptureType) {
1266 StringRef mapCaptureKey;
1270 if (mapCaptureKey ==
"This")
1272 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
1273 if (mapCaptureKey ==
"ByRef")
1275 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
1276 if (mapCaptureKey ==
"ByCopy")
1278 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1279 if (mapCaptureKey ==
"VLAType")
1281 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
1290 for (
auto mapOp : mapVars) {
1291 if (!mapOp.getDefiningOp())
1294 if (
auto mapInfoOp =
1295 mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1296 if (!mapInfoOp.getMapType().has_value())
1299 if (!mapInfoOp.getMapCaptureType().has_value())
1302 uint64_t mapTypeBits = mapInfoOp.getMapType().value();
1305 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1307 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1309 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1312 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1314 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1316 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1318 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1320 "to, from, tofrom and alloc map types are permitted");
1322 if (isa<TargetEnterDataOp>(op) && (from || del))
1323 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
1325 if (isa<TargetExitDataOp>(op) && to)
1327 "from, release and delete map types are permitted");
1329 if (isa<TargetUpdateOp>(op)) {
1332 "at least one of to or from map types must be "
1333 "specified, other map types are not permitted");
1338 "at least one of to or from map types must be "
1339 "specified, other map types are not permitted");
1342 auto updateVar = mapInfoOp.getVarPtr();
1344 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1345 (from && updateToVars.contains(updateVar))) {
1348 "either to or from map types can be specified, not both");
1351 if (always || close || implicit) {
1354 "present, mapper and iterator map type modifiers are permitted");
1357 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1373 TargetDataOp::build(builder, state, clauses.device, clauses.ifVar,
1374 clauses.mapVars, clauses.useDeviceAddrVars,
1375 clauses.useDevicePtrVars);
1379 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1380 getUseDeviceAddrVars().empty()) {
1382 "At least one of map, use_device_ptr_vars, or "
1383 "use_device_addr_vars operand must be present");
1392 void TargetEnterDataOp::build(
1396 TargetEnterDataOp::build(builder, state,
1398 clauses.dependVars, clauses.device, clauses.ifVar,
1399 clauses.mapVars, clauses.nowait);
1403 LogicalResult verifyDependVars =
1405 return failed(verifyDependVars) ? verifyDependVars
1416 TargetExitDataOp::build(builder, state,
1418 clauses.dependVars, clauses.device, clauses.ifVar,
1419 clauses.mapVars, clauses.nowait);
1423 LogicalResult verifyDependVars =
1425 return failed(verifyDependVars) ? verifyDependVars
1436 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
1437 clauses.dependVars, clauses.device, clauses.ifVar,
1438 clauses.mapVars, clauses.nowait);
1442 LogicalResult verifyDependVars =
1444 return failed(verifyDependVars) ? verifyDependVars
1457 TargetOp::build(builder, state, {}, {},
1458 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
1459 clauses.device, clauses.hasDeviceAddrVars, clauses.ifVar,
1461 nullptr, clauses.isDevicePtrVars,
1462 clauses.mapVars, clauses.nowait, clauses.privateVars,
1463 makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit);
1467 LogicalResult verifyDependVars =
1469 return failed(verifyDependVars) ? verifyDependVars
1479 ParallelOp::build(builder, state,
ValueRange(),
1485 state.addAttributes(attributes);
1492 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1493 clauses.ifVar, clauses.numThreads, clauses.privateVars,
1495 clauses.procBindKind, clauses.reductionVars,
1500 template <
typename OpType>
1502 auto privateVars = op.getPrivateVars();
1503 auto privateSyms = op.getPrivateSymsAttr();
1505 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
1508 auto numPrivateVars = privateVars.size();
1509 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
1511 if (numPrivateVars != numPrivateSyms)
1512 return op.
emitError() <<
"inconsistent number of private variables and "
1513 "privatizer op symbols, private vars: "
1515 <<
" vs. privatizer op symbols: " << numPrivateSyms;
1517 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
1518 Type varType = std::get<0>(privateVarInfo).getType();
1519 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
1520 PrivateClauseOp privatizerOp =
1521 SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
1523 if (privatizerOp ==
nullptr)
1524 return op.
emitError() <<
"failed to lookup privatizer op with symbol: '"
1525 << privateSym <<
"'";
1527 Type privatizerType = privatizerOp.getType();
1529 if (varType != privatizerType)
1531 <<
"type mismatch between a "
1532 << (privatizerOp.getDataSharingType() ==
1533 DataSharingClauseType::Private
1536 <<
" variable and its privatizer op, var type: " << varType
1537 <<
" vs. privatizer op type: " << privatizerType;
1544 auto distributeChildOps = getOps<DistributeOp>();
1545 if (!distributeChildOps.empty()) {
1548 <<
"'omp.composite' attribute missing from composite operation";
1551 Operation &distributeOp = **distributeChildOps.begin();
1553 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
1557 return emitError() <<
"unexpected OpenMP operation inside of composite "
1560 }
else if (isComposite()) {
1562 <<
"'omp.composite' attribute present in non-composite operation";
1565 if (getAllocateVars().size() != getAllocatorVars().size())
1567 "expected equal sizes for allocate and allocator variables");
1573 getReductionByref());
1591 TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1592 clauses.ifVar, clauses.numTeamsLower, clauses.numTeamsUpper,
1594 nullptr, clauses.reductionVars,
1597 clauses.threadLimit);
1609 return emitError(
"expected to be nested inside of omp.target or not nested "
1610 "in any OpenMP dialect operations");
1613 if (
auto numTeamsLowerBound = getNumTeamsLower()) {
1614 auto numTeamsUpperBound = getNumTeamsUpper();
1615 if (!numTeamsUpperBound)
1616 return emitError(
"expected num_teams upper bound to be defined if the "
1617 "lower bound is defined");
1618 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
1620 "expected num_teams upper bound and lower bound to be the same type");
1624 if (getAllocateVars().size() != getAllocatorVars().size())
1626 "expected equal sizes for allocate and allocator variables");
1629 getReductionByref());
1640 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1642 nullptr, clauses.reductionVars,
1648 if (getAllocateVars().size() != getAllocatorVars().size())
1650 "expected equal sizes for allocate and allocator variables");
1653 getReductionByref());
1656 LogicalResult SectionsOp::verifyRegions() {
1657 for (
auto &inst : *getRegion().begin()) {
1658 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
1659 return emitOpError()
1660 <<
"expected omp.section op or terminator op inside region";
1675 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1676 clauses.copyprivateVars,
1677 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
1683 if (getAllocateVars().size() != getAllocatorVars().size())
1685 "expected equal sizes for allocate and allocator variables");
1688 getCopyprivateSyms());
1704 reductionTypes, reductionByRef,
1705 reductionSymbols, privates)))
1714 if (reductionSymbols) {
1717 reductionOperands, reductionTypes, isByRef,
1725 return op->
emitOpError() <<
"loop wrapper contains multiple regions";
1729 return op->
emitOpError() <<
"loop wrapper contains multiple blocks";
1731 if (::llvm::range_size(region.
getOps()) != 2)
1733 <<
"loop wrapper does not contain exactly two nested ops";
1740 <<
"second nested op in loop wrapper is not a terminator";
1742 if (!::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp))
1743 return op->
emitOpError() <<
"first nested op in loop wrapper is not "
1744 "another loop wrapper or `omp.loop_nest`";
1751 build(builder, state, {}, {},
1753 false,
nullptr,
nullptr,
1754 nullptr, {},
nullptr,
1759 state.addAttributes(attributes);
1769 {}, {}, clauses.linearVars,
1770 clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
1771 clauses.ordered, {},
nullptr,
1772 clauses.reductionVars,
1774 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
1775 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
1782 bool isCompositeChildLeaf =
1783 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
1785 if (LoopWrapperInterface nested = getNestedWrapper()) {
1788 <<
"'omp.composite' attribute missing from composite wrapper";
1792 if (!isa<SimdOp>(nested))
1793 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
1795 }
else if (isComposite() && !isCompositeChildLeaf) {
1797 <<
"'omp.composite' attribute present in non-composite wrapper";
1798 }
else if (!isComposite() && isCompositeChildLeaf) {
1800 <<
"'omp.composite' attribute missing from composite wrapper";
1804 getReductionByref());
1816 SimdOp::build(builder, state, clauses.alignedVars,
1819 clauses.nontemporalVars, clauses.order, clauses.orderMod,
1822 nullptr, clauses.safelen, clauses.simdlen);
1826 if (getSimdlen().has_value() && getSafelen().has_value() &&
1827 getSimdlen().value() > getSafelen().value())
1828 return emitOpError()
1829 <<
"simdlen clause and safelen clause are both present, but the "
1830 "simdlen value is not less than or equal to safelen value";
1841 if (getNestedWrapper())
1842 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
1844 bool isCompositeChildLeaf =
1845 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
1847 if (!isComposite() && isCompositeChildLeaf)
1849 <<
"'omp.composite' attribute missing from composite wrapper";
1851 if (isComposite() && !isCompositeChildLeaf)
1853 <<
"'omp.composite' attribute present in non-composite wrapper";
1865 DistributeOp::build(
1866 builder, state, clauses.allocateVars, clauses.allocatorVars,
1867 clauses.distScheduleStatic, clauses.distScheduleChunkSize, clauses.order,
1868 clauses.orderMod, {},
nullptr);
1872 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
1873 return emitOpError() <<
"chunk size set without "
1874 "dist_schedule_static being present";
1876 if (getAllocateVars().size() != getAllocatorVars().size())
1878 "expected equal sizes for allocate and allocator variables");
1883 if (LoopWrapperInterface nested = getNestedWrapper()) {
1886 <<
"'omp.composite' attribute missing from composite wrapper";
1889 if (isa<WsloopOp>(nested)) {
1890 if (!llvm::dyn_cast_if_present<ParallelOp>((*this)->getParentOp()))
1891 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
1892 "when 'omp.parallel' is the direct parent";
1893 }
else if (!isa<SimdOp>(nested))
1894 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
1896 }
else if (isComposite()) {
1898 <<
"'omp.composite' attribute present in non-composite wrapper";
1908 LogicalResult DeclareReductionOp::verifyRegions() {
1909 if (!getAllocRegion().empty()) {
1910 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
1911 if (yieldOp.getResults().size() != 1 ||
1912 yieldOp.getResults().getTypes()[0] !=
getType())
1913 return emitOpError() <<
"expects alloc region to yield a value "
1914 "of the reduction type";
1918 if (getInitializerRegion().empty())
1919 return emitOpError() <<
"expects non-empty initializer region";
1920 Block &initializerEntryBlock = getInitializerRegion().
front();
1923 if (!getAllocRegion().empty())
1924 return emitOpError() <<
"expects two arguments to the initializer region "
1925 "when an allocation region is used";
1927 if (getAllocRegion().empty())
1928 return emitOpError() <<
"expects one argument to the initializer region "
1929 "when no allocation region is used";
1931 return emitOpError()
1932 <<
"expects one or two arguments to the initializer region";
1936 if (arg.getType() !=
getType())
1937 return emitOpError() <<
"expects initializer region argument to match "
1938 "the reduction type";
1940 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
1941 if (yieldOp.getResults().size() != 1 ||
1942 yieldOp.getResults().getTypes()[0] !=
getType())
1943 return emitOpError() <<
"expects initializer region to yield a value "
1944 "of the reduction type";
1947 if (getReductionRegion().empty())
1948 return emitOpError() <<
"expects non-empty reduction region";
1949 Block &reductionEntryBlock = getReductionRegion().
front();
1954 return emitOpError() <<
"expects reduction region with two arguments of "
1955 "the reduction type";
1956 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
1957 if (yieldOp.getResults().size() != 1 ||
1958 yieldOp.getResults().getTypes()[0] !=
getType())
1959 return emitOpError() <<
"expects reduction region to yield a value "
1960 "of the reduction type";
1963 if (!getAtomicReductionRegion().empty()) {
1964 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
1968 return emitOpError() <<
"expects atomic reduction region with two "
1969 "arguments of the same type";
1970 auto ptrType = llvm::dyn_cast<PointerLikeType>(
1973 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
1974 return emitOpError() <<
"expects atomic reduction region arguments to "
1975 "be accumulators containing the reduction type";
1978 if (getCleanupRegion().empty())
1980 Block &cleanupEntryBlock = getCleanupRegion().
front();
1983 return emitOpError() <<
"expects cleanup region with one argument "
1984 "of the reduction type";
1997 TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1998 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
1999 clauses.final, clauses.ifVar, clauses.inReductionVars,
2001 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2002 clauses.priority, {},
nullptr,
2007 LogicalResult verifyDependVars =
2009 return failed(verifyDependVars)
2012 getInReductionVars(),
2013 getInReductionByref());
2023 TaskgroupOp::build(builder, state, clauses.allocateVars,
2024 clauses.allocatorVars, clauses.taskReductionVars,
2031 getTaskReductionVars(),
2032 getTaskReductionByref());
2044 builder, state, clauses.allocateVars, clauses.allocatorVars,
2045 clauses.final, clauses.grainsize, clauses.ifVar, clauses.inReductionVars,
2047 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2048 clauses.nogroup, clauses.numTasks, clauses.priority, {},
2049 nullptr, clauses.reductionVars,
2056 getInReductionVars().end());
2057 allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
2058 getReductionVars().end());
2059 return allReductionNvars;
2063 if (getAllocateVars().size() != getAllocatorVars().size())
2065 "expected equal sizes for allocate and allocator variables");
2067 getReductionVars(), getReductionByref())) ||
2069 getInReductionVars(),
2070 getInReductionByref())))
2073 if (!getReductionVars().empty() && getNogroup())
2074 return emitError(
"if a reduction clause is present on the taskloop "
2075 "directive, the nogroup clause must not be specified");
2076 for (
auto var : getReductionVars()) {
2077 if (llvm::is_contained(getInReductionVars(), var))
2078 return emitError(
"the same list item cannot appear in both a reduction "
2079 "and an in_reduction clause");
2082 if (getGrainsize() && getNumTasks()) {
2084 "the grainsize clause and num_tasks clause are mutually exclusive and "
2085 "may not appear on the same taskloop directive");
2091 if (LoopWrapperInterface nested = getNestedWrapper()) {
2094 <<
"'omp.composite' attribute missing from composite wrapper";
2098 if (!isa<SimdOp>(nested))
2099 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
2100 }
else if (isComposite()) {
2102 <<
"'omp.composite' attribute present in non-composite wrapper";
2126 for (
auto &iv : ivs)
2127 iv.type = loopVarType;
2156 Region ®ion = getRegion();
2158 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
2159 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
2160 if (getLoopInclusive())
2162 p <<
"step (" << getLoopSteps() <<
") ";
2168 LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2169 clauses.loopUpperBounds, clauses.loopSteps,
2170 clauses.loopInclusive);
2174 if (getLoopLowerBounds().empty())
2175 return emitOpError() <<
"must represent at least one loop";
2177 if (getLoopLowerBounds().size() != getIVs().size())
2178 return emitOpError() <<
"number of range arguments and IVs do not match";
2180 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
2181 if (lb.getType() != iv.getType())
2182 return emitOpError()
2183 <<
"range argument type does not match corresponding IV type";
2186 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
2187 return emitOpError() <<
"expects parent op to be a loop wrapper";
2192 void LoopNestOp::gatherWrappers(
2195 while (
auto wrapper =
2196 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
2197 wrappers.push_back(wrapper);
2208 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
2216 if (getNameAttr()) {
2217 SymbolRefAttr symbolRef = getNameAttr();
2221 return emitOpError() <<
"expected symbol reference " << symbolRef
2222 <<
" to point to a critical declaration";
2242 return op.
emitOpError() <<
"must be nested inside of a loop";
2246 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
2247 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
2249 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
2250 "have an ordered clause";
2252 if (hasRegion && orderedAttr.getInt() != 0)
2253 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
2254 "have a parameter present";
2256 if (!hasRegion && orderedAttr.getInt() == 0)
2257 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
2258 "have a parameter present";
2259 }
else if (!isa<SimdOp>(wrapper)) {
2260 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
2261 "or worksharing simd loop";
2268 OrderedOp::build(builder, state, clauses.doacrossDependType,
2269 clauses.doacrossNumLoops, clauses.doacrossDependVars);
2276 auto wrapper = (*this)->getParentOfType<WsloopOp>();
2277 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
2278 return emitOpError() <<
"number of variables in depend clause does not "
2279 <<
"match number of iteration variables in the "
2287 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
2292 if (getParLevelSimd())
2305 TaskwaitOp::build(builder, state,
nullptr,
2314 if (verifyCommon().failed())
2315 return mlir::failure();
2317 if (
auto mo = getMemoryOrder()) {
2318 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2319 *mo == ClauseMemoryOrderKind::Release) {
2321 "memory-order must not be acq_rel or release for atomic reads");
2332 if (verifyCommon().failed())
2333 return mlir::failure();
2335 if (
auto mo = getMemoryOrder()) {
2336 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2337 *mo == ClauseMemoryOrderKind::Acquire) {
2339 "memory-order must not be acq_rel or acquire for atomic writes");
2349 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2355 if (
Value writeVal = op.getWriteOpVal()) {
2357 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
2364 if (verifyCommon().failed())
2365 return mlir::failure();
2367 if (
auto mo = getMemoryOrder()) {
2368 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2369 *mo == ClauseMemoryOrderKind::Acquire) {
2371 "memory-order must not be acq_rel or acquire for atomic updates");
2378 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
2384 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2385 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2387 return dyn_cast<AtomicReadOp>(getSecondOp());
2390 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2391 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2393 return dyn_cast<AtomicWriteOp>(getSecondOp());
2396 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2397 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2399 return dyn_cast<AtomicUpdateOp>(getSecondOp());
2406 LogicalResult AtomicCaptureOp::verifyRegions() {
2407 if (verifyRegionsCommon().failed())
2408 return mlir::failure();
2410 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
2412 "operations inside capture region must not have hint clause");
2414 if (getFirstOp()->getAttr(
"memory_order") ||
2415 getSecondOp()->getAttr(
"memory_order"))
2417 "operations inside capture region must not have memory_order clause");
2427 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifVar);
2431 ClauseCancellationConstructType cct = getCancelDirective();
2435 return emitOpError() <<
"must be used within a region supporting "
2439 if ((cct == ClauseCancellationConstructType::Parallel) &&
2440 !isa<ParallelOp>(parentOp)) {
2441 return emitOpError() <<
"cancel parallel must appear "
2442 <<
"inside a parallel region";
2444 if (cct == ClauseCancellationConstructType::Loop) {
2445 auto loopOp = dyn_cast<LoopNestOp>(parentOp);
2446 auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
2447 loopOp ? loopOp->getParentOp() :
nullptr);
2450 return emitOpError()
2451 <<
"cancel loop must appear inside a worksharing-loop region";
2453 if (wsloopOp.getNowaitAttr()) {
2454 return emitError() <<
"A worksharing construct that is canceled "
2455 <<
"must not have a nowait clause";
2457 if (wsloopOp.getOrderedAttr()) {
2458 return emitError() <<
"A worksharing construct that is canceled "
2459 <<
"must not have an ordered clause";
2462 }
else if (cct == ClauseCancellationConstructType::Sections) {
2463 if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2464 return emitOpError() <<
"cancel sections must appear "
2465 <<
"inside a sections region";
2467 if (isa_and_nonnull<SectionsOp>(parentOp->
getParentOp()) &&
2468 cast<SectionsOp>(parentOp->
getParentOp()).getNowaitAttr()) {
2469 return emitError() <<
"A sections construct that is canceled "
2470 <<
"must not have a nowait clause";
2483 CancellationPointOp::build(builder, state, clauses.cancelDirective);
2487 ClauseCancellationConstructType cct = getCancelDirective();
2491 return emitOpError() <<
"must be used within a region supporting "
2492 "cancellation point directive";
2495 if ((cct == ClauseCancellationConstructType::Parallel) &&
2496 !(isa<ParallelOp>(parentOp))) {
2497 return emitOpError() <<
"cancellation point parallel must appear "
2498 <<
"inside a parallel region";
2500 if ((cct == ClauseCancellationConstructType::Loop) &&
2501 (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->
getParentOp()))) {
2502 return emitOpError() <<
"cancellation point loop must appear "
2503 <<
"inside a worksharing-loop region";
2505 if ((cct == ClauseCancellationConstructType::Sections) &&
2506 !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2507 return emitOpError() <<
"cancellation point sections must appear "
2508 <<
"inside a sections region";
2519 auto extent = getExtent();
2521 if (!extent && !upperbound)
2522 return emitError(
"expected extent or upperbound.");
2529 PrivateClauseOp::build(
2530 odsBuilder, odsState, symName, type,
2532 DataSharingClauseType::Private));
2538 auto verifyTerminator = [&](
Operation *terminator,
2539 bool yieldsValue) -> LogicalResult {
2543 if (!llvm::isa<YieldOp>(terminator))
2545 <<
"expected exit block terminator to be an `omp.yield` op.";
2547 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
2548 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
2551 if (yieldedTypes.empty())
2555 <<
"Did not expect any values to be yielded.";
2558 if (yieldedTypes.size() == 1 && yieldedTypes.front() == symType)
2562 <<
"Invalid yielded value. Expected type: " << symType
2565 if (yieldedTypes.empty())
2568 error << yieldedTypes;
2574 StringRef regionName,
2575 bool yieldsValue) -> LogicalResult {
2576 assert(!region.
empty());
2580 <<
"`" << regionName <<
"`: "
2581 <<
"expected " << expectedNumArgs
2584 for (
Block &block : region) {
2586 if (!block.mightHaveTerminator())
2589 if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
2600 DataSharingClauseType dsType = getDataSharingType();
2602 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
2603 return emitError(
"`private` clauses require only an `alloc` region.");
2605 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
2607 "`firstprivate` clauses require both `alloc` and `copy` regions.");
2609 if (dsType == DataSharingClauseType::FirstPrivate &&
2614 if (!getDeallocRegion().empty() &&
2628 MaskedOp::build(builder, state, clauses.filteredThreadId);
2631 #define GET_ATTRDEF_CLASSES
2632 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
2634 #define GET_OP_CLASSES
2635 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
2637 #define GET_TYPEDEF_CLASSES
2638 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static ParseResult parseReductionVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
reduction-entry-list ::= reduction-entry | reduction-entry-list , reduction-entry reduction-entry ::=...
static ParseResult parsePrivateList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static void printMapEntries(OpAsmPrinter &p, Operation *op, OperandRange mapVars, TypeRange mapTypes)
static void printReductionVarList(OpAsmPrinter &p, Operation *op, OperandRange reductionVars, TypeRange reductionTypes, std::optional< DenseBoolArrayAttr > reductionByref, std::optional< ArrayAttr > reductionSyms)
Print Reduction clause.
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static ParseResult parseParallelRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
void printWsloop(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange reductionOperands, TypeRange reductionTypes, DenseBoolArrayAttr isByRef, ArrayAttr reductionSymbols)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static LogicalResult verifyLoopWrapperInterface(Operation *op)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, DenseBoolArrayAttr &byref, ArrayAttr &symbols, SmallVectorImpl< OpAsmParser::Argument > ®ionPrivateArgs)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
ParseResult parseWsloop(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionOperands, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByRef, ArrayAttr &reductionSymbols)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, DenseIntElementsAttr membersIdx)
static ParseResult parseMapEntries(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op, ValueRange argsSubrange, StringRef clauseName, ValueRange operands, TypeRange types, DenseBoolArrayAttr byref, ArrayAttr symbols)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printPrivateList(OpAsmPrinter &p, Operation *op, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > ©privateVars, SmallVectorImpl< Type > ©privateTypes, ArrayAttr ©privateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseMembersIndex(OpAsmParser &parser, DenseIntElementsAttr &membersIdx)
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
BlockArgListType getArguments()
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class provides the API for ops that are known to be terminators.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
unsigned getNumArguments()
Location getLoc()
Return a location for this region.
bool hasOneBlock()
Return true if this region has exactly one block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
ArrayRef< T > asArrayRef() const
Runtime
Potential runtimes for AMD GPU kernels.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.