23 #include "llvm/ADT/BitVector.h"
24 #include "llvm/ADT/STLForwardCompat.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Frontend/OpenMP/OMPConstants.h"
33 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
34 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
35 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
36 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
42 struct MemRefPointerLikeModel
43 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
46 return llvm::cast<MemRefType>(pointer).getElementType();
50 struct LLVMPointerPointerLikeModel
51 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
52 LLVM::LLVMPointerType> {
59 bool shouldMaterializeInto(
Region *region)
const final {
61 return isa<TargetOp>(region->getParentOp());
66 void OpenMPDialect::initialize() {
69 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
72 #define GET_ATTRDEF_LIST
73 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
76 #define GET_TYPEDEF_LIST
77 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
80 addInterface<OpenMPDialectFoldInterface>();
81 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
82 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
87 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
93 mlir::LLVM::GlobalOp::attachInterface<
96 mlir::LLVM::LLVMFuncOp::attachInterface<
99 mlir::func::FuncOp::attachInterface<
103 mlir::func::FuncOp::attachInterface<
105 mlir::LLVM::LLVMFuncOp::attachInterface<
132 operandsAllocator.push_back(operand);
133 typesAllocator.push_back(type);
139 operandsAllocate.push_back(operand);
140 typesAllocate.push_back(type);
151 for (
unsigned i = 0; i < varsAllocate.size(); ++i) {
152 std::string separator = i == varsAllocate.size() - 1 ?
"" :
", ";
153 p << varsAllocator[i] <<
" : " << typesAllocator[i] <<
" -> ";
154 p << varsAllocate[i] <<
" : " << typesAllocate[i] << separator;
162 template <
typename ClauseAttr>
164 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
169 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
173 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
176 template <
typename ClauseAttr>
178 p << stringifyEnum(attr.getValue());
202 types.push_back(type);
203 stepVars.push_back(stepVar);
212 size_t linearVarsSize = linearVars.size();
213 for (
unsigned i = 0; i < linearVarsSize; ++i) {
214 std::string separator = i == linearVarsSize - 1 ?
"" :
", ";
216 if (linearStepVars.size() > i)
217 p <<
" = " << linearStepVars[i];
218 p <<
" : " << linearVars[i].getType() << separator;
231 for (
const auto &it : nontemporalVariables)
232 if (!nontemporalItems.insert(it).second)
233 return op->
emitOpError() <<
"nontemporal variable used more than once";
245 if (!alignedVariables.empty()) {
246 if (!alignmentValues || alignmentValues->size() != alignedVariables.size())
248 <<
"expected as many alignment values as aligned variables";
251 return op->
emitOpError() <<
"unexpected alignment values attribute";
257 for (
auto it : alignedVariables)
258 if (!alignedItems.insert(it).second)
259 return op->
emitOpError() <<
"aligned variable used more than once";
261 if (!alignmentValues)
265 for (
unsigned i = 0; i < (*alignmentValues).size(); ++i) {
266 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
267 if (intAttr.getValue().sle(0))
268 return op->
emitOpError() <<
"alignment should be greater than 0";
270 return op->
emitOpError() <<
"expected integer alignment";
286 if (parser.parseOperand(alignedItems.emplace_back()) ||
287 parser.parseColonType(types.emplace_back()) ||
288 parser.parseArrow() ||
289 parser.parseAttribute(alignmentVec.emplace_back())) {
296 alignmentValues =
ArrayAttr::get(parser.getContext(), alignments);
304 std::optional<ArrayAttr> alignmentValues) {
305 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
308 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
309 p <<
" -> " << (*alignmentValues)[i];
320 if (modifiers.size() > 2)
322 for (
const auto &
mod : modifiers) {
325 auto symbol = symbolizeScheduleModifier(
mod);
328 <<
" unknown modifier type: " <<
mod;
333 if (modifiers.size() == 1) {
334 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
335 modifiers.push_back(modifiers[0]);
336 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
338 }
else if (modifiers.size() == 2) {
341 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
342 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
344 <<
" incorrect modifier order";
359 OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
360 ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier,
361 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
Type &chunkType) {
365 std::optional<mlir::omp::ClauseScheduleKind> schedule =
366 symbolizeClauseScheduleKind(keyword);
372 case ClauseScheduleKind::Static:
373 case ClauseScheduleKind::Dynamic:
374 case ClauseScheduleKind::Guided:
380 chunkSize = std::nullopt;
383 case ClauseScheduleKind::Auto:
385 chunkSize = std::nullopt;
394 modifiers.push_back(
mod);
400 if (!modifiers.empty()) {
402 if (std::optional<ScheduleModifier>
mod =
403 symbolizeScheduleModifier(modifiers[0])) {
406 return parser.
emitError(loc,
"invalid schedule modifier");
409 if (modifiers.size() > 1) {
410 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
420 ClauseScheduleKindAttr schedAttr,
421 ScheduleModifierAttr modifier, UnitAttr simd,
422 Value scheduleChunkVar,
423 Type scheduleChunkType) {
424 p << stringifyClauseScheduleKind(schedAttr.getValue());
425 if (scheduleChunkVar)
426 p <<
" = " << scheduleChunkVar <<
" : " << scheduleChunkVar.
getType();
428 p <<
", " << stringifyScheduleModifier(modifier.getValue());
444 ArrayAttr &redcuctionSymbols) {
447 if (parser.parseAttribute(reductionVec.emplace_back()) ||
448 parser.parseArrow() ||
449 parser.parseOperand(operands.emplace_back()) ||
450 parser.parseColonType(types.emplace_back()))
464 std::optional<ArrayAttr> reductions) {
465 for (
unsigned i = 0, e = reductions->size(); i < e; ++i) {
468 p << (*reductions)[i] <<
" -> " << reductionVars[i] <<
" : "
475 std::optional<ArrayAttr> reductions,
477 if (!reductionVars.empty()) {
478 if (!reductions || reductions->size() != reductionVars.size())
480 <<
"expected as many reduction symbol references "
481 "as reduction variables";
484 return op->
emitOpError() <<
"unexpected reduction symbol references";
491 for (
auto args : llvm::zip(reductionVars, *reductions)) {
492 Value accum = std::get<0>(args);
494 if (!accumulators.insert(accum).second)
495 return op->
emitOpError() <<
"accumulator variable used more than once";
498 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
500 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
502 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
503 <<
" to point to a reduction declaration";
505 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
507 <<
"expected accumulator (" << varType
508 <<
") to be the same type as reduction declaration ("
509 << decl.getAccumulatorType() <<
")";
529 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
530 parser.parseOperand(operands.emplace_back()) ||
531 parser.parseColonType(types.emplace_back()))
533 if (std::optional<ClauseTaskDepend> keywordDepend =
534 (symbolizeClauseTaskDepend(keyword)))
535 dependVec.emplace_back(
536 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
550 std::optional<ArrayAttr> depends) {
552 for (
unsigned i = 0, e = depends->size(); i < e; ++i) {
555 p << stringifyClauseTaskDepend(
556 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*depends)[i])
558 <<
" -> " << dependVars[i] <<
" : " << dependTypes[i];
564 std::optional<ArrayAttr> depends,
566 if (!dependVars.empty()) {
567 if (!depends || depends->size() != dependVars.size())
568 return op->
emitOpError() <<
"expected as many depend values"
569 " as depend variables";
572 return op->
emitOpError() <<
"unexpected depend values";
588 IntegerAttr &hintAttr) {
589 StringRef hintKeyword;
598 if (hintKeyword ==
"uncontended")
600 else if (hintKeyword ==
"contended")
602 else if (hintKeyword ==
"nonspeculative")
604 else if (hintKeyword ==
"speculative")
608 << hintKeyword <<
" is not a valid hint";
619 IntegerAttr hintAttr) {
620 int64_t hint = hintAttr.getInt();
628 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
630 bool uncontended = bitn(hint, 0);
631 bool contended = bitn(hint, 1);
632 bool nonspeculative = bitn(hint, 2);
633 bool speculative = bitn(hint, 3);
637 hints.push_back(
"uncontended");
639 hints.push_back(
"contended");
641 hints.push_back(
"nonspeculative");
643 hints.push_back(
"speculative");
645 llvm::interleaveComma(hints, p);
652 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
654 bool uncontended = bitn(hint, 0);
655 bool contended = bitn(hint, 1);
656 bool nonspeculative = bitn(hint, 2);
657 bool speculative = bitn(hint, 3);
659 if (uncontended && contended)
660 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
661 "omp_sync_hint_contended cannot be combined";
662 if (nonspeculative && speculative)
663 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
664 "omp_sync_hint_speculative cannot be combined.";
674 llvm::omp::OpenMPOffloadMappingFlags flag) {
675 return value & llvm::to_underlying(flag);
684 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
685 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
690 StringRef mapTypeMod;
694 if (mapTypeMod ==
"always")
695 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
697 if (mapTypeMod ==
"implicit")
698 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
700 if (mapTypeMod ==
"close")
701 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
703 if (mapTypeMod ==
"present")
704 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
706 if (mapTypeMod ==
"to")
707 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
709 if (mapTypeMod ==
"from")
710 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
712 if (mapTypeMod ==
"tofrom")
713 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
714 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
716 if (mapTypeMod ==
"delete")
717 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
727 llvm::to_underlying(mapTypeBits));
735 IntegerAttr mapType) {
736 uint64_t mapTypeBits = mapType.getUInt();
738 bool emitAllocRelease =
true;
744 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
745 mapTypeStrs.push_back(
"always");
747 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
748 mapTypeStrs.push_back(
"implicit");
750 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
751 mapTypeStrs.push_back(
"close");
753 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
754 mapTypeStrs.push_back(
"present");
760 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
762 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
764 emitAllocRelease =
false;
765 mapTypeStrs.push_back(
"tofrom");
767 emitAllocRelease =
false;
768 mapTypeStrs.push_back(
"from");
770 emitAllocRelease =
false;
771 mapTypeStrs.push_back(
"to");
774 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
775 emitAllocRelease =
false;
776 mapTypeStrs.push_back(
"delete");
778 if (emitAllocRelease)
779 mapTypeStrs.push_back(
"exit_release_or_enter_alloc");
781 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
783 if (i + 1 < mapTypeStrs.size()) {
800 mapOperands.push_back(arg);
807 mapOperandTypes.push_back(argType);
827 unsigned argIndex = 0;
829 for (
const auto &mapOp : mapOperands) {
831 p << mapOp <<
" -> " << blockArg;
833 if (argIndex < mapOperands.size())
839 for (
const auto &mapType : mapOperandTypes) {
842 if (argIndex < mapOperands.size())
848 VariableCaptureKindAttr mapCaptureType) {
849 std::string typeCapStr;
850 llvm::raw_string_ostream typeCap(typeCapStr);
851 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
853 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
855 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
856 typeCap <<
"VLAType";
857 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
863 VariableCaptureKindAttr &mapCapture) {
864 StringRef mapCaptureKey;
868 if (mapCaptureKey ==
"This")
870 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
871 if (mapCaptureKey ==
"ByRef")
873 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
874 if (mapCaptureKey ==
"ByCopy")
876 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
877 if (mapCaptureKey ==
"VLAType")
879 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
886 for (
auto mapOp : mapOperands) {
887 if (!mapOp.getDefiningOp())
891 mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
893 if (!MapInfoOp.getMapType().has_value())
896 if (!MapInfoOp.getMapCaptureType().has_value())
899 uint64_t mapTypeBits = MapInfoOp.getMapType().value();
902 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
904 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
906 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
908 if ((isa<DataOp>(op) || isa<TargetOp>(op)) && del)
910 "to, from, tofrom and alloc map types are permitted");
912 if (isa<EnterDataOp>(op) && (from || del))
913 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
915 if (isa<ExitDataOp>(op) && to)
917 "from, release and delete map types are permitted");
927 if (getMapOperands().empty() && getUseDevicePtr().empty() &&
928 getUseDeviceAddr().empty()) {
930 "useDeviceAddr operand must be present");
954 builder, state,
nullptr,
nullptr,
958 state.addAttributes(attributes);
962 if (getAllocateVars().size() != getAllocatorsVars().size())
964 "expected equal sizes for allocate and allocator variables");
988 return emitError(
"expected to be nested inside of omp.target or not nested "
989 "in any OpenMP dialect operations");
992 if (
auto numTeamsLowerBound = getNumTeamsLower()) {
993 auto numTeamsUpperBound = getNumTeamsUpper();
994 if (!numTeamsUpperBound)
995 return emitError(
"expected num_teams upper bound to be defined if the "
996 "lower bound is defined");
997 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
999 "expected num_teams upper bound and lower bound to be the same type");
1003 if (getAllocateVars().size() != getAllocatorsVars().size())
1005 "expected equal sizes for allocate and allocator variables");
1015 if (getAllocateVars().size() != getAllocatorsVars().size())
1017 "expected equal sizes for allocate and allocator variables");
1023 for (
auto &inst : *getRegion().begin()) {
1024 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
1025 return emitOpError()
1026 <<
"expected omp.section op or terminator op inside region";
1035 if (getAllocateVars().size() != getAllocatorsVars().size())
1037 "expected equal sizes for allocate and allocator variables");
1079 for (
auto &iv : ivs)
1080 iv.type = loopVarType;
1087 UnitAttr inclusive) {
1089 p <<
" (" << args <<
") : " << args[0].getType() <<
" = (" << lowerBound
1090 <<
") to (" << upperBound <<
") ";
1093 p <<
"step (" << steps <<
") ";
1103 return emitOpError() <<
"empty lowerbound for simd loop operation";
1105 if (this->getSimdlen().has_value() && this->getSafelen().has_value() &&
1106 this->getSimdlen().value() > this->getSafelen().value()) {
1107 return emitOpError()
1108 <<
"simdlen clause and safelen clause are both present, but the "
1109 "simdlen value is not less than or equal to safelen value";
1112 this->getAlignedVars())
1132 ReductionDeclareOp op,
Region ®ion) {
1135 printer <<
"atomic ";
1140 if (getInitializerRegion().empty())
1141 return emitOpError() <<
"expects non-empty initializer region";
1142 Block &initializerEntryBlock = getInitializerRegion().
front();
1145 return emitOpError() <<
"expects initializer region with one argument "
1146 "of the reduction type";
1149 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
1150 if (yieldOp.getResults().size() != 1 ||
1151 yieldOp.getResults().getTypes()[0] != getType())
1152 return emitOpError() <<
"expects initializer region to yield a value "
1153 "of the reduction type";
1156 if (getReductionRegion().empty())
1157 return emitOpError() <<
"expects non-empty reduction region";
1158 Block &reductionEntryBlock = getReductionRegion().
front();
1163 return emitOpError() <<
"expects reduction region with two arguments of "
1164 "the reduction type";
1165 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
1166 if (yieldOp.getResults().size() != 1 ||
1167 yieldOp.getResults().getTypes()[0] != getType())
1168 return emitOpError() <<
"expects reduction region to yield a value "
1169 "of the reduction type";
1172 if (getAtomicReductionRegion().empty())
1175 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
1179 return emitOpError() <<
"expects atomic reduction region with two "
1180 "arguments of the same type";
1181 auto ptrType = llvm::dyn_cast<PointerLikeType>(
1184 (ptrType.getElementType() && ptrType.getElementType() != getType()))
1185 return emitOpError() <<
"expects atomic reduction region arguments to "
1186 "be accumulators containing the reduction type";
1193 return emitOpError() <<
"must be used within an operation supporting "
1194 "reduction clause interface";
1196 for (
const auto &var :
1197 cast<ReductionClauseInterface>(op).getAllReductionVars())
1198 if (var == getAccumulator())
1202 return emitOpError() <<
"the accumulator is not used by the parent";
1211 return failed(verifyDependVars)
1214 getInReductionVars());
1222 getTaskReductionVars());
1230 getInReductionVars().end());
1231 allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
1232 getReductionVars().end());
1233 return allReductionNvars;
1237 if (getAllocateVars().size() != getAllocatorsVars().size())
1239 "expected equal sizes for allocate and allocator variables");
1243 getInReductionVars())))
1246 if (!getReductionVars().empty() && getNogroup())
1247 return emitError(
"if a reduction clause is present on the taskloop "
1248 "directive, the nogroup clause must not be specified");
1249 for (
auto var : getReductionVars()) {
1250 if (llvm::is_contained(getInReductionVars(), var))
1251 return emitError(
"the same list item cannot appear in both a reduction "
1252 "and an in_reduction clause");
1255 if (getGrainSize() && getNumTasks()) {
1257 "the grainsize clause and num_tasks clause are mutually exclusive and "
1258 "may not appear on the same taskloop directive");
1270 build(builder, state, lowerBound, upperBound, step,
1275 false,
false,
nullptr,
1277 state.addAttributes(attributes);
1293 if (getNameAttr()) {
1294 SymbolRefAttr symbolRef = getNameAttr();
1298 return emitOpError() <<
"expected symbol reference " << symbolRef
1299 <<
" to point to a critical declaration";
1311 auto container = (*this)->getParentOfType<WsLoopOp>();
1312 if (!container || !container.getOrderedValAttr() ||
1313 container.getOrderedValAttr().getInt() == 0)
1314 return emitOpError() <<
"ordered depend directive must be closely "
1315 <<
"nested inside a worksharing-loop with ordered "
1316 <<
"clause with parameter present";
1318 if (container.getOrderedValAttr().getInt() != (int64_t)*getNumLoopsVal())
1319 return emitOpError() <<
"number of variables in depend clause does not "
1320 <<
"match number of iteration variables in the "
1331 if (
auto container = (*this)->getParentOfType<WsLoopOp>()) {
1332 if (!container.getOrderedValAttr() ||
1333 container.getOrderedValAttr().getInt() != 0)
1334 return emitOpError() <<
"ordered region must be closely nested inside "
1335 <<
"a worksharing-loop region with an ordered "
1336 <<
"clause without parameter present";
1347 if (verifyCommon().
failed())
1350 if (
auto mo = getMemoryOrderVal()) {
1351 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
1352 *mo == ClauseMemoryOrderKind::Release) {
1354 "memory-order must not be acq_rel or release for atomic reads");
1365 if (verifyCommon().
failed())
1368 if (
auto mo = getMemoryOrderVal()) {
1369 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
1370 *mo == ClauseMemoryOrderKind::Acquire) {
1372 "memory-order must not be acq_rel or acquire for atomic writes");
1382 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
1388 if (
Value writeVal = op.getWriteOpVal()) {
1390 op.getHintValAttr(),
1391 op.getMemoryOrderValAttr());
1398 if (verifyCommon().
failed())
1401 if (
auto mo = getMemoryOrderVal()) {
1402 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
1403 *mo == ClauseMemoryOrderKind::Acquire) {
1405 "memory-order must not be acq_rel or acquire for atomic updates");
1412 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
1418 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
1419 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
1421 return dyn_cast<AtomicReadOp>(getSecondOp());
1424 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
1425 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
1427 return dyn_cast<AtomicWriteOp>(getSecondOp());
1430 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
1431 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
1433 return dyn_cast<AtomicUpdateOp>(getSecondOp());
1441 if (verifyRegionsCommon().
failed())
1444 if (getFirstOp()->getAttr(
"hint_val") || getSecondOp()->getAttr(
"hint_val"))
1446 "operations inside capture region must not have hint clause");
1448 if (getFirstOp()->getAttr(
"memory_order_val") ||
1449 getSecondOp()->getAttr(
"memory_order_val"))
1451 "operations inside capture region must not have memory_order clause");
1460 ClauseCancellationConstructType cct = getCancellationConstructTypeVal();
1464 return emitOpError() <<
"must be used within a region supporting "
1468 if ((cct == ClauseCancellationConstructType::Parallel) &&
1469 !isa<ParallelOp>(parentOp)) {
1470 return emitOpError() <<
"cancel parallel must appear "
1471 <<
"inside a parallel region";
1473 if (cct == ClauseCancellationConstructType::Loop) {
1474 if (!isa<WsLoopOp>(parentOp)) {
1475 return emitOpError() <<
"cancel loop must appear "
1476 <<
"inside a worksharing-loop region";
1478 if (cast<WsLoopOp>(parentOp).getNowaitAttr()) {
1479 return emitError() <<
"A worksharing construct that is canceled "
1480 <<
"must not have a nowait clause";
1482 if (cast<WsLoopOp>(parentOp).getOrderedValAttr()) {
1483 return emitError() <<
"A worksharing construct that is canceled "
1484 <<
"must not have an ordered clause";
1487 }
else if (cct == ClauseCancellationConstructType::Sections) {
1488 if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
1489 return emitOpError() <<
"cancel sections must appear "
1490 <<
"inside a sections region";
1492 if (isa_and_nonnull<SectionsOp>(parentOp->
getParentOp()) &&
1493 cast<SectionsOp>(parentOp->
getParentOp()).getNowaitAttr()) {
1494 return emitError() <<
"A sections construct that is canceled "
1495 <<
"must not have a nowait clause";
1506 ClauseCancellationConstructType cct = getCancellationConstructTypeVal();
1510 return emitOpError() <<
"must be used within a region supporting "
1511 "cancellation point directive";
1514 if ((cct == ClauseCancellationConstructType::Parallel) &&
1515 !(isa<ParallelOp>(parentOp))) {
1516 return emitOpError() <<
"cancellation point parallel must appear "
1517 <<
"inside a parallel region";
1519 if ((cct == ClauseCancellationConstructType::Loop) &&
1520 !isa<WsLoopOp>(parentOp)) {
1521 return emitOpError() <<
"cancellation point loop must appear "
1522 <<
"inside a worksharing-loop region";
1524 if ((cct == ClauseCancellationConstructType::Sections) &&
1525 !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
1526 return emitOpError() <<
"cancellation point sections must appear "
1527 <<
"inside a sections region";
1538 auto extent = getExtent();
1540 if (!extent && !upperbound)
1541 return emitError(
"expected extent or upperbound.");
1545 #define GET_ATTRDEF_CLASSES
1546 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
1548 #define GET_OP_CLASSES
1549 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1551 #define GET_TYPEDEF_CLASSES
1552 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedItems, SmallVectorImpl< Type > &types, ArrayAttr &alignmentValues)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerBound, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperBound, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &loopVarTypes, UnitAttr &inclusive)
loop-control ::= ( ssa-id-list ) : type = loop-bounds loop-bounds := ( ssa-id-list ) to ( ssa-id-list...
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > depends)
Print Depend clause.
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCapture)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &vars, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &stepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange varsAllocate, TypeRange typesAllocate, OperandRange varsAllocator, TypeRange typesAllocator)
Print allocate clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignmentValues, OperandRange alignedVariables)
static void printReductionVarList(OpAsmPrinter &p, Operation *op, OperandRange reductionVars, TypeRange reductionTypes, std::optional< ArrayAttr > reductions)
Print Reduction clause.
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductions, OperandRange reductionVars)
Verifies Reduction Clause.
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operandsAllocate, SmallVectorImpl< Type > &typesAllocate, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operandsAllocator, SmallVectorImpl< Type > &typesAllocator)
Parse an allocate clause with allocators and a list of operands with types.
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedVarTypes, std::optional< ArrayAttr > alignmentValues)
Print Aligned Clause.
static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > depends, OperandRange dependVars)
Verifies Depend clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printAtomicReductionRegion(OpAsmPrinter &printer, ReductionDeclareOp op, Region ®ion)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearVarTypes, ValueRange linearStepVars)
Print Linear Clause.
static ParseResult parseReductionVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &redcuctionSymbols)
reduction-entry-list ::= reduction-entry | reduction-entry-list , reduction-entry reduction-entry ::=...
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVariables)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerBound, ValueRange upperBound, ValueRange steps, TypeRange loopVarTypes, UnitAttr inclusive)
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr schedAttr, ScheduleModifierAttr modifier, UnitAttr simd, Value scheduleChunkVar, Type scheduleChunkType)
Print schedule clause.
static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, Region ®ion)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &dependsArray)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseMapEntries(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapOperands, SmallVectorImpl< Type > &mapOperandTypes)
static void printMapEntries(OpAsmPrinter &p, Operation *op, OperandRange mapOperands, TypeRange mapOperandTypes)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
BlockArgListType getArguments()
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Define a fold interface to allow for dialects to control specific aspects of the folding behavior for...
DialectFoldInterface(Dialect *dialect)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Runtime
Potential runtimes for AMD GPU kernels.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
This class represents an efficient way to signal success or failure.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.