28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/PostOrderIterator.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/STLForwardCompat.h"
32#include "llvm/ADT/SmallString.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/ADT/StringRef.h"
35#include "llvm/ADT/TypeSwitch.h"
36#include "llvm/ADT/bit.h"
37#include "llvm/Support/InterleavedRange.h"
43#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
45#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
46#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
53 return attrs.empty() ?
nullptr : ArrayAttr::get(context, attrs);
67struct MemRefPointerLikeModel
68 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
71 return llvm::cast<MemRefType>(pointer).getElementType();
75struct LLVMPointerPointerLikeModel
76 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
77 LLVM::LLVMPointerType> {
102 bool isRegionArgOfOp;
112 assert(isRegionArgOfOp &&
"Must describe a region operand");
115 size_t &getArgIdx() {
116 assert(isRegionArgOfOp &&
"Must describe a region operand");
121 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
125 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
128 bool isLoopOp()
const {
129 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
130 return isa<CanonicalLoopOp>(op);
132 Region *&getParentRegion() {
133 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
136 size_t &getLoopDepth() {
137 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
141 void skipIf(
bool v =
true) { skip = skip || v; }
159 llvm::ReversePostOrderTraversal<Block *> traversal(&r->
getBlocks().front());
162 size_t sequentialIdx = -1;
163 bool isOnlyContainerOp =
true;
164 for (
Block *
b : traversal) {
166 if (&op == o && !found) {
170 if (op.getNumRegions()) {
173 isOnlyContainerOp =
false;
175 if (found && !isOnlyContainerOp)
180 Component &containerOpInRegion = components.emplace_back();
181 containerOpInRegion.isRegionArgOfOp =
false;
182 containerOpInRegion.isUnique = isOnlyContainerOp;
183 containerOpInRegion.getContainerOp() = o;
184 containerOpInRegion.getOpPos() = sequentialIdx;
185 containerOpInRegion.getParentRegion() = r;
190 Component ®ionArgOfOperation = components.emplace_back();
191 regionArgOfOperation.isRegionArgOfOp =
true;
192 regionArgOfOperation.isUnique =
true;
193 regionArgOfOperation.getArgIdx() = 0;
194 regionArgOfOperation.getOwnerOp() = parent;
206 for (
auto [idx, region] : llvm::enumerate(o->
getRegions())) {
210 llvm_unreachable(
"Region not child of its parent operation");
212 regionArgOfOperation.isUnique =
false;
213 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
221 for (Component &c : components)
222 c.skipIf(c.isRegionArgOfOp && c.isUnique);
225 size_t numSurroundingLoops = 0;
226 for (Component &c : llvm::reverse(components)) {
231 if (c.isRegionArgOfOp) {
232 numSurroundingLoops = 0;
239 numSurroundingLoops = 0;
241 c.getLoopDepth() = numSurroundingLoops;
244 if (isa<CanonicalLoopOp>(c.getContainerOp()))
245 numSurroundingLoops += 1;
250 bool isLoopNest =
false;
251 for (Component &c : components) {
252 if (c.skip || c.isRegionArgOfOp)
255 if (!isLoopNest && c.getLoopDepth() >= 1) {
258 }
else if (isLoopNest) {
260 c.skipIf(c.isUnique);
264 if (c.getLoopDepth() == 0)
271 for (Component &c : components)
272 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
273 !isa<CanonicalLoopOp>(c.getContainerOp()));
277 bool newRegion =
true;
278 for (Component &c : llvm::reverse(components)) {
279 c.skipIf(newRegion && c.isUnique);
286 if (!c.isRegionArgOfOp && c.getContainerOp())
292 llvm::raw_svector_ostream NameOS(Name);
293 for (
auto &c : llvm::reverse(components)) {
297 if (c.isRegionArgOfOp)
298 NameOS <<
"_r" << c.getArgIdx();
299 else if (c.getLoopDepth() >= 1)
300 NameOS <<
"_d" << c.getLoopDepth();
302 NameOS <<
"_s" << c.getOpPos();
305 return NameOS.str().str();
308void OpenMPDialect::initialize() {
311#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
314#define GET_ATTRDEF_LIST
315#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
318#define GET_TYPEDEF_LIST
319#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
322 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
324 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
325 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
330 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
336 mlir::LLVM::GlobalOp::attachInterface<
339 mlir::LLVM::LLVMFuncOp::attachInterface<
342 mlir::func::FuncOp::attachInterface<
368 allocatorVars.push_back(operand);
369 allocatorTypes.push_back(type);
375 allocateVars.push_back(operand);
376 allocateTypes.push_back(type);
387 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
388 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
389 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
390 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
398template <
typename ClauseAttr>
400 using ClauseT =
decltype(std::declval<ClauseAttr>().getValue());
405 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
406 attr = ClauseAttr::get(parser.
getContext(), *enumValue);
409 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
412template <
typename ClauseAttr>
414 p << stringifyEnum(attr.getValue());
439 std::optional<omp::LinearModifier> linearModifier;
441 linearModifier = omp::LinearModifier::val;
443 linearModifier = omp::LinearModifier::ref;
445 linearModifier = omp::LinearModifier::uval;
448 bool hasLinearModifierParens = linearModifier.has_value();
449 if (hasLinearModifierParens && parser.
parseLParen())
457 if (hasLinearModifierParens && parser.
parseRParen())
460 linearVars.push_back(var);
461 linearTypes.push_back(type);
462 linearStepVars.push_back(stepVar);
463 linearStepTypes.push_back(stepType);
464 if (linearModifier) {
466 omp::LinearModifierAttr::get(parser.
getContext(), *linearModifier));
468 modifiers.push_back(UnitAttr::get(parser.
getContext()));
474 linearModifiers = ArrayAttr::get(parser.
getContext(), modifiers);
483 size_t linearVarsSize = linearVars.size();
484 for (
unsigned i = 0; i < linearVarsSize; ++i) {
488 Attribute modAttr = linearModifiers ? linearModifiers[i] :
nullptr;
489 auto mod = modAttr ? dyn_cast<omp::LinearModifierAttr>(modAttr) :
nullptr;
491 p << omp::stringifyLinearModifier(mod.getValue()) <<
"(";
493 p << linearVars[i] <<
" : " << linearTypes[i];
494 p <<
" = " << linearStepVars[i] <<
" : " << stepVarTypes[i];
510 if (!linearModifiers)
512 if (linearModifiers->size() != linearVars.size())
514 <<
"expected as many linear modifiers as linear variables";
515 if (!isDeclareSimd) {
516 for (
Attribute attr : *linearModifiers) {
519 auto modAttr = dyn_cast<omp::LinearModifierAttr>(attr);
522 omp::LinearModifier mod = modAttr.getValue();
523 if (mod == omp::LinearModifier::ref || mod == omp::LinearModifier::uval)
525 <<
"linear modifier '" << omp::stringifyLinearModifier(mod)
526 <<
"' may only be specified on a declare simd directive";
541 for (
const auto &it : nontemporalVars)
542 if (!nontemporalItems.insert(it).second)
543 return op->
emitOpError() <<
"nontemporal variable used more than once";
552 std::optional<ArrayAttr> alignments,
555 if (!alignedVars.empty()) {
556 if (!alignments || alignments->size() != alignedVars.size())
558 <<
"expected as many alignment values as aligned variables";
561 return op->
emitOpError() <<
"unexpected alignment values attribute";
567 for (
auto it : alignedVars)
568 if (!alignedItems.insert(it).second)
569 return op->
emitOpError() <<
"aligned variable used more than once";
575 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
576 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
577 if (intAttr.getValue().sle(0))
578 return op->
emitOpError() <<
"alignment should be greater than 0";
580 return op->
emitOpError() <<
"expected integer alignment";
597 if (parser.parseOperand(alignedVars.emplace_back()) ||
598 parser.parseColonType(alignedTypes.emplace_back()) ||
599 parser.parseArrow() ||
600 parser.parseAttribute(alignmentVec.emplace_back())) {
607 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
614 std::optional<ArrayAttr> alignments) {
615 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
618 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
619 p <<
" -> " << (*alignments)[i];
630 if (modifiers.size() > 2)
632 for (
const auto &mod : modifiers) {
635 auto symbol = symbolizeScheduleModifier(mod);
638 <<
" unknown modifier type: " << mod;
643 if (modifiers.size() == 1) {
644 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
645 modifiers.push_back(modifiers[0]);
646 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
648 }
else if (modifiers.size() == 2) {
651 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
652 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
654 <<
" incorrect modifier order";
670 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
671 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
676 std::optional<mlir::omp::ClauseScheduleKind> schedule =
677 symbolizeClauseScheduleKind(keyword);
681 scheduleAttr = ClauseScheduleKindAttr::get(parser.
getContext(), *schedule);
683 case ClauseScheduleKind::Static:
684 case ClauseScheduleKind::Dynamic:
685 case ClauseScheduleKind::Guided:
691 chunkSize = std::nullopt;
694 case ClauseScheduleKind::Auto:
695 case ClauseScheduleKind::Runtime:
696 case ClauseScheduleKind::Distribute:
697 chunkSize = std::nullopt;
706 modifiers.push_back(mod);
712 if (!modifiers.empty()) {
714 if (std::optional<ScheduleModifier> mod =
715 symbolizeScheduleModifier(modifiers[0])) {
716 scheduleMod = ScheduleModifierAttr::get(parser.
getContext(), *mod);
718 return parser.
emitError(loc,
"invalid schedule modifier");
721 if (modifiers.size() > 1) {
722 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
732 ClauseScheduleKindAttr scheduleKind,
733 ScheduleModifierAttr scheduleMod,
734 UnitAttr scheduleSimd,
Value scheduleChunk,
735 Type scheduleChunkType) {
736 p << stringifyClauseScheduleKind(scheduleKind.getValue());
738 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
740 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
752 ClauseOrderKindAttr &order,
753 OrderModifierAttr &orderMod) {
758 if (std::optional<OrderModifier> enumValue =
759 symbolizeOrderModifier(enumStr)) {
760 orderMod = OrderModifierAttr::get(parser.
getContext(), *enumValue);
767 if (std::optional<ClauseOrderKind> enumValue =
768 symbolizeClauseOrderKind(enumStr)) {
769 order = ClauseOrderKindAttr::get(parser.
getContext(), *enumValue);
772 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
776 ClauseOrderKindAttr order,
777 OrderModifierAttr orderMod) {
779 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
781 p << stringifyClauseOrderKind(order.getValue());
784template <
typename ClauseTypeAttr,
typename ClauseType>
787 std::optional<OpAsmParser::UnresolvedOperand> &operand,
789 std::optional<ClauseType> (*symbolizeClause)(StringRef),
790 StringRef clauseName) {
793 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
794 prescriptiveness = ClauseTypeAttr::get(parser.
getContext(), *enumValue);
799 <<
"invalid " << clauseName <<
" modifier : '" << enumStr <<
"'";
809 <<
"expected " << clauseName <<
" operand";
812 if (operand.has_value()) {
820template <
typename ClauseTypeAttr,
typename ClauseType>
823 ClauseTypeAttr prescriptiveness,
Value operand,
825 StringRef (*stringifyClauseType)(ClauseType)) {
827 if (prescriptiveness)
828 p << stringifyClauseType(prescriptiveness.getValue()) <<
", ";
831 p << operand <<
": " << operandType;
841 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
842 Type &grainsizeType) {
844 parser, grainsizeMod, grainsize, grainsizeType,
845 &symbolizeClauseGrainsizeType,
"grainsize");
849 ClauseGrainsizeTypeAttr grainsizeMod,
852 p, op, grainsizeMod, grainsize, grainsizeType,
853 &stringifyClauseGrainsizeType);
863 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
864 Type &numTasksType) {
866 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
871 ClauseNumTasksTypeAttr numTasksMod,
874 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
890 return mlir::failure();
891 inTypeAttr = TypeAttr::get(inType);
920 if (!typeparams.empty()) {
921 p <<
'(' << typeparams <<
" : " << typeparamsTypes <<
')';
923 for (
auto sh :
shape) {
935 FallbackModifierAttr fallback,
936 Value dynGroupprivateSize) {
937 if (!dynGroupprivateSize && (accessGroup || fallback))
938 return op->
emitOpError(
"dyn_groupprivate modifiers require a size operand");
944 OpAsmParser &parser, AccessGroupModifierAttr &accessGroupAttr,
945 FallbackModifierAttr &fallbackAttr,
946 std::optional<OpAsmParser::UnresolvedOperand> &dynGroupprivateSize,
949 bool parsedAccessGroup =
false;
950 bool parsedFallback =
false;
951 bool parsedSize =
false;
956 if (parsedAccessGroup)
958 "duplicate access group modifier");
959 accessGroupAttr = AccessGroupModifierAttr::get(
960 parser.
getContext(), AccessGroupModifier::cgroup);
961 parsedAccessGroup =
true;
968 "duplicate fallback modifier");
971 "expected '(' after 'fallback'");
972 llvm::StringRef fbKind;
976 "expected fallback modifier (abort/null/default_mem)");
977 std::optional<FallbackModifier> fbEnum;
978 if (fbKind ==
"abort")
979 fbEnum = FallbackModifier::abort;
980 else if (fbKind ==
"null")
981 fbEnum = FallbackModifier::null;
982 else if (fbKind ==
"default_mem")
983 fbEnum = FallbackModifier::default_mem;
986 "invalid fallback modifier '" + fbKind +
"'");
987 fallbackAttr = FallbackModifierAttr::get(parser.
getContext(), *fbEnum);
990 "expected ')' after fallback modifier");
991 parsedFallback =
true;
999 "duplicate size operand");
1000 dynGroupprivateSize = operand;
1004 "expected ':' and type after size operand");
1008 "expected dyn_groupprivate_size operand");
1013 AccessGroupModifierAttr modifierFirst,
1014 FallbackModifierAttr modifierSecond,
1015 Value dynGroupprivateSize,
1018 bool needsComma =
false;
1020 if (modifierFirst) {
1021 printer << modifierFirst.getValue();
1025 if (modifierSecond) {
1028 printer <<
"fallback(";
1029 printer << modifierSecond.getValue();
1034 if (dynGroupprivateSize) {
1037 printer << dynGroupprivateSize <<
" : " << sizeType;
1046struct MapParseArgs {
1047 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1048 SmallVectorImpl<Type> &types;
1049 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1050 SmallVectorImpl<Type> &types)
1051 : vars(vars), types(types) {}
1053struct PrivateParseArgs {
1054 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1055 llvm::SmallVectorImpl<Type> &types;
1057 UnitAttr &needsBarrier;
1059 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1060 SmallVectorImpl<Type> &types,
ArrayAttr &syms,
1061 UnitAttr &needsBarrier,
1063 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1064 mapIndices(mapIndices) {}
1067struct ReductionParseArgs {
1068 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1069 SmallVectorImpl<Type> &types;
1072 ReductionModifierAttr *modifier;
1073 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1075 ArrayAttr &syms, ReductionModifierAttr *mod =
nullptr)
1076 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1079struct AllRegionParseArgs {
1080 std::optional<MapParseArgs> hasDeviceAddrArgs;
1081 std::optional<MapParseArgs> hostEvalArgs;
1082 std::optional<ReductionParseArgs> inReductionArgs;
1083 std::optional<MapParseArgs> mapArgs;
1084 std::optional<PrivateParseArgs> privateArgs;
1085 std::optional<ReductionParseArgs> reductionArgs;
1086 std::optional<ReductionParseArgs> taskReductionArgs;
1087 std::optional<MapParseArgs> useDeviceAddrArgs;
1088 std::optional<MapParseArgs> useDevicePtrArgs;
1093 return "private_barrier";
1103 ReductionModifierAttr *modifier =
nullptr,
1104 UnitAttr *needsBarrier =
nullptr) {
1108 unsigned regionArgOffset = regionPrivateArgs.size();
1118 std::optional<ReductionModifier> enumValue =
1119 symbolizeReductionModifier(enumStr);
1120 if (!enumValue.has_value())
1122 *modifier = ReductionModifierAttr::get(parser.
getContext(), *enumValue);
1129 isByRefVec.push_back(
1130 parser.parseOptionalKeyword(
"byref").succeeded());
1132 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
1135 if (parser.parseOperand(operands.emplace_back()) ||
1136 parser.parseArrow() ||
1137 parser.parseArgument(regionPrivateArgs.emplace_back()))
1141 if (parser.parseOptionalLSquare().succeeded()) {
1142 if (parser.parseKeyword(
"map_idx") || parser.parseEqual() ||
1143 parser.parseInteger(mapIndicesVec.emplace_back()) ||
1144 parser.parseRSquare())
1147 mapIndicesVec.push_back(-1);
1159 if (parser.parseType(types.emplace_back()))
1166 if (operands.size() != types.size())
1175 *needsBarrier = mlir::UnitAttr::get(parser.
getContext());
1178 auto *argsBegin = regionPrivateArgs.begin();
1180 argsBegin + regionArgOffset + types.size());
1181 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
1187 *symbols = ArrayAttr::get(parser.
getContext(), symbolAttrs);
1190 if (!mapIndicesVec.empty())
1203 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
1218 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
1224 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
1225 &privateArgs->syms, privateArgs->mapIndices,
nullptr,
1226 nullptr, &privateArgs->needsBarrier)))
1235 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
1240 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1241 &reductionArgs->syms,
nullptr, &reductionArgs->byref,
1242 reductionArgs->modifier)))
1249 AllRegionParseArgs args) {
1253 args.hasDeviceAddrArgs)))
1255 <<
"invalid `has_device_addr` format";
1258 args.hostEvalArgs)))
1260 <<
"invalid `host_eval` format";
1263 args.inReductionArgs)))
1265 <<
"invalid `in_reduction` format";
1270 <<
"invalid `map_entries` format";
1275 <<
"invalid `private` format";
1278 args.reductionArgs)))
1280 <<
"invalid `reduction` format";
1283 args.taskReductionArgs)))
1285 <<
"invalid `task_reduction` format";
1288 args.useDeviceAddrArgs)))
1290 <<
"invalid `use_device_addr` format";
1293 args.useDevicePtrArgs)))
1295 <<
"invalid `use_device_addr` format";
1297 return parser.
parseRegion(region, entryBlockArgs);
1316 AllRegionParseArgs args;
1317 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1318 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1319 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1320 inReductionByref, inReductionSyms);
1321 args.mapArgs.emplace(mapVars, mapTypes);
1322 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1323 privateNeedsBarrier, &privateMaps);
1334 UnitAttr &privateNeedsBarrier) {
1335 AllRegionParseArgs args;
1336 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1337 inReductionByref, inReductionSyms);
1338 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1339 privateNeedsBarrier);
1350 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1354 AllRegionParseArgs args;
1355 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1356 inReductionByref, inReductionSyms);
1357 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1358 privateNeedsBarrier);
1359 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1360 reductionSyms, &reductionMod);
1368 UnitAttr &privateNeedsBarrier) {
1369 AllRegionParseArgs args;
1370 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1371 privateNeedsBarrier);
1379 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1383 AllRegionParseArgs args;
1384 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1385 privateNeedsBarrier);
1386 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1387 reductionSyms, &reductionMod);
1396 AllRegionParseArgs args;
1397 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1398 taskReductionByref, taskReductionSyms);
1408 AllRegionParseArgs args;
1409 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1410 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1419struct MapPrintArgs {
1424struct PrivatePrintArgs {
1428 UnitAttr needsBarrier;
1432 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1433 mapIndices(mapIndices) {}
1435struct ReductionPrintArgs {
1440 ReductionModifierAttr modifier;
1442 ArrayAttr syms, ReductionModifierAttr mod =
nullptr)
1443 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1445struct AllRegionPrintArgs {
1446 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1447 std::optional<MapPrintArgs> hostEvalArgs;
1448 std::optional<ReductionPrintArgs> inReductionArgs;
1449 std::optional<MapPrintArgs> mapArgs;
1450 std::optional<PrivatePrintArgs> privateArgs;
1451 std::optional<ReductionPrintArgs> reductionArgs;
1452 std::optional<ReductionPrintArgs> taskReductionArgs;
1453 std::optional<MapPrintArgs> useDeviceAddrArgs;
1454 std::optional<MapPrintArgs> useDevicePtrArgs;
1463 ReductionModifierAttr modifier =
nullptr, UnitAttr needsBarrier =
nullptr) {
1464 if (argsSubrange.empty())
1467 p << clauseName <<
"(";
1470 p <<
"mod: " << stringifyReductionModifier(modifier.getValue()) <<
", ";
1474 symbols = ArrayAttr::get(ctx, values);
1487 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1488 mapIndices.asArrayRef(),
1489 byref.asArrayRef()),
1491 auto [op, arg, sym, map, isByRef] = t;
1497 p << op <<
" -> " << arg;
1500 p <<
" [map_idx=" << map <<
"]";
1503 llvm::interleaveComma(types, p);
1511 StringRef clauseName,
ValueRange argsSubrange,
1512 std::optional<MapPrintArgs> mapArgs) {
1519 StringRef clauseName,
ValueRange argsSubrange,
1520 std::optional<PrivatePrintArgs> privateArgs) {
1523 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1524 privateArgs->syms, privateArgs->mapIndices,
nullptr,
1525 nullptr, privateArgs->needsBarrier);
1531 std::optional<ReductionPrintArgs> reductionArgs) {
1534 reductionArgs->vars, reductionArgs->types,
1535 reductionArgs->syms,
nullptr,
1536 reductionArgs->byref, reductionArgs->modifier);
1540 const AllRegionPrintArgs &args) {
1541 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1545 iface.getHasDeviceAddrBlockArgs(),
1546 args.hasDeviceAddrArgs);
1550 args.inReductionArgs);
1556 args.reductionArgs);
1558 iface.getTaskReductionBlockArgs(),
1559 args.taskReductionArgs);
1561 iface.getUseDeviceAddrBlockArgs(),
1562 args.useDeviceAddrArgs);
1564 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1580 AllRegionPrintArgs args;
1581 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1582 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1583 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1584 inReductionByref, inReductionSyms);
1585 args.mapArgs.emplace(mapVars, mapTypes);
1586 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1587 privateNeedsBarrier, privateMaps);
1595 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1596 AllRegionPrintArgs args;
1597 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1598 inReductionByref, inReductionSyms);
1599 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1600 privateNeedsBarrier,
1609 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1610 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1613 AllRegionPrintArgs args;
1614 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1615 inReductionByref, inReductionSyms);
1616 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1617 privateNeedsBarrier,
1619 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1620 reductionSyms, reductionMod);
1627 UnitAttr privateNeedsBarrier) {
1628 AllRegionPrintArgs args;
1629 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1630 privateNeedsBarrier,
1638 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1641 AllRegionPrintArgs args;
1642 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1643 privateNeedsBarrier,
1645 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1646 reductionSyms, reductionMod);
1656 AllRegionPrintArgs args;
1657 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1658 taskReductionByref, taskReductionSyms);
1668 AllRegionPrintArgs args;
1669 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1670 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1674template <
typename ParsePrefixFn>
1683 if (failed(parsePrefix()))
1691 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1692 iteratedVars.push_back(v);
1693 iteratedTypes.push_back(ty);
1695 plainVars.push_back(v);
1696 plainTypes.push_back(ty);
1702template <
typename Pr
intPrefixFn>
1706 PrintPrefixFn &&printPrefixForPlain,
1707 PrintPrefixFn &&printPrefixForIterated) {
1714 p << v <<
" : " << t;
1718 for (
unsigned i = 0; i < iteratedVars.size(); ++i)
1719 emit(iteratedVars[i], iteratedTypes[i], printPrefixForIterated);
1720 for (
unsigned i = 0; i < plainVars.size(); ++i)
1721 emit(plainVars[i], plainTypes[i], printPrefixForPlain);
1729 if (!reductionVars.empty()) {
1730 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1732 <<
"expected as many reduction symbol references "
1733 "as reduction variables";
1734 if (reductionByref && reductionByref->size() != reductionVars.size())
1735 return op->
emitError() <<
"expected as many reduction variable by "
1736 "reference attributes as reduction variables";
1739 return op->
emitOpError() <<
"unexpected reduction symbol references";
1746 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
1747 Value accum = std::get<0>(args);
1749 if (!accumulators.insert(accum).second)
1750 return op->
emitOpError() <<
"accumulator variable used more than once";
1753 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1757 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1758 <<
" to point to a reduction declaration";
1760 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1762 <<
"expected accumulator (" << varType
1763 <<
") to be the same type as reduction declaration ("
1764 << decl.getAccumulatorType() <<
")";
1783 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1784 parser.parseArrow() ||
1785 parser.parseAttribute(symsVec.emplace_back()) ||
1786 parser.parseColonType(copyprivateTypes.emplace_back()))
1792 copyprivateSyms = ArrayAttr::get(parser.
getContext(), syms);
1800 std::optional<ArrayAttr> copyprivateSyms) {
1801 if (!copyprivateSyms.has_value())
1803 llvm::interleaveComma(
1804 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1805 [&](
const auto &args) {
1806 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
1807 << std::get<2>(args);
1814 std::optional<ArrayAttr> copyprivateSyms) {
1815 size_t copyprivateSymsSize =
1816 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1817 if (copyprivateSymsSize != copyprivateVars.size())
1818 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
1819 << copyprivateVars.size()
1820 <<
") and functions (= " << copyprivateSymsSize
1821 <<
"), both must be equal";
1822 if (!copyprivateSyms.has_value())
1825 for (
auto copyprivateVarAndSym :
1826 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1828 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1829 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1831 if (mlir::func::FuncOp mlirFuncOp =
1834 funcOp = mlirFuncOp;
1835 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1838 funcOp = llvmFuncOp;
1840 auto getNumArguments = [&] {
1841 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
1844 auto getArgumentType = [&](
unsigned i) {
1845 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
1850 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1851 <<
" to point to a copy function";
1853 if (getNumArguments() != 2)
1855 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
1857 Type argTy = getArgumentType(0);
1858 if (argTy != getArgumentType(1))
1859 return op->
emitOpError() <<
"expected copy function " << symbolRef
1860 <<
" arguments to have the same type";
1862 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1863 if (argTy != varType)
1865 <<
"expected copy function arguments' type (" << argTy
1866 <<
") to be the same as copyprivate variable's type (" << varType
1891 OpAsmParser::UnresolvedOperand operand;
1893 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1894 parser.parseOperand(operand) || parser.parseColonType(ty))
1896 std::optional<ClauseTaskDepend> keywordDepend =
1897 symbolizeClauseTaskDepend(keyword);
1901 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend);
1902 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1903 iteratedVars.push_back(operand);
1904 iteratedTypes.push_back(ty);
1905 iterKindsVec.push_back(kindAttr);
1907 dependVars.push_back(operand);
1908 dependTypes.push_back(ty);
1909 kindsVec.push_back(kindAttr);
1915 dependKinds = ArrayAttr::get(parser.
getContext(), kinds);
1917 iteratedKinds = ArrayAttr::get(parser.
getContext(), iterKinds);
1924 std::optional<ArrayAttr> dependKinds,
1927 std::optional<ArrayAttr> iteratedKinds) {
1930 std::optional<ArrayAttr> kinds) {
1931 for (
unsigned i = 0, e = vars.size(); i < e; ++i) {
1934 p << stringifyClauseTaskDepend(
1935 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*kinds)[i])
1937 <<
" -> " << vars[i] <<
" : " << types[i];
1941 printEntries(dependVars, dependTypes, dependKinds);
1942 printEntries(iteratedVars, iteratedTypes, iteratedKinds);
1947 std::optional<ArrayAttr> dependKinds,
1949 std::optional<ArrayAttr> iteratedKinds,
1951 if (!dependVars.empty()) {
1952 if (!dependKinds || dependKinds->size() != dependVars.size())
1953 return op->
emitOpError() <<
"expected as many depend values"
1954 " as depend variables";
1956 if (dependKinds && !dependKinds->empty())
1957 return op->
emitOpError() <<
"unexpected depend values";
1960 if (!iteratedVars.empty()) {
1961 if (!iteratedKinds || iteratedKinds->size() != iteratedVars.size())
1962 return op->
emitOpError() <<
"expected as many depend iterated values"
1963 " as depend iterated variables";
1965 if (iteratedKinds && !iteratedKinds->empty())
1966 return op->
emitOpError() <<
"unexpected depend iterated values";
1981 IntegerAttr &hintAttr) {
1982 StringRef hintKeyword;
1988 auto parseKeyword = [&]() -> ParseResult {
1991 if (hintKeyword ==
"uncontended")
1993 else if (hintKeyword ==
"contended")
1995 else if (hintKeyword ==
"nonspeculative")
1997 else if (hintKeyword ==
"speculative")
2001 << hintKeyword <<
" is not a valid hint";
2012 IntegerAttr hintAttr) {
2013 int64_t hint = hintAttr.getInt();
2021 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
2023 bool uncontended = bitn(hint, 0);
2024 bool contended = bitn(hint, 1);
2025 bool nonspeculative = bitn(hint, 2);
2026 bool speculative = bitn(hint, 3);
2030 hints.push_back(
"uncontended");
2032 hints.push_back(
"contended");
2034 hints.push_back(
"nonspeculative");
2036 hints.push_back(
"speculative");
2038 llvm::interleaveComma(hints, p);
2045 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
2047 bool uncontended = bitn(hint, 0);
2048 bool contended = bitn(hint, 1);
2049 bool nonspeculative = bitn(hint, 2);
2050 bool speculative = bitn(hint, 3);
2052 if (uncontended && contended)
2053 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
2054 "omp_sync_hint_contended cannot be combined";
2055 if (nonspeculative && speculative)
2056 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
2057 "omp_sync_hint_speculative cannot be combined.";
2068 return (value & flag) == flag;
2076static ParseResult parseMapClause(
OpAsmParser &parser,
2077 ClauseMapFlagsAttr &mapType) {
2078 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
2081 auto parseTypeAndMod = [&]() -> ParseResult {
2082 StringRef mapTypeMod;
2086 if (mapTypeMod ==
"always")
2087 mapTypeBits |= ClauseMapFlags::always;
2089 if (mapTypeMod ==
"implicit")
2090 mapTypeBits |= ClauseMapFlags::implicit;
2092 if (mapTypeMod ==
"ompx_hold")
2093 mapTypeBits |= ClauseMapFlags::ompx_hold;
2095 if (mapTypeMod ==
"close")
2096 mapTypeBits |= ClauseMapFlags::close;
2098 if (mapTypeMod ==
"present")
2099 mapTypeBits |= ClauseMapFlags::present;
2101 if (mapTypeMod ==
"to")
2102 mapTypeBits |= ClauseMapFlags::to;
2104 if (mapTypeMod ==
"from")
2105 mapTypeBits |= ClauseMapFlags::from;
2107 if (mapTypeMod ==
"tofrom")
2108 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
2110 if (mapTypeMod ==
"delete")
2111 mapTypeBits |= ClauseMapFlags::del;
2113 if (mapTypeMod ==
"storage")
2114 mapTypeBits |= ClauseMapFlags::storage;
2116 if (mapTypeMod ==
"return_param")
2117 mapTypeBits |= ClauseMapFlags::return_param;
2119 if (mapTypeMod ==
"private")
2120 mapTypeBits |= ClauseMapFlags::priv;
2122 if (mapTypeMod ==
"literal")
2123 mapTypeBits |= ClauseMapFlags::literal;
2125 if (mapTypeMod ==
"attach")
2126 mapTypeBits |= ClauseMapFlags::attach;
2128 if (mapTypeMod ==
"attach_always")
2129 mapTypeBits |= ClauseMapFlags::attach_always;
2131 if (mapTypeMod ==
"attach_never")
2132 mapTypeBits |= ClauseMapFlags::attach_never;
2134 if (mapTypeMod ==
"attach_auto")
2135 mapTypeBits |= ClauseMapFlags::attach_auto;
2137 if (mapTypeMod ==
"ref_ptr")
2138 mapTypeBits |= ClauseMapFlags::ref_ptr;
2140 if (mapTypeMod ==
"ref_ptee")
2141 mapTypeBits |= ClauseMapFlags::ref_ptee;
2143 if (mapTypeMod ==
"is_device_ptr")
2144 mapTypeBits |= ClauseMapFlags::is_device_ptr;
2161 ClauseMapFlagsAttr mapType) {
2163 ClauseMapFlags mapFlags = mapType.getValue();
2168 mapTypeStrs.push_back(
"always");
2170 mapTypeStrs.push_back(
"implicit");
2172 mapTypeStrs.push_back(
"ompx_hold");
2174 mapTypeStrs.push_back(
"close");
2176 mapTypeStrs.push_back(
"present");
2185 mapTypeStrs.push_back(
"tofrom");
2187 mapTypeStrs.push_back(
"from");
2189 mapTypeStrs.push_back(
"to");
2192 mapTypeStrs.push_back(
"delete");
2194 mapTypeStrs.push_back(
"return_param");
2196 mapTypeStrs.push_back(
"storage");
2198 mapTypeStrs.push_back(
"private");
2200 mapTypeStrs.push_back(
"literal");
2202 mapTypeStrs.push_back(
"attach");
2204 mapTypeStrs.push_back(
"attach_always");
2206 mapTypeStrs.push_back(
"attach_never");
2208 mapTypeStrs.push_back(
"attach_auto");
2210 mapTypeStrs.push_back(
"ref_ptr");
2212 mapTypeStrs.push_back(
"ref_ptee");
2214 mapTypeStrs.push_back(
"is_device_ptr");
2215 if (mapFlags == ClauseMapFlags::none)
2216 mapTypeStrs.push_back(
"none");
2218 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
2219 p << mapTypeStrs[i];
2220 if (i + 1 < mapTypeStrs.size()) {
2226static ParseResult parseMembersIndex(
OpAsmParser &parser,
2230 auto parseIndices = [&]() -> ParseResult {
2235 APInt(64, value,
false)));
2249 memberIdxs.push_back(ArrayAttr::get(parser.
getContext(), values));
2253 if (!memberIdxs.empty())
2254 membersIdx = ArrayAttr::get(parser.
getContext(), memberIdxs);
2264 llvm::interleaveComma(membersIdx, p, [&p](
Attribute v) {
2266 auto memberIdx = cast<ArrayAttr>(v);
2267 llvm::interleaveComma(memberIdx.getValue(), p, [&p](
Attribute v2) {
2268 p << cast<IntegerAttr>(v2).getInt();
2275 VariableCaptureKindAttr mapCaptureType) {
2276 std::string typeCapStr;
2277 llvm::raw_string_ostream typeCap(typeCapStr);
2278 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
2280 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
2281 typeCap <<
"ByCopy";
2282 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
2283 typeCap <<
"VLAType";
2284 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
2290 VariableCaptureKindAttr &mapCaptureType) {
2291 StringRef mapCaptureKey;
2295 if (mapCaptureKey ==
"This")
2296 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2297 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
2298 if (mapCaptureKey ==
"ByRef")
2299 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2300 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
2301 if (mapCaptureKey ==
"ByCopy")
2302 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2303 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
2304 if (mapCaptureKey ==
"VLAType")
2305 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2306 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
2315 for (
auto mapOp : mapVars) {
2316 if (!mapOp.getDefiningOp())
2319 if (
auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
2320 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
2323 bool from =
mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2326 bool always =
mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2327 bool close =
mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2328 bool implicit =
mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2329 bool attach =
mapTypeToBool(mapTypeBits, ClauseMapFlags::attach);
2331 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2333 "to, from, tofrom and alloc map types are permitted");
2335 if (isa<TargetEnterDataOp>(op) && (from || del))
2336 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
2338 if (isa<TargetExitDataOp>(op) && to)
2340 "from, release and delete map types are permitted");
2342 if (isa<TargetUpdateOp>(op)) {
2345 "at least one of to or from map types must be "
2346 "specified, other map types are not permitted");
2349 if (!to && !from && !attach) {
2352 "at least one of to or from or attach map types must be "
2353 "specified, other map types are not permitted");
2356 auto updateVar = mapInfoOp.getVarPtr();
2358 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2359 (from && updateToVars.contains(updateVar))) {
2362 "either to or from map types can be specified, not both");
2365 if (always || close || implicit) {
2368 "present, mapper and iterator map type modifiers are permitted");
2374 to ? updateToVars.insert(updateVar)
2375 : updateFromVars.insert(updateVar);
2379 if ((mapInfoOp.getVarPtrPtr() && !mapInfoOp.getVarPtrPtrType()) ||
2380 (!mapInfoOp.getVarPtrPtr() && mapInfoOp.getVarPtrPtrType())) {
2383 "if varPtrPtr or varPtrPtrType is specified, then both "
2386 }
else if (!isa<DeclareMapperInfoOp>(op)) {
2388 "map argument is not a map entry operation");
2395template <
typename OpType>
2399 std::optional<DenseI64ArrayAttr> privateMapIndices =
2400 targetOp.getPrivateMapsAttr();
2403 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2408 if (privateMapIndices.value().size() !=
2409 static_cast<int64_t>(privateVars.size()))
2410 return emitError(targetOp.getLoc(),
"sizes of `private` operand range and "
2411 "`private_maps` attribute mismatch");
2421 StringRef clauseName,
2423 for (
Value var : vars)
2424 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2426 <<
"'" << clauseName
2427 <<
"' arguments must be defined by 'omp.map.info' ops";
2431LogicalResult MapInfoOp::verify() {
2432 if (getMapperId() &&
2434 *
this, getMapperIdAttr())) {
2449 const TargetDataOperands &clauses) {
2450 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2451 clauses.mapVars, clauses.useDeviceAddrVars,
2452 clauses.useDevicePtrVars);
2455LogicalResult TargetDataOp::verify() {
2456 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2457 getUseDeviceAddrVars().empty()) {
2458 return ::emitError(this->getLoc(),
2459 "At least one of map, use_device_ptr_vars, or "
2460 "use_device_addr_vars operand must be present");
2464 getUseDevicePtrVars())))
2468 getUseDeviceAddrVars())))
2478void TargetEnterDataOp::build(
2482 TargetEnterDataOp::build(
2484 clauses.dependVars,
makeArrayAttr(ctx, clauses.dependIteratedKinds),
2485 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2489LogicalResult TargetEnterDataOp::verify() {
2490 LogicalResult verifyDependVars =
2492 getDependIteratedKinds(), getDependIterated());
2493 return failed(verifyDependVars) ? verifyDependVars
2504 TargetExitDataOp::build(
2506 clauses.dependVars,
makeArrayAttr(ctx, clauses.dependIteratedKinds),
2507 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2511LogicalResult TargetExitDataOp::verify() {
2512 LogicalResult verifyDependVars =
2514 getDependIteratedKinds(), getDependIterated());
2515 return failed(verifyDependVars) ? verifyDependVars
2526 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
2529 clauses.dependIterated, clauses.device, clauses.ifExpr,
2530 clauses.mapVars, clauses.nowait);
2533LogicalResult TargetUpdateOp::verify() {
2534 LogicalResult verifyDependVars =
2536 getDependIteratedKinds(), getDependIterated());
2537 return failed(verifyDependVars) ? verifyDependVars
2546 const TargetOperands &clauses) {
2551 builder, state, {}, {}, clauses.bare,
2552 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2553 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
2554 clauses.device, clauses.dynGroupprivateAccessGroup,
2555 clauses.dynGroupprivateFallback, clauses.dynGroupprivateSize,
2556 clauses.hasDeviceAddrVars, clauses.hostEvalVars, clauses.ifExpr,
2558 nullptr, clauses.isDevicePtrVars, clauses.mapVars,
2559 clauses.nowait, clauses.privateVars,
2560 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2561 clauses.threadLimitVars,
2565LogicalResult TargetOp::verify() {
2567 getDependIteratedKinds(),
2568 getDependIterated())))
2572 getHasDeviceAddrVars())))
2579 *
this, getDynGroupprivateAccessGroupAttr(),
2580 getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
2589LogicalResult TargetOp::verifyRegions() {
2590 auto teamsOps = getOps<TeamsOp>();
2591 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2592 return emitError(
"target containing multiple 'omp.teams' nested ops");
2595 bool hostEvalTripCount;
2596 Operation *capturedOp = getInnermostCapturedOmpOp();
2597 TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount);
2598 for (
Value hostEvalArg :
2599 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2601 if (
auto teamsOp = dyn_cast<TeamsOp>(user)) {
2603 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2604 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2605 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2608 return emitOpError() <<
"host_eval argument only legal as 'num_teams' "
2609 "and 'thread_limit' in 'omp.teams'";
2611 if (
auto parallelOp = dyn_cast<ParallelOp>(user)) {
2612 if (execMode == TargetExecMode::spmd &&
2613 parallelOp->isAncestor(capturedOp) &&
2614 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2618 <<
"host_eval argument only legal as 'num_threads' in "
2619 "'omp.parallel' when representing target SPMD";
2621 if (
auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2622 if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp &&
2623 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2624 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2625 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2628 return emitOpError() <<
"host_eval argument only legal as loop bounds "
2629 "and steps in 'omp.loop_nest' when trip count "
2630 "must be evaluated in the host";
2633 return emitOpError() <<
"host_eval argument illegal use in '"
2634 << user->getName() <<
"' operation";
2643 assert(rootOp &&
"expected valid operation");
2660 bool isOmpDialect = op->
getDialect() == ompDialect;
2662 if (!isOmpDialect || !hasRegions)
2669 if (checkSingleMandatoryExec) {
2674 if (successor->isReachable(parentBlock))
2677 for (
Block &block : *parentRegion)
2679 !domInfo.
dominates(parentBlock, &block))
2686 if (&sibling != op && !siblingAllowedFn(&sibling))
2699Operation *TargetOp::getInnermostCapturedOmpOp() {
2700 auto *ompDialect =
getContext()->getLoadedDialect<omp::OpenMPDialect>();
2712 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2715 memOp.getEffects(effects);
2716 return !llvm::any_of(
2718 return isa<MemoryEffects::Write>(effect.
getEffect()) &&
2719 isa<SideEffects::AutomaticAllocationScopeResource>(
2729 WsloopOp *wsLoopOp) {
2731 if (!teamsOp.getNumTeamsUpperVars().empty())
2735 if (teamsOp.getNumReductionVars())
2737 if (wsLoopOp->getNumReductionVars())
2741 OffloadModuleInterface offloadMod =
2745 auto ompFlags = offloadMod.getFlags();
2748 return ompFlags.getAssumeTeamsOversubscription() &&
2749 ompFlags.getAssumeThreadsOversubscription();
2752TargetExecMode TargetOp::getKernelExecFlags(
Operation *capturedOp,
2753 bool *hostEvalTripCount) {
2759 assert((!capturedOp ||
2760 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2761 "unexpected captured op");
2763 if (hostEvalTripCount)
2764 *hostEvalTripCount =
false;
2767 if (!isa_and_present<LoopNestOp>(capturedOp))
2768 return TargetExecMode::generic;
2772 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2773 assert(!loopWrappers.empty());
2775 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2776 if (isa<SimdOp>(innermostWrapper))
2777 innermostWrapper = std::next(innermostWrapper);
2779 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2780 if (numWrappers != 1 && numWrappers != 2)
2781 return TargetExecMode::generic;
2784 if (numWrappers == 2) {
2785 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2787 return TargetExecMode::generic;
2789 innermostWrapper = std::next(innermostWrapper);
2790 if (!isa<DistributeOp>(innermostWrapper))
2791 return TargetExecMode::generic;
2794 if (!isa_and_present<ParallelOp>(parallelOp))
2795 return TargetExecMode::generic;
2797 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->
getParentOp());
2799 return TargetExecMode::generic;
2801 if (teamsOp->getParentOp() == targetOp.getOperation()) {
2802 TargetExecMode
result = TargetExecMode::spmd;
2804 result = TargetExecMode::no_loop;
2805 if (hostEvalTripCount)
2806 *hostEvalTripCount =
true;
2811 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2813 if (!isa_and_present<TeamsOp>(teamsOp))
2814 return TargetExecMode::generic;
2816 if (teamsOp->
getParentOp() != targetOp.getOperation())
2817 return TargetExecMode::generic;
2819 if (hostEvalTripCount)
2820 *hostEvalTripCount =
true;
2822 if (isa<LoopOp>(innermostWrapper))
2823 return TargetExecMode::spmd;
2825 return TargetExecMode::generic;
2828 else if (isa<WsloopOp>(innermostWrapper)) {
2830 if (!isa_and_present<ParallelOp>(parallelOp))
2831 return TargetExecMode::generic;
2833 if (parallelOp->
getParentOp() == targetOp.getOperation())
2834 return TargetExecMode::spmd;
2837 return TargetExecMode::generic;
2846 ParallelOp::build(builder, state,
ValueRange(),
2858 const ParallelOperands &clauses) {
2860 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2861 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2863 clauses.privateNeedsBarrier, clauses.procBindKind,
2864 clauses.reductionMod, clauses.reductionVars,
2869template <
typename OpType>
2871 auto privateVars = op.getPrivateVars();
2872 auto privateSyms = op.getPrivateSymsAttr();
2874 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
2877 auto numPrivateVars = privateVars.size();
2878 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
2880 if (numPrivateVars != numPrivateSyms)
2881 return op.emitError() <<
"inconsistent number of private variables and "
2882 "privatizer op symbols, private vars: "
2884 <<
" vs. privatizer op symbols: " << numPrivateSyms;
2886 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2887 Type varType = std::get<0>(privateVarInfo).getType();
2888 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2889 PrivateClauseOp privatizerOp =
2892 if (privatizerOp ==
nullptr)
2893 return op.emitError() <<
"failed to lookup privatizer op with symbol: '"
2894 << privateSym <<
"'";
2896 Type privatizerType = privatizerOp.getArgType();
2898 if (privatizerType && (varType != privatizerType))
2899 return op.emitError()
2900 <<
"type mismatch between a "
2901 << (privatizerOp.getDataSharingType() ==
2902 DataSharingClauseType::Private
2905 <<
" variable and its privatizer op, var type: " << varType
2906 <<
" vs. privatizer op type: " << privatizerType;
2912LogicalResult ParallelOp::verify() {
2913 if (getAllocateVars().size() != getAllocatorVars().size())
2915 "expected equal sizes for allocate and allocator variables");
2921 getReductionByref());
2924LogicalResult ParallelOp::verifyRegions() {
2925 auto distChildOps = getOps<DistributeOp>();
2926 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2927 if (numDistChildOps > 1)
2929 <<
"multiple 'omp.distribute' nested inside of 'omp.parallel'";
2931 if (numDistChildOps == 1) {
2934 <<
"'omp.composite' attribute missing from composite operation";
2936 auto *ompDialect =
getContext()->getLoadedDialect<OpenMPDialect>();
2937 Operation &distributeOp = **distChildOps.begin();
2939 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2943 return emitError() <<
"unexpected OpenMP operation inside of composite "
2945 << childOp.getName();
2947 }
else if (isComposite()) {
2949 <<
"'omp.composite' attribute present in non-composite operation";
2966 const TeamsOperands &clauses) {
2970 builder, state, clauses.allocateVars, clauses.allocatorVars,
2971 clauses.dynGroupprivateAccessGroup, clauses.dynGroupprivateFallback,
2972 clauses.dynGroupprivateSize, clauses.ifExpr, clauses.numTeamsLower,
2973 clauses.numTeamsUpperVars, {},
nullptr,
2974 nullptr, clauses.reductionMod,
2975 clauses.reductionVars,
2977 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
2984 if (numTeamsLower) {
2985 if (numTeamsUpperVars.size() != 1)
2987 "expected exactly one num_teams upper bound when lower bound is "
2991 "expected num_teams upper bound and lower bound to be "
2998LogicalResult TeamsOp::verify() {
3007 return emitError(
"expected to be nested inside of omp.target or not nested "
3008 "in any OpenMP dialect operations");
3012 this->getNumTeamsUpperVars())))
3016 if (getAllocateVars().size() != getAllocatorVars().size())
3018 "expected equal sizes for allocate and allocator variables");
3021 op, getDynGroupprivateAccessGroupAttr(),
3022 getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
3029 getReductionByref());
3037 return getParentOp().getPrivateVars();
3041 return getParentOp().getReductionVars();
3049 const SectionsOperands &clauses) {
3052 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3055 clauses.reductionMod, clauses.reductionVars,
3060LogicalResult SectionsOp::verify() {
3061 if (getAllocateVars().size() != getAllocatorVars().size())
3063 "expected equal sizes for allocate and allocator variables");
3066 getReductionByref());
3069LogicalResult SectionsOp::verifyRegions() {
3070 for (
auto &inst : *getRegion().begin()) {
3071 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
3073 <<
"expected omp.section op or terminator op inside region";
3085 const ScopeOperands &clauses) {
3087 ScopeOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3088 clauses.nowait, clauses.privateVars,
3090 clauses.privateNeedsBarrier, clauses.reductionMod,
3091 clauses.reductionVars,
3096LogicalResult ScopeOp::verify() {
3097 if (getAllocateVars().size() != getAllocatorVars().size())
3099 "expected equal sizes for allocate and allocator variables");
3105 getReductionByref());
3113 const SingleOperands &clauses) {
3116 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3117 clauses.copyprivateVars,
3118 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
3123LogicalResult SingleOp::verify() {
3125 if (getAllocateVars().size() != getAllocatorVars().size())
3127 "expected equal sizes for allocate and allocator variables");
3130 getCopyprivateSyms());
3138 const WorkshareOperands &clauses) {
3139 WorkshareOp::build(builder, state, clauses.nowait);
3146LogicalResult WorkshareLoopWrapperOp::verify() {
3147 if (!(*this)->getParentOfType<WorkshareOp>())
3148 return emitOpError() <<
"must be nested in an omp.workshare";
3152LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
3153 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3155 return emitOpError() <<
"expected to be a standalone loop wrapper";
3164LogicalResult LoopWrapperInterface::verifyImpl() {
3168 return emitOpError() <<
"loop wrapper must also have the `NoTerminator` "
3169 "and `SingleBlock` traits";
3172 return emitOpError() <<
"loop wrapper does not contain exactly one region";
3175 if (range_size(region.
getOps()) != 1)
3177 <<
"loop wrapper does not contain exactly one nested op";
3180 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
3181 return emitOpError() <<
"nested in loop wrapper is not another loop "
3182 "wrapper or `omp.loop_nest`";
3192 const LoopOperands &clauses) {
3195 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
3197 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
3198 clauses.reductionMod, clauses.reductionVars,
3203LogicalResult LoopOp::verify() {
3208 getReductionByref());
3211LogicalResult LoopOp::verifyRegions() {
3212 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3214 return emitOpError() <<
"expected to be a standalone loop wrapper";
3225 build(builder, state, {}, {},
3228 false,
nullptr,
nullptr,
3229 nullptr, {},
nullptr,
3240 const WsloopOperands &clauses) {
3245 {}, {}, clauses.linearVars,
3246 clauses.linearStepVars, clauses.linearVarTypes, clauses.linearModifiers,
3247 clauses.nowait, clauses.order, clauses.orderMod, clauses.ordered,
3248 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
3249 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3251 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
3252 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
3255LogicalResult WsloopOp::verify() {
3259 if (getLinearVars().size() &&
3260 getLinearVarTypes().value().size() != getLinearVars().size())
3261 return emitError() <<
"Ill-formed type attributes for linear variables";
3267 getReductionByref());
3270LogicalResult WsloopOp::verifyRegions() {
3271 bool isCompositeChildLeaf =
3272 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3274 if (LoopWrapperInterface nested = getNestedWrapper()) {
3277 <<
"'omp.composite' attribute missing from composite wrapper";
3281 if (!isa<SimdOp>(nested))
3282 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3284 }
else if (isComposite() && !isCompositeChildLeaf) {
3286 <<
"'omp.composite' attribute present in non-composite wrapper";
3287 }
else if (!isComposite() && isCompositeChildLeaf) {
3289 <<
"'omp.composite' attribute missing from composite wrapper";
3300 const SimdOperands &clauses) {
3302 SimdOp::build(builder, state, clauses.alignedVars,
3304 clauses.linearVars, clauses.linearStepVars,
3305 clauses.linearVarTypes, clauses.linearModifiers,
3306 clauses.nontemporalVars, clauses.order, clauses.orderMod,
3307 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
3308 clauses.privateNeedsBarrier, clauses.reductionMod,
3309 clauses.reductionVars,
3315LogicalResult SimdOp::verify() {
3316 if (getSimdlen().has_value() && getSafelen().has_value() &&
3317 getSimdlen().value() > getSafelen().value())
3319 <<
"simdlen clause and safelen clause are both present, but the "
3320 "simdlen value is not less than or equal to safelen value";
3332 bool isCompositeChildLeaf =
3333 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3335 if (!isComposite() && isCompositeChildLeaf)
3337 <<
"'omp.composite' attribute missing from composite wrapper";
3339 if (isComposite() && !isCompositeChildLeaf)
3341 <<
"'omp.composite' attribute present in non-composite wrapper";
3345 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
3347 for (
const Attribute &sym : *privateSyms) {
3348 auto symRef = cast<SymbolRefAttr>(sym);
3349 omp::PrivateClauseOp privatizer =
3351 getOperation(), symRef);
3353 return emitError() <<
"Cannot find privatizer '" << symRef <<
"'";
3354 if (privatizer.getDataSharingType() ==
3355 DataSharingClauseType::FirstPrivate)
3356 return emitError() <<
"FIRSTPRIVATE cannot be used with SIMD";
3363 if (getLinearVars().size() &&
3364 getLinearVarTypes().value().size() != getLinearVars().size())
3365 return emitError() <<
"Ill-formed type attributes for linear variables";
3369LogicalResult SimdOp::verifyRegions() {
3370 if (getNestedWrapper())
3371 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
3381 const DistributeOperands &clauses) {
3382 DistributeOp::build(builder, state, clauses.allocateVars,
3383 clauses.allocatorVars, clauses.distScheduleStatic,
3384 clauses.distScheduleChunkSize, clauses.order,
3385 clauses.orderMod, clauses.privateVars,
3387 clauses.privateNeedsBarrier);
3390LogicalResult DistributeOp::verify() {
3391 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
3393 "dist_schedule_static being present";
3395 if (getAllocateVars().size() != getAllocatorVars().size())
3397 "expected equal sizes for allocate and allocator variables");
3405LogicalResult DistributeOp::verifyRegions() {
3406 if (LoopWrapperInterface nested = getNestedWrapper()) {
3409 <<
"'omp.composite' attribute missing from composite wrapper";
3412 if (isa<WsloopOp>(nested)) {
3414 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3415 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3416 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
3417 "when a composite 'omp.parallel' is the direct "
3420 }
else if (!isa<SimdOp>(nested))
3421 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
3423 }
else if (isComposite()) {
3425 <<
"'omp.composite' attribute present in non-composite wrapper";
3435LogicalResult DeclareMapperInfoOp::verify() {
3439LogicalResult DeclareMapperOp::verifyRegions() {
3440 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3441 getRegion().getBlocks().front().getTerminator()))
3442 return emitOpError() <<
"expected terminator to be a DeclareMapperInfoOp";
3451LogicalResult DeclareReductionOp::verifyRegions() {
3452 if (!getAllocRegion().empty()) {
3453 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3454 if (yieldOp.getResults().size() != 1 ||
3455 yieldOp.getResults().getTypes()[0] !=
getType())
3456 return emitOpError() <<
"expects alloc region to yield a value "
3457 "of the reduction type";
3461 if (getInitializerRegion().empty())
3462 return emitOpError() <<
"expects non-empty initializer region";
3463 Block &initializerEntryBlock = getInitializerRegion().
front();
3466 if (!getAllocRegion().empty())
3467 return emitOpError() <<
"expects two arguments to the initializer region "
3468 "when an allocation region is used";
3470 if (getAllocRegion().empty())
3471 return emitOpError() <<
"expects one argument to the initializer region "
3472 "when no allocation region is used";
3475 <<
"expects one or two arguments to the initializer region";
3479 if (arg.getType() !=
getType())
3480 return emitOpError() <<
"expects initializer region argument to match "
3481 "the reduction type";
3483 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3484 if (yieldOp.getResults().size() != 1 ||
3485 yieldOp.getResults().getTypes()[0] !=
getType())
3486 return emitOpError() <<
"expects initializer region to yield a value "
3487 "of the reduction type";
3490 if (getReductionRegion().empty())
3491 return emitOpError() <<
"expects non-empty reduction region";
3492 Block &reductionEntryBlock = getReductionRegion().
front();
3497 return emitOpError() <<
"expects reduction region with two arguments of "
3498 "the reduction type";
3499 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3500 if (yieldOp.getResults().size() != 1 ||
3501 yieldOp.getResults().getTypes()[0] !=
getType())
3502 return emitOpError() <<
"expects reduction region to yield a value "
3503 "of the reduction type";
3506 if (!getAtomicReductionRegion().empty()) {
3507 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
3511 return emitOpError() <<
"expects atomic reduction region with two "
3512 "arguments of the same type";
3513 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3516 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
3517 return emitOpError() <<
"expects atomic reduction region arguments to "
3518 "be accumulators containing the reduction type";
3521 if (getCleanupRegion().empty())
3523 Block &cleanupEntryBlock = getCleanupRegion().
front();
3526 return emitOpError() <<
"expects cleanup region with one argument "
3527 "of the reduction type";
3537 const TaskOperands &clauses) {
3540 builder, state, clauses.iterated, clauses.affinityVars,
3541 clauses.allocateVars, clauses.allocatorVars,
3542 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3543 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
3544 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3546 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3547 clauses.priority, clauses.privateVars,
3549 clauses.privateNeedsBarrier, clauses.untied, clauses.eventHandle);
3552LogicalResult TaskOp::verify() {
3553 LogicalResult verifyDependVars =
3555 getDependIteratedKinds(), getDependIterated());
3556 if (
failed(verifyDependVars))
3557 return verifyDependVars;
3563 getInReductionVars(), getInReductionByref());
3571 const TaskgroupOperands &clauses) {
3573 TaskgroupOp::build(builder, state, clauses.allocateVars,
3574 clauses.allocatorVars, clauses.taskReductionVars,
3579LogicalResult TaskgroupOp::verify() {
3581 getTaskReductionVars(),
3582 getTaskReductionByref());
3590 const TaskloopContextOperands &clauses) {
3592 TaskloopContextOp::build(
3593 builder, state, clauses.allocateVars, clauses.allocatorVars,
3594 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3595 clauses.inReductionVars,
3597 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3598 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3599 clauses.privateVars,
3601 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3606TaskloopWrapperOp TaskloopContextOp::getLoopOp() {
3607 return cast<TaskloopWrapperOp>(
3609 return isa<TaskloopWrapperOp>(op);
3613LogicalResult TaskloopContextOp::verify() {
3614 if (getAllocateVars().size() != getAllocatorVars().size())
3616 "expected equal sizes for allocate and allocator variables");
3622 getReductionVars(), getReductionByref())) ||
3624 getInReductionVars(),
3625 getInReductionByref())))
3628 if (!getReductionVars().empty() && getNogroup())
3629 return emitError(
"if a reduction clause is present on the taskloop "
3630 "directive, the nogroup clause must not be specified");
3631 for (
auto var : getReductionVars()) {
3632 if (llvm::is_contained(getInReductionVars(), var))
3633 return emitError(
"the same list item cannot appear in both a reduction "
3634 "and an in_reduction clause");
3637 if (getGrainsize() && getNumTasks()) {
3639 "the grainsize clause and num_tasks clause are mutually exclusive and "
3640 "may not appear on the same taskloop directive");
3646LogicalResult TaskloopContextOp::verifyRegions() {
3647 Region ®ion = getRegion();
3649 return emitOpError() <<
"expected non-empty region";
3652 return isa<TaskloopWrapperOp>(op);
3656 <<
"expected exactly 1 TaskloopWrapperOp directly nested in "
3658 << count <<
" were found";
3659 TaskloopWrapperOp loopWrapperOp = getLoopOp();
3661 auto loopNestOp = dyn_cast<LoopNestOp>(loopWrapperOp.getWrappedLoop());
3667 std::function<
bool(
Value)> isValidBoundValue = [&](
Value value) ->
bool {
3668 Region *valueRegion = value.getParentRegion();
3674 Operation *defOp = value.getDefiningOp();
3678 return llvm::all_of(defOp->
getOperands(), isValidBoundValue);
3680 auto hasUnsupportedTaskloopLocalBound = [&](
OperandRange range) ->
bool {
3681 return llvm::any_of(range,
3682 [&](
Value value) {
return !isValidBoundValue(value); });
3685 if (hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopLowerBounds()) ||
3686 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopUpperBounds()) ||
3687 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopSteps())) {
3689 <<
"expects loop bounds and steps to be defined outside of the "
3690 "taskloop.context region or by pure, regionless operations "
3691 "that do not depend on block arguments";
3702 const TaskloopWrapperOperands &clauses) {
3703 TaskloopWrapperOp::build(builder, state);
3706TaskloopContextOp TaskloopWrapperOp::getTaskloopContext() {
3707 return dyn_cast<TaskloopContextOp>(getOperation()->getParentOp());
3710LogicalResult TaskloopWrapperOp::verify() {
3711 TaskloopContextOp context = getTaskloopContext();
3713 return emitOpError() <<
"expected to be nested in a taskloop context op";
3717LogicalResult TaskloopWrapperOp::verifyRegions() {
3718 if (LoopWrapperInterface nested = getNestedWrapper()) {
3721 <<
"'omp.composite' attribute missing from composite wrapper";
3725 if (!isa<SimdOp>(nested))
3726 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3727 }
else if (isComposite()) {
3729 <<
"'omp.composite' attribute present in non-composite wrapper";
3753 for (
auto &iv : ivs)
3754 iv.type = loopVarType;
3759 result.addAttribute(
"loop_inclusive", UnitAttr::get(ctx));
3775 "collapse_num_loops",
3780 auto parseTiles = [&]() -> ParseResult {
3784 tiles.push_back(
tile);
3793 if (tiles.size() > 0)
3812 Region ®ion = getRegion();
3814 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
3815 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
3816 if (getLoopInclusive())
3818 p <<
"step (" << getLoopSteps() <<
") ";
3819 if (
int64_t numCollapse = getCollapseNumLoops())
3820 if (numCollapse > 1)
3821 p <<
"collapse(" << numCollapse <<
") ";
3824 p <<
"tiles(" << tiles.value() <<
") ";
3830 const LoopNestOperands &clauses) {
3832 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3833 clauses.loopLowerBounds, clauses.loopUpperBounds,
3834 clauses.loopSteps, clauses.loopInclusive,
3838LogicalResult LoopNestOp::verify() {
3839 if (getLoopLowerBounds().empty())
3840 return emitOpError() <<
"must represent at least one loop";
3842 if (getLoopLowerBounds().size() != getIVs().size())
3843 return emitOpError() <<
"number of range arguments and IVs do not match";
3845 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3846 if (lb.getType() != iv.getType())
3848 <<
"range argument type does not match corresponding IV type";
3851 uint64_t numIVs = getIVs().size();
3853 if (
const auto &numCollapse = getCollapseNumLoops())
3854 if (numCollapse > numIVs)
3856 <<
"collapse value is larger than the number of loops";
3859 if (tiles.value().size() > numIVs)
3860 return emitOpError() <<
"too few canonical loops for tile dimensions";
3862 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3863 return emitOpError() <<
"expects parent op to be a loop wrapper";
3868void LoopNestOp::gatherWrappers(
3871 while (
auto wrapper =
3872 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3873 wrappers.push_back(wrapper);
3882std::tuple<NewCliOp, OpOperand *, OpOperand *>
3888 return {{},
nullptr,
nullptr};
3891 "Unexpected type of cli");
3897 auto op = cast<LoopTransformationInterface>(use.getOwner());
3899 unsigned opnum = use.getOperandNumber();
3900 if (op.isGeneratee(opnum)) {
3901 assert(!gen &&
"Each CLI may have at most one def");
3903 }
else if (op.isApplyee(opnum)) {
3904 assert(!cons &&
"Each CLI may have at most one consumer");
3907 llvm_unreachable(
"Unexpected operand for a CLI");
3911 return {create, gen, cons};
3934 std::string cliName{
"cli"};
3938 .Case([&](CanonicalLoopOp op) {
3941 .Case([&](UnrollHeuristicOp op) -> std::string {
3942 llvm_unreachable(
"heuristic unrolling does not generate a loop");
3944 .Case([&](FuseOp op) -> std::string {
3945 unsigned opnum =
generator->getOperandNumber();
3948 if (op.getFirst().has_value() && opnum != op.getFirst().value())
3949 return "canonloop_fuse";
3953 .Case([&](TileOp op) -> std::string {
3954 auto [generateesFirst, generateesCount] =
3955 op.getGenerateesODSOperandIndexAndLength();
3956 unsigned firstGrid = generateesFirst;
3957 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3958 unsigned end = generateesFirst + generateesCount;
3959 unsigned opnum =
generator->getOperandNumber();
3961 if (firstGrid <= opnum && opnum < firstIntratile) {
3962 unsigned gridnum = opnum - firstGrid + 1;
3963 return (
"grid" + Twine(gridnum)).str();
3965 if (firstIntratile <= opnum && opnum < end) {
3966 unsigned intratilenum = opnum - firstIntratile + 1;
3967 return (
"intratile" + Twine(intratilenum)).str();
3969 llvm_unreachable(
"Unexpected generatee argument");
3971 .DefaultUnreachable(
"TODO: Custom name for this operation");
3974 setNameFn(
result, cliName);
3977LogicalResult NewCliOp::verify() {
3978 Value cli = getResult();
3981 "Unexpected type of cli");
3987 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3989 unsigned opnum = use.getOperandNumber();
3990 if (op.isGeneratee(opnum)) {
3993 emitOpError(
"CLI must have at most one generator");
3995 .
append(
"first generator here:");
3997 .
append(
"second generator here:");
4002 }
else if (op.isApplyee(opnum)) {
4005 emitOpError(
"CLI must have at most one consumer");
4007 .
append(
"first consumer here:")
4011 .
append(
"second consumer here:")
4018 llvm_unreachable(
"Unexpected operand for a CLI");
4026 .
append(
"see consumer here: ")
4049 setNameFn(&getRegion().front(),
"body_entry");
4052void CanonicalLoopOp::getAsmBlockArgumentNames(
Region ®ion,
4060 p <<
'(' << getCli() <<
')';
4061 p <<
' ' << getInductionVar() <<
" : " << getInductionVar().getType()
4062 <<
" in range(" << getTripCount() <<
") ";
4072 CanonicalLoopInfoType cliType =
4073 CanonicalLoopInfoType::get(parser.
getContext());
4098 if (parser.
parseRegion(*region, {inductionVariable}))
4103 result.operands.append(cliOperand);
4109 return mlir::success();
4112LogicalResult CanonicalLoopOp::verify() {
4115 if (!getRegion().empty()) {
4116 Region ®ion = getRegion();
4119 "Canonical loop region must have exactly one argument");
4123 "Region argument must be the same type as the trip count");
4129Value CanonicalLoopOp::getInductionVar() {
return getRegion().getArgument(0); }
4131std::pair<unsigned, unsigned>
4132CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
4137std::pair<unsigned, unsigned>
4138CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
4139 return getODSOperandIndexAndLength(odsIndex_cli);
4153 p <<
'(' << getApplyee() <<
')';
4160 auto cliType = CanonicalLoopInfoType::get(parser.
getContext());
4183 return mlir::success();
4186std::pair<unsigned, unsigned>
4187UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
4188 return getODSOperandIndexAndLength(odsIndex_applyee);
4191std::pair<unsigned, unsigned>
4192UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
4203 if (!generatees.empty())
4204 p <<
'(' << llvm::interleaved(generatees) <<
')';
4206 if (!applyees.empty())
4207 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
4249 bool isOnlyCanonLoops =
true;
4251 for (
Value applyee : op.getApplyees()) {
4252 auto [create, gen, cons] =
decodeCli(applyee);
4255 return op.emitOpError() <<
"applyee CLI has no generator";
4257 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4258 canonLoops.push_back(loop);
4260 isOnlyCanonLoops =
false;
4265 if (!isOnlyCanonLoops)
4269 for (
auto i : llvm::seq<int>(1, canonLoops.size())) {
4270 auto parentLoop = canonLoops[i - 1];
4271 auto loop = canonLoops[i];
4273 if (parentLoop.getOperation() != loop.getOperation()->getParentOp())
4274 return op.emitOpError()
4275 <<
"tiled loop nest must be nested within each other";
4277 parentIVs.insert(parentLoop.getInductionVar());
4282 bool isPerfectlyNested = [&]() {
4283 auto &parentBody = parentLoop.getRegion();
4284 if (!parentBody.hasOneBlock())
4286 auto &parentBlock = parentBody.getBlocks().front();
4288 auto nestedLoopIt = parentBlock.begin();
4289 if (nestedLoopIt == parentBlock.end() ||
4290 (&*nestedLoopIt != loop.getOperation()))
4293 auto termIt = std::next(nestedLoopIt);
4294 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
4297 if (std::next(termIt) != parentBlock.end())
4302 if (!isPerfectlyNested)
4303 return op.emitOpError() <<
"tiled loop nest must be perfectly nested";
4305 if (parentIVs.contains(loop.getTripCount()))
4306 return op.emitOpError() <<
"tiled loop nest must be rectangular";
4323LogicalResult TileOp::verify() {
4324 if (getApplyees().empty())
4325 return emitOpError() <<
"must apply to at least one loop";
4327 if (getSizes().size() != getApplyees().size())
4328 return emitOpError() <<
"there must be one tile size for each applyee";
4330 if (!getGeneratees().empty() &&
4331 2 * getSizes().size() != getGeneratees().size())
4333 <<
"expecting two times the number of generatees than applyees";
4338std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
4339 return getODSOperandIndexAndLength(odsIndex_applyees);
4342std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
4343 return getODSOperandIndexAndLength(odsIndex_generatees);
4353 if (!generatees.empty())
4354 p <<
'(' << llvm::interleaved(generatees) <<
')';
4356 if (!applyees.empty())
4357 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
4360LogicalResult FuseOp::verify() {
4361 if (getApplyees().size() < 2)
4362 return emitOpError() <<
"must apply to at least two loops";
4364 if (getFirst().has_value() && getCount().has_value()) {
4365 int64_t first = getFirst().value();
4366 int64_t count = getCount().value();
4367 if ((
unsigned)(first + count - 1) > getApplyees().size())
4368 return emitOpError() <<
"the numbers of applyees must be at least first "
4369 "minus one plus count attributes";
4370 if (!getGeneratees().empty() &&
4371 getGeneratees().size() != getApplyees().size() + 1 - count)
4372 return emitOpError() <<
"the number of generatees must be the number of "
4373 "aplyees plus one minus count";
4376 if (!getGeneratees().empty() && getGeneratees().size() != 1)
4378 <<
"in a complete fuse the number of generatees must be exactly 1";
4380 for (
auto &&applyee : getApplyees()) {
4381 auto [create, gen, cons] =
decodeCli(applyee);
4384 return emitOpError() <<
"applyee CLI has no generator";
4385 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4388 <<
"currently only supports omp.canonical_loop as applyee";
4392std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
4393 return getODSOperandIndexAndLength(odsIndex_applyees);
4396std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
4397 return getODSOperandIndexAndLength(odsIndex_generatees);
4405 const CriticalDeclareOperands &clauses) {
4406 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
4409LogicalResult CriticalDeclareOp::verify() {
4414 if (getNameAttr()) {
4415 SymbolRefAttr symbolRef = getNameAttr();
4419 return emitOpError() <<
"expected symbol reference " << symbolRef
4420 <<
" to point to a critical declaration";
4440 return op.
emitOpError() <<
"must be nested inside of a loop";
4444 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
4445 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
4447 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
4448 "have an ordered clause";
4450 if (hasRegion && orderedAttr.getInt() != 0)
4451 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
4452 "have a parameter present";
4454 if (!hasRegion && orderedAttr.getInt() == 0)
4455 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
4456 "have a parameter present";
4457 }
else if (!isa<SimdOp>(wrapper)) {
4458 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
4459 "or worksharing simd loop";
4465 const OrderedOperands &clauses) {
4466 OrderedOp::build(builder, state, clauses.doacrossDependType,
4467 clauses.doacrossNumLoops, clauses.doacrossDependVars);
4470LogicalResult OrderedOp::verify() {
4474 auto wrapper = (*this)->getParentOfType<WsloopOp>();
4475 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
4476 return emitOpError() <<
"number of variables in depend clause does not "
4477 <<
"match number of iteration variables in the "
4484 const OrderedRegionOperands &clauses) {
4485 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
4495 const TaskwaitOperands &clauses) {
4497 TaskwaitOp::build(builder, state,
nullptr,
4506LogicalResult AtomicReadOp::verify() {
4507 if (verifyCommon().
failed())
4508 return mlir::failure();
4511 if (
auto moduleOp = getOperation()->getParentOfType<ModuleOp>())
4512 if (
Attribute verAttr = moduleOp->getAttr(
"omp.version"))
4513 version = llvm::cast<VersionAttr>(verAttr).getVersion();
4515 if (
auto mo = getMemoryOrder()) {
4516 if (*mo == ClauseMemoryOrderKind::Release) {
4517 return emitError(
"memory-order must not be release for atomic reads");
4519 if (*mo == ClauseMemoryOrderKind::Acq_rel) {
4522 return emitError(
"memory-order must not be acq_rel for atomic reads");
4532LogicalResult AtomicWriteOp::verify() {
4533 if (verifyCommon().
failed())
4534 return mlir::failure();
4537 if (
auto moduleOp = getOperation()->getParentOfType<ModuleOp>())
4538 if (
Attribute verAttr = moduleOp->getAttr(
"omp.version"))
4539 version = llvm::cast<VersionAttr>(verAttr).getVersion();
4541 if (
auto mo = getMemoryOrder()) {
4542 if (*mo == ClauseMemoryOrderKind::Acquire) {
4543 return emitError(
"memory-order must not be acquire for atomic writes");
4545 if (*mo == ClauseMemoryOrderKind::Acq_rel) {
4548 return emitError(
"memory-order must not be acq_rel for atomic writes");
4558LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4564 if (
Value writeVal = op.getWriteOpVal()) {
4566 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4572LogicalResult AtomicUpdateOp::verify() {
4573 if (verifyCommon().
failed())
4574 return mlir::failure();
4577 if (
auto moduleOp = getOperation()->getParentOfType<ModuleOp>())
4578 if (
Attribute verAttr = moduleOp->getAttr(
"omp.version"))
4579 version = llvm::cast<VersionAttr>(verAttr).getVersion();
4581 if (
auto mo = getMemoryOrder()) {
4582 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4583 *mo == ClauseMemoryOrderKind::Acquire) {
4587 "memory-order must not be acq_rel or acquire for atomic updates");
4594LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4600AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4601 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4603 return dyn_cast<AtomicReadOp>(getSecondOp());
4606AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4607 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4609 return dyn_cast<AtomicWriteOp>(getSecondOp());
4612AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4613 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4615 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4618LogicalResult AtomicCaptureOp::verify() {
4622LogicalResult AtomicCaptureOp::verifyRegions() {
4623 if (verifyRegionsCommon().
failed())
4624 return mlir::failure();
4626 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
4628 "operations inside capture region must not have hint clause");
4630 if (getFirstOp()->getAttr(
"memory_order") ||
4631 getSecondOp()->getAttr(
"memory_order"))
4633 "operations inside capture region must not have memory_order clause");
4641LogicalResult AtomicCompareOp::verify() {
4642 if (verifyCommon().
failed())
4643 return mlir::failure();
4647LogicalResult AtomicCompareOp::verifyRegions() {
4648 if (verifyRegionsCommon().
failed())
4649 return mlir::failure();
4651 if (verifyOperator().
failed())
4652 return mlir::failure();
4657 if (!terminator || !isa<YieldOp>(terminator))
4658 return emitOpError(
"region must be terminated with omp.yield");
4668 const CancelOperands &clauses) {
4669 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4682LogicalResult CancelOp::verify() {
4683 ClauseCancellationConstructType cct = getCancelDirective();
4686 if (!structuralParent)
4687 return emitOpError() <<
"Orphaned cancel construct";
4689 if ((cct == ClauseCancellationConstructType::Parallel) &&
4690 !mlir::isa<ParallelOp>(structuralParent)) {
4691 return emitOpError() <<
"cancel parallel must appear "
4692 <<
"inside a parallel region";
4694 if (cct == ClauseCancellationConstructType::Loop) {
4697 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->
getParentOp());
4701 <<
"cancel loop must appear inside a worksharing-loop region";
4703 if (wsloopOp.getNowaitAttr()) {
4704 return emitError() <<
"A worksharing construct that is canceled "
4705 <<
"must not have a nowait clause";
4707 if (wsloopOp.getOrderedAttr()) {
4708 return emitError() <<
"A worksharing construct that is canceled "
4709 <<
"must not have an ordered clause";
4712 }
else if (cct == ClauseCancellationConstructType::Sections) {
4716 mlir::dyn_cast<SectionsOp>(structuralParent->
getParentOp());
4718 return emitOpError() <<
"cancel sections must appear "
4719 <<
"inside a sections region";
4721 if (sectionsOp.getNowait()) {
4722 return emitError() <<
"A sections construct that is canceled "
4723 <<
"must not have a nowait clause";
4726 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4727 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4728 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->
getParentOp()))) {
4729 return emitOpError() <<
"cancel taskgroup must appear "
4730 <<
"inside a task region";
4740 const CancellationPointOperands &clauses) {
4741 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4744LogicalResult CancellationPointOp::verify() {
4745 ClauseCancellationConstructType cct = getCancelDirective();
4748 if (!structuralParent)
4749 return emitOpError() <<
"Orphaned cancellation point";
4751 if ((cct == ClauseCancellationConstructType::Parallel) &&
4752 !mlir::isa<ParallelOp>(structuralParent)) {
4753 return emitOpError() <<
"cancellation point parallel must appear "
4754 <<
"inside a parallel region";
4758 if ((cct == ClauseCancellationConstructType::Loop) &&
4759 !mlir::isa<WsloopOp>(structuralParent->
getParentOp())) {
4760 return emitOpError() <<
"cancellation point loop must appear "
4761 <<
"inside a worksharing-loop region";
4763 if ((cct == ClauseCancellationConstructType::Sections) &&
4764 !mlir::isa<omp::SectionOp>(structuralParent)) {
4765 return emitOpError() <<
"cancellation point sections must appear "
4766 <<
"inside a sections region";
4768 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4769 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4770 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->
getParentOp()))) {
4771 return emitOpError() <<
"cancellation point taskgroup must appear "
4772 <<
"inside a task region";
4781LogicalResult MapBoundsOp::verify() {
4782 auto extent = getExtent();
4784 if (!extent && !upperbound)
4785 return emitError(
"expected extent or upperbound.");
4792 PrivateClauseOp::build(
4793 odsBuilder, odsState, symName, type,
4794 DataSharingClauseTypeAttr::get(odsBuilder.
getContext(),
4795 DataSharingClauseType::Private));
4798LogicalResult PrivateClauseOp::verifyRegions() {
4799 Type argType = getArgType();
4800 auto verifyTerminator = [&](
Operation *terminator,
4801 bool yieldsValue) -> LogicalResult {
4805 if (!llvm::isa<YieldOp>(terminator))
4807 <<
"expected exit block terminator to be an `omp.yield` op.";
4809 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4810 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4813 if (yieldedTypes.empty())
4817 <<
"Did not expect any values to be yielded.";
4820 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4824 <<
"Invalid yielded value. Expected type: " << argType
4827 if (yieldedTypes.empty())
4830 error << yieldedTypes;
4836 StringRef regionName,
4837 bool yieldsValue) -> LogicalResult {
4838 assert(!region.
empty());
4842 <<
"`" << regionName <<
"`: "
4843 <<
"expected " << expectedNumArgs
4846 for (
Block &block : region) {
4859 for (
Region *region : getRegions())
4860 for (
Type ty : region->getArgumentTypes())
4862 return emitError() <<
"Region argument type mismatch: got " << ty
4863 <<
" expected " << argType <<
".";
4866 if (!initRegion.
empty() &&
4871 DataSharingClauseType dsType = getDataSharingType();
4873 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4874 return emitError(
"`private` clauses do not require a `copy` region.");
4876 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4878 "`firstprivate` clauses require at least a `copy` region.");
4880 if (dsType == DataSharingClauseType::FirstPrivate &&
4885 if (!getDeallocRegion().empty() &&
4898 const MaskedOperands &clauses) {
4899 MaskedOp::build(builder, state, clauses.filteredThreadId);
4907 const ScanOperands &clauses) {
4908 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4911LogicalResult ScanOp::verify() {
4912 if (hasExclusiveVars() == hasInclusiveVars())
4914 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4915 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4916 if (parentWsLoopOp.getReductionModAttr() &&
4917 parentWsLoopOp.getReductionModAttr().getValue() ==
4918 ReductionModifier::inscan)
4921 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4922 if (parentSimdOp.getReductionModAttr() &&
4923 parentSimdOp.getReductionModAttr().getValue() ==
4924 ReductionModifier::inscan)
4927 return emitError(
"SCAN directive needs to be enclosed within a parent "
4928 "worksharing loop construct or SIMD construct with INSCAN "
4929 "reduction modifier");
4934 std::optional<uint64_t> alignment) {
4935 if (alignment.has_value()) {
4936 if ((alignment.value() != 0) && !llvm::has_single_bit(alignment.value()))
4938 <<
"ALIGN value : " << alignment.value() <<
" must be power of 2";
4943LogicalResult AllocateDirOp::verify() {
4951LogicalResult AllocSharedMemOp::verify() {
4959LogicalResult FreeSharedMemOp::verify() {
4967LogicalResult WorkdistributeOp::verify() {
4969 Region ®ion = getRegion();
4974 if (entryBlock.
empty())
4975 return emitOpError(
"region must contain a structured block");
4977 bool hasTerminator =
false;
4978 for (
Block &block : region) {
4979 if (isa<TerminatorOp>(block.
back())) {
4980 if (hasTerminator) {
4981 return emitOpError(
"region must have exactly one terminator");
4983 hasTerminator =
true;
4986 if (!hasTerminator) {
4987 return emitOpError(
"region must be terminated with omp.terminator");
4991 if (isa<BarrierOp>(op)) {
4993 "explicit barriers are not allowed in workdistribute region");
4996 if (isa<ParallelOp>(op)) {
4998 "nested parallel constructs not allowed in workdistribute");
5000 if (isa<TeamsOp>(op)) {
5002 "nested teams constructs not allowed in workdistribute");
5006 if (walkResult.wasInterrupted())
5010 if (!llvm::dyn_cast<TeamsOp>(parentOp))
5011 return emitOpError(
"workdistribute must be nested under teams");
5019LogicalResult DeclareSimdOp::verify() {
5022 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
5024 return emitOpError() <<
"must be nested inside a function";
5026 if (getInbranch() && getNotinbranch())
5027 return emitOpError(
"cannot have both 'inbranch' and 'notinbranch'");
5037 const DeclareSimdOperands &clauses) {
5039 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
5041 clauses.linearVars, clauses.linearStepVars,
5042 clauses.linearVarTypes, clauses.linearModifiers,
5043 clauses.notinbranch, clauses.simdlen,
5044 clauses.uniformVars);
5061 return mlir::failure();
5062 return mlir::success();
5069 for (
unsigned i = 0; i < uniformVars.size(); ++i) {
5072 p << uniformVars[i] <<
" : " << uniformTypes[i];
5087 parser, iterated, iteratedTypes, affinityVars, affinityVarTypes,
5088 [&]() -> ParseResult {
return success(); })))
5122 OpAsmParser::Argument &arg = ivArgs.emplace_back();
5123 if (parser.parseArgument(arg))
5127 if (succeeded(parser.parseOptionalColon())) {
5128 if (parser.parseType(arg.type))
5131 arg.type = parser.getBuilder().getIndexType();
5143 OpAsmParser::UnresolvedOperand lb, ub, st;
5144 if (parser.parseOperand(lb) || parser.parseKeyword(
"to") ||
5145 parser.parseOperand(ub) || parser.parseKeyword(
"step") ||
5146 parser.parseOperand(st))
5151 steps.push_back(st);
5159 if (ivArgs.size() != lbs.size())
5161 <<
"mismatch: " << ivArgs.size() <<
" variables but " << lbs.size()
5164 for (
auto &arg : ivArgs) {
5165 lbTypes.push_back(arg.type);
5166 ubTypes.push_back(arg.type);
5167 stepTypes.push_back(arg.type);
5187 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
5190 p << lbs[i] <<
" to " << ubs[i] <<
" step " << steps[i];
5198LogicalResult IteratorOp::verify() {
5199 auto iteratedTy = llvm::dyn_cast<omp::IteratedType>(getIterated().
getType());
5201 return emitOpError() <<
"result must be omp.iterated<entry_ty>";
5203 for (
auto [lb,
ub, step] : llvm::zip_equal(
5204 getLoopLowerBounds(), getLoopUpperBounds(), getLoopSteps())) {
5206 return emitOpError() <<
"loop step must not be zero";
5210 IntegerAttr stepAttr;
5216 const APInt &lbVal = lbAttr.getValue();
5217 const APInt &ubVal = ubAttr.getValue();
5218 const APInt &stepVal = stepAttr.getValue();
5219 if (stepVal.isStrictlyPositive() && lbVal.sgt(ubVal))
5220 return emitOpError() <<
"positive loop step requires lower bound to be "
5221 "less than or equal to upper bound";
5222 if (stepVal.isNegative() && lbVal.slt(ubVal))
5223 return emitOpError() <<
"negative loop step requires lower bound to be "
5224 "greater than or equal to upper bound";
5227 Block &
b = getRegion().front();
5228 auto yield = llvm::dyn_cast<omp::YieldOp>(
b.getTerminator());
5231 return emitOpError() <<
"region must be terminated by omp.yield";
5233 if (yield.getNumOperands() != 1)
5235 <<
"omp.yield in omp.iterator region must yield exactly one value";
5237 mlir::Type yieldedTy = yield.getOperand(0).getType();
5238 mlir::Type elemTy = iteratedTy.getElementType();
5240 if (yieldedTy != elemTy)
5241 return emitOpError() <<
"omp.iterated element type (" << elemTy
5242 <<
") does not match omp.yield operand type ("
5243 << yieldedTy <<
")";
5256 return emitOpError() <<
"expected symbol reference '" << getSymName()
5257 <<
"' to point to a global variable";
5259 if (isa<FunctionOpInterface>(symbol))
5260 return emitOpError() <<
"expected symbol reference '" << getSymName()
5261 <<
"' to point to a global variable, not a function";
5266#define GET_ATTRDEF_CLASSES
5267#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
5269#define GET_OP_CLASSES
5270#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
5272#define GET_TYPEDEF_CLASSES
5273#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static Type getElementType(Type type)
Determine the element type of type.
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds, OperandRange iteratedVars, TypeRange iteratedTypes, std::optional< ArrayAttr > iteratedKinds)
Print Depend clause.
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static void printHeapAllocClause(OpAsmPrinter &p, Operation *op, TypeAttr inType, ValueRange typeparams, TypeRange typeparamsTypes, ValueRange shape, TypeRange shapeTypes)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars, SmallVectorImpl< Type > &linearStepTypes, ArrayAttr &linearModifiers)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printDynGroupprivateClause(OpAsmPrinter &printer, Operation *op, AccessGroupModifierAttr modifierFirst, FallbackModifierAttr modifierSecond, Value dynGroupprivateSize, Type sizeType)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iterated, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< Type > &affinityVarTypes)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printSplitIteratedList(OpAsmPrinter &p, ValueRange iteratedVars, TypeRange iteratedTypes, ValueRange plainVars, TypeRange plainTypes, PrintPrefixFn &&printPrefixForPlain, PrintPrefixFn &&printPrefixForIterated)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars, std::optional< ArrayAttr > iteratedKinds, OperandRange iteratedVars)
Verifies Depend clause.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange iterated, ValueRange affinityVars, TypeRange iteratedTypes, TypeRange affinityVarTypes)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printIteratorHeader(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lbs, ValueRange ubs, ValueRange steps, TypeRange, TypeRange, TypeRange)
static ParseResult parseHeapAllocClause(OpAsmParser &parser, TypeAttr &inTypeAttr, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &typeparams, SmallVectorImpl< Type > &typeparamsTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &shape, SmallVectorImpl< Type > &shapeTypes)
operation ::= $in_type ( ( $typeparams ) )? ( , $shape )?
static ParseResult parseIteratorHeader(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lbs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &ubs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &lbTypes, SmallVectorImpl< Type > &ubTypes, SmallVectorImpl< Type > &stepTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, ArrayAttr &iteratedKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static ParseResult parseDynGroupprivateClause(OpAsmParser &parser, AccessGroupModifierAttr &accessGroupAttr, FallbackModifierAttr &fallbackAttr, std::optional< OpAsmParser::UnresolvedOperand > &dynGroupprivateSize, Type &sizeType)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static ParseResult parseSplitIteratedList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &plainVars, SmallVectorImpl< Type > &plainTypes, ParsePrefixFn &&parsePrefix)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyDynGroupprivateClause(Operation *op, AccessGroupModifierAttr accessGroup, FallbackModifierAttr fallback, Value dynGroupprivateSize)
static LogicalResult verifyLinearModifiers(Operation *op, std::optional< ArrayAttr > linearModifiers, OperandRange linearVars, bool isDeclareSimd=false)
OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or uval only on a declare simd ...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
LogicalResult verifyAlignment(Operation &op, std::optional< uint64_t > alignment)
Verifies align clause in allocate directive.
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars, TypeRange stepVarTypes, ArrayAttr linearModifiers)
Print Linear Clause.
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult checkApplyeesNesting(TileOp op)
Check properties of the loop nest consisting of the transformation's applyees:
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
static LogicalResult emit(SolverOp solver, const SMTEmissionOptions &options, mlir::raw_indented_ostream &stream)
Emit the SMT operations in the given 'solver' to the 'stream'.
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
Operation * getTerminator()
Get the terminator operation of this block.
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
BlockArgListType getArguments()
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
Location getLoc()
Return a location for this region.
BlockArgument getArgument(unsigned i)
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
SideEffects::EffectInstance< Effect > EffectInstance
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
llvm::TypeSwitch< T, ResultT > TypeSwitch
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.