23 #include "llvm/ADT/BitVector.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/STLForwardCompat.h"
26 #include "llvm/ADT/SmallString.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Frontend/OpenMP/OMPConstants.h"
35 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
36 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
37 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
38 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
45 struct MemRefPointerLikeModel
46 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
49 return llvm::cast<MemRefType>(pointer).getElementType();
53 struct LLVMPointerPointerLikeModel
54 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
55 LLVM::LLVMPointerType> {
62 bool shouldMaterializeInto(
Region *region)
const final {
64 return isa<TargetOp>(region->getParentOp());
69 void OpenMPDialect::initialize() {
72 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
75 #define GET_ATTRDEF_LIST
76 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
79 #define GET_TYPEDEF_LIST
80 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
83 addInterface<OpenMPDialectFoldInterface>();
84 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
85 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
90 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
96 mlir::LLVM::GlobalOp::attachInterface<
99 mlir::LLVM::LLVMFuncOp::attachInterface<
102 mlir::func::FuncOp::attachInterface<
128 operandsAllocator.push_back(operand);
129 typesAllocator.push_back(type);
135 operandsAllocate.push_back(operand);
136 typesAllocate.push_back(type);
147 for (
unsigned i = 0; i < varsAllocate.size(); ++i) {
148 std::string separator = i == varsAllocate.size() - 1 ?
"" :
", ";
149 p << varsAllocator[i] <<
" : " << typesAllocator[i] <<
" -> ";
150 p << varsAllocate[i] <<
" : " << typesAllocate[i] << separator;
158 template <
typename ClauseAttr>
160 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
165 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
169 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
172 template <
typename ClauseAttr>
174 p << stringifyEnum(attr.getValue());
198 types.push_back(type);
199 stepVars.push_back(stepVar);
208 size_t linearVarsSize = linearVars.size();
209 for (
unsigned i = 0; i < linearVarsSize; ++i) {
210 std::string separator = i == linearVarsSize - 1 ?
"" :
", ";
212 if (linearStepVars.size() > i)
213 p <<
" = " << linearStepVars[i];
214 p <<
" : " << linearVars[i].getType() << separator;
227 for (
const auto &it : nontemporalVariables)
228 if (!nontemporalItems.insert(it).second)
229 return op->
emitOpError() <<
"nontemporal variable used more than once";
241 if (!alignedVariables.empty()) {
242 if (!alignmentValues || alignmentValues->size() != alignedVariables.size())
244 <<
"expected as many alignment values as aligned variables";
247 return op->
emitOpError() <<
"unexpected alignment values attribute";
253 for (
auto it : alignedVariables)
254 if (!alignedItems.insert(it).second)
255 return op->
emitOpError() <<
"aligned variable used more than once";
257 if (!alignmentValues)
261 for (
unsigned i = 0; i < (*alignmentValues).size(); ++i) {
262 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
263 if (intAttr.getValue().sle(0))
264 return op->
emitOpError() <<
"alignment should be greater than 0";
266 return op->
emitOpError() <<
"expected integer alignment";
282 if (parser.parseOperand(alignedItems.emplace_back()) ||
283 parser.parseColonType(types.emplace_back()) ||
284 parser.parseArrow() ||
285 parser.parseAttribute(alignmentVec.emplace_back())) {
292 alignmentValues =
ArrayAttr::get(parser.getContext(), alignments);
300 std::optional<ArrayAttr> alignmentValues) {
301 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
304 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
305 p <<
" -> " << (*alignmentValues)[i];
316 if (modifiers.size() > 2)
318 for (
const auto &
mod : modifiers) {
321 auto symbol = symbolizeScheduleModifier(
mod);
324 <<
" unknown modifier type: " <<
mod;
329 if (modifiers.size() == 1) {
330 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
331 modifiers.push_back(modifiers[0]);
332 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
334 }
else if (modifiers.size() == 2) {
337 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
338 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
340 <<
" incorrect modifier order";
355 OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
356 ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier,
357 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
Type &chunkType) {
361 std::optional<mlir::omp::ClauseScheduleKind> schedule =
362 symbolizeClauseScheduleKind(keyword);
368 case ClauseScheduleKind::Static:
369 case ClauseScheduleKind::Dynamic:
370 case ClauseScheduleKind::Guided:
376 chunkSize = std::nullopt;
379 case ClauseScheduleKind::Auto:
381 chunkSize = std::nullopt;
390 modifiers.push_back(
mod);
396 if (!modifiers.empty()) {
398 if (std::optional<ScheduleModifier>
mod =
399 symbolizeScheduleModifier(modifiers[0])) {
402 return parser.
emitError(loc,
"invalid schedule modifier");
405 if (modifiers.size() > 1) {
406 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
416 ClauseScheduleKindAttr schedAttr,
417 ScheduleModifierAttr modifier, UnitAttr simd,
418 Value scheduleChunkVar,
419 Type scheduleChunkType) {
420 p << stringifyClauseScheduleKind(schedAttr.getValue());
421 if (scheduleChunkVar)
422 p <<
" = " << scheduleChunkVar <<
" : " << scheduleChunkVar.
getType();
424 p <<
", " << stringifyScheduleModifier(modifier.getValue());
439 unsigned regionArgOffset = regionPrivateArgs.size();
443 if (parser.parseAttribute(reductionVec.emplace_back()) ||
444 parser.parseOperand(operands.emplace_back()) ||
445 parser.parseArrow() ||
446 parser.parseArgument(regionPrivateArgs.emplace_back()) ||
447 parser.parseColonType(types.emplace_back()))
453 auto *argsBegin = regionPrivateArgs.begin();
455 argsBegin + regionArgOffset + types.size());
456 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
468 p << clauseName <<
"(";
469 llvm::interleaveComma(
470 llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](
auto t) {
471 auto [sym, op, arg, type] = t;
472 p << sym <<
" " << op <<
" -> " << arg <<
" : " << type;
483 ArrayAttr &privatizerSymbols) {
488 reductionVarTypes, reductionSymbols,
495 privateVarsTypes, privatizerSymbols,
500 return parser.
parseRegion(region, regionPrivateArgs);
506 ArrayAttr reductionSymbols,
509 ArrayAttr privatizerSymbols) {
510 if (reductionSymbols) {
513 argsBegin + reductionVarTypes.size());
515 reductionVarOperands, reductionVarTypes,
519 if (privatizerSymbols) {
522 argsBegin + reductionVarOperands.size() +
523 privateVarTypes.size());
525 privateVarOperands, privateVarTypes,
539 ArrayAttr &redcuctionSymbols) {
542 if (parser.parseAttribute(reductionVec.emplace_back()) ||
543 parser.parseArrow() ||
544 parser.parseOperand(operands.emplace_back()) ||
545 parser.parseColonType(types.emplace_back()))
559 std::optional<ArrayAttr> reductions) {
560 for (
unsigned i = 0, e = reductions->size(); i < e; ++i) {
563 p << (*reductions)[i] <<
" -> " << reductionVars[i] <<
" : "
570 std::optional<ArrayAttr> reductions,
572 if (!reductionVars.empty()) {
573 if (!reductions || reductions->size() != reductionVars.size())
575 <<
"expected as many reduction symbol references "
576 "as reduction variables";
579 return op->
emitOpError() <<
"unexpected reduction symbol references";
586 for (
auto args : llvm::zip(reductionVars, *reductions)) {
587 Value accum = std::get<0>(args);
589 if (!accumulators.insert(accum).second)
590 return op->
emitOpError() <<
"accumulator variable used more than once";
593 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
595 SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
597 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
598 <<
" to point to a reduction declaration";
600 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
602 <<
"expected accumulator (" << varType
603 <<
") to be the same type as reduction declaration ("
604 << decl.getAccumulatorType() <<
")";
623 if (parser.parseOperand(operands.emplace_back()) ||
624 parser.parseArrow() ||
625 parser.parseAttribute(copyPrivateFuncsVec.emplace_back()) ||
626 parser.parseColonType(types.emplace_back()))
632 copyPrivateFuncsVec.end());
641 std::optional<ArrayAttr> copyPrivateFuncs) {
642 if (!copyPrivateFuncs.has_value())
644 llvm::interleaveComma(
645 llvm::zip(copyPrivateVars, *copyPrivateFuncs, copyPrivateTypes), p,
646 [&](
const auto &args) {
647 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
648 << std::get<2>(args);
655 std::optional<ArrayAttr> copyPrivateFuncs) {
656 size_t copyPrivateFuncsSize =
657 copyPrivateFuncs.has_value() ? copyPrivateFuncs->size() : 0;
658 if (copyPrivateFuncsSize != copyPrivateVars.size())
659 return op->
emitOpError() <<
"inconsistent number of copyPrivate vars (= "
660 << copyPrivateVars.size()
661 <<
") and functions (= " << copyPrivateFuncsSize
662 <<
"), both must be equal";
663 if (!copyPrivateFuncs.has_value())
666 for (
auto copyPrivateVarAndFunc :
667 llvm::zip(copyPrivateVars, *copyPrivateFuncs)) {
669 llvm::cast<SymbolRefAttr>(std::get<1>(copyPrivateVarAndFunc));
670 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
672 if (mlir::func::FuncOp mlirFuncOp =
673 SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
676 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
677 SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
681 auto getNumArguments = [&] {
682 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
685 auto getArgumentType = [&](
unsigned i) {
686 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
691 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
692 <<
" to point to a copy function";
694 if (getNumArguments() != 2)
696 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
698 Type argTy = getArgumentType(0);
699 if (argTy != getArgumentType(1))
700 return op->
emitOpError() <<
"expected copy function " << symbolRef
701 <<
" arguments to have the same type";
703 Type varType = std::get<0>(copyPrivateVarAndFunc).getType();
704 if (argTy != varType)
706 <<
"expected copy function arguments' type (" << argTy
707 <<
") to be the same as copyprivate variable's type (" << varType
728 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
729 parser.parseOperand(operands.emplace_back()) ||
730 parser.parseColonType(types.emplace_back()))
732 if (std::optional<ClauseTaskDepend> keywordDepend =
733 (symbolizeClauseTaskDepend(keyword)))
734 dependVec.emplace_back(
735 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
749 std::optional<ArrayAttr> depends) {
751 for (
unsigned i = 0, e = depends->size(); i < e; ++i) {
754 p << stringifyClauseTaskDepend(
755 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*depends)[i])
757 <<
" -> " << dependVars[i] <<
" : " << dependTypes[i];
763 std::optional<ArrayAttr> depends,
765 if (!dependVars.empty()) {
766 if (!depends || depends->size() != dependVars.size())
767 return op->
emitOpError() <<
"expected as many depend values"
768 " as depend variables";
770 if (depends && !depends->empty())
771 return op->
emitOpError() <<
"unexpected depend values";
787 IntegerAttr &hintAttr) {
788 StringRef hintKeyword;
797 if (hintKeyword ==
"uncontended")
799 else if (hintKeyword ==
"contended")
801 else if (hintKeyword ==
"nonspeculative")
803 else if (hintKeyword ==
"speculative")
807 << hintKeyword <<
" is not a valid hint";
818 IntegerAttr hintAttr) {
819 int64_t hint = hintAttr.getInt();
827 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
829 bool uncontended = bitn(hint, 0);
830 bool contended = bitn(hint, 1);
831 bool nonspeculative = bitn(hint, 2);
832 bool speculative = bitn(hint, 3);
836 hints.push_back(
"uncontended");
838 hints.push_back(
"contended");
840 hints.push_back(
"nonspeculative");
842 hints.push_back(
"speculative");
844 llvm::interleaveComma(hints, p);
851 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
853 bool uncontended = bitn(hint, 0);
854 bool contended = bitn(hint, 1);
855 bool nonspeculative = bitn(hint, 2);
856 bool speculative = bitn(hint, 3);
858 if (uncontended && contended)
859 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
860 "omp_sync_hint_contended cannot be combined";
861 if (nonspeculative && speculative)
862 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
863 "omp_sync_hint_speculative cannot be combined.";
873 llvm::omp::OpenMPOffloadMappingFlags flag) {
874 return value & llvm::to_underlying(flag);
883 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
884 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
889 StringRef mapTypeMod;
893 if (mapTypeMod ==
"always")
894 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
896 if (mapTypeMod ==
"implicit")
897 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
899 if (mapTypeMod ==
"close")
900 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
902 if (mapTypeMod ==
"present")
903 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
905 if (mapTypeMod ==
"to")
906 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
908 if (mapTypeMod ==
"from")
909 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
911 if (mapTypeMod ==
"tofrom")
912 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
913 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
915 if (mapTypeMod ==
"delete")
916 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
926 llvm::to_underlying(mapTypeBits));
934 IntegerAttr mapType) {
935 uint64_t mapTypeBits = mapType.getUInt();
937 bool emitAllocRelease =
true;
943 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
944 mapTypeStrs.push_back(
"always");
946 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
947 mapTypeStrs.push_back(
"implicit");
949 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
950 mapTypeStrs.push_back(
"close");
952 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
953 mapTypeStrs.push_back(
"present");
959 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
961 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
963 emitAllocRelease =
false;
964 mapTypeStrs.push_back(
"tofrom");
966 emitAllocRelease =
false;
967 mapTypeStrs.push_back(
"from");
969 emitAllocRelease =
false;
970 mapTypeStrs.push_back(
"to");
973 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
974 emitAllocRelease =
false;
975 mapTypeStrs.push_back(
"delete");
977 if (emitAllocRelease)
978 mapTypeStrs.push_back(
"exit_release_or_enter_alloc");
980 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
982 if (i + 1 < mapTypeStrs.size()) {
999 mapOperands.push_back(arg);
1006 mapOperandTypes.push_back(argType);
1026 unsigned argIndex = 0;
1028 for (
const auto &mapOp : mapOperands) {
1030 p << mapOp <<
" -> " << blockArg;
1032 if (argIndex < mapOperands.size())
1038 for (
const auto &mapType : mapOperandTypes) {
1041 if (argIndex < mapOperands.size())
1047 VariableCaptureKindAttr mapCaptureType) {
1048 std::string typeCapStr;
1049 llvm::raw_string_ostream typeCap(typeCapStr);
1050 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1052 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1053 typeCap <<
"ByCopy";
1054 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1055 typeCap <<
"VLAType";
1056 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1062 VariableCaptureKindAttr &mapCapture) {
1063 StringRef mapCaptureKey;
1067 if (mapCaptureKey ==
"This")
1069 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
1070 if (mapCaptureKey ==
"ByRef")
1072 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
1073 if (mapCaptureKey ==
"ByCopy")
1075 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1076 if (mapCaptureKey ==
"VLAType")
1078 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
1087 for (
auto mapOp : mapOperands) {
1088 if (!mapOp.getDefiningOp())
1091 if (
auto mapInfoOp =
1092 mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1093 if (!mapInfoOp.getMapType().has_value())
1096 if (!mapInfoOp.getMapCaptureType().has_value())
1099 uint64_t mapTypeBits = mapInfoOp.getMapType().value();
1102 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1104 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1106 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1109 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1111 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1113 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1115 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1117 "to, from, tofrom and alloc map types are permitted");
1119 if (isa<TargetEnterDataOp>(op) && (from || del))
1120 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
1122 if (isa<TargetExitDataOp>(op) && to)
1124 "from, release and delete map types are permitted");
1126 if (isa<TargetUpdateOp>(op)) {
1129 "at least one of to or from map types must be "
1130 "specified, other map types are not permitted");
1135 "at least one of to or from map types must be "
1136 "specified, other map types are not permitted");
1139 auto updateVar = mapInfoOp.getVarPtr();
1141 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1142 (from && updateToVars.contains(updateVar))) {
1145 "either to or from map types can be specified, not both");
1148 if (always || close || implicit) {
1151 "present, mapper and iterator map type modifiers are permitted");
1154 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1165 if (getMapOperands().empty() && getUseDevicePtr().empty() &&
1166 getUseDeviceAddr().empty()) {
1168 "useDeviceAddr operand must be present");
1176 return failed(verifyDependVars) ? verifyDependVars
1183 return failed(verifyDependVars) ? verifyDependVars
1190 return failed(verifyDependVars) ? verifyDependVars
1197 return failed(verifyDependVars) ? verifyDependVars
1208 builder, state,
nullptr,
nullptr,
1213 state.addAttributes(attributes);
1216 template <
typename OpType>
1218 auto privateVars = op.getPrivateVars();
1219 auto privatizers = op.getPrivatizersAttr();
1221 if (privateVars.empty() && (privatizers ==
nullptr || privatizers.empty()))
1224 auto numPrivateVars = privateVars.size();
1225 auto numPrivatizers = (privatizers ==
nullptr) ? 0 : privatizers.size();
1227 if (numPrivateVars != numPrivatizers)
1228 return op.
emitError() <<
"inconsistent number of private variables and "
1229 "privatizer op symbols, private vars: "
1231 <<
" vs. privatizer op symbols: " << numPrivatizers;
1233 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privatizers)) {
1234 Type varType = std::get<0>(privateVarInfo).getType();
1235 SymbolRefAttr privatizerSym =
1236 std::get<1>(privateVarInfo).template cast<SymbolRefAttr>();
1237 PrivateClauseOp privatizerOp =
1238 SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
1241 if (privatizerOp ==
nullptr)
1242 return op.
emitError() <<
"failed to lookup privatizer op with symbol: '"
1243 << privatizerSym <<
"'";
1245 Type privatizerType = privatizerOp.getType();
1247 if (varType != privatizerType)
1249 <<
"type mismatch between a "
1250 << (privatizerOp.getDataSharingType() ==
1251 DataSharingClauseType::Private
1254 <<
" variable and its privatizer op, var type: " << varType
1255 <<
" vs. privatizer op type: " << privatizerType;
1262 if (getAllocateVars().size() != getAllocatorsVars().size())
1264 "expected equal sizes for allocate and allocator variables");
1292 return emitError(
"expected to be nested inside of omp.target or not nested "
1293 "in any OpenMP dialect operations");
1296 if (
auto numTeamsLowerBound = getNumTeamsLower()) {
1297 auto numTeamsUpperBound = getNumTeamsUpper();
1298 if (!numTeamsUpperBound)
1299 return emitError(
"expected num_teams upper bound to be defined if the "
1300 "lower bound is defined");
1301 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
1303 "expected num_teams upper bound and lower bound to be the same type");
1307 if (getAllocateVars().size() != getAllocatorsVars().size())
1309 "expected equal sizes for allocate and allocator variables");
1319 if (getAllocateVars().size() != getAllocatorsVars().size())
1321 "expected equal sizes for allocate and allocator variables");
1327 for (
auto &inst : *getRegion().begin()) {
1328 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
1329 return emitOpError()
1330 <<
"expected omp.section op or terminator op inside region";
1339 if (getAllocateVars().size() != getAllocatorsVars().size())
1341 "expected equal sizes for allocate and allocator variables");
1344 getCopyprivateFuncs());
1362 UnitAttr &inclusive) {
1368 parser, region, reductionOperands, reductionTypes,
1369 reductionSymbols, privates));
1398 for (
auto &iv : ivs)
1399 iv.type = loopVarType;
1403 llvm::copy(privates, std::back_inserter(regionArgs));
1411 TypeRange reductionTypes, ArrayAttr reductionSymbols,
1412 UnitAttr inclusive) {
1413 if (reductionSymbols) {
1414 auto reductionArgs =
1417 reductionOperands, reductionTypes,
1423 p <<
" (" << args <<
") : " << args[0].getType() <<
" = (" << lowerBound
1424 <<
") to (" << upperBound <<
") ";
1427 p <<
"step (" << steps <<
") ";
1464 for (
auto &iv : ivs)
1465 iv.type = loopVarType;
1473 UnitAttr inclusive) {
1475 p <<
" (" << args <<
") : " << args[0].getType() <<
" = (" << lowerBound
1476 <<
") to (" << upperBound <<
") ";
1479 p <<
"step (" << steps <<
") ";
1489 return emitOpError() <<
"empty lowerbound for simd loop operation";
1491 if (this->getSimdlen().has_value() && this->getSafelen().has_value() &&
1492 this->getSimdlen().value() > this->getSafelen().value()) {
1493 return emitOpError()
1494 <<
"simdlen clause and safelen clause are both present, but the "
1495 "simdlen value is not less than or equal to safelen value";
1498 this->getAlignedVars())
1511 if (this->getChunkSize() && !this->getDistScheduleStatic())
1512 return emitOpError() <<
"chunk size set without "
1513 "dist_schedule_static being present";
1515 if (getAllocateVars().size() != getAllocatorsVars().size())
1517 "expected equal sizes for allocate and allocator variables");
1534 DeclareReductionOp op,
Region ®ion) {
1537 printer <<
"atomic ";
1542 if (getInitializerRegion().empty())
1543 return emitOpError() <<
"expects non-empty initializer region";
1544 Block &initializerEntryBlock = getInitializerRegion().
front();
1547 return emitOpError() <<
"expects initializer region with one argument "
1548 "of the reduction type";
1551 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
1552 if (yieldOp.getResults().size() != 1 ||
1553 yieldOp.getResults().getTypes()[0] != getType())
1554 return emitOpError() <<
"expects initializer region to yield a value "
1555 "of the reduction type";
1558 if (getReductionRegion().empty())
1559 return emitOpError() <<
"expects non-empty reduction region";
1560 Block &reductionEntryBlock = getReductionRegion().
front();
1565 return emitOpError() <<
"expects reduction region with two arguments of "
1566 "the reduction type";
1567 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
1568 if (yieldOp.getResults().size() != 1 ||
1569 yieldOp.getResults().getTypes()[0] != getType())
1570 return emitOpError() <<
"expects reduction region to yield a value "
1571 "of the reduction type";
1574 if (getAtomicReductionRegion().empty())
1577 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
1581 return emitOpError() <<
"expects atomic reduction region with two "
1582 "arguments of the same type";
1583 auto ptrType = llvm::dyn_cast<PointerLikeType>(
1586 (ptrType.getElementType() && ptrType.getElementType() != getType()))
1587 return emitOpError() <<
"expects atomic reduction region arguments to "
1588 "be accumulators containing the reduction type";
1595 return emitOpError() <<
"must be used within an operation supporting "
1596 "reduction clause interface";
1598 for (
const auto &var :
1599 cast<ReductionClauseInterface>(op).getAllReductionVars())
1600 if (var == getAccumulator())
1604 return emitOpError() <<
"the accumulator is not used by the parent";
1613 return failed(verifyDependVars)
1616 getInReductionVars());
1624 getTaskReductionVars());
1632 getInReductionVars().end());
1633 allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
1634 getReductionVars().end());
1635 return allReductionNvars;
1639 if (getAllocateVars().size() != getAllocatorsVars().size())
1641 "expected equal sizes for allocate and allocator variables");
1645 getInReductionVars())))
1648 if (!getReductionVars().empty() && getNogroup())
1649 return emitError(
"if a reduction clause is present on the taskloop "
1650 "directive, the nogroup clause must not be specified");
1651 for (
auto var : getReductionVars()) {
1652 if (llvm::is_contained(getInReductionVars(), var))
1653 return emitError(
"the same list item cannot appear in both a reduction "
1654 "and an in_reduction clause");
1657 if (getGrainSize() && getNumTasks()) {
1659 "the grainsize clause and num_tasks clause are mutually exclusive and "
1660 "may not appear on the same taskloop directive");
1672 build(builder, state, lowerBound, upperBound, step,
1677 false,
false,
false,
1680 state.addAttributes(attributes);
1696 if (getNameAttr()) {
1697 SymbolRefAttr symbolRef = getNameAttr();
1701 return emitOpError() <<
"expected symbol reference " << symbolRef
1702 <<
" to point to a critical declaration";
1714 auto container = (*this)->getParentOfType<WsloopOp>();
1715 if (!container || !container.getOrderedValAttr() ||
1716 container.getOrderedValAttr().getInt() == 0)
1717 return emitOpError() <<
"ordered depend directive must be closely "
1718 <<
"nested inside a worksharing-loop with ordered "
1719 <<
"clause with parameter present";
1721 if (container.getOrderedValAttr().getInt() != (int64_t)*getNumLoopsVal())
1722 return emitOpError() <<
"number of variables in depend clause does not "
1723 <<
"match number of iteration variables in the "
1734 if (
auto container = (*this)->getParentOfType<WsloopOp>()) {
1735 if (!container.getOrderedValAttr() ||
1736 container.getOrderedValAttr().getInt() != 0)
1737 return emitOpError() <<
"ordered region must be closely nested inside "
1738 <<
"a worksharing-loop region with an ordered "
1739 <<
"clause without parameter present";
1750 if (verifyCommon().
failed())
1753 if (
auto mo = getMemoryOrderVal()) {
1754 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
1755 *mo == ClauseMemoryOrderKind::Release) {
1757 "memory-order must not be acq_rel or release for atomic reads");
1768 if (verifyCommon().
failed())
1771 if (
auto mo = getMemoryOrderVal()) {
1772 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
1773 *mo == ClauseMemoryOrderKind::Acquire) {
1775 "memory-order must not be acq_rel or acquire for atomic writes");
1785 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
1791 if (
Value writeVal = op.getWriteOpVal()) {
1793 op.getHintValAttr(),
1794 op.getMemoryOrderValAttr());
1801 if (verifyCommon().
failed())
1804 if (
auto mo = getMemoryOrderVal()) {
1805 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
1806 *mo == ClauseMemoryOrderKind::Acquire) {
1808 "memory-order must not be acq_rel or acquire for atomic updates");
1815 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
1821 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
1822 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
1824 return dyn_cast<AtomicReadOp>(getSecondOp());
1827 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
1828 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
1830 return dyn_cast<AtomicWriteOp>(getSecondOp());
1833 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
1834 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
1836 return dyn_cast<AtomicUpdateOp>(getSecondOp());
1844 if (verifyRegionsCommon().
failed())
1847 if (getFirstOp()->getAttr(
"hint_val") || getSecondOp()->getAttr(
"hint_val"))
1849 "operations inside capture region must not have hint clause");
1851 if (getFirstOp()->getAttr(
"memory_order_val") ||
1852 getSecondOp()->getAttr(
"memory_order_val"))
1854 "operations inside capture region must not have memory_order clause");
1863 ClauseCancellationConstructType cct = getCancellationConstructTypeVal();
1867 return emitOpError() <<
"must be used within a region supporting "
1871 if ((cct == ClauseCancellationConstructType::Parallel) &&
1872 !isa<ParallelOp>(parentOp)) {
1873 return emitOpError() <<
"cancel parallel must appear "
1874 <<
"inside a parallel region";
1876 if (cct == ClauseCancellationConstructType::Loop) {
1877 if (!isa<WsloopOp>(parentOp)) {
1878 return emitOpError() <<
"cancel loop must appear "
1879 <<
"inside a worksharing-loop region";
1881 if (cast<WsloopOp>(parentOp).getNowaitAttr()) {
1882 return emitError() <<
"A worksharing construct that is canceled "
1883 <<
"must not have a nowait clause";
1885 if (cast<WsloopOp>(parentOp).getOrderedValAttr()) {
1886 return emitError() <<
"A worksharing construct that is canceled "
1887 <<
"must not have an ordered clause";
1890 }
else if (cct == ClauseCancellationConstructType::Sections) {
1891 if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
1892 return emitOpError() <<
"cancel sections must appear "
1893 <<
"inside a sections region";
1895 if (isa_and_nonnull<SectionsOp>(parentOp->
getParentOp()) &&
1896 cast<SectionsOp>(parentOp->
getParentOp()).getNowaitAttr()) {
1897 return emitError() <<
"A sections construct that is canceled "
1898 <<
"must not have a nowait clause";
1909 ClauseCancellationConstructType cct = getCancellationConstructTypeVal();
1913 return emitOpError() <<
"must be used within a region supporting "
1914 "cancellation point directive";
1917 if ((cct == ClauseCancellationConstructType::Parallel) &&
1918 !(isa<ParallelOp>(parentOp))) {
1919 return emitOpError() <<
"cancellation point parallel must appear "
1920 <<
"inside a parallel region";
1922 if ((cct == ClauseCancellationConstructType::Loop) &&
1923 !isa<WsloopOp>(parentOp)) {
1924 return emitOpError() <<
"cancellation point loop must appear "
1925 <<
"inside a worksharing-loop region";
1927 if ((cct == ClauseCancellationConstructType::Sections) &&
1928 !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
1929 return emitOpError() <<
"cancellation point sections must appear "
1930 <<
"inside a sections region";
1941 auto extent = getExtent();
1943 if (!extent && !upperbound)
1944 return emitError(
"expected extent or upperbound.");
1951 PrivateClauseOp::build(
1952 odsBuilder, odsState, symName, type,
1954 DataSharingClauseType::Private));
1958 Type symType = getType();
1961 if (!terminator->getBlock()->getSuccessors().empty())
1964 if (!llvm::isa<YieldOp>(terminator))
1966 <<
"expected exit block terminator to be an `omp.yield` op.";
1968 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
1969 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
1971 if (yieldedTypes.size() == 1 && yieldedTypes.front() == symType)
1975 <<
"Invalid yielded value. Expected type: " << symType
1978 if (yieldedTypes.empty())
1981 error << yieldedTypes;
1986 auto verifyRegion = [&](
Region ®ion,
unsigned expectedNumArgs,
1988 assert(!region.
empty());
1992 <<
"`" << regionName <<
"`: "
1993 <<
"expected " << expectedNumArgs
1996 for (
Block &block : region) {
1998 if (!block.mightHaveTerminator())
2001 if (
failed(verifyTerminator(block.getTerminator())))
2008 if (
failed(verifyRegion(getAllocRegion(), 1,
"alloc")))
2011 DataSharingClauseType dsType = getDataSharingType();
2013 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
2014 return emitError(
"`private` clauses require only an `alloc` region.");
2016 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
2018 "`firstprivate` clauses require both `alloc` and `copy` regions.");
2020 if (dsType == DataSharingClauseType::FirstPrivate &&
2021 failed(verifyRegion(getCopyRegion(), 2,
"copy")))
2027 #define GET_ATTRDEF_CLASSES
2028 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
2030 #define GET_OP_CLASSES
2031 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
2033 #define GET_TYPEDEF_CLASSES
2034 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedItems, SmallVectorImpl< Type > &types, ArrayAttr &alignmentValues)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerBound, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperBound, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &loopVarTypes, UnitAttr &inclusive)
loop-control ::= ( ssa-id-list ) : type = loop-bounds loop-bounds := ( ssa-id-list ) to ( ssa-id-list...
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > depends)
Print Depend clause.
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCapture)
ParseResult parseWsloop(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerBound, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperBound, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &loopVarTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionOperands, SmallVectorImpl< Type > &reductionTypes, ArrayAttr &reductionSymbols, UnitAttr &inclusive)
loop-control ::= ( ssa-id-list ) : type = loop-bounds loop-bounds := ( ssa-id-list ) to ( ssa-id-list...
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &vars, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &stepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange varsAllocate, TypeRange typesAllocate, OperandRange varsAllocator, TypeRange typesAllocator)
Print allocate clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignmentValues, OperandRange alignedVariables)
static void printReductionVarList(OpAsmPrinter &p, Operation *op, OperandRange reductionVars, TypeRange reductionTypes, std::optional< ArrayAttr > reductions)
Print Reduction clause.
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductions, OperandRange reductionVars)
Verifies Reduction Clause.
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operandsAllocate, SmallVectorImpl< Type > &typesAllocate, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operandsAllocator, SmallVectorImpl< Type > &typesAllocator)
Parse an allocate clause with allocators and a list of operands with types.
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedVarTypes, std::optional< ArrayAttr > alignmentValues)
Print Aligned Clause.
void printWsloop(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerBound, ValueRange upperBound, ValueRange steps, TypeRange loopVarTypes, ValueRange reductionOperands, TypeRange reductionTypes, ArrayAttr reductionSymbols, UnitAttr inclusive)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > depends, OperandRange dependVars)
Verifies Depend clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printAtomicReductionRegion(OpAsmPrinter &printer, DeclareReductionOp op, Region &region)
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op, ValueRange argsSubrange, StringRef clauseName, ValueRange operands, TypeRange types, ArrayAttr symbols)
static ParseResult parseCopyPrivateVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &copyPrivateSymbols)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static void printCopyPrivateVarList(OpAsmPrinter &p, Operation *op, OperandRange copyPrivateVars, TypeRange copyPrivateTypes, std::optional< ArrayAttr > copyPrivateFuncs)
Print CopyPrivate clause.
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearVarTypes, ValueRange linearStepVars)
Print Linear Clause.
static ParseResult parseReductionVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &redcuctionSymbols)
reduction-entry-list ::= reduction-entry | reduction-entry-list , reduction-entry reduction-entry ::=...
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVariables)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerBound, ValueRange upperBound, ValueRange steps, TypeRange loopVarTypes, UnitAttr inclusive)
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr schedAttr, ScheduleModifierAttr modifier, UnitAttr simd, Value scheduleChunkVar, Type scheduleChunkType)
Print schedule clause.
static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, Region &region)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange reductionVarOperands, TypeRange reductionVarTypes, ArrayAttr reductionSymbols, ValueRange privateVarOperands, TypeRange privateVarTypes, ArrayAttr privatizerSymbols)
static bool opInGlobalImplicitParallelRegion(Operation *op)
ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &symbols, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs)
static LogicalResult verifyCopyPrivateVarList(Operation *op, OperandRange copyPrivateVars, std::optional< ArrayAttr > copyPrivateFuncs)
Verifies CopyPrivate Clause.
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &dependsArray)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static ParseResult parseParallelRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVarOperands, SmallVectorImpl< Type > &reductionVarTypes, ArrayAttr &reductionSymbols, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVarOperands, llvm::SmallVectorImpl< Type > &privateVarsTypes, ArrayAttr &privatizerSymbols)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseMapEntries(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapOperands, SmallVectorImpl< Type > &mapOperandTypes)
static void printMapEntries(OpAsmPrinter &p, Operation *op, OperandRange mapOperands, TypeRange mapOperandTypes)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
BlockArgListType getArguments()
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Define a fold interface to allow for dialects to control specific aspects of the folding behavior for...
DialectFoldInterface(Dialect *dialect)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getNumArguments()
Location getLoc()
Return a location for this region.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Runtime
Potential runtimes for AMD GPU kernels.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
This class represents an efficient way to signal success or failure.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.