24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/BitVector.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/STLForwardCompat.h"
28 #include "llvm/ADT/SmallString.h"
29 #include "llvm/ADT/StringExtras.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Frontend/OpenMP/OMPConstants.h"
38 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
39 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
40 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
41 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
57 struct MemRefPointerLikeModel
58 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
61 return llvm::cast<MemRefType>(pointer).getElementType();
65 struct LLVMPointerPointerLikeModel
66 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
67 LLVM::LLVMPointerType> {
72 void OpenMPDialect::initialize() {
75 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
78 #define GET_ATTRDEF_LIST
79 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
82 #define GET_TYPEDEF_LIST
83 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
86 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
87 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
92 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
98 mlir::LLVM::GlobalOp::attachInterface<
101 mlir::LLVM::LLVMFuncOp::attachInterface<
104 mlir::func::FuncOp::attachInterface<
130 operandsAllocator.push_back(operand);
131 typesAllocator.push_back(type);
137 operandsAllocate.push_back(operand);
138 typesAllocate.push_back(type);
149 for (
unsigned i = 0; i < varsAllocate.size(); ++i) {
150 std::string separator = i == varsAllocate.size() - 1 ?
"" :
", ";
151 p << varsAllocator[i] <<
" : " << typesAllocator[i] <<
" -> ";
152 p << varsAllocate[i] <<
" : " << typesAllocate[i] << separator;
160 template <
typename ClauseAttr>
162 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
167 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
171 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
174 template <
typename ClauseAttr>
176 p << stringifyEnum(attr.getValue());
200 types.push_back(type);
201 stepVars.push_back(stepVar);
210 size_t linearVarsSize = linearVars.size();
211 for (
unsigned i = 0; i < linearVarsSize; ++i) {
212 std::string separator = i == linearVarsSize - 1 ?
"" :
", ";
214 if (linearStepVars.size() > i)
215 p <<
" = " << linearStepVars[i];
216 p <<
" : " << linearVars[i].getType() << separator;
229 for (
const auto &it : nontemporalVariables)
230 if (!nontemporalItems.insert(it).second)
231 return op->
emitOpError() <<
"nontemporal variable used more than once";
243 if (!alignedVariables.empty()) {
244 if (!alignmentValues || alignmentValues->size() != alignedVariables.size())
246 <<
"expected as many alignment values as aligned variables";
249 return op->
emitOpError() <<
"unexpected alignment values attribute";
255 for (
auto it : alignedVariables)
256 if (!alignedItems.insert(it).second)
257 return op->
emitOpError() <<
"aligned variable used more than once";
259 if (!alignmentValues)
263 for (
unsigned i = 0; i < (*alignmentValues).size(); ++i) {
264 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
265 if (intAttr.getValue().sle(0))
266 return op->
emitOpError() <<
"alignment should be greater than 0";
268 return op->
emitOpError() <<
"expected integer alignment";
284 if (parser.parseOperand(alignedItems.emplace_back()) ||
285 parser.parseColonType(types.emplace_back()) ||
286 parser.parseArrow() ||
287 parser.parseAttribute(alignmentVec.emplace_back())) {
294 alignmentValues =
ArrayAttr::get(parser.getContext(), alignments);
302 std::optional<ArrayAttr> alignmentValues) {
303 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
306 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
307 p <<
" -> " << (*alignmentValues)[i];
318 if (modifiers.size() > 2)
320 for (
const auto &mod : modifiers) {
323 auto symbol = symbolizeScheduleModifier(mod);
326 <<
" unknown modifier type: " << mod;
331 if (modifiers.size() == 1) {
332 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
333 modifiers.push_back(modifiers[0]);
334 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
336 }
else if (modifiers.size() == 2) {
339 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
340 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
342 <<
" incorrect modifier order";
357 OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
358 ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier,
359 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
Type &chunkType) {
363 std::optional<mlir::omp::ClauseScheduleKind> schedule =
364 symbolizeClauseScheduleKind(keyword);
370 case ClauseScheduleKind::Static:
371 case ClauseScheduleKind::Dynamic:
372 case ClauseScheduleKind::Guided:
378 chunkSize = std::nullopt;
381 case ClauseScheduleKind::Auto:
383 chunkSize = std::nullopt;
392 modifiers.push_back(mod);
398 if (!modifiers.empty()) {
400 if (std::optional<ScheduleModifier> mod =
401 symbolizeScheduleModifier(modifiers[0])) {
404 return parser.
emitError(loc,
"invalid schedule modifier");
407 if (modifiers.size() > 1) {
408 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
418 ClauseScheduleKindAttr schedAttr,
419 ScheduleModifierAttr modifier, UnitAttr simd,
420 Value scheduleChunkVar,
421 Type scheduleChunkType) {
422 p << stringifyClauseScheduleKind(schedAttr.getValue());
423 if (scheduleChunkVar)
424 p <<
" = " << scheduleChunkVar <<
" : " << scheduleChunkVar.
getType();
426 p <<
", " << stringifyScheduleModifier(modifier.getValue());
438 ClauseOrderKindAttr &kindAttr,
439 OrderModifierAttr &modifierAttr) {
444 if (std::optional<OrderModifier> enumValue =
445 symbolizeOrderModifier(enumStr)) {
453 if (std::optional<ClauseOrderKind> enumValue =
454 symbolizeClauseOrderKind(enumStr)) {
458 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
462 ClauseOrderKindAttr kindAttr,
463 OrderModifierAttr modifierAttr) {
465 p << stringifyOrderModifier(modifierAttr.getValue()) <<
":";
467 p << stringifyClauseOrderKind(kindAttr.getValue());
482 unsigned regionArgOffset = regionPrivateArgs.size();
486 ParseResult optionalByref = parser.parseOptionalKeyword(
"byref");
487 if (parser.parseAttribute(reductionVec.emplace_back()) ||
488 parser.parseOperand(operands.emplace_back()) ||
489 parser.parseArrow() ||
490 parser.parseArgument(regionPrivateArgs.emplace_back()) ||
491 parser.parseColonType(types.emplace_back()))
493 isByRefVec.push_back(optionalByref.succeeded());
499 auto *argsBegin = regionPrivateArgs.begin();
501 argsBegin + regionArgOffset + types.size());
502 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
515 if (!clauseName.empty())
516 p << clauseName <<
"(";
518 llvm::interleaveComma(llvm::zip_equal(symbols, operands, argsSubrange, types,
521 auto [sym, op, arg, type, isByRef] = t;
522 p << (isByRef ?
"byref " :
"") << sym <<
" " << op
523 <<
" -> " << arg <<
" : " << type;
526 if (!clauseName.empty())
537 ArrayAttr &privatizerSymbols) {
542 reductionVarTypes, reductionByRef,
543 reductionSymbols, regionPrivateArgs)))
550 privateVarsTypes, privateByRef,
551 privatizerSymbols, regionPrivateArgs)))
553 if (llvm::any_of(privateByRef.asArrayRef(),
554 [](
bool byref) { return byref; })) {
556 "private clause cannot have byref attributes");
561 return parser.
parseRegion(region, regionPrivateArgs);
568 ArrayAttr reductionSymbols,
571 ArrayAttr privatizerSymbols) {
572 if (reductionSymbols) {
575 argsBegin + reductionVarTypes.size());
577 reductionVarOperands, reductionVarTypes,
578 reductionVarIsByRef, reductionSymbols);
581 if (privatizerSymbols) {
584 argsBegin + reductionVarOperands.size() +
585 privateVarTypes.size());
587 isByRefVec.resize(privateVarTypes.size(),
false);
592 privateVarOperands, privateVarTypes, isByRef,
606 ArrayAttr &reductionSymbols) {
610 ParseResult optionalByref = parser.parseOptionalKeyword(
"byref");
611 if (parser.parseAttribute(reductionVec.emplace_back()) ||
612 parser.parseArrow() ||
613 parser.parseOperand(operands.emplace_back()) ||
614 parser.parseColonType(types.emplace_back()))
616 isByRefVec.push_back(optionalByref.succeeded());
630 std::optional<DenseBoolArrayAttr> isByRef,
631 std::optional<ArrayAttr> reductions) {
632 auto getByRef = [&](
unsigned i) ->
const char * {
633 if (!isByRef || !*isByRef)
635 assert(isByRef->empty() || i < isByRef->size());
636 if (!isByRef->empty() && (*isByRef)[i])
641 for (
unsigned i = 0, e = reductionVars.size(); i < e; ++i) {
644 p << getByRef(i) << (*reductions)[i] <<
" -> " << reductionVars[i] <<
" : "
654 if (!reductionVars.empty()) {
655 if (!reductions || reductions->size() != reductionVars.size())
657 <<
"expected as many reduction symbol references "
658 "as reduction variables";
659 if (byRef && byRef->size() != reductionVars.size())
660 return op->
emitError() <<
"expected as many reduction variable by "
661 "reference attributes as reduction variables";
664 return op->
emitOpError() <<
"unexpected reduction symbol references";
671 for (
auto args : llvm::zip(reductionVars, *reductions)) {
672 Value accum = std::get<0>(args);
674 if (!accumulators.insert(accum).second)
675 return op->
emitOpError() <<
"accumulator variable used more than once";
678 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
680 SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
682 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
683 <<
" to point to a reduction declaration";
685 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
687 <<
"expected accumulator (" << varType
688 <<
") to be the same type as reduction declaration ("
689 << decl.getAccumulatorType() <<
")";
708 if (parser.parseOperand(operands.emplace_back()) ||
709 parser.parseArrow() ||
710 parser.parseAttribute(copyPrivateFuncsVec.emplace_back()) ||
711 parser.parseColonType(types.emplace_back()))
717 copyPrivateFuncsVec.end());
726 std::optional<ArrayAttr> copyPrivateFuncs) {
727 if (!copyPrivateFuncs.has_value())
729 llvm::interleaveComma(
730 llvm::zip(copyPrivateVars, *copyPrivateFuncs, copyPrivateTypes), p,
731 [&](
const auto &args) {
732 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
733 << std::get<2>(args);
740 std::optional<ArrayAttr> copyPrivateFuncs) {
741 size_t copyPrivateFuncsSize =
742 copyPrivateFuncs.has_value() ? copyPrivateFuncs->size() : 0;
743 if (copyPrivateFuncsSize != copyPrivateVars.size())
744 return op->
emitOpError() <<
"inconsistent number of copyPrivate vars (= "
745 << copyPrivateVars.size()
746 <<
") and functions (= " << copyPrivateFuncsSize
747 <<
"), both must be equal";
748 if (!copyPrivateFuncs.has_value())
751 for (
auto copyPrivateVarAndFunc :
752 llvm::zip(copyPrivateVars, *copyPrivateFuncs)) {
754 llvm::cast<SymbolRefAttr>(std::get<1>(copyPrivateVarAndFunc));
755 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
757 if (mlir::func::FuncOp mlirFuncOp =
758 SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
761 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
762 SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
766 auto getNumArguments = [&] {
767 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
770 auto getArgumentType = [&](
unsigned i) {
771 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
776 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
777 <<
" to point to a copy function";
779 if (getNumArguments() != 2)
781 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
783 Type argTy = getArgumentType(0);
784 if (argTy != getArgumentType(1))
785 return op->
emitOpError() <<
"expected copy function " << symbolRef
786 <<
" arguments to have the same type";
788 Type varType = std::get<0>(copyPrivateVarAndFunc).getType();
789 if (argTy != varType)
791 <<
"expected copy function arguments' type (" << argTy
792 <<
") to be the same as copyprivate variable's type (" << varType
813 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
814 parser.parseOperand(operands.emplace_back()) ||
815 parser.parseColonType(types.emplace_back()))
817 if (std::optional<ClauseTaskDepend> keywordDepend =
818 (symbolizeClauseTaskDepend(keyword)))
819 dependVec.emplace_back(
820 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
834 std::optional<ArrayAttr> depends) {
836 for (
unsigned i = 0, e = depends->size(); i < e; ++i) {
839 p << stringifyClauseTaskDepend(
840 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*depends)[i])
842 <<
" -> " << dependVars[i] <<
" : " << dependTypes[i];
848 std::optional<ArrayAttr> depends,
850 if (!dependVars.empty()) {
851 if (!depends || depends->size() != dependVars.size())
852 return op->
emitOpError() <<
"expected as many depend values"
853 " as depend variables";
855 if (depends && !depends->empty())
856 return op->
emitOpError() <<
"unexpected depend values";
872 IntegerAttr &hintAttr) {
873 StringRef hintKeyword;
879 auto parseKeyword = [&]() -> ParseResult {
882 if (hintKeyword ==
"uncontended")
884 else if (hintKeyword ==
"contended")
886 else if (hintKeyword ==
"nonspeculative")
888 else if (hintKeyword ==
"speculative")
892 << hintKeyword <<
" is not a valid hint";
903 IntegerAttr hintAttr) {
904 int64_t hint = hintAttr.getInt();
912 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
914 bool uncontended = bitn(hint, 0);
915 bool contended = bitn(hint, 1);
916 bool nonspeculative = bitn(hint, 2);
917 bool speculative = bitn(hint, 3);
921 hints.push_back(
"uncontended");
923 hints.push_back(
"contended");
925 hints.push_back(
"nonspeculative");
927 hints.push_back(
"speculative");
929 llvm::interleaveComma(hints, p);
936 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
938 bool uncontended = bitn(hint, 0);
939 bool contended = bitn(hint, 1);
940 bool nonspeculative = bitn(hint, 2);
941 bool speculative = bitn(hint, 3);
943 if (uncontended && contended)
944 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
945 "omp_sync_hint_contended cannot be combined";
946 if (nonspeculative && speculative)
947 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
948 "omp_sync_hint_speculative cannot be combined.";
958 llvm::omp::OpenMPOffloadMappingFlags flag) {
959 return value & llvm::to_underlying(flag);
968 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
969 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
973 auto parseTypeAndMod = [&]() -> ParseResult {
974 StringRef mapTypeMod;
978 if (mapTypeMod ==
"always")
979 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
981 if (mapTypeMod ==
"implicit")
982 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
984 if (mapTypeMod ==
"close")
985 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
987 if (mapTypeMod ==
"present")
988 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
990 if (mapTypeMod ==
"to")
991 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
993 if (mapTypeMod ==
"from")
994 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
996 if (mapTypeMod ==
"tofrom")
997 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
998 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1000 if (mapTypeMod ==
"delete")
1001 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1011 llvm::to_underlying(mapTypeBits));
1019 IntegerAttr mapType) {
1020 uint64_t mapTypeBits = mapType.getUInt();
1022 bool emitAllocRelease =
true;
1028 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1029 mapTypeStrs.push_back(
"always");
1031 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1032 mapTypeStrs.push_back(
"implicit");
1034 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1035 mapTypeStrs.push_back(
"close");
1037 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1038 mapTypeStrs.push_back(
"present");
1044 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1046 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1048 emitAllocRelease =
false;
1049 mapTypeStrs.push_back(
"tofrom");
1051 emitAllocRelease =
false;
1052 mapTypeStrs.push_back(
"from");
1054 emitAllocRelease =
false;
1055 mapTypeStrs.push_back(
"to");
1058 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1059 emitAllocRelease =
false;
1060 mapTypeStrs.push_back(
"delete");
1062 if (emitAllocRelease)
1063 mapTypeStrs.push_back(
"exit_release_or_enter_alloc");
1065 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1066 p << mapTypeStrs[i];
1067 if (i + 1 < mapTypeStrs.size()) {
1077 int64_t shape[2] = {0, 0};
1078 unsigned shapeTmp = 0;
1079 auto parseIndices = [&]() -> ParseResult {
1083 values.push_back(APInt(32, value));
1100 shape[1] = shapeTmp;
1106 if (shapeTmp != shape[1])
1113 if (!values.empty()) {
1114 ShapedType valueType =
1125 assert(shape.size() <= 2);
1130 for (
int i = 0; i < shape[0]; ++i) {
1132 int rowOffset = i * shape[1];
1133 for (
int j = 0;
j < shape[1]; ++
j) {
1134 p << membersIdx.getValues<int32_t>()[rowOffset +
j];
1135 if ((
j + 1) < shape[1])
1140 if ((i + 1) < shape[0])
1152 auto parseEntries = [&]() -> ParseResult {
1157 mapOperands.push_back(arg);
1161 auto parseTypes = [&]() -> ParseResult {
1164 mapOperandTypes.push_back(argType);
1187 unsigned argIndex = 0;
1189 for (
const auto &mapOp : mapOperands) {
1192 const auto &blockArg = entryBlock->
getArgument(argIndex);
1193 p <<
" -> " << blockArg;
1196 if (argIndex < mapOperands.size())
1202 for (
const auto &mapType : mapOperandTypes) {
1205 if (argIndex < mapOperands.size())
1218 if (parser.parseAttribute(privateSymRefs.emplace_back()) ||
1219 parser.parseOperand(privateOperands.emplace_back()) ||
1220 parser.parseArrow() ||
1221 parser.parseArgument(regionPrivateArgs.emplace_back()) ||
1222 parser.parseColonType(privateOperandTypes.emplace_back()))
1229 privateSymRefs.end());
1238 ArrayAttr privatizerSymbols) {
1240 auto targetOp = mlir::dyn_cast<mlir::omp::TargetOp>(op);
1245 MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(),
1246 argsBegin + targetOp.getMapOperands().size() +
1247 privateVarTypes.size());
1249 isByRefVec.resize(privateVarTypes.size(),
false);
1254 p, op, argsSubrange, llvm::StringRef{}, privateVarOperands,
1255 privateVarTypes, isByRef, privatizerSymbols);
1259 VariableCaptureKindAttr mapCaptureType) {
1260 std::string typeCapStr;
1261 llvm::raw_string_ostream typeCap(typeCapStr);
1262 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1264 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1265 typeCap <<
"ByCopy";
1266 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1267 typeCap <<
"VLAType";
1268 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1274 VariableCaptureKindAttr &mapCapture) {
1275 StringRef mapCaptureKey;
1279 if (mapCaptureKey ==
"This")
1281 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
1282 if (mapCaptureKey ==
"ByRef")
1284 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
1285 if (mapCaptureKey ==
"ByCopy")
1287 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1288 if (mapCaptureKey ==
"VLAType")
1290 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
1299 for (
auto mapOp : mapOperands) {
1300 if (!mapOp.getDefiningOp())
1303 if (
auto mapInfoOp =
1304 mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1305 if (!mapInfoOp.getMapType().has_value())
1308 if (!mapInfoOp.getMapCaptureType().has_value())
1311 uint64_t mapTypeBits = mapInfoOp.getMapType().value();
1314 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1316 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1318 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1321 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1323 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1325 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1327 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1329 "to, from, tofrom and alloc map types are permitted");
1331 if (isa<TargetEnterDataOp>(op) && (from || del))
1332 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
1334 if (isa<TargetExitDataOp>(op) && to)
1336 "from, release and delete map types are permitted");
1338 if (isa<TargetUpdateOp>(op)) {
1341 "at least one of to or from map types must be "
1342 "specified, other map types are not permitted");
1347 "at least one of to or from map types must be "
1348 "specified, other map types are not permitted");
1351 auto updateVar = mapInfoOp.getVarPtr();
1353 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1354 (from && updateToVars.contains(updateVar))) {
1357 "either to or from map types can be specified, not both");
1360 if (always || close || implicit) {
1363 "present, mapper and iterator map type modifiers are permitted");
1366 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1382 TargetDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
1383 clauses.useDevicePtrVars, clauses.useDeviceAddrVars,
1388 if (getMapOperands().empty() && getUseDevicePtr().empty() &&
1389 getUseDeviceAddr().empty()) {
1391 "useDeviceAddr operand must be present");
1400 void TargetEnterDataOp::build(
1404 TargetEnterDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
1406 clauses.dependVars, clauses.nowaitAttr,
1411 LogicalResult verifyDependVars =
1413 return failed(verifyDependVars) ? verifyDependVars
1421 void TargetExitDataOp::build(
1425 TargetExitDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
1427 clauses.dependVars, clauses.nowaitAttr,
1432 LogicalResult verifyDependVars =
1434 return failed(verifyDependVars) ? verifyDependVars
1445 TargetUpdateOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
1447 clauses.dependVars, clauses.nowaitAttr,
1452 LogicalResult verifyDependVars =
1454 return failed(verifyDependVars) ? verifyDependVars
1469 builder, state, clauses.ifVar, clauses.deviceVar, clauses.threadLimitVar,
1470 makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
1471 clauses.nowaitAttr, clauses.isDevicePtrVars, clauses.hasDeviceAddrVars,
1472 clauses.mapVars, clauses.privateVars,
1477 LogicalResult verifyDependVars =
1479 return failed(verifyDependVars) ? verifyDependVars
1490 builder, state,
nullptr,
nullptr,
1495 state.addAttributes(attributes);
1502 ParallelOp::build(builder, state, clauses.ifVar, clauses.numThreadsVar,
1503 clauses.allocateVars, clauses.allocatorVars,
1504 clauses.reductionVars,
1507 clauses.procBindKindAttr, clauses.privateVars,
1511 template <
typename OpType>
1513 auto privateVars = op.getPrivateVars();
1514 auto privatizers = op.getPrivatizersAttr();
1516 if (privateVars.empty() && (privatizers ==
nullptr || privatizers.empty()))
1519 auto numPrivateVars = privateVars.size();
1520 auto numPrivatizers = (privatizers ==
nullptr) ? 0 : privatizers.size();
1522 if (numPrivateVars != numPrivatizers)
1523 return op.
emitError() <<
"inconsistent number of private variables and "
1524 "privatizer op symbols, private vars: "
1526 <<
" vs. privatizer op symbols: " << numPrivatizers;
1528 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privatizers)) {
1529 Type varType = std::get<0>(privateVarInfo).getType();
1530 SymbolRefAttr privatizerSym =
1531 cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
1532 PrivateClauseOp privatizerOp =
1533 SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
1536 if (privatizerOp ==
nullptr)
1537 return op.
emitError() <<
"failed to lookup privatizer op with symbol: '"
1538 << privatizerSym <<
"'";
1540 Type privatizerType = privatizerOp.getType();
1542 if (varType != privatizerType)
1544 <<
"type mismatch between a "
1545 << (privatizerOp.getDataSharingType() ==
1546 DataSharingClauseType::Private
1549 <<
" variable and its privatizer op, var type: " << varType
1550 <<
" vs. privatizer op type: " << privatizerType;
1558 if (isa<DistributeOp>((*this)->getParentOp())) {
1560 return emitOpError() <<
"must take a loop wrapper role if nested inside "
1561 "of 'omp.distribute'";
1563 if (LoopWrapperInterface nested = getNestedWrapper()) {
1566 if (!isa<WsloopOp>(nested))
1567 return emitError() <<
"only supported nested wrapper is 'omp.wsloop'";
1569 return emitOpError() <<
"must not wrap an 'omp.loop_nest' directly";
1573 if (getAllocateVars().size() != getAllocatorsVars().size())
1575 "expected equal sizes for allocate and allocator variables");
1581 getReductionVarsByref());
1599 TeamsOp::build(builder, state, clauses.numTeamsLowerVar,
1600 clauses.numTeamsUpperVar, clauses.ifVar,
1601 clauses.threadLimitVar, clauses.allocateVars,
1602 clauses.allocatorVars, clauses.reductionVars,
1616 return emitError(
"expected to be nested inside of omp.target or not nested "
1617 "in any OpenMP dialect operations");
1620 if (
auto numTeamsLowerBound = getNumTeamsLower()) {
1621 auto numTeamsUpperBound = getNumTeamsUpper();
1622 if (!numTeamsUpperBound)
1623 return emitError(
"expected num_teams upper bound to be defined if the "
1624 "lower bound is defined");
1625 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
1627 "expected num_teams upper bound and lower bound to be the same type");
1631 if (getAllocateVars().size() != getAllocatorsVars().size())
1633 "expected equal sizes for allocate and allocator variables");
1636 getReductionVarsByref());
1647 SectionsOp::build(builder, state, clauses.reductionVars,
1650 clauses.allocateVars, clauses.allocatorVars,
1651 clauses.nowaitAttr);
1655 if (getAllocateVars().size() != getAllocatorsVars().size())
1657 "expected equal sizes for allocate and allocator variables");
1660 getReductionVarsByref());
1663 LogicalResult SectionsOp::verifyRegions() {
1664 for (
auto &inst : *getRegion().begin()) {
1665 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
1666 return emitOpError()
1667 <<
"expected omp.section op or terminator op inside region";
1682 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1683 clauses.copyprivateVars,
1685 clauses.nowaitAttr);
1690 if (getAllocateVars().size() != getAllocatorsVars().size())
1692 "expected equal sizes for allocate and allocator variables");
1695 getCopyprivateFuncs());
1711 reductionTypes, reductionByRef,
1712 reductionSymbols, privates)))
1721 if (reductionSymbols) {
1724 reductionOperands, reductionTypes, isByRef,
1740 state.addAttributes(attributes);
1748 WsloopOp::build(builder, state, clauses.linearVars, clauses.linearStepVars,
1749 clauses.reductionVars,
1752 clauses.scheduleValAttr, clauses.scheduleChunkVar,
1753 clauses.scheduleModAttr, clauses.scheduleSimdAttr,
1754 clauses.nowaitAttr, clauses.orderedAttr, clauses.orderAttr,
1755 clauses.orderModAttr);
1760 return emitOpError() <<
"must be a loop wrapper";
1762 if (LoopWrapperInterface nested = getNestedWrapper()) {
1765 if (!isa<SimdOp>(nested))
1766 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
1770 getReductionVarsByref());
1782 SimdOp::build(builder, state, clauses.alignedVars,
1784 clauses.nontemporalVars, clauses.orderAttr,
1785 clauses.orderModAttr, clauses.safelenAttr, clauses.simdlenAttr);
1789 if (getSimdlen().has_value() && getSafelen().has_value() &&
1790 getSimdlen().value() > getSafelen().value())
1791 return emitOpError()
1792 <<
"simdlen clause and safelen clause are both present, but the "
1793 "simdlen value is not less than or equal to safelen value";
1803 return emitOpError() <<
"must be a loop wrapper";
1805 if (getNestedWrapper())
1806 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
1818 DistributeOp::build(builder, state, clauses.distScheduleStaticAttr,
1819 clauses.distScheduleChunkSizeVar, clauses.allocateVars,
1820 clauses.allocatorVars, clauses.orderAttr,
1821 clauses.orderModAttr);
1825 if (this->getChunkSize() && !this->getDistScheduleStatic())
1826 return emitOpError() <<
"chunk size set without "
1827 "dist_schedule_static being present";
1829 if (getAllocateVars().size() != getAllocatorsVars().size())
1831 "expected equal sizes for allocate and allocator variables");
1834 return emitOpError() <<
"must be a loop wrapper";
1836 if (LoopWrapperInterface nested = getNestedWrapper()) {
1839 if (!isa<ParallelOp, SimdOp>(nested))
1840 return emitError() <<
"only supported nested wrappers are 'omp.parallel' "
1859 DeclareReductionOp op,
Region ®ion) {
1862 printer <<
"atomic ";
1874 DeclareReductionOp op,
Region ®ion) {
1877 printer <<
"cleanup ";
1881 LogicalResult DeclareReductionOp::verifyRegions() {
1882 if (getInitializerRegion().empty())
1883 return emitOpError() <<
"expects non-empty initializer region";
1884 Block &initializerEntryBlock = getInitializerRegion().
front();
1887 return emitOpError() <<
"expects initializer region with one argument "
1888 "of the reduction type";
1891 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
1892 if (yieldOp.getResults().size() != 1 ||
1893 yieldOp.getResults().getTypes()[0] !=
getType())
1894 return emitOpError() <<
"expects initializer region to yield a value "
1895 "of the reduction type";
1898 if (getReductionRegion().empty())
1899 return emitOpError() <<
"expects non-empty reduction region";
1900 Block &reductionEntryBlock = getReductionRegion().
front();
1905 return emitOpError() <<
"expects reduction region with two arguments of "
1906 "the reduction type";
1907 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
1908 if (yieldOp.getResults().size() != 1 ||
1909 yieldOp.getResults().getTypes()[0] !=
getType())
1910 return emitOpError() <<
"expects reduction region to yield a value "
1911 "of the reduction type";
1914 if (!getAtomicReductionRegion().empty()) {
1915 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
1919 return emitOpError() <<
"expects atomic reduction region with two "
1920 "arguments of the same type";
1921 auto ptrType = llvm::dyn_cast<PointerLikeType>(
1924 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
1925 return emitOpError() <<
"expects atomic reduction region arguments to "
1926 "be accumulators containing the reduction type";
1929 if (getCleanupRegion().empty())
1931 Block &cleanupEntryBlock = getCleanupRegion().
front();
1934 return emitOpError() <<
"expects cleanup region with one argument "
1935 "of the reduction type";
1949 builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
1950 clauses.mergeableAttr, clauses.inReductionVars,
1952 makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.priorityVar,
1953 makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
1954 clauses.allocateVars, clauses.allocatorVars);
1958 LogicalResult verifyDependVars =
1960 return failed(verifyDependVars)
1963 getInReductionVars(),
1964 getInReductionVarsByref());
1975 builder, state, clauses.taskReductionVars,
1978 clauses.allocateVars, clauses.allocatorVars);
1983 getTaskReductionVars(),
1984 getTaskReductionVarsByref());
1996 builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
1997 clauses.mergeableAttr, clauses.inReductionVars,
1999 makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars,
2001 makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar,
2002 clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar,
2003 clauses.numTasksVar, clauses.nogroupAttr);
2008 getInReductionVars().end());
2009 allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
2010 getReductionVars().end());
2011 return allReductionNvars;
2015 if (getAllocateVars().size() != getAllocatorsVars().size())
2017 "expected equal sizes for allocate and allocator variables");
2019 getReductionVarsByref())) ||
2021 getInReductionVars(),
2022 getInReductionVarsByref())))
2025 if (!getReductionVars().empty() && getNogroup())
2026 return emitError(
"if a reduction clause is present on the taskloop "
2027 "directive, the nogroup clause must not be specified");
2028 for (
auto var : getReductionVars()) {
2029 if (llvm::is_contained(getInReductionVars(), var))
2030 return emitError(
"the same list item cannot appear in both a reduction "
2031 "and an in_reduction clause");
2034 if (getGrainSize() && getNumTasks()) {
2036 "the grainsize clause and num_tasks clause are mutually exclusive and "
2037 "may not appear on the same taskloop directive");
2041 return emitOpError() <<
"must be a loop wrapper";
2043 if (LoopWrapperInterface nested = getNestedWrapper()) {
2046 if (!isa<SimdOp>(nested))
2047 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
2070 for (
auto &iv : ivs)
2071 iv.type = loopVarType;
2100 Region ®ion = getRegion();
2102 p <<
" (" << args <<
") : " << args[0].getType() <<
" = (" <<
getLowerBound()
2106 p <<
"step (" << getStep() <<
") ";
2112 LoopNestOp::build(builder, state, clauses.loopLBVar, clauses.loopUBVar,
2113 clauses.loopStepVar, clauses.loopInclusiveAttr);
2118 return emitOpError() <<
"must represent at least one loop";
2121 return emitOpError() <<
"number of range arguments and IVs do not match";
2123 for (
auto [lb, iv] : llvm::zip_equal(
getLowerBound(), getIVs())) {
2124 if (lb.getType() != iv.getType())
2125 return emitOpError()
2126 <<
"range argument type does not match corresponding IV type";
2130 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2132 if (!wrapper || !wrapper.isWrapper())
2133 return emitOpError() <<
"expects parent op to be a valid loop wrapper";
2138 void LoopNestOp::gatherWrappers(
2141 while (
auto wrapper =
2142 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
2143 if (!wrapper.isWrapper())
2145 wrappers.push_back(wrapper);
2156 CriticalDeclareOp::build(builder, state, clauses.criticalNameAttr,
2165 if (getNameAttr()) {
2166 SymbolRefAttr symbolRef = getNameAttr();
2170 return emitOpError() <<
"expected symbol reference " << symbolRef
2171 <<
" to point to a critical declaration";
2191 return op.
emitOpError() <<
"must be nested inside of a loop";
2195 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
2196 IntegerAttr orderedAttr = wsloopOp.getOrderedValAttr();
2198 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
2199 "have an ordered clause";
2201 if (hasRegion && orderedAttr.getInt() != 0)
2202 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
2203 "have a parameter present";
2205 if (!hasRegion && orderedAttr.getInt() == 0)
2206 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
2207 "have a parameter present";
2208 }
else if (!isa<SimdOp>(wrapper)) {
2209 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
2210 "or worksharing simd loop";
2217 OrderedOp::build(builder, state, clauses.doacrossDependTypeAttr,
2218 clauses.doacrossNumLoopsAttr, clauses.doacrossVectorVars);
2225 auto wrapper = (*this)->getParentOfType<WsloopOp>();
2226 if (!wrapper || *wrapper.getOrderedVal() != *getNumLoopsVal())
2227 return emitOpError() <<
"number of variables in depend clause does not "
2228 <<
"match number of iteration variables in the "
2236 OrderedRegionOp::build(builder, state, clauses.parLevelSimdAttr);
2254 TaskwaitOp::build(builder, state);
2262 if (verifyCommon().failed())
2263 return mlir::failure();
2265 if (
auto mo = getMemoryOrderVal()) {
2266 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2267 *mo == ClauseMemoryOrderKind::Release) {
2269 "memory-order must not be acq_rel or release for atomic reads");
2280 if (verifyCommon().failed())
2281 return mlir::failure();
2283 if (
auto mo = getMemoryOrderVal()) {
2284 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2285 *mo == ClauseMemoryOrderKind::Acquire) {
2287 "memory-order must not be acq_rel or acquire for atomic writes");
2297 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2303 if (
Value writeVal = op.getWriteOpVal()) {
2305 op.getHintValAttr(),
2306 op.getMemoryOrderValAttr());
2313 if (verifyCommon().failed())
2314 return mlir::failure();
2316 if (
auto mo = getMemoryOrderVal()) {
2317 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2318 *mo == ClauseMemoryOrderKind::Acquire) {
2320 "memory-order must not be acq_rel or acquire for atomic updates");
2327 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
2333 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2334 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2336 return dyn_cast<AtomicReadOp>(getSecondOp());
2339 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2340 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2342 return dyn_cast<AtomicWriteOp>(getSecondOp());
2345 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2346 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2348 return dyn_cast<AtomicUpdateOp>(getSecondOp());
2355 LogicalResult AtomicCaptureOp::verifyRegions() {
2356 if (verifyRegionsCommon().failed())
2357 return mlir::failure();
2359 if (getFirstOp()->getAttr(
"hint_val") || getSecondOp()->getAttr(
"hint_val"))
2361 "operations inside capture region must not have hint clause");
2363 if (getFirstOp()->getAttr(
"memory_order_val") ||
2364 getSecondOp()->getAttr(
"memory_order_val"))
2366 "operations inside capture region must not have memory_order clause");
2376 CancelOp::build(builder, state, clauses.cancelDirectiveNameAttr,
2381 ClauseCancellationConstructType cct = getCancellationConstructTypeVal();
2385 return emitOpError() <<
"must be used within a region supporting "
2389 if ((cct == ClauseCancellationConstructType::Parallel) &&
2390 !isa<ParallelOp>(parentOp)) {
2391 return emitOpError() <<
"cancel parallel must appear "
2392 <<
"inside a parallel region";
2394 if (cct == ClauseCancellationConstructType::Loop) {
2395 auto loopOp = dyn_cast<LoopNestOp>(parentOp);
2396 auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
2397 loopOp ? loopOp->getParentOp() :
nullptr);
2400 return emitOpError()
2401 <<
"cancel loop must appear inside a worksharing-loop region";
2403 if (wsloopOp.getNowaitAttr()) {
2404 return emitError() <<
"A worksharing construct that is canceled "
2405 <<
"must not have a nowait clause";
2407 if (wsloopOp.getOrderedValAttr()) {
2408 return emitError() <<
"A worksharing construct that is canceled "
2409 <<
"must not have an ordered clause";
2412 }
else if (cct == ClauseCancellationConstructType::Sections) {
2413 if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2414 return emitOpError() <<
"cancel sections must appear "
2415 <<
"inside a sections region";
2417 if (isa_and_nonnull<SectionsOp>(parentOp->
getParentOp()) &&
2418 cast<SectionsOp>(parentOp->
getParentOp()).getNowaitAttr()) {
2419 return emitError() <<
"A sections construct that is canceled "
2420 <<
"must not have a nowait clause";
2433 CancellationPointOp::build(builder, state, clauses.cancelDirectiveNameAttr);
2437 ClauseCancellationConstructType cct = getCancellationConstructTypeVal();
2441 return emitOpError() <<
"must be used within a region supporting "
2442 "cancellation point directive";
2445 if ((cct == ClauseCancellationConstructType::Parallel) &&
2446 !(isa<ParallelOp>(parentOp))) {
2447 return emitOpError() <<
"cancellation point parallel must appear "
2448 <<
"inside a parallel region";
2450 if ((cct == ClauseCancellationConstructType::Loop) &&
2451 (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->
getParentOp()))) {
2452 return emitOpError() <<
"cancellation point loop must appear "
2453 <<
"inside a worksharing-loop region";
2455 if ((cct == ClauseCancellationConstructType::Sections) &&
2456 !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2457 return emitOpError() <<
"cancellation point sections must appear "
2458 <<
"inside a sections region";
2469 auto extent = getExtent();
2471 if (!extent && !upperbound)
2472 return emitError(
"expected extent or upperbound.");
2479 PrivateClauseOp::build(
2480 odsBuilder, odsState, symName, type,
2482 DataSharingClauseType::Private));
2488 auto verifyTerminator = [&](
Operation *terminator,
2489 bool yieldsValue) -> LogicalResult {
2493 if (!llvm::isa<YieldOp>(terminator))
2495 <<
"expected exit block terminator to be an `omp.yield` op.";
2497 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
2498 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
2501 if (yieldedTypes.empty())
2505 <<
"Did not expect any values to be yielded.";
2508 if (yieldedTypes.size() == 1 && yieldedTypes.front() == symType)
2512 <<
"Invalid yielded value. Expected type: " << symType
2515 if (yieldedTypes.empty())
2518 error << yieldedTypes;
2523 auto verifyRegion = [&](
Region ®ion,
unsigned expectedNumArgs,
2524 StringRef regionName,
2525 bool yieldsValue) -> LogicalResult {
2526 assert(!region.
empty());
2530 <<
"`" << regionName <<
"`: "
2531 <<
"expected " << expectedNumArgs
2534 for (
Block &block : region) {
2536 if (!block.mightHaveTerminator())
2539 if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
2546 if (failed(verifyRegion(getAllocRegion(), 1,
"alloc",
2550 DataSharingClauseType dsType = getDataSharingType();
2552 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
2553 return emitError(
"`private` clauses require only an `alloc` region.");
2555 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
2557 "`firstprivate` clauses require both `alloc` and `copy` regions.");
2559 if (dsType == DataSharingClauseType::FirstPrivate &&
2560 failed(verifyRegion(getCopyRegion(), 2,
"copy",
2564 if (!getDeallocRegion().empty() &&
2565 failed(verifyRegion(getDeallocRegion(), 1,
"dealloc",
2578 MaskedOp::build(builder, state, clauses.filteredThreadIdVar);
2581 #define GET_ATTRDEF_CLASSES
2582 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
2584 #define GET_OP_CLASSES
2585 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
2587 #define GET_TYPEDEF_CLASSES
2588 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static void printPrivateList(OpAsmPrinter &p, Operation *op, ValueRange privateVarOperands, TypeRange privateVarTypes, ArrayAttr privatizerSymbols)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedItems, SmallVectorImpl< Type > &types, ArrayAttr &alignmentValues)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > depends)
Print Depend clause.
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCapture)
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &vars, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &stepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange varsAllocate, TypeRange typesAllocate, OperandRange varsAllocator, TypeRange typesAllocator)
Print allocate clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignmentValues, OperandRange alignedVariables)
static ParseResult parseParallelRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVarOperands, SmallVectorImpl< Type > &reductionVarTypes, DenseBoolArrayAttr &reductionByRef, ArrayAttr &reductionSymbols, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVarOperands, llvm::SmallVectorImpl< Type > &privateVarsTypes, ArrayAttr &privatizerSymbols)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operandsAllocate, SmallVectorImpl< Type > &typesAllocate, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operandsAllocator, SmallVectorImpl< Type > &typesAllocator)
Parse an allocate clause with allocators and a list of operands with types.
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedVarTypes, std::optional< ArrayAttr > alignmentValues)
Print Aligned Clause.
static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands)
void printWsloop(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange reductionOperands, TypeRange reductionTypes, DenseBoolArrayAttr isByRef, ArrayAttr reductionSymbols)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > depends, OperandRange dependVars)
Verifies Depend clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductions, OperandRange reductionVars, std::optional< ArrayRef< bool >> byRef)
Verifies Reduction Clause.
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr kindAttr, OrderModifierAttr modifierAttr)
static ParseResult parseReductionVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, DenseBoolArrayAttr &isByRef, ArrayAttr &reductionSymbols)
reduction-entry-list ::= reduction-entry | reduction-entry-list , reduction-entry reduction-entry ::=...
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, DenseBoolArrayAttr &isByRef, ArrayAttr &symbols, SmallVectorImpl< OpAsmParser::Argument > ®ionPrivateArgs)
static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange reductionVarOperands, TypeRange reductionVarTypes, DenseBoolArrayAttr reductionVarIsByRef, ArrayAttr reductionSymbols, ValueRange privateVarOperands, TypeRange privateVarTypes, ArrayAttr privatizerSymbols)
static void printAtomicReductionRegion(OpAsmPrinter &printer, DeclareReductionOp op, Region ®ion)
ParseResult parseWsloop(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionOperands, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByRef, ArrayAttr &reductionSymbols)
static ParseResult parseCopyPrivateVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr ©PrivateSymbols)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static void printCopyPrivateVarList(OpAsmPrinter &p, Operation *op, OperandRange copyPrivateVars, TypeRange copyPrivateTypes, std::optional< ArrayAttr > copyPrivateFuncs)
Print CopyPrivate clause.
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearVarTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op, ValueRange argsSubrange, StringRef clauseName, ValueRange operands, TypeRange types, DenseBoolArrayAttr byRef, ArrayAttr symbols)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, DenseIntElementsAttr membersIdx)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVariables)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static void printReductionVarList(OpAsmPrinter &p, Operation *op, OperandRange reductionVars, TypeRange reductionTypes, std::optional< DenseBoolArrayAttr > isByRef, std::optional< ArrayAttr > reductions)
Print Reduction clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr schedAttr, ScheduleModifierAttr modifier, UnitAttr simd, Value scheduleChunkVar, Type scheduleChunkType)
Print schedule clause.
static ParseResult parseCleanupReductionRegion(OpAsmParser &parser, Region ®ion)
static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, Region ®ion)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static ParseResult parsePrivateList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateOperands, SmallVectorImpl< Type > &privateOperandTypes, ArrayAttr &privatizerSymbols)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static LogicalResult verifyCopyPrivateVarList(Operation *op, OperandRange copyPrivateVars, std::optional< ArrayAttr > copyPrivateFuncs)
Verifies CopyPrivate Clause.
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &dependsArray)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &kindAttr, OrderModifierAttr &modifierAttr)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseMapEntries(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapOperands, SmallVectorImpl< Type > &mapOperandTypes)
static void printMapEntries(OpAsmPrinter &p, Operation *op, OperandRange mapOperands, TypeRange mapOperandTypes)
static void printCleanupReductionRegion(OpAsmPrinter &printer, DeclareReductionOp op, Region ®ion)
static ParseResult parseMembersIndex(OpAsmParser &parser, DenseIntElementsAttr &membersIdx)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
BlockArgListType getArguments()
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getNumArguments()
Location getLoc()
Return a location for this region.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
ArrayRef< T > asArrayRef() const
Runtime
Potential runtimes for AMD GPU kernels.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.