28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/PostOrderIterator.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/STLForwardCompat.h"
32#include "llvm/ADT/SmallString.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/ADT/StringRef.h"
35#include "llvm/ADT/TypeSwitch.h"
36#include "llvm/ADT/bit.h"
37#include "llvm/Support/InterleavedRange.h"
43#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
45#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
46#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
53 return attrs.empty() ?
nullptr : ArrayAttr::get(context, attrs);
67struct MemRefPointerLikeModel
68 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
71 return llvm::cast<MemRefType>(pointer).getElementType();
75struct LLVMPointerPointerLikeModel
76 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
77 LLVM::LLVMPointerType> {
102 bool isRegionArgOfOp;
112 assert(isRegionArgOfOp &&
"Must describe a region operand");
115 size_t &getArgIdx() {
116 assert(isRegionArgOfOp &&
"Must describe a region operand");
121 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
125 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
128 bool isLoopOp()
const {
129 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
130 return isa<CanonicalLoopOp>(op);
132 Region *&getParentRegion() {
133 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
136 size_t &getLoopDepth() {
137 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
/// Sticky skip marker: sets `skip` to true when \p v is true and never clears
/// it (passing false leaves the current value unchanged). Calling with no
/// argument unconditionally marks the component as skipped.
141 void skipIf(
bool v =
true) { skip = skip || v; }
159 llvm::ReversePostOrderTraversal<Block *> traversal(&r->
getBlocks().front());
162 size_t sequentialIdx = -1;
163 bool isOnlyContainerOp =
true;
164 for (
Block *
b : traversal) {
166 if (&op == o && !found) {
170 if (op.getNumRegions()) {
173 isOnlyContainerOp =
false;
175 if (found && !isOnlyContainerOp)
180 Component &containerOpInRegion = components.emplace_back();
181 containerOpInRegion.isRegionArgOfOp =
false;
182 containerOpInRegion.isUnique = isOnlyContainerOp;
183 containerOpInRegion.getContainerOp() = o;
184 containerOpInRegion.getOpPos() = sequentialIdx;
185 containerOpInRegion.getParentRegion() = r;
190 Component ®ionArgOfOperation = components.emplace_back();
191 regionArgOfOperation.isRegionArgOfOp =
true;
192 regionArgOfOperation.isUnique =
true;
193 regionArgOfOperation.getArgIdx() = 0;
194 regionArgOfOperation.getOwnerOp() = parent;
206 for (
auto [idx, region] : llvm::enumerate(o->
getRegions())) {
210 llvm_unreachable(
"Region not child of its parent operation");
212 regionArgOfOperation.isUnique =
false;
213 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
221 for (Component &c : components)
222 c.skipIf(c.isRegionArgOfOp && c.isUnique);
225 size_t numSurroundingLoops = 0;
226 for (Component &c : llvm::reverse(components)) {
231 if (c.isRegionArgOfOp) {
232 numSurroundingLoops = 0;
239 numSurroundingLoops = 0;
241 c.getLoopDepth() = numSurroundingLoops;
244 if (isa<CanonicalLoopOp>(c.getContainerOp()))
245 numSurroundingLoops += 1;
250 bool isLoopNest =
false;
251 for (Component &c : components) {
252 if (c.skip || c.isRegionArgOfOp)
255 if (!isLoopNest && c.getLoopDepth() >= 1) {
258 }
else if (isLoopNest) {
260 c.skipIf(c.isUnique);
264 if (c.getLoopDepth() == 0)
271 for (Component &c : components)
272 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
273 !isa<CanonicalLoopOp>(c.getContainerOp()));
277 bool newRegion =
true;
278 for (Component &c : llvm::reverse(components)) {
279 c.skipIf(newRegion && c.isUnique);
286 if (!c.isRegionArgOfOp && c.getContainerOp())
292 llvm::raw_svector_ostream NameOS(Name);
293 for (
auto &c : llvm::reverse(components)) {
297 if (c.isRegionArgOfOp)
298 NameOS <<
"_r" << c.getArgIdx();
299 else if (c.getLoopDepth() >= 1)
300 NameOS <<
"_d" << c.getLoopDepth();
302 NameOS <<
"_s" << c.getOpPos();
305 return NameOS.str().str();
308void OpenMPDialect::initialize() {
311#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
314#define GET_ATTRDEF_LIST
315#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
318#define GET_TYPEDEF_LIST
319#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
322 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
324 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
325 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
330 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
336 mlir::LLVM::GlobalOp::attachInterface<
339 mlir::LLVM::LLVMFuncOp::attachInterface<
342 mlir::func::FuncOp::attachInterface<
368 allocatorVars.push_back(operand);
369 allocatorTypes.push_back(type);
375 allocateVars.push_back(operand);
376 allocateTypes.push_back(type);
387 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
388 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
389 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
390 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
398template <
typename ClauseAttr>
400 using ClauseT =
decltype(std::declval<ClauseAttr>().getValue());
405 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
406 attr = ClauseAttr::get(parser.
getContext(), *enumValue);
409 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
412template <
typename ClauseAttr>
414 p << stringifyEnum(attr.getValue());
439 std::optional<omp::LinearModifier> linearModifier;
441 linearModifier = omp::LinearModifier::val;
443 linearModifier = omp::LinearModifier::ref;
445 linearModifier = omp::LinearModifier::uval;
448 bool hasLinearModifierParens = linearModifier.has_value();
449 if (hasLinearModifierParens && parser.
parseLParen())
457 if (hasLinearModifierParens && parser.
parseRParen())
460 linearVars.push_back(var);
461 linearTypes.push_back(type);
462 linearStepVars.push_back(stepVar);
463 linearStepTypes.push_back(stepType);
464 if (linearModifier) {
466 omp::LinearModifierAttr::get(parser.
getContext(), *linearModifier));
468 modifiers.push_back(UnitAttr::get(parser.
getContext()));
474 linearModifiers = ArrayAttr::get(parser.
getContext(), modifiers);
483 size_t linearVarsSize = linearVars.size();
484 for (
unsigned i = 0; i < linearVarsSize; ++i) {
488 Attribute modAttr = linearModifiers ? linearModifiers[i] :
nullptr;
489 auto mod = modAttr ? dyn_cast<omp::LinearModifierAttr>(modAttr) :
nullptr;
491 p << omp::stringifyLinearModifier(mod.getValue()) <<
"(";
493 p << linearVars[i] <<
" : " << linearTypes[i];
494 p <<
" = " << linearStepVars[i] <<
" : " << stepVarTypes[i];
510 if (!linearModifiers)
512 if (linearModifiers->size() != linearVars.size())
514 <<
"expected as many linear modifiers as linear variables";
515 if (!isDeclareSimd) {
516 for (
Attribute attr : *linearModifiers) {
519 auto modAttr = dyn_cast<omp::LinearModifierAttr>(attr);
522 omp::LinearModifier mod = modAttr.getValue();
523 if (mod == omp::LinearModifier::ref || mod == omp::LinearModifier::uval)
525 <<
"linear modifier '" << omp::stringifyLinearModifier(mod)
526 <<
"' may only be specified on a declare simd directive";
541 for (
const auto &it : nontemporalVars)
542 if (!nontemporalItems.insert(it).second)
543 return op->
emitOpError() <<
"nontemporal variable used more than once";
552 std::optional<ArrayAttr> alignments,
555 if (!alignedVars.empty()) {
556 if (!alignments || alignments->size() != alignedVars.size())
558 <<
"expected as many alignment values as aligned variables";
561 return op->
emitOpError() <<
"unexpected alignment values attribute";
567 for (
auto it : alignedVars)
568 if (!alignedItems.insert(it).second)
569 return op->
emitOpError() <<
"aligned variable used more than once";
575 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
576 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
577 if (intAttr.getValue().sle(0))
578 return op->
emitOpError() <<
"alignment should be greater than 0";
580 return op->
emitOpError() <<
"expected integer alignment";
597 if (parser.parseOperand(alignedVars.emplace_back()) ||
598 parser.parseColonType(alignedTypes.emplace_back()) ||
599 parser.parseArrow() ||
600 parser.parseAttribute(alignmentVec.emplace_back())) {
607 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
614 std::optional<ArrayAttr> alignments) {
615 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
618 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
619 p <<
" -> " << (*alignments)[i];
630 if (modifiers.size() > 2)
632 for (
const auto &mod : modifiers) {
635 auto symbol = symbolizeScheduleModifier(mod);
638 <<
" unknown modifier type: " << mod;
643 if (modifiers.size() == 1) {
644 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
645 modifiers.push_back(modifiers[0]);
646 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
648 }
else if (modifiers.size() == 2) {
651 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
652 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
654 <<
" incorrect modifier order";
670 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
671 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
676 std::optional<mlir::omp::ClauseScheduleKind> schedule =
677 symbolizeClauseScheduleKind(keyword);
681 scheduleAttr = ClauseScheduleKindAttr::get(parser.
getContext(), *schedule);
683 case ClauseScheduleKind::Static:
684 case ClauseScheduleKind::Dynamic:
685 case ClauseScheduleKind::Guided:
691 chunkSize = std::nullopt;
694 case ClauseScheduleKind::Auto:
695 case ClauseScheduleKind::Runtime:
696 case ClauseScheduleKind::Distribute:
697 chunkSize = std::nullopt;
706 modifiers.push_back(mod);
712 if (!modifiers.empty()) {
714 if (std::optional<ScheduleModifier> mod =
715 symbolizeScheduleModifier(modifiers[0])) {
716 scheduleMod = ScheduleModifierAttr::get(parser.
getContext(), *mod);
718 return parser.
emitError(loc,
"invalid schedule modifier");
721 if (modifiers.size() > 1) {
722 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
732 ClauseScheduleKindAttr scheduleKind,
733 ScheduleModifierAttr scheduleMod,
734 UnitAttr scheduleSimd,
Value scheduleChunk,
735 Type scheduleChunkType) {
736 p << stringifyClauseScheduleKind(scheduleKind.getValue());
738 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
740 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
752 ClauseOrderKindAttr &order,
753 OrderModifierAttr &orderMod) {
758 if (std::optional<OrderModifier> enumValue =
759 symbolizeOrderModifier(enumStr)) {
760 orderMod = OrderModifierAttr::get(parser.
getContext(), *enumValue);
767 if (std::optional<ClauseOrderKind> enumValue =
768 symbolizeClauseOrderKind(enumStr)) {
769 order = ClauseOrderKindAttr::get(parser.
getContext(), *enumValue);
772 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
776 ClauseOrderKindAttr order,
777 OrderModifierAttr orderMod) {
779 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
781 p << stringifyClauseOrderKind(order.getValue());
784template <
typename ClauseTypeAttr,
typename ClauseType>
787 std::optional<OpAsmParser::UnresolvedOperand> &operand,
789 std::optional<ClauseType> (*symbolizeClause)(StringRef),
790 StringRef clauseName) {
793 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
794 prescriptiveness = ClauseTypeAttr::get(parser.
getContext(), *enumValue);
799 <<
"invalid " << clauseName <<
" modifier : '" << enumStr <<
"'";
809 <<
"expected " << clauseName <<
" operand";
812 if (operand.has_value()) {
820template <
typename ClauseTypeAttr,
typename ClauseType>
823 ClauseTypeAttr prescriptiveness,
Value operand,
825 StringRef (*stringifyClauseType)(ClauseType)) {
827 if (prescriptiveness)
828 p << stringifyClauseType(prescriptiveness.getValue()) <<
", ";
831 p << operand <<
": " << operandType;
841 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
842 Type &grainsizeType) {
844 parser, grainsizeMod, grainsize, grainsizeType,
845 &symbolizeClauseGrainsizeType,
"grainsize");
849 ClauseGrainsizeTypeAttr grainsizeMod,
852 p, op, grainsizeMod, grainsize, grainsizeType,
853 &stringifyClauseGrainsizeType);
863 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
864 Type &numTasksType) {
866 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
871 ClauseNumTasksTypeAttr numTasksMod,
874 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
890 return mlir::failure();
891 inTypeAttr = TypeAttr::get(inType);
920 if (!typeparams.empty()) {
921 p <<
'(' << typeparams <<
" : " << typeparamsTypes <<
')';
923 for (
auto sh :
shape) {
935 FallbackModifierAttr fallback,
936 Value dynGroupprivateSize) {
937 if (!dynGroupprivateSize && (accessGroup || fallback))
938 return op->
emitOpError(
"dyn_groupprivate modifiers require a size operand");
944 OpAsmParser &parser, AccessGroupModifierAttr &accessGroupAttr,
945 FallbackModifierAttr &fallbackAttr,
946 std::optional<OpAsmParser::UnresolvedOperand> &dynGroupprivateSize,
949 bool parsedAccessGroup =
false;
950 bool parsedFallback =
false;
951 bool parsedSize =
false;
956 if (parsedAccessGroup)
958 "duplicate access group modifier");
959 accessGroupAttr = AccessGroupModifierAttr::get(
960 parser.
getContext(), AccessGroupModifier::cgroup);
961 parsedAccessGroup =
true;
968 "duplicate fallback modifier");
971 "expected '(' after 'fallback'");
972 llvm::StringRef fbKind;
976 "expected fallback modifier (abort/null/default_mem)");
977 std::optional<FallbackModifier> fbEnum;
978 if (fbKind ==
"abort")
979 fbEnum = FallbackModifier::abort;
980 else if (fbKind ==
"null")
981 fbEnum = FallbackModifier::null;
982 else if (fbKind ==
"default_mem")
983 fbEnum = FallbackModifier::default_mem;
986 "invalid fallback modifier '" + fbKind +
"'");
987 fallbackAttr = FallbackModifierAttr::get(parser.
getContext(), *fbEnum);
990 "expected ')' after fallback modifier");
991 parsedFallback =
true;
999 "duplicate size operand");
1000 dynGroupprivateSize = operand;
1004 "expected ':' and type after size operand");
1008 "expected dyn_groupprivate_size operand");
1013 AccessGroupModifierAttr modifierFirst,
1014 FallbackModifierAttr modifierSecond,
1015 Value dynGroupprivateSize,
1018 bool needsComma =
false;
1020 if (modifierFirst) {
1021 printer << modifierFirst.getValue();
1025 if (modifierSecond) {
1028 printer <<
"fallback(";
1029 printer << modifierSecond.getValue();
1034 if (dynGroupprivateSize) {
1037 printer << dynGroupprivateSize <<
" : " << sizeType;
1046struct MapParseArgs {
1047 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1048 SmallVectorImpl<Type> &types;
1049 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1050 SmallVectorImpl<Type> &types)
1051 : vars(vars), types(types) {}
1053struct PrivateParseArgs {
1054 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1055 llvm::SmallVectorImpl<Type> &types;
1057 UnitAttr &needsBarrier;
1059 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1060 SmallVectorImpl<Type> &types,
ArrayAttr &syms,
1061 UnitAttr &needsBarrier,
1063 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1064 mapIndices(mapIndices) {}
1067struct ReductionParseArgs {
1068 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1069 SmallVectorImpl<Type> &types;
1072 ReductionModifierAttr *modifier;
1073 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1075 ArrayAttr &syms, ReductionModifierAttr *mod =
nullptr)
1076 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1079struct AllRegionParseArgs {
1080 std::optional<MapParseArgs> hasDeviceAddrArgs;
1081 std::optional<MapParseArgs> hostEvalArgs;
1082 std::optional<ReductionParseArgs> inReductionArgs;
1083 std::optional<MapParseArgs> mapArgs;
1084 std::optional<PrivateParseArgs> privateArgs;
1085 std::optional<ReductionParseArgs> reductionArgs;
1086 std::optional<ReductionParseArgs> taskReductionArgs;
1087 std::optional<MapParseArgs> useDeviceAddrArgs;
1088 std::optional<MapParseArgs> useDevicePtrArgs;
1093 return "private_barrier";
1103 ReductionModifierAttr *modifier =
nullptr,
1104 UnitAttr *needsBarrier =
nullptr) {
1108 unsigned regionArgOffset = regionPrivateArgs.size();
1118 std::optional<ReductionModifier> enumValue =
1119 symbolizeReductionModifier(enumStr);
1120 if (!enumValue.has_value())
1122 *modifier = ReductionModifierAttr::get(parser.
getContext(), *enumValue);
1129 isByRefVec.push_back(
1130 parser.parseOptionalKeyword(
"byref").succeeded());
1132 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
1135 if (parser.parseOperand(operands.emplace_back()) ||
1136 parser.parseArrow() ||
1137 parser.parseArgument(regionPrivateArgs.emplace_back()))
1141 if (parser.parseOptionalLSquare().succeeded()) {
1142 if (parser.parseKeyword(
"map_idx") || parser.parseEqual() ||
1143 parser.parseInteger(mapIndicesVec.emplace_back()) ||
1144 parser.parseRSquare())
1147 mapIndicesVec.push_back(-1);
1159 if (parser.parseType(types.emplace_back()))
1166 if (operands.size() != types.size())
1175 *needsBarrier = mlir::UnitAttr::get(parser.
getContext());
1178 auto *argsBegin = regionPrivateArgs.begin();
1180 argsBegin + regionArgOffset + types.size());
1181 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
1187 *symbols = ArrayAttr::get(parser.
getContext(), symbolAttrs);
1190 if (!mapIndicesVec.empty())
1203 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
1218 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
1224 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
1225 &privateArgs->syms, privateArgs->mapIndices,
nullptr,
1226 nullptr, &privateArgs->needsBarrier)))
1235 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
1240 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1241 &reductionArgs->syms,
nullptr, &reductionArgs->byref,
1242 reductionArgs->modifier)))
1249 AllRegionParseArgs args) {
1253 args.hasDeviceAddrArgs)))
1255 <<
"invalid `has_device_addr` format";
1258 args.hostEvalArgs)))
1260 <<
"invalid `host_eval` format";
1263 args.inReductionArgs)))
1265 <<
"invalid `in_reduction` format";
1270 <<
"invalid `map_entries` format";
1275 <<
"invalid `private` format";
1278 args.reductionArgs)))
1280 <<
"invalid `reduction` format";
1283 args.taskReductionArgs)))
1285 <<
"invalid `task_reduction` format";
1288 args.useDeviceAddrArgs)))
1290 <<
"invalid `use_device_addr` format";
1293 args.useDevicePtrArgs)))
1295 <<
"invalid `use_device_ptr` format"; // name the clause that actually failed to parse
1297 return parser.
parseRegion(region, entryBlockArgs);
1316 AllRegionParseArgs args;
1317 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1318 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1319 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1320 inReductionByref, inReductionSyms);
1321 args.mapArgs.emplace(mapVars, mapTypes);
1322 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1323 privateNeedsBarrier, &privateMaps);
1334 UnitAttr &privateNeedsBarrier) {
1335 AllRegionParseArgs args;
1336 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1337 inReductionByref, inReductionSyms);
1338 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1339 privateNeedsBarrier);
1350 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1354 AllRegionParseArgs args;
1355 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1356 inReductionByref, inReductionSyms);
1357 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1358 privateNeedsBarrier);
1359 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1360 reductionSyms, &reductionMod);
1368 UnitAttr &privateNeedsBarrier) {
1369 AllRegionParseArgs args;
1370 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1371 privateNeedsBarrier);
1379 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1383 AllRegionParseArgs args;
1384 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1385 privateNeedsBarrier);
1386 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1387 reductionSyms, &reductionMod);
1396 AllRegionParseArgs args;
1397 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1398 taskReductionByref, taskReductionSyms);
1408 AllRegionParseArgs args;
1409 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1410 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1419struct MapPrintArgs {
1424struct PrivatePrintArgs {
1428 UnitAttr needsBarrier;
1432 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1433 mapIndices(mapIndices) {}
1435struct ReductionPrintArgs {
1440 ReductionModifierAttr modifier;
1442 ArrayAttr syms, ReductionModifierAttr mod =
nullptr)
1443 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1445struct AllRegionPrintArgs {
1446 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1447 std::optional<MapPrintArgs> hostEvalArgs;
1448 std::optional<ReductionPrintArgs> inReductionArgs;
1449 std::optional<MapPrintArgs> mapArgs;
1450 std::optional<PrivatePrintArgs> privateArgs;
1451 std::optional<ReductionPrintArgs> reductionArgs;
1452 std::optional<ReductionPrintArgs> taskReductionArgs;
1453 std::optional<MapPrintArgs> useDeviceAddrArgs;
1454 std::optional<MapPrintArgs> useDevicePtrArgs;
1463 ReductionModifierAttr modifier =
nullptr, UnitAttr needsBarrier =
nullptr) {
1464 if (argsSubrange.empty())
1467 p << clauseName <<
"(";
1470 p <<
"mod: " << stringifyReductionModifier(modifier.getValue()) <<
", ";
1474 symbols = ArrayAttr::get(ctx, values);
1487 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1488 mapIndices.asArrayRef(),
1489 byref.asArrayRef()),
1491 auto [op, arg, sym, map, isByRef] = t;
1497 p << op <<
" -> " << arg;
1500 p <<
" [map_idx=" << map <<
"]";
1503 llvm::interleaveComma(types, p);
1511 StringRef clauseName,
ValueRange argsSubrange,
1512 std::optional<MapPrintArgs> mapArgs) {
1519 StringRef clauseName,
ValueRange argsSubrange,
1520 std::optional<PrivatePrintArgs> privateArgs) {
1523 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1524 privateArgs->syms, privateArgs->mapIndices,
nullptr,
1525 nullptr, privateArgs->needsBarrier);
1531 std::optional<ReductionPrintArgs> reductionArgs) {
1534 reductionArgs->vars, reductionArgs->types,
1535 reductionArgs->syms,
nullptr,
1536 reductionArgs->byref, reductionArgs->modifier);
1540 const AllRegionPrintArgs &args) {
1541 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1545 iface.getHasDeviceAddrBlockArgs(),
1546 args.hasDeviceAddrArgs);
1550 args.inReductionArgs);
1556 args.reductionArgs);
1558 iface.getTaskReductionBlockArgs(),
1559 args.taskReductionArgs);
1561 iface.getUseDeviceAddrBlockArgs(),
1562 args.useDeviceAddrArgs);
1564 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1580 AllRegionPrintArgs args;
1581 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1582 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1583 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1584 inReductionByref, inReductionSyms);
1585 args.mapArgs.emplace(mapVars, mapTypes);
1586 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1587 privateNeedsBarrier, privateMaps);
1595 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1596 AllRegionPrintArgs args;
1597 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1598 inReductionByref, inReductionSyms);
1599 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1600 privateNeedsBarrier,
1609 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1610 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1613 AllRegionPrintArgs args;
1614 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1615 inReductionByref, inReductionSyms);
1616 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1617 privateNeedsBarrier,
1619 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1620 reductionSyms, reductionMod);
1627 UnitAttr privateNeedsBarrier) {
1628 AllRegionPrintArgs args;
1629 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1630 privateNeedsBarrier,
1638 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1641 AllRegionPrintArgs args;
1642 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1643 privateNeedsBarrier,
1645 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1646 reductionSyms, reductionMod);
1656 AllRegionPrintArgs args;
1657 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1658 taskReductionByref, taskReductionSyms);
1668 AllRegionPrintArgs args;
1669 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1670 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1674template <
typename ParsePrefixFn>
1683 if (failed(parsePrefix()))
1691 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1692 iteratedVars.push_back(v);
1693 iteratedTypes.push_back(ty);
1695 plainVars.push_back(v);
1696 plainTypes.push_back(ty);
1702template <
typename Pr
intPrefixFn>
1706 PrintPrefixFn &&printPrefixForPlain,
1707 PrintPrefixFn &&printPrefixForIterated) {
1714 p << v <<
" : " << t;
1718 for (
unsigned i = 0; i < iteratedVars.size(); ++i)
1719 emit(iteratedVars[i], iteratedTypes[i], printPrefixForIterated);
1720 for (
unsigned i = 0; i < plainVars.size(); ++i)
1721 emit(plainVars[i], plainTypes[i], printPrefixForPlain);
1729 if (!reductionVars.empty()) {
1730 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1732 <<
"expected as many reduction symbol references "
1733 "as reduction variables";
1734 if (reductionByref && reductionByref->size() != reductionVars.size())
1735 return op->
emitError() <<
"expected as many reduction variable by "
1736 "reference attributes as reduction variables";
1739 return op->
emitOpError() <<
"unexpected reduction symbol references";
1746 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
1747 Value accum = std::get<0>(args);
1749 if (!accumulators.insert(accum).second)
1750 return op->
emitOpError() <<
"accumulator variable used more than once";
1753 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1757 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1758 <<
" to point to a reduction declaration";
1760 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1762 <<
"expected accumulator (" << varType
1763 <<
") to be the same type as reduction declaration ("
1764 << decl.getAccumulatorType() <<
")";
1783 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1784 parser.parseArrow() ||
1785 parser.parseAttribute(symsVec.emplace_back()) ||
1786 parser.parseColonType(copyprivateTypes.emplace_back()))
1792 copyprivateSyms = ArrayAttr::get(parser.
getContext(), syms);
1800 std::optional<ArrayAttr> copyprivateSyms) {
1801 if (!copyprivateSyms.has_value())
1803 llvm::interleaveComma(
1804 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1805 [&](
const auto &args) {
1806 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
1807 << std::get<2>(args);
1814 std::optional<ArrayAttr> copyprivateSyms) {
1815 size_t copyprivateSymsSize =
1816 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1817 if (copyprivateSymsSize != copyprivateVars.size())
1818 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
1819 << copyprivateVars.size()
1820 <<
") and functions (= " << copyprivateSymsSize
1821 <<
"), both must be equal";
1822 if (!copyprivateSyms.has_value())
1825 for (
auto copyprivateVarAndSym :
1826 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1828 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1829 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1831 if (mlir::func::FuncOp mlirFuncOp =
1834 funcOp = mlirFuncOp;
1835 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1838 funcOp = llvmFuncOp;
1840 auto getNumArguments = [&] {
1841 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
1844 auto getArgumentType = [&](
unsigned i) {
1845 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
1850 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1851 <<
" to point to a copy function";
1853 if (getNumArguments() != 2)
1855 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
1857 Type argTy = getArgumentType(0);
1858 if (argTy != getArgumentType(1))
1859 return op->
emitOpError() <<
"expected copy function " << symbolRef
1860 <<
" arguments to have the same type";
1862 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1863 if (argTy != varType)
1865 <<
"expected copy function arguments' type (" << argTy
1866 <<
") to be the same as copyprivate variable's type (" << varType
1891 OpAsmParser::UnresolvedOperand operand;
1893 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1894 parser.parseOperand(operand) || parser.parseColonType(ty))
1896 std::optional<ClauseTaskDepend> keywordDepend =
1897 symbolizeClauseTaskDepend(keyword);
1901 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend);
1902 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1903 iteratedVars.push_back(operand);
1904 iteratedTypes.push_back(ty);
1905 iterKindsVec.push_back(kindAttr);
1907 dependVars.push_back(operand);
1908 dependTypes.push_back(ty);
1909 kindsVec.push_back(kindAttr);
1915 dependKinds = ArrayAttr::get(parser.
getContext(), kinds);
1917 iteratedKinds = ArrayAttr::get(parser.
getContext(), iterKinds);
1924 std::optional<ArrayAttr> dependKinds,
1927 std::optional<ArrayAttr> iteratedKinds) {
1930 std::optional<ArrayAttr> kinds) {
1931 for (
unsigned i = 0, e = vars.size(); i < e; ++i) {
1934 p << stringifyClauseTaskDepend(
1935 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*kinds)[i])
1937 <<
" -> " << vars[i] <<
" : " << types[i];
1941 printEntries(dependVars, dependTypes, dependKinds);
1942 printEntries(iteratedVars, iteratedTypes, iteratedKinds);
1947 std::optional<ArrayAttr> dependKinds,
1949 std::optional<ArrayAttr> iteratedKinds,
1951 if (!dependVars.empty()) {
1952 if (!dependKinds || dependKinds->size() != dependVars.size())
1953 return op->
emitOpError() <<
"expected as many depend values"
1954 " as depend variables";
1956 if (dependKinds && !dependKinds->empty())
1957 return op->
emitOpError() <<
"unexpected depend values";
1960 if (!iteratedVars.empty()) {
1961 if (!iteratedKinds || iteratedKinds->size() != iteratedVars.size())
1962 return op->
emitOpError() <<
"expected as many depend iterated values"
1963 " as depend iterated variables";
1965 if (iteratedKinds && !iteratedKinds->empty())
1966 return op->
emitOpError() <<
"unexpected depend iterated values";
1981 IntegerAttr &hintAttr) {
1982 StringRef hintKeyword;
1988 auto parseKeyword = [&]() -> ParseResult {
1991 if (hintKeyword ==
"uncontended")
1993 else if (hintKeyword ==
"contended")
1995 else if (hintKeyword ==
"nonspeculative")
1997 else if (hintKeyword ==
"speculative")
2001 << hintKeyword <<
" is not a valid hint";
2012 IntegerAttr hintAttr) {
2013 int64_t hint = hintAttr.getInt();
2021 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
2023 bool uncontended = bitn(hint, 0);
2024 bool contended = bitn(hint, 1);
2025 bool nonspeculative = bitn(hint, 2);
2026 bool speculative = bitn(hint, 3);
2030 hints.push_back(
"uncontended");
2032 hints.push_back(
"contended");
2034 hints.push_back(
"nonspeculative");
2036 hints.push_back(
"speculative");
2038 llvm::interleaveComma(hints, p);
2045 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
2047 bool uncontended = bitn(hint, 0);
2048 bool contended = bitn(hint, 1);
2049 bool nonspeculative = bitn(hint, 2);
2050 bool speculative = bitn(hint, 3);
2052 if (uncontended && contended)
2053 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
2054 "omp_sync_hint_contended cannot be combined";
2055 if (nonspeculative && speculative)
2056 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
2057 "omp_sync_hint_speculative cannot be combined.";
2068 return (value & flag) == flag;
2076static ParseResult parseMapClause(
OpAsmParser &parser,
2077 ClauseMapFlagsAttr &mapType) {
2078 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
2081 auto parseTypeAndMod = [&]() -> ParseResult {
2082 StringRef mapTypeMod;
2086 if (mapTypeMod ==
"always")
2087 mapTypeBits |= ClauseMapFlags::always;
2089 if (mapTypeMod ==
"implicit")
2090 mapTypeBits |= ClauseMapFlags::implicit;
2092 if (mapTypeMod ==
"ompx_hold")
2093 mapTypeBits |= ClauseMapFlags::ompx_hold;
2095 if (mapTypeMod ==
"close")
2096 mapTypeBits |= ClauseMapFlags::close;
2098 if (mapTypeMod ==
"present")
2099 mapTypeBits |= ClauseMapFlags::present;
2101 if (mapTypeMod ==
"to")
2102 mapTypeBits |= ClauseMapFlags::to;
2104 if (mapTypeMod ==
"from")
2105 mapTypeBits |= ClauseMapFlags::from;
2107 if (mapTypeMod ==
"tofrom")
2108 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
2110 if (mapTypeMod ==
"delete")
2111 mapTypeBits |= ClauseMapFlags::del;
2113 if (mapTypeMod ==
"storage")
2114 mapTypeBits |= ClauseMapFlags::storage;
2116 if (mapTypeMod ==
"return_param")
2117 mapTypeBits |= ClauseMapFlags::return_param;
2119 if (mapTypeMod ==
"private")
2120 mapTypeBits |= ClauseMapFlags::priv;
2122 if (mapTypeMod ==
"literal")
2123 mapTypeBits |= ClauseMapFlags::literal;
2125 if (mapTypeMod ==
"attach")
2126 mapTypeBits |= ClauseMapFlags::attach;
2128 if (mapTypeMod ==
"attach_always")
2129 mapTypeBits |= ClauseMapFlags::attach_always;
2131 if (mapTypeMod ==
"attach_never")
2132 mapTypeBits |= ClauseMapFlags::attach_never;
2134 if (mapTypeMod ==
"attach_auto")
2135 mapTypeBits |= ClauseMapFlags::attach_auto;
2137 if (mapTypeMod ==
"ref_ptr")
2138 mapTypeBits |= ClauseMapFlags::ref_ptr;
2140 if (mapTypeMod ==
"ref_ptee")
2141 mapTypeBits |= ClauseMapFlags::ref_ptee;
2143 if (mapTypeMod ==
"is_device_ptr")
2144 mapTypeBits |= ClauseMapFlags::is_device_ptr;
2161 ClauseMapFlagsAttr mapType) {
2163 ClauseMapFlags mapFlags = mapType.getValue();
2168 mapTypeStrs.push_back(
"always");
2170 mapTypeStrs.push_back(
"implicit");
2172 mapTypeStrs.push_back(
"ompx_hold");
2174 mapTypeStrs.push_back(
"close");
2176 mapTypeStrs.push_back(
"present");
2185 mapTypeStrs.push_back(
"tofrom");
2187 mapTypeStrs.push_back(
"from");
2189 mapTypeStrs.push_back(
"to");
2192 mapTypeStrs.push_back(
"delete");
2194 mapTypeStrs.push_back(
"return_param");
2196 mapTypeStrs.push_back(
"storage");
2198 mapTypeStrs.push_back(
"private");
2200 mapTypeStrs.push_back(
"literal");
2202 mapTypeStrs.push_back(
"attach");
2204 mapTypeStrs.push_back(
"attach_always");
2206 mapTypeStrs.push_back(
"attach_never");
2208 mapTypeStrs.push_back(
"attach_auto");
2210 mapTypeStrs.push_back(
"ref_ptr");
2212 mapTypeStrs.push_back(
"ref_ptee");
2214 mapTypeStrs.push_back(
"is_device_ptr");
2215 if (mapFlags == ClauseMapFlags::none)
2216 mapTypeStrs.push_back(
"none");
2218 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
2219 p << mapTypeStrs[i];
2220 if (i + 1 < mapTypeStrs.size()) {
2226static ParseResult parseMembersIndex(
OpAsmParser &parser,
2230 auto parseIndices = [&]() -> ParseResult {
2235 APInt(64, value,
false)));
2249 memberIdxs.push_back(ArrayAttr::get(parser.
getContext(), values));
2253 if (!memberIdxs.empty())
2254 membersIdx = ArrayAttr::get(parser.
getContext(), memberIdxs);
2264 llvm::interleaveComma(membersIdx, p, [&p](
Attribute v) {
2266 auto memberIdx = cast<ArrayAttr>(v);
2267 llvm::interleaveComma(memberIdx.getValue(), p, [&p](
Attribute v2) {
2268 p << cast<IntegerAttr>(v2).getInt();
2275 VariableCaptureKindAttr mapCaptureType) {
2276 std::string typeCapStr;
2277 llvm::raw_string_ostream typeCap(typeCapStr);
2278 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
2280 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
2281 typeCap <<
"ByCopy";
2282 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
2283 typeCap <<
"VLAType";
2284 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
2290 VariableCaptureKindAttr &mapCaptureType) {
2291 StringRef mapCaptureKey;
2295 if (mapCaptureKey ==
"This")
2296 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2297 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
2298 if (mapCaptureKey ==
"ByRef")
2299 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2300 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
2301 if (mapCaptureKey ==
"ByCopy")
2302 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2303 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
2304 if (mapCaptureKey ==
"VLAType")
2305 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2306 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
2312 Operation *op, mlir::omp::MapInfoOp mapInfoOp,
2316 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
2319 bool from =
mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2322 bool always =
mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2323 bool close =
mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2324 bool implicit =
mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2325 bool attach =
mapTypeToBool(mapTypeBits, ClauseMapFlags::attach);
2327 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2329 "to, from, tofrom and alloc map types are permitted");
2331 if (isa<TargetEnterDataOp>(op) && (from || del))
2332 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
2334 if (isa<TargetExitDataOp>(op) && to)
2336 "from, release and delete map types are permitted");
2338 if (isa<TargetUpdateOp>(op)) {
2341 "at least one of to or from map types must be "
2342 "specified, other map types are not permitted");
2345 if (!to && !from && !attach) {
2347 "at least one of to or from or attach map types must be "
2348 "specified, other map types are not permitted");
2351 auto updateVar = mapInfoOp.getVarPtr();
2353 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2354 (from && updateToVars.contains(updateVar))) {
2357 "either to or from map types can be specified, not both");
2360 if (always || close || implicit) {
2363 "present, mapper and iterator map type modifiers are permitted");
2369 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2373 if ((mapInfoOp.getVarPtrPtr() && !mapInfoOp.getVarPtrPtrType()) ||
2374 (!mapInfoOp.getVarPtrPtr() && mapInfoOp.getVarPtrPtrType())) {
2376 "if varPtrPtr or varPtrPtrType is specified, then both "
2388 for (
auto mapOp : mapVars) {
2389 if (!mapOp.getDefiningOp())
2392 if (
auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
2396 }
else if (!isa<DeclareMapperInfoOp>(op)) {
2398 "map argument is not a map entry operation");
2403 for (
auto iterVal : mapIterated) {
2404 auto iterOp = iterVal.getDefiningOp<mlir::omp::IteratorOp>();
2406 return op->
emitOpError() <<
"'map_iterated' arguments must be defined by "
2407 "'omp.iterator' ops";
2411 cast<mlir::omp::YieldOp>(iterOp.getRegion().front().getTerminator());
2412 auto yieldedMapInfo =
2413 yieldOp.getResults()[0].getDefiningOp<mlir::omp::MapInfoOp>();
2414 if (!yieldedMapInfo)
2415 return op->
emitOpError() <<
"'map_iterated' iterator body must yield "
2416 "a value defined by 'omp.map.info'";
2426template <
typename OpType>
2430 std::optional<DenseI64ArrayAttr> privateMapIndices =
2431 targetOp.getPrivateMapsAttr();
2434 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2439 if (privateMapIndices.value().size() !=
2440 static_cast<int64_t>(privateVars.size()))
2441 return emitError(targetOp.getLoc(),
"sizes of `private` operand range and "
2442 "`private_maps` attribute mismatch");
2452 StringRef clauseName,
2454 for (
Value var : vars)
2455 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2457 <<
"'" << clauseName
2458 <<
"' arguments must be defined by 'omp.map.info' ops";
2462LogicalResult MapInfoOp::verify() {
2463 if (getMapperId() &&
2465 *
this, getMapperIdAttr())) {
2480 const TargetDataOperands &clauses) {
2481 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2482 clauses.mapVars, clauses.mapIterated,
2483 clauses.useDeviceAddrVars, clauses.useDevicePtrVars);
2486LogicalResult TargetDataOp::verify() {
2487 if (getMapVars().empty() && getMapIterated().empty() &&
2488 getUseDevicePtrVars().empty() && getUseDeviceAddrVars().empty()) {
2489 return ::emitError(this->getLoc(),
2490 "At least one of map, use_device_ptr_vars, or "
2491 "use_device_addr_vars operand must be present");
2495 getUseDevicePtrVars())))
2499 getUseDeviceAddrVars())))
2509void TargetEnterDataOp::build(
2513 TargetEnterDataOp::build(
2515 clauses.dependVars,
makeArrayAttr(ctx, clauses.dependIteratedKinds),
2516 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2517 clauses.mapIterated, clauses.nowait);
2520LogicalResult TargetEnterDataOp::verify() {
2521 LogicalResult verifyDependVars =
2523 getDependIteratedKinds(), getDependIterated());
2524 return failed(verifyDependVars)
2536 TargetExitDataOp::build(
2538 clauses.dependVars,
makeArrayAttr(ctx, clauses.dependIteratedKinds),
2539 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2540 clauses.mapIterated, clauses.nowait);
2543LogicalResult TargetExitDataOp::verify() {
2544 LogicalResult verifyDependVars =
2546 getDependIteratedKinds(), getDependIterated());
2547 return failed(verifyDependVars)
2559 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
2562 clauses.dependIterated, clauses.device, clauses.ifExpr,
2563 clauses.mapVars, clauses.mapIterated, clauses.nowait);
2566LogicalResult TargetUpdateOp::verify() {
2567 LogicalResult verifyDependVars =
2569 getDependIteratedKinds(), getDependIterated());
2570 return failed(verifyDependVars)
2580 const TargetOperands &clauses) {
2585 builder, state, {}, {}, clauses.bare,
2586 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2587 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
2588 clauses.device, clauses.dynGroupprivateAccessGroup,
2589 clauses.dynGroupprivateFallback, clauses.dynGroupprivateSize,
2590 clauses.hasDeviceAddrVars, clauses.hostEvalVars, clauses.ifExpr,
2592 nullptr, clauses.isDevicePtrVars, clauses.mapVars,
2593 clauses.mapIterated, clauses.nowait, clauses.privateVars,
2594 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2595 clauses.threadLimitVars,
2599LogicalResult TargetOp::verify() {
2601 getDependIteratedKinds(),
2602 getDependIterated())))
2606 getHasDeviceAddrVars())))
2613 *
this, getDynGroupprivateAccessGroupAttr(),
2614 getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
2623LogicalResult TargetOp::verifyRegions() {
2624 auto teamsOps = getOps<TeamsOp>();
2625 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2626 return emitError(
"target containing multiple 'omp.teams' nested ops");
2629 bool hostEvalTripCount;
2630 Operation *capturedOp = getInnermostCapturedOmpOp();
2631 TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount);
2632 for (
Value hostEvalArg :
2633 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2635 if (
auto teamsOp = dyn_cast<TeamsOp>(user)) {
2637 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2638 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2639 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2642 return emitOpError() <<
"host_eval argument only legal as 'num_teams' "
2643 "and 'thread_limit' in 'omp.teams'";
2645 if (
auto parallelOp = dyn_cast<ParallelOp>(user)) {
2646 if (execMode == TargetExecMode::spmd &&
2647 parallelOp->isAncestor(capturedOp) &&
2648 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2652 <<
"host_eval argument only legal as 'num_threads' in "
2653 "'omp.parallel' when representing target SPMD";
2655 if (
auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2656 if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp &&
2657 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2658 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2659 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2662 return emitOpError() <<
"host_eval argument only legal as loop bounds "
2663 "and steps in 'omp.loop_nest' when trip count "
2664 "must be evaluated in the host";
2667 return emitOpError() <<
"host_eval argument illegal use in '"
2668 << user->getName() <<
"' operation";
2677 assert(rootOp &&
"expected valid operation");
2694 bool isOmpDialect = op->
getDialect() == ompDialect;
2696 if (!isOmpDialect || !hasRegions)
2703 if (checkSingleMandatoryExec) {
2708 if (successor->isReachable(parentBlock))
2711 for (
Block &block : *parentRegion)
2713 !domInfo.
dominates(parentBlock, &block))
2720 if (&sibling != op && !siblingAllowedFn(&sibling))
2733Operation *TargetOp::getInnermostCapturedOmpOp() {
2734 auto *ompDialect =
getContext()->getLoadedDialect<omp::OpenMPDialect>();
2746 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2749 memOp.getEffects(effects);
2750 return !llvm::any_of(
2752 return isa<MemoryEffects::Write>(effect.
getEffect()) &&
2753 isa<SideEffects::AutomaticAllocationScopeResource>(
2763 WsloopOp *wsLoopOp) {
2765 if (!teamsOp.getNumTeamsUpperVars().empty())
2769 if (teamsOp.getNumReductionVars())
2771 if (wsLoopOp->getNumReductionVars())
2775 OffloadModuleInterface offloadMod =
2779 auto ompFlags = offloadMod.getFlags();
2782 return ompFlags.getAssumeTeamsOversubscription() &&
2783 ompFlags.getAssumeThreadsOversubscription();
2786TargetExecMode TargetOp::getKernelExecFlags(
Operation *capturedOp,
2787 bool *hostEvalTripCount) {
2793 assert((!capturedOp ||
2794 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2795 "unexpected captured op");
2797 if (hostEvalTripCount)
2798 *hostEvalTripCount =
false;
2801 if (!isa_and_present<LoopNestOp>(capturedOp))
2802 return TargetExecMode::generic;
2806 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2807 assert(!loopWrappers.empty());
2809 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2810 if (isa<SimdOp>(innermostWrapper))
2811 innermostWrapper = std::next(innermostWrapper);
2813 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2814 if (numWrappers != 1 && numWrappers != 2)
2815 return TargetExecMode::generic;
2818 if (numWrappers == 2) {
2819 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2821 return TargetExecMode::generic;
2823 innermostWrapper = std::next(innermostWrapper);
2824 if (!isa<DistributeOp>(innermostWrapper))
2825 return TargetExecMode::generic;
2828 if (!isa_and_present<ParallelOp>(parallelOp))
2829 return TargetExecMode::generic;
2831 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->
getParentOp());
2833 return TargetExecMode::generic;
2835 if (teamsOp->getParentOp() == targetOp.getOperation()) {
2836 TargetExecMode
result = TargetExecMode::spmd;
2838 result = TargetExecMode::no_loop;
2839 if (hostEvalTripCount)
2840 *hostEvalTripCount =
true;
2845 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2847 if (!isa_and_present<TeamsOp>(teamsOp))
2848 return TargetExecMode::generic;
2850 if (teamsOp->
getParentOp() != targetOp.getOperation())
2851 return TargetExecMode::generic;
2853 if (hostEvalTripCount)
2854 *hostEvalTripCount =
true;
2856 if (isa<LoopOp>(innermostWrapper))
2857 return TargetExecMode::spmd;
2859 return TargetExecMode::generic;
2862 else if (isa<WsloopOp>(innermostWrapper)) {
2864 if (!isa_and_present<ParallelOp>(parallelOp))
2865 return TargetExecMode::generic;
2867 if (parallelOp->
getParentOp() == targetOp.getOperation())
2868 return TargetExecMode::spmd;
2871 return TargetExecMode::generic;
2880 ParallelOp::build(builder, state,
ValueRange(),
2892 const ParallelOperands &clauses) {
2894 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2895 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2897 clauses.privateNeedsBarrier, clauses.procBindKind,
2898 clauses.reductionMod, clauses.reductionVars,
2903template <
typename OpType>
2905 auto privateVars = op.getPrivateVars();
2906 auto privateSyms = op.getPrivateSymsAttr();
2908 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
2911 auto numPrivateVars = privateVars.size();
2912 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
2914 if (numPrivateVars != numPrivateSyms)
2915 return op.emitError() <<
"inconsistent number of private variables and "
2916 "privatizer op symbols, private vars: "
2918 <<
" vs. privatizer op symbols: " << numPrivateSyms;
2920 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2921 Type varType = std::get<0>(privateVarInfo).getType();
2922 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2923 PrivateClauseOp privatizerOp =
2926 if (privatizerOp ==
nullptr)
2927 return op.emitError() <<
"failed to lookup privatizer op with symbol: '"
2928 << privateSym <<
"'";
2930 Type privatizerType = privatizerOp.getArgType();
2932 if (privatizerType && (varType != privatizerType))
2933 return op.emitError()
2934 <<
"type mismatch between a "
2935 << (privatizerOp.getDataSharingType() ==
2936 DataSharingClauseType::Private
2939 <<
" variable and its privatizer op, var type: " << varType
2940 <<
" vs. privatizer op type: " << privatizerType;
2946LogicalResult ParallelOp::verify() {
2947 if (getAllocateVars().size() != getAllocatorVars().size())
2949 "expected equal sizes for allocate and allocator variables");
2955 getReductionByref());
2958LogicalResult ParallelOp::verifyRegions() {
2959 auto distChildOps = getOps<DistributeOp>();
2960 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2961 if (numDistChildOps > 1)
2963 <<
"multiple 'omp.distribute' nested inside of 'omp.parallel'";
2965 if (numDistChildOps == 1) {
2968 <<
"'omp.composite' attribute missing from composite operation";
2970 auto *ompDialect =
getContext()->getLoadedDialect<OpenMPDialect>();
2971 Operation &distributeOp = **distChildOps.begin();
2973 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2977 return emitError() <<
"unexpected OpenMP operation inside of composite "
2979 << childOp.getName();
2981 }
else if (isComposite()) {
2983 <<
"'omp.composite' attribute present in non-composite operation";
3000 const TeamsOperands &clauses) {
3004 builder, state, clauses.allocateVars, clauses.allocatorVars,
3005 clauses.dynGroupprivateAccessGroup, clauses.dynGroupprivateFallback,
3006 clauses.dynGroupprivateSize, clauses.ifExpr, clauses.numTeamsLower,
3007 clauses.numTeamsUpperVars, {},
nullptr,
3008 nullptr, clauses.reductionMod,
3009 clauses.reductionVars,
3011 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
3018 if (numTeamsLower) {
3019 if (numTeamsUpperVars.size() != 1)
3021 "expected exactly one num_teams upper bound when lower bound is "
3025 "expected num_teams upper bound and lower bound to be "
3032LogicalResult TeamsOp::verify() {
3041 return emitError(
"expected to be nested inside of omp.target or not nested "
3042 "in any OpenMP dialect operations");
3046 this->getNumTeamsUpperVars())))
3050 if (getAllocateVars().size() != getAllocatorVars().size())
3052 "expected equal sizes for allocate and allocator variables");
3055 op, getDynGroupprivateAccessGroupAttr(),
3056 getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
3063 getReductionByref());
3071 return getParentOp().getPrivateVars();
3075 return getParentOp().getReductionVars();
3083 const SectionsOperands &clauses) {
3086 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3089 clauses.reductionMod, clauses.reductionVars,
3094LogicalResult SectionsOp::verify() {
3095 if (getAllocateVars().size() != getAllocatorVars().size())
3097 "expected equal sizes for allocate and allocator variables");
3100 getReductionByref());
3103LogicalResult SectionsOp::verifyRegions() {
3104 for (
auto &inst : *getRegion().begin()) {
3105 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
3107 <<
"expected omp.section op or terminator op inside region";
3119 const ScopeOperands &clauses) {
3121 ScopeOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3122 clauses.nowait, clauses.privateVars,
3124 clauses.privateNeedsBarrier, clauses.reductionMod,
3125 clauses.reductionVars,
3130LogicalResult ScopeOp::verify() {
3131 if (getAllocateVars().size() != getAllocatorVars().size())
3133 "expected equal sizes for allocate and allocator variables");
3139 getReductionByref());
3147 const SingleOperands &clauses) {
3150 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3151 clauses.copyprivateVars,
3152 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
3157LogicalResult SingleOp::verify() {
3159 if (getAllocateVars().size() != getAllocatorVars().size())
3161 "expected equal sizes for allocate and allocator variables");
3164 getCopyprivateSyms());
3172 const WorkshareOperands &clauses) {
3173 WorkshareOp::build(builder, state, clauses.nowait);
3180LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
3181 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3183 return emitOpError() <<
"expected to be a standalone loop wrapper";
3192LogicalResult LoopWrapperInterface::verifyImpl() {
3196 return emitOpError() <<
"loop wrapper must also have the `NoTerminator` "
3197 "and `SingleBlock` traits";
3200 return emitOpError() <<
"loop wrapper does not contain exactly one region";
3203 if (range_size(region.
getOps()) != 1)
3205 <<
"loop wrapper does not contain exactly one nested op";
3208 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
3209 return emitOpError() <<
"nested in loop wrapper is not another loop "
3210 "wrapper or `omp.loop_nest`";
3220 const LoopOperands &clauses) {
3223 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
3225 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
3226 clauses.reductionMod, clauses.reductionVars,
3231LogicalResult LoopOp::verify() {
3236 getReductionByref());
3239LogicalResult LoopOp::verifyRegions() {
3240 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3242 return emitOpError() <<
"expected to be a standalone loop wrapper";
3253 build(builder, state, {}, {},
3256 false,
nullptr,
nullptr,
3257 nullptr, {},
nullptr,
3268 const WsloopOperands &clauses) {
3273 {}, {}, clauses.linearVars,
3274 clauses.linearStepVars, clauses.linearVarTypes, clauses.linearModifiers,
3275 clauses.nowait, clauses.order, clauses.orderMod, clauses.ordered,
3276 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
3277 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3279 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
3280 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
3283LogicalResult WsloopOp::verify() {
3287 if (getLinearVars().size() &&
3288 getLinearVarTypes().value().size() != getLinearVars().size())
3289 return emitError() <<
"Ill-formed type attributes for linear variables";
3295 getReductionByref());
3298LogicalResult WsloopOp::verifyRegions() {
3299 bool isCompositeChildLeaf =
3300 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3302 if (LoopWrapperInterface nested = getNestedWrapper()) {
3305 <<
"'omp.composite' attribute missing from composite wrapper";
3309 if (!isa<SimdOp>(nested))
3310 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3312 }
else if (isComposite() && !isCompositeChildLeaf) {
3314 <<
"'omp.composite' attribute present in non-composite wrapper";
3315 }
else if (!isComposite() && isCompositeChildLeaf) {
3317 <<
"'omp.composite' attribute missing from composite wrapper";
3328 const SimdOperands &clauses) {
3330 SimdOp::build(builder, state, clauses.alignedVars,
3332 clauses.linearVars, clauses.linearStepVars,
3333 clauses.linearVarTypes, clauses.linearModifiers,
3334 clauses.nontemporalVars, clauses.order, clauses.orderMod,
3335 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
3336 clauses.privateNeedsBarrier, clauses.reductionMod,
3337 clauses.reductionVars,
3343LogicalResult SimdOp::verify() {
3344 if (getSimdlen().has_value() && getSafelen().has_value() &&
3345 getSimdlen().value() > getSafelen().value())
3347 <<
"simdlen clause and safelen clause are both present, but the "
3348 "simdlen value is not less than or equal to safelen value";
3360 bool isCompositeChildLeaf =
3361 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3363 if (!isComposite() && isCompositeChildLeaf)
3365 <<
"'omp.composite' attribute missing from composite wrapper";
3367 if (isComposite() && !isCompositeChildLeaf)
3369 <<
"'omp.composite' attribute present in non-composite wrapper";
3373 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
3375 for (
const Attribute &sym : *privateSyms) {
3376 auto symRef = cast<SymbolRefAttr>(sym);
3377 omp::PrivateClauseOp privatizer =
3379 getOperation(), symRef);
3381 return emitError() <<
"Cannot find privatizer '" << symRef <<
"'";
3382 if (privatizer.getDataSharingType() ==
3383 DataSharingClauseType::FirstPrivate)
3384 return emitError() <<
"FIRSTPRIVATE cannot be used with SIMD";
3391 if (getLinearVars().size() &&
3392 getLinearVarTypes().value().size() != getLinearVars().size())
3393 return emitError() <<
"Ill-formed type attributes for linear variables";
3397LogicalResult SimdOp::verifyRegions() {
3398 if (getNestedWrapper())
3399 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
3409 const DistributeOperands &clauses) {
3410 DistributeOp::build(builder, state, clauses.allocateVars,
3411 clauses.allocatorVars, clauses.distScheduleStatic,
3412 clauses.distScheduleChunkSize, clauses.order,
3413 clauses.orderMod, clauses.privateVars,
3415 clauses.privateNeedsBarrier);
3418LogicalResult DistributeOp::verify() {
3419 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
3421 "dist_schedule_static being present";
3423 if (getAllocateVars().size() != getAllocatorVars().size())
3425 "expected equal sizes for allocate and allocator variables");
3433LogicalResult DistributeOp::verifyRegions() {
3434 if (LoopWrapperInterface nested = getNestedWrapper()) {
3437 <<
"'omp.composite' attribute missing from composite wrapper";
3440 if (isa<WsloopOp>(nested)) {
3442 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3443 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3444 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
3445 "when a composite 'omp.parallel' is the direct "
3448 }
else if (!isa<SimdOp>(nested))
3449 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
3451 }
else if (isComposite()) {
3453 <<
"'omp.composite' attribute present in non-composite wrapper";
3464 const DeclareMapperInfoOperands &clauses) {
3465 DeclareMapperInfoOp::build(builder, state, clauses.mapVars,
3466 clauses.mapIterated);
3469LogicalResult DeclareMapperInfoOp::verify() {
3473LogicalResult DeclareMapperOp::verifyRegions() {
3474 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3475 getRegion().getBlocks().front().getTerminator()))
3476 return emitOpError() <<
"expected terminator to be a DeclareMapperInfoOp";
3485LogicalResult DeclareReductionOp::verifyRegions() {
3486 if (!getAllocRegion().empty()) {
3487 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3488 if (yieldOp.getResults().size() != 1 ||
3489 yieldOp.getResults().getTypes()[0] !=
getType())
3490 return emitOpError() <<
"expects alloc region to yield a value "
3491 "of the reduction type";
3495 if (getInitializerRegion().empty())
3496 return emitOpError() <<
"expects non-empty initializer region";
3497 Block &initializerEntryBlock = getInitializerRegion().
front();
3500 if (!getAllocRegion().empty())
3501 return emitOpError() <<
"expects two arguments to the initializer region "
3502 "when an allocation region is used";
3504 if (getAllocRegion().empty())
3505 return emitOpError() <<
"expects one argument to the initializer region "
3506 "when no allocation region is used";
3509 <<
"expects one or two arguments to the initializer region";
3513 if (arg.getType() !=
getType())
3514 return emitOpError() <<
"expects initializer region argument to match "
3515 "the reduction type";
3517 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3518 if (yieldOp.getResults().size() != 1 ||
3519 yieldOp.getResults().getTypes()[0] !=
getType())
3520 return emitOpError() <<
"expects initializer region to yield a value "
3521 "of the reduction type";
3524 if (getReductionRegion().empty())
3525 return emitOpError() <<
"expects non-empty reduction region";
3526 Block &reductionEntryBlock = getReductionRegion().
front();
3531 return emitOpError() <<
"expects reduction region with two arguments of "
3532 "the reduction type";
3533 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3534 if (yieldOp.getResults().size() != 1 ||
3535 yieldOp.getResults().getTypes()[0] !=
getType())
3536 return emitOpError() <<
"expects reduction region to yield a value "
3537 "of the reduction type";
3540 if (!getAtomicReductionRegion().empty()) {
3541 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
3545 return emitOpError() <<
"expects atomic reduction region with two "
3546 "arguments of the same type";
3547 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3550 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
3551 return emitOpError() <<
"expects atomic reduction region arguments to "
3552 "be accumulators containing the reduction type";
3555 if (getCleanupRegion().empty())
3557 Block &cleanupEntryBlock = getCleanupRegion().
front();
3560 return emitOpError() <<
"expects cleanup region with one argument "
3561 "of the reduction type";
3571 const TaskOperands &clauses) {
3574 builder, state, clauses.iterated, clauses.affinityVars,
3575 clauses.allocateVars, clauses.allocatorVars,
3576 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3577 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
3578 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3580 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3581 clauses.priority, clauses.privateVars,
3583 clauses.privateNeedsBarrier, clauses.untied, clauses.eventHandle);
3586LogicalResult TaskOp::verify() {
3587 LogicalResult verifyDependVars =
3589 getDependIteratedKinds(), getDependIterated());
3590 if (
failed(verifyDependVars))
3591 return verifyDependVars;
3597 getInReductionVars(), getInReductionByref());
3605 const TaskgroupOperands &clauses) {
3607 TaskgroupOp::build(builder, state, clauses.allocateVars,
3608 clauses.allocatorVars, clauses.taskReductionVars,
3613LogicalResult TaskgroupOp::verify() {
3615 getTaskReductionVars(),
3616 getTaskReductionByref());
3624 const TaskloopContextOperands &clauses) {
3626 TaskloopContextOp::build(
3627 builder, state, clauses.allocateVars, clauses.allocatorVars,
3628 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3629 clauses.inReductionVars,
3631 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3632 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3633 clauses.privateVars,
3635 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3640TaskloopWrapperOp TaskloopContextOp::getLoopOp() {
3641 return cast<TaskloopWrapperOp>(
3643 return isa<TaskloopWrapperOp>(op);
3647LogicalResult TaskloopContextOp::verify() {
3648 if (getAllocateVars().size() != getAllocatorVars().size())
3650 "expected equal sizes for allocate and allocator variables");
3656 getReductionVars(), getReductionByref())) ||
3658 getInReductionVars(),
3659 getInReductionByref())))
3662 if (!getReductionVars().empty() && getNogroup())
3663 return emitError(
"if a reduction clause is present on the taskloop "
3664 "directive, the nogroup clause must not be specified");
3665 for (
auto var : getReductionVars()) {
3666 if (llvm::is_contained(getInReductionVars(), var))
3667 return emitError(
"the same list item cannot appear in both a reduction "
3668 "and an in_reduction clause");
3671 if (getGrainsize() && getNumTasks()) {
3673 "the grainsize clause and num_tasks clause are mutually exclusive and "
3674 "may not appear on the same taskloop directive");
3680LogicalResult TaskloopContextOp::verifyRegions() {
3681 Region ®ion = getRegion();
3683 return emitOpError() <<
"expected non-empty region";
3686 return isa<TaskloopWrapperOp>(op);
3690 <<
"expected exactly 1 TaskloopWrapperOp directly nested in "
3692 << count <<
" were found";
3693 TaskloopWrapperOp loopWrapperOp = getLoopOp();
3695 auto loopNestOp = dyn_cast<LoopNestOp>(loopWrapperOp.getWrappedLoop());
3701 std::function<
bool(
Value)> isValidBoundValue = [&](
Value value) ->
bool {
3702 Region *valueRegion = value.getParentRegion();
3708 Operation *defOp = value.getDefiningOp();
3712 return llvm::all_of(defOp->
getOperands(), isValidBoundValue);
3714 auto hasUnsupportedTaskloopLocalBound = [&](
OperandRange range) ->
bool {
3715 return llvm::any_of(range,
3716 [&](
Value value) {
return !isValidBoundValue(value); });
3719 if (hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopLowerBounds()) ||
3720 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopUpperBounds()) ||
3721 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopSteps())) {
3723 <<
"expects loop bounds and steps to be defined outside of the "
3724 "taskloop.context region or by pure, regionless operations "
3725 "that do not depend on block arguments";
3736 const TaskloopWrapperOperands &clauses) {
3737 TaskloopWrapperOp::build(builder, state);
3740TaskloopContextOp TaskloopWrapperOp::getTaskloopContext() {
3741 return dyn_cast<TaskloopContextOp>(getOperation()->getParentOp());
3744LogicalResult TaskloopWrapperOp::verify() {
3745 TaskloopContextOp context = getTaskloopContext();
3747 return emitOpError() <<
"expected to be nested in a taskloop context op";
3751LogicalResult TaskloopWrapperOp::verifyRegions() {
3752 if (LoopWrapperInterface nested = getNestedWrapper()) {
3755 <<
"'omp.composite' attribute missing from composite wrapper";
3759 if (!isa<SimdOp>(nested))
3760 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3761 }
else if (isComposite()) {
3763 <<
"'omp.composite' attribute present in non-composite wrapper";
3787 for (
auto &iv : ivs)
3788 iv.type = loopVarType;
3793 result.addAttribute(
"loop_inclusive", UnitAttr::get(ctx));
3809 "collapse_num_loops",
3814 auto parseTiles = [&]() -> ParseResult {
3818 tiles.push_back(
tile);
3827 if (tiles.size() > 0)
3846 Region ®ion = getRegion();
3848 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
3849 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
3850 if (getLoopInclusive())
3852 p <<
"step (" << getLoopSteps() <<
") ";
3853 if (
int64_t numCollapse = getCollapseNumLoops())
3854 if (numCollapse > 1)
3855 p <<
"collapse(" << numCollapse <<
") ";
3858 p <<
"tiles(" << tiles.value() <<
") ";
3864 const LoopNestOperands &clauses) {
3866 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3867 clauses.loopLowerBounds, clauses.loopUpperBounds,
3868 clauses.loopSteps, clauses.loopInclusive,
3872LogicalResult LoopNestOp::verify() {
3873 if (getLoopLowerBounds().empty())
3874 return emitOpError() <<
"must represent at least one loop";
3876 if (getLoopLowerBounds().size() != getIVs().size())
3877 return emitOpError() <<
"number of range arguments and IVs do not match";
3879 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3880 if (lb.getType() != iv.getType())
3882 <<
"range argument type does not match corresponding IV type";
3885 uint64_t numIVs = getIVs().size();
3887 if (
const auto &numCollapse = getCollapseNumLoops())
3888 if (numCollapse > numIVs)
3890 <<
"collapse value is larger than the number of loops";
3893 if (tiles.value().size() > numIVs)
3894 return emitOpError() <<
"too few canonical loops for tile dimensions";
3896 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3897 return emitOpError() <<
"expects parent op to be a loop wrapper";
3902void LoopNestOp::gatherWrappers(
3905 while (
auto wrapper =
3906 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3907 wrappers.push_back(wrapper);
3916std::tuple<NewCliOp, OpOperand *, OpOperand *>
3922 return {{},
nullptr,
nullptr};
3925 "Unexpected type of cli");
3931 auto op = cast<LoopTransformationInterface>(use.getOwner());
3933 unsigned opnum = use.getOperandNumber();
3934 if (op.isGeneratee(opnum)) {
3935 assert(!gen &&
"Each CLI may have at most one def");
3937 }
else if (op.isApplyee(opnum)) {
3938 assert(!cons &&
"Each CLI may have at most one consumer");
3941 llvm_unreachable(
"Unexpected operand for a CLI");
3945 return {create, gen, cons};
3968 std::string cliName{
"cli"};
3972 .Case([&](CanonicalLoopOp op) {
3975 .Case([&](UnrollHeuristicOp op) -> std::string {
3976 llvm_unreachable(
"heuristic unrolling does not generate a loop");
3978 .Case([&](FuseOp op) -> std::string {
3979 unsigned opnum =
generator->getOperandNumber();
3982 if (op.getFirst().has_value() && opnum != op.getFirst().value())
3983 return "canonloop_fuse";
3987 .Case([&](TileOp op) -> std::string {
3988 auto [generateesFirst, generateesCount] =
3989 op.getGenerateesODSOperandIndexAndLength();
3990 unsigned firstGrid = generateesFirst;
3991 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3992 unsigned end = generateesFirst + generateesCount;
3993 unsigned opnum =
generator->getOperandNumber();
3995 if (firstGrid <= opnum && opnum < firstIntratile) {
3996 unsigned gridnum = opnum - firstGrid + 1;
3997 return (
"grid" + Twine(gridnum)).str();
3999 if (firstIntratile <= opnum && opnum < end) {
4000 unsigned intratilenum = opnum - firstIntratile + 1;
4001 return (
"intratile" + Twine(intratilenum)).str();
4003 llvm_unreachable(
"Unexpected generatee argument");
4005 .DefaultUnreachable(
"TODO: Custom name for this operation");
4008 setNameFn(
result, cliName);
4011LogicalResult NewCliOp::verify() {
4012 Value cli = getResult();
4015 "Unexpected type of cli");
4021 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
4023 unsigned opnum = use.getOperandNumber();
4024 if (op.isGeneratee(opnum)) {
4027 emitOpError(
"CLI must have at most one generator");
4029 .
append(
"first generator here:");
4031 .
append(
"second generator here:");
4036 }
else if (op.isApplyee(opnum)) {
4039 emitOpError(
"CLI must have at most one consumer");
4041 .
append(
"first consumer here:")
4045 .
append(
"second consumer here:")
4052 llvm_unreachable(
"Unexpected operand for a CLI");
4060 .
append(
"see consumer here: ")
4083 setNameFn(&getRegion().front(),
"body_entry");
4086void CanonicalLoopOp::getAsmBlockArgumentNames(
Region ®ion,
4094 p <<
'(' << getCli() <<
')';
4095 p <<
' ' << getInductionVar() <<
" : " << getInductionVar().getType()
4096 <<
" in range(" << getTripCount() <<
") ";
4106 CanonicalLoopInfoType cliType =
4107 CanonicalLoopInfoType::get(parser.
getContext());
4132 if (parser.
parseRegion(*region, {inductionVariable}))
4137 result.operands.append(cliOperand);
4143 return mlir::success();
4146LogicalResult CanonicalLoopOp::verify() {
4149 if (!getRegion().empty()) {
4150 Region ®ion = getRegion();
4153 "Canonical loop region must have exactly one argument");
4157 "Region argument must be the same type as the trip count");
4163Value CanonicalLoopOp::getInductionVar() {
return getRegion().getArgument(0); }
4165std::pair<unsigned, unsigned>
4166CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
4171std::pair<unsigned, unsigned>
4172CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
4173 return getODSOperandIndexAndLength(odsIndex_cli);
4187 p <<
'(' << getApplyee() <<
')';
4194 auto cliType = CanonicalLoopInfoType::get(parser.
getContext());
4217 return mlir::success();
4220std::pair<unsigned, unsigned>
4221UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
4222 return getODSOperandIndexAndLength(odsIndex_applyee);
4225std::pair<unsigned, unsigned>
4226UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
4237 if (!generatees.empty())
4238 p <<
'(' << llvm::interleaved(generatees) <<
')';
4240 if (!applyees.empty())
4241 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
4283 bool isOnlyCanonLoops =
true;
4285 for (
Value applyee : op.getApplyees()) {
4286 auto [create, gen, cons] =
decodeCli(applyee);
4289 return op.emitOpError() <<
"applyee CLI has no generator";
4291 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4292 canonLoops.push_back(loop);
4294 isOnlyCanonLoops =
false;
4299 if (!isOnlyCanonLoops)
4303 for (
auto i : llvm::seq<int>(1, canonLoops.size())) {
4304 auto parentLoop = canonLoops[i - 1];
4305 auto loop = canonLoops[i];
4307 if (parentLoop.getOperation() != loop.getOperation()->getParentOp())
4308 return op.emitOpError()
4309 <<
"tiled loop nest must be nested within each other";
4311 parentIVs.insert(parentLoop.getInductionVar());
4316 bool isPerfectlyNested = [&]() {
4317 auto &parentBody = parentLoop.getRegion();
4318 if (!parentBody.hasOneBlock())
4320 auto &parentBlock = parentBody.getBlocks().front();
4322 auto nestedLoopIt = parentBlock.begin();
4323 if (nestedLoopIt == parentBlock.end() ||
4324 (&*nestedLoopIt != loop.getOperation()))
4327 auto termIt = std::next(nestedLoopIt);
4328 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
4331 if (std::next(termIt) != parentBlock.end())
4336 if (!isPerfectlyNested)
4337 return op.emitOpError() <<
"tiled loop nest must be perfectly nested";
4339 if (parentIVs.contains(loop.getTripCount()))
4340 return op.emitOpError() <<
"tiled loop nest must be rectangular";
4357LogicalResult TileOp::verify() {
4358 if (getApplyees().empty())
4359 return emitOpError() <<
"must apply to at least one loop";
4361 if (getSizes().size() != getApplyees().size())
4362 return emitOpError() <<
"there must be one tile size for each applyee";
4364 if (!getGeneratees().empty() &&
4365 2 * getSizes().size() != getGeneratees().size())
4367 <<
"expecting two times the number of generatees than applyees";
4372std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
4373 return getODSOperandIndexAndLength(odsIndex_applyees);
4376std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
4377 return getODSOperandIndexAndLength(odsIndex_generatees);
4387 if (!generatees.empty())
4388 p <<
'(' << llvm::interleaved(generatees) <<
')';
4390 if (!applyees.empty())
4391 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
4394LogicalResult FuseOp::verify() {
4395 if (getApplyees().size() < 2)
4396 return emitOpError() <<
"must apply to at least two loops";
4398 if (getFirst().has_value() && getCount().has_value()) {
4399 int64_t first = getFirst().value();
4400 int64_t count = getCount().value();
4401 if ((
unsigned)(first + count - 1) > getApplyees().size())
4402 return emitOpError() <<
"the numbers of applyees must be at least first "
4403 "minus one plus count attributes";
4404 if (!getGeneratees().empty() &&
4405 getGeneratees().size() != getApplyees().size() + 1 - count)
4406 return emitOpError() <<
"the number of generatees must be the number of "
4407 "aplyees plus one minus count";
4410 if (!getGeneratees().empty() && getGeneratees().size() != 1)
4412 <<
"in a complete fuse the number of generatees must be exactly 1";
4414 for (
auto &&applyee : getApplyees()) {
4415 auto [create, gen, cons] =
decodeCli(applyee);
4418 return emitOpError() <<
"applyee CLI has no generator";
4419 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4422 <<
"currently only supports omp.canonical_loop as applyee";
4426std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
4427 return getODSOperandIndexAndLength(odsIndex_applyees);
4430std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
4431 return getODSOperandIndexAndLength(odsIndex_generatees);
4439 const CriticalDeclareOperands &clauses) {
4440 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
4443LogicalResult CriticalDeclareOp::verify() {
4448 if (getNameAttr()) {
4449 SymbolRefAttr symbolRef = getNameAttr();
4453 return emitOpError() <<
"expected symbol reference " << symbolRef
4454 <<
" to point to a critical declaration";
4474 return op.
emitOpError() <<
"must be nested inside of a loop";
4478 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
4479 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
4481 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
4482 "have an ordered clause";
4484 if (hasRegion && orderedAttr.getInt() != 0)
4485 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
4486 "have a parameter present";
4488 if (!hasRegion && orderedAttr.getInt() == 0)
4489 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
4490 "have a parameter present";
4491 }
else if (!isa<SimdOp>(wrapper)) {
4492 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
4493 "or worksharing simd loop";
4499 const OrderedOperands &clauses) {
4500 OrderedOp::build(builder, state, clauses.doacrossDependType,
4501 clauses.doacrossNumLoops, clauses.doacrossDependVars);
4504LogicalResult OrderedOp::verify() {
4508 auto wrapper = (*this)->getParentOfType<WsloopOp>();
4509 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
4510 return emitOpError() <<
"number of variables in depend clause does not "
4511 <<
"match number of iteration variables in the "
4518 const OrderedRegionOperands &clauses) {
4519 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
4529 const TaskwaitOperands &clauses) {
4531 TaskwaitOp::build(builder, state,
nullptr,
4540LogicalResult AtomicReadOp::verify() {
4541 if (verifyCommon().
failed())
4542 return mlir::failure();
4545 if (
auto moduleOp = getOperation()->getParentOfType<ModuleOp>())
4546 if (
Attribute verAttr = moduleOp->getAttr(
"omp.version"))
4547 version = llvm::cast<VersionAttr>(verAttr).getVersion();
4549 if (
auto mo = getMemoryOrder()) {
4550 if (*mo == ClauseMemoryOrderKind::Release) {
4551 return emitError(
"memory-order must not be release for atomic reads");
4553 if (*mo == ClauseMemoryOrderKind::Acq_rel) {
4556 return emitError(
"memory-order must not be acq_rel for atomic reads");
4566LogicalResult AtomicWriteOp::verify() {
4567 if (verifyCommon().
failed())
4568 return mlir::failure();
4571 if (
auto moduleOp = getOperation()->getParentOfType<ModuleOp>())
4572 if (
Attribute verAttr = moduleOp->getAttr(
"omp.version"))
4573 version = llvm::cast<VersionAttr>(verAttr).getVersion();
4575 if (
auto mo = getMemoryOrder()) {
4576 if (*mo == ClauseMemoryOrderKind::Acquire) {
4577 return emitError(
"memory-order must not be acquire for atomic writes");
4579 if (*mo == ClauseMemoryOrderKind::Acq_rel) {
4582 return emitError(
"memory-order must not be acq_rel for atomic writes");
4592LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4598 if (
Value writeVal = op.getWriteOpVal()) {
4600 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4606LogicalResult AtomicUpdateOp::verify() {
4607 if (verifyCommon().
failed())
4608 return mlir::failure();
4611 if (
auto moduleOp = getOperation()->getParentOfType<ModuleOp>())
4612 if (
Attribute verAttr = moduleOp->getAttr(
"omp.version"))
4613 version = llvm::cast<VersionAttr>(verAttr).getVersion();
4615 if (
auto mo = getMemoryOrder()) {
4616 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4617 *mo == ClauseMemoryOrderKind::Acquire) {
4621 "memory-order must not be acq_rel or acquire for atomic updates");
4628LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4634AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4635 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4637 return dyn_cast<AtomicReadOp>(getSecondOp());
4640AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4641 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4643 return dyn_cast<AtomicWriteOp>(getSecondOp());
4646AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4647 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4649 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4652LogicalResult AtomicCaptureOp::verify() {
4656LogicalResult AtomicCaptureOp::verifyRegions() {
4657 if (verifyRegionsCommon().
failed())
4658 return mlir::failure();
4660 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
4662 "operations inside capture region must not have hint clause");
4664 if (getFirstOp()->getAttr(
"memory_order") ||
4665 getSecondOp()->getAttr(
"memory_order"))
4667 "operations inside capture region must not have memory_order clause");
4675LogicalResult AtomicCompareOp::verify() {
4676 if (verifyCommon().
failed())
4677 return mlir::failure();
4681LogicalResult AtomicCompareOp::verifyRegions() {
4682 if (verifyRegionsCommon().
failed())
4683 return mlir::failure();
4685 if (verifyOperator().
failed())
4686 return mlir::failure();
4691 if (!terminator || !isa<YieldOp>(terminator))
4692 return emitOpError(
"region must be terminated with omp.yield");
4702 const CancelOperands &clauses) {
4703 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4716LogicalResult CancelOp::verify() {
4717 ClauseCancellationConstructType cct = getCancelDirective();
4720 if (!structuralParent)
4721 return emitOpError() <<
"Orphaned cancel construct";
4723 if ((cct == ClauseCancellationConstructType::Parallel) &&
4724 !mlir::isa<ParallelOp>(structuralParent)) {
4725 return emitOpError() <<
"cancel parallel must appear "
4726 <<
"inside a parallel region";
4728 if (cct == ClauseCancellationConstructType::Loop) {
4731 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->
getParentOp());
4735 <<
"cancel loop must appear inside a worksharing-loop region";
4737 if (wsloopOp.getNowaitAttr()) {
4738 return emitError() <<
"A worksharing construct that is canceled "
4739 <<
"must not have a nowait clause";
4741 if (wsloopOp.getOrderedAttr()) {
4742 return emitError() <<
"A worksharing construct that is canceled "
4743 <<
"must not have an ordered clause";
4746 }
else if (cct == ClauseCancellationConstructType::Sections) {
4750 mlir::dyn_cast<SectionsOp>(structuralParent->
getParentOp());
4752 return emitOpError() <<
"cancel sections must appear "
4753 <<
"inside a sections region";
4755 if (sectionsOp.getNowait()) {
4756 return emitError() <<
"A sections construct that is canceled "
4757 <<
"must not have a nowait clause";
4760 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4761 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4762 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->
getParentOp()))) {
4763 return emitOpError() <<
"cancel taskgroup must appear "
4764 <<
"inside a task region";
4774 const CancellationPointOperands &clauses) {
4775 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4778LogicalResult CancellationPointOp::verify() {
4779 ClauseCancellationConstructType cct = getCancelDirective();
4782 if (!structuralParent)
4783 return emitOpError() <<
"Orphaned cancellation point";
4785 if ((cct == ClauseCancellationConstructType::Parallel) &&
4786 !mlir::isa<ParallelOp>(structuralParent)) {
4787 return emitOpError() <<
"cancellation point parallel must appear "
4788 <<
"inside a parallel region";
4792 if ((cct == ClauseCancellationConstructType::Loop) &&
4793 !mlir::isa<WsloopOp>(structuralParent->
getParentOp())) {
4794 return emitOpError() <<
"cancellation point loop must appear "
4795 <<
"inside a worksharing-loop region";
4797 if ((cct == ClauseCancellationConstructType::Sections) &&
4798 !mlir::isa<omp::SectionOp>(structuralParent)) {
4799 return emitOpError() <<
"cancellation point sections must appear "
4800 <<
"inside a sections region";
4802 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4803 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4804 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->
getParentOp()))) {
4805 return emitOpError() <<
"cancellation point taskgroup must appear "
4806 <<
"inside a task region";
4815LogicalResult MapBoundsOp::verify() {
4816 auto extent = getExtent();
4818 if (!extent && !upperbound)
4819 return emitError(
"expected extent or upperbound.");
4826 PrivateClauseOp::build(
4827 odsBuilder, odsState, symName, type,
4828 DataSharingClauseTypeAttr::get(odsBuilder.
getContext(),
4829 DataSharingClauseType::Private));
4832LogicalResult PrivateClauseOp::verifyRegions() {
4833 Type argType = getArgType();
4834 auto verifyTerminator = [&](
Operation *terminator,
4835 bool yieldsValue) -> LogicalResult {
4839 if (!llvm::isa<YieldOp>(terminator))
4841 <<
"expected exit block terminator to be an `omp.yield` op.";
4843 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4844 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4847 if (yieldedTypes.empty())
4851 <<
"Did not expect any values to be yielded.";
4854 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4858 <<
"Invalid yielded value. Expected type: " << argType
4861 if (yieldedTypes.empty())
4864 error << yieldedTypes;
4870 StringRef regionName,
4871 bool yieldsValue) -> LogicalResult {
4872 assert(!region.
empty());
4876 <<
"`" << regionName <<
"`: "
4877 <<
"expected " << expectedNumArgs
4880 for (
Block &block : region) {
4893 for (
Region *region : getRegions())
4894 for (
Type ty : region->getArgumentTypes())
4896 return emitError() <<
"Region argument type mismatch: got " << ty
4897 <<
" expected " << argType <<
".";
4900 if (!initRegion.
empty() &&
4905 DataSharingClauseType dsType = getDataSharingType();
4907 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4908 return emitError(
"`private` clauses do not require a `copy` region.");
4910 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4912 "`firstprivate` clauses require at least a `copy` region.");
4914 if (dsType == DataSharingClauseType::FirstPrivate &&
4919 if (!getDeallocRegion().empty() &&
4932 const MaskedOperands &clauses) {
4933 MaskedOp::build(builder, state, clauses.filteredThreadId);
4941 const ScanOperands &clauses) {
4942 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4945LogicalResult ScanOp::verify() {
4946 if (hasExclusiveVars() == hasInclusiveVars())
4948 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4949 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4950 if (parentWsLoopOp.getReductionModAttr() &&
4951 parentWsLoopOp.getReductionModAttr().getValue() ==
4952 ReductionModifier::inscan)
4955 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4956 if (parentSimdOp.getReductionModAttr() &&
4957 parentSimdOp.getReductionModAttr().getValue() ==
4958 ReductionModifier::inscan)
4961 return emitError(
"SCAN directive needs to be enclosed within a parent "
4962 "worksharing loop construct or SIMD construct with INSCAN "
4963 "reduction modifier");
4968 std::optional<uint64_t> alignment) {
4969 if (alignment.has_value()) {
4970 if ((alignment.value() != 0) && !llvm::has_single_bit(alignment.value()))
4972 <<
"ALIGN value : " << alignment.value() <<
" must be power of 2";
4977LogicalResult AllocateDirOp::verify() {
4985LogicalResult AllocSharedMemOp::verify() {
4993LogicalResult FreeSharedMemOp::verify() {
5001LogicalResult WorkdistributeOp::verify() {
5003 Region ®ion = getRegion();
5008 if (entryBlock.
empty())
5009 return emitOpError(
"region must contain a structured block");
5011 bool hasTerminator =
false;
5012 for (
Block &block : region) {
5013 if (isa<TerminatorOp>(block.
back())) {
5014 if (hasTerminator) {
5015 return emitOpError(
"region must have exactly one terminator");
5017 hasTerminator =
true;
5020 if (!hasTerminator) {
5021 return emitOpError(
"region must be terminated with omp.terminator");
5025 if (isa<BarrierOp>(op)) {
5027 "explicit barriers are not allowed in workdistribute region");
5030 if (isa<ParallelOp>(op)) {
5032 "nested parallel constructs not allowed in workdistribute");
5034 if (isa<TeamsOp>(op)) {
5036 "nested teams constructs not allowed in workdistribute");
5040 if (walkResult.wasInterrupted())
5044 if (!llvm::dyn_cast<TeamsOp>(parentOp))
5045 return emitOpError(
"workdistribute must be nested under teams");
5053LogicalResult DeclareSimdOp::verify() {
5056 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
5058 return emitOpError() <<
"must be nested inside a function";
5060 if (getInbranch() && getNotinbranch())
5061 return emitOpError(
"cannot have both 'inbranch' and 'notinbranch'");
5071 const DeclareSimdOperands &clauses) {
5073 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
5075 clauses.linearVars, clauses.linearStepVars,
5076 clauses.linearVarTypes, clauses.linearModifiers,
5077 clauses.notinbranch, clauses.simdlen,
5078 clauses.uniformVars);
5095 return mlir::failure();
5096 return mlir::success();
5103 for (
unsigned i = 0; i < uniformVars.size(); ++i) {
5106 p << uniformVars[i] <<
" : " << uniformTypes[i];
5121 parser, iterated, iteratedTypes, affinityVars, affinityVarTypes,
5122 [&]() -> ParseResult {
return success(); })))
5156 OpAsmParser::Argument &arg = ivArgs.emplace_back();
5157 if (parser.parseArgument(arg))
5161 if (succeeded(parser.parseOptionalColon())) {
5162 if (parser.parseType(arg.type))
5165 arg.type = parser.getBuilder().getIndexType();
5177 OpAsmParser::UnresolvedOperand lb, ub, st;
5178 if (parser.parseOperand(lb) || parser.parseKeyword(
"to") ||
5179 parser.parseOperand(ub) || parser.parseKeyword(
"step") ||
5180 parser.parseOperand(st))
5185 steps.push_back(st);
5193 if (ivArgs.size() != lbs.size())
5195 <<
"mismatch: " << ivArgs.size() <<
" variables but " << lbs.size()
5198 for (
auto &arg : ivArgs) {
5199 lbTypes.push_back(arg.type);
5200 ubTypes.push_back(arg.type);
5201 stepTypes.push_back(arg.type);
5221 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
5224 p << lbs[i] <<
" to " << ubs[i] <<
" step " << steps[i];
5232LogicalResult IteratorOp::verify() {
5233 auto iteratedTy = llvm::dyn_cast<omp::IteratedType>(getIterated().
getType());
5235 return emitOpError() <<
"result must be omp.iterated<entry_ty>";
5237 for (
auto [lb,
ub, step] : llvm::zip_equal(
5238 getLoopLowerBounds(), getLoopUpperBounds(), getLoopSteps())) {
5240 return emitOpError() <<
"loop step must not be zero";
5244 IntegerAttr stepAttr;
5250 const APInt &lbVal = lbAttr.getValue();
5251 const APInt &ubVal = ubAttr.getValue();
5252 const APInt &stepVal = stepAttr.getValue();
5253 if (stepVal.isStrictlyPositive() && lbVal.sgt(ubVal))
5254 return emitOpError() <<
"positive loop step requires lower bound to be "
5255 "less than or equal to upper bound";
5256 if (stepVal.isNegative() && lbVal.slt(ubVal))
5257 return emitOpError() <<
"negative loop step requires lower bound to be "
5258 "greater than or equal to upper bound";
5261 Block &
b = getRegion().front();
5262 auto yield = llvm::dyn_cast<omp::YieldOp>(
b.getTerminator());
5265 return emitOpError() <<
"region must be terminated by omp.yield";
5267 if (yield.getNumOperands() != 1)
5269 <<
"omp.yield in omp.iterator region must yield exactly one value";
5271 mlir::Type yieldedTy = yield.getOperand(0).getType();
5272 mlir::Type elemTy = iteratedTy.getElementType();
5274 if (yieldedTy != elemTy)
5275 return emitOpError() <<
"omp.iterated element type (" << elemTy
5276 <<
") does not match omp.yield operand type ("
5277 << yieldedTy <<
")";
5290 return emitOpError() <<
"expected symbol reference '" << getSymName()
5291 <<
"' to point to a global variable";
5293 if (isa<FunctionOpInterface>(symbol))
5294 return emitOpError() <<
"expected symbol reference '" << getSymName()
5295 <<
"' to point to a global variable, not a function";
5300#define GET_ATTRDEF_CLASSES
5301#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
5303#define GET_OP_CLASSES
5304#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
5306#define GET_TYPEDEF_CLASSES
5307#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static Type getElementType(Type type)
Determine the element type of type.
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds, OperandRange iteratedVars, TypeRange iteratedTypes, std::optional< ArrayAttr > iteratedKinds)
Print Depend clause.
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static void printHeapAllocClause(OpAsmPrinter &p, Operation *op, TypeAttr inType, ValueRange typeparams, TypeRange typeparamsTypes, ValueRange shape, TypeRange shapeTypes)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars, SmallVectorImpl< Type > &linearStepTypes, ArrayAttr &linearModifiers)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printDynGroupprivateClause(OpAsmPrinter &printer, Operation *op, AccessGroupModifierAttr modifierFirst, FallbackModifierAttr modifierSecond, Value dynGroupprivateSize, Type sizeType)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iterated, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< Type > &affinityVarTypes)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printSplitIteratedList(OpAsmPrinter &p, ValueRange iteratedVars, TypeRange iteratedTypes, ValueRange plainVars, TypeRange plainTypes, PrintPrefixFn &&printPrefixForPlain, PrintPrefixFn &&printPrefixForIterated)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars, std::optional< ArrayAttr > iteratedKinds, OperandRange iteratedVars)
Verifies Depend clause.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange iterated, ValueRange affinityVars, TypeRange iteratedTypes, TypeRange affinityVarTypes)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printIteratorHeader(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lbs, ValueRange ubs, ValueRange steps, TypeRange, TypeRange, TypeRange)
static ParseResult parseHeapAllocClause(OpAsmParser &parser, TypeAttr &inTypeAttr, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &typeparams, SmallVectorImpl< Type > &typeparamsTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &shape, SmallVectorImpl< Type > &shapeTypes)
operation ::= $in_type ( ( $typeparams ) )? ( , $shape )?
static ParseResult parseIteratorHeader(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lbs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &ubs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &lbTypes, SmallVectorImpl< Type > &ubTypes, SmallVectorImpl< Type > &stepTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, ArrayAttr &iteratedKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static ParseResult parseDynGroupprivateClause(OpAsmParser &parser, AccessGroupModifierAttr &accessGroupAttr, FallbackModifierAttr &fallbackAttr, std::optional< OpAsmParser::UnresolvedOperand > &dynGroupprivateSize, Type &sizeType)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static ParseResult parseSplitIteratedList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &plainVars, SmallVectorImpl< Type > &plainTypes, ParsePrefixFn &&parsePrefix)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static LogicalResult verifyMapInfoForMapClause(Operation *op, mlir::omp::MapInfoOp mapInfoOp, llvm::DenseSet< mlir::TypedValue< mlir::omp::PointerLikeType > > &updateToVars, llvm::DenseSet< mlir::TypedValue< mlir::omp::PointerLikeType > > &updateFromVars)
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyDynGroupprivateClause(Operation *op, AccessGroupModifierAttr accessGroup, FallbackModifierAttr fallback, Value dynGroupprivateSize)
static LogicalResult verifyLinearModifiers(Operation *op, std::optional< ArrayAttr > linearModifiers, OperandRange linearVars, bool isDeclareSimd=false)
OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or uval only on a declare simd ...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars, OperandRange mapIterated)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
LogicalResult verifyAlignment(Operation &op, std::optional< uint64_t > alignment)
Verifies align clause in allocate directive.
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars, TypeRange stepVarTypes, ArrayAttr linearModifiers)
Print Linear Clause.
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult checkApplyeesNesting(TileOp op)
Check properties of the loop nest consisting of the transformation's applyees:
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
static LogicalResult emit(SolverOp solver, const SMTEmissionOptions &options, mlir::raw_indented_ostream &stream)
Emit the SMT operations in the given 'solver' to the 'stream'.
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
Operation * getTerminator()
Get the terminator operation of this block.
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
BlockArgListType getArguments()
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
Location getLoc()
Return a location for this region.
BlockArgument getArgument(unsigned i)
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
SideEffects::EffectInstance< Effect > EffectInstance
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
llvm::TypeSwitch< T, ResultT > TypeSwitch
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.